Skip to content

Commit a7eedd7

Browse files
committed
adjusted pettingzoo to PettingZooEnv simultaneous environment more convenient (#925)
1 parent 309fbf5 commit a7eedd7

File tree

2 files changed

+63
-58
lines changed

2 files changed

+63
-58
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
using .PyCall
22

3-
43
np = pyimport("numpy")
54

6-
export PettingzooEnv
5+
export PettingZooEnv
6+
77

88
"""
9-
PettingzooEnv(;kwargs...)
9+
PettingZooEnv(;kwargs...)
1010
11-
`PettingzooEnv` is an interface of the python library pettingzoo for multi agent reinforcement learning environments. It can be used to test multi
11+
`PettingZooEnv` is an interface to the Python library PettingZoo for multi-agent reinforcement learning environments. It can be used to test multi
1212
agent reinforcement learning algorithms implemented in Julia ReinforcementLearning.
1313
"""
14-
function PettingzooEnv(name::String; seed=123, args...)
14+
15+
function PettingZooEnv(name::String; seed=123, args...)
1516
if !PyCall.pyexists("pettingzoo.$name")
1617
error("Cannot import pettingzoo.$name")
1718
end
@@ -20,7 +21,7 @@ function PettingzooEnv(name::String; seed=123, args...)
2021
pyenv.reset(seed=seed)
2122
obs_space = space_transform(pyenv.observation_space(pyenv.agents[1]))
2223
act_space = space_transform(pyenv.action_space(pyenv.agents[1]))
23-
env = PettingzooEnv{typeof(act_space),typeof(obs_space),typeof(pyenv)}(
24+
env = PettingZooEnv{typeof(act_space),typeof(obs_space),typeof(pyenv)}(
2425
pyenv,
2526
obs_space,
2627
act_space,
@@ -33,13 +34,12 @@ end
3334

3435
# basic function needed for simulation ========================================================================
3536

36-
function RLBase.reset!(env::PettingzooEnv)
37+
function RLBase.reset!(env::PettingZooEnv)
3738
pycall!(env.state, env.pyenv.reset, PyObject, env.seed)
38-
env.ts = 1
3939
nothing
4040
end
4141

42-
function RLBase.is_terminated(env::PettingzooEnv)
42+
function RLBase.is_terminated(env::PettingZooEnv)
4343
_, _, t, d, _ = pycall(env.pyenv.last, PyObject)
4444
t || d
4545
end
@@ -48,96 +48,96 @@ end
4848

4949
## State / observation implementations ========================================================================
5050

51-
RLBase.state(env::PettingzooEnv, ::Observation{Any}, players::Tuple) = Dict(p => state(env, p) for p in players)
51+
RLBase.state(env::PettingZooEnv, ::Observation{Any}, players::Tuple) = Dict(p => state(env, p) for p in players)
5252

5353

5454
# partial observability is default for pettingzoo
55-
function RLBase.state(env::PettingzooEnv, ::Observation{Any}, player)
55+
function RLBase.state(env::PettingZooEnv, ::Observation{Any}, player)
5656
env.pyenv.observe(player)
5757
end
5858

5959

6060
## state space =========================================================================================================================================
6161

62-
RLBase.state_space(env::PettingzooEnv, ::Observation{Any}, players) = Space(Dict(player => state_space(env, player) for player in players))
62+
RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, players) = Space(Dict(player => state_space(env, player) for player in players))
6363

6464
# partial observability
65-
RLBase.state_space(env::PettingzooEnv, ::Observation{Any}, player::String) = space_transform(env.pyenv.observation_space(player))
65+
RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, player::Symbol) = space_transform(env.pyenv.observation_space(String(player)))
6666

6767
# for full observability. Be careful: action_space has also to be adjusted
68-
# RLBase.state_space(env::PettingzooEnv, ::Observation{Any}, player::String) = space_transform(env.pyenv.state_space)
68+
# RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, player::String) = space_transform(env.pyenv.state_space)
6969

7070

7171
## action space implementations ====================================================================================
7272

73-
RLBase.action_space(env::PettingzooEnv, players::Tuple{String}) =
73+
RLBase.action_space(env::PettingZooEnv, players::Tuple{Symbol}) =
7474
Space(Dict(p => action_space(env, p) for p in players))
7575

76-
RLBase.action_space(env::PettingzooEnv, player::String) = space_transform(env.pyenv.action_space(player))
76+
RLBase.action_space(env::PettingZooEnv, player::Symbol) = space_transform(env.pyenv.action_space(String(player)))
7777

78-
RLBase.action_space(env::PettingzooEnv, player::Integer) = space_transform(env.pyenv.action_space(env.pyenv.agents[player]))
78+
RLBase.action_space(env::PettingZooEnv, player::Integer) = space_transform(env.pyenv.action_space(env.pyenv.agents[player]))
7979

80-
RLBase.action_space(env::PettingzooEnv, player::DefaultPlayer) = env.action_space
80+
RLBase.action_space(env::PettingZooEnv, player::DefaultPlayer) = env.action_space
8181

8282
## action functions ========================================================================================================================
8383

84-
function RLBase.act!(env::PettingzooEnv, actions::Dict, players::Tuple)
85-
@assert length(actions) == length(players)
86-
env.ts += 1
87-
for p in players
88-
env(actions[p])
84+
function RLBase.act!(env::PettingZooEnv, actions::Dict{Symbol, Int})
85+
@assert length(actions) == length(players(env))
86+
for p in env.pyenv.agents
87+
pycall(env.pyenv.step, PyObject, actions[p])
8988
end
9089
end
9190

92-
function RLBase.act!(env::PettingzooEnv, actions::Dict, player)
93-
@assert length(actions) == length(players(env))
94-
for p in players(env)
95-
env(actions[p])
91+
function RLBase.act!(env::PettingZooEnv, actions::Dict{Symbol, Real})
92+
@assert length(actions) == length(env.pyenv.agents)
93+
for p in env.pyenv.agents
94+
pycall(env.pyenv.step, PyObject, np.array(actions[p]; dtype=np.float32))
9695
end
9796
end
9897

99-
function RLBase.act!(env::PettingzooEnv, actions::Dict{String, Int})
100-
@assert length(actions) == length(players(env))
98+
function RLBase.act!(env::PettingZooEnv, actions::Dict{Symbol, Vector})
99+
@assert length(actions) == length(env.pyenv.agents)
101100
for p in env.pyenv.agents
102-
pycall(env.pyenv.step, PyObject, actions[p])
101+
RLBase.act!(env, actions[p])
103102
end
104103
end
105104

106-
function RLBase.act!(env::PettingzooEnv, actions::Dict{String, Real})
107-
@assert length(actions) == length(players(env))
108-
env.ts += 1
109-
for p in env.pyenv.agents
110-
pycall(env.pyenv.step, PyObject, np.array(actions[p]; dtype=np.float32))
105+
function RLBase.act!(env::PettingZooEnv, actions::NamedTuple)
106+
@assert length(actions) == length(env.pyenv.agents)
107+
for player in players(env)
108+
RLBase.act!(env, actions[player])
111109
end
112110
end
113111

114-
function RLBase.act!(env::PettingzooEnv, action::Vector)
112+
# for vectors, pettingzoo needs them to be in the proper numpy dtype
113+
function RLBase.act!(env::PettingZooEnv, action::Vector)
115114
pycall(env.pyenv.step, PyObject, np.array(action; dtype=np.float32))
116115
end
117116

118-
function RLBase.act!(env::PettingzooEnv, action::Integer)
119-
env.ts += 1
117+
function RLBase.act!(env::PettingZooEnv, action)
120118
pycall(env.pyenv.step, PyObject, action)
121119
end
122120

123121
# reward of player ======================================================================================================================
124-
function RLBase.reward(env::PettingzooEnv, player::String)
125-
env.pyenv.rewards[player]
122+
function RLBase.reward(env::PettingZooEnv, player::Symbol)
123+
env.pyenv.rewards[String(player)]
126124
end
127125

128126

129127
# Multi agent part =========================================================================================================================================
130128

131129

132-
RLBase.players(env::PettingzooEnv) = env.pyenv.agents
130+
RLBase.players(env::PettingZooEnv) = Symbol.(env.pyenv.agents)
131+
132+
function RLBase.current_player(env::PettingZooEnv)
133+
return Symbol(env.pyenv.agents[env.current_player])
134+
end
133135

134-
function RLBase.current_player(env::PettingzooEnv, post_action=false)
135-
cur_id = env.ts % length(env.pyenv.agents) == 0 ? length(env.pyenv.agents) : env.ts % length(env.pyenv.agents)
136-
cur_id = post_action ? (cur_id - 1 == 0 ? length(env.pyenv.agents) : cur_id - 1) : cur_id
137-
return env.pyenv.agents[cur_id]
136+
function RLBase.next_player!(env::PettingZooEnv)
137+
env.current_player = env.current_player < length(env.pyenv.agents) ? env.current_player + 1 : 1
138138
end
139139

140-
function RLBase.NumAgentStyle(env::PettingzooEnv)
140+
function RLBase.NumAgentStyle(env::PettingZooEnv)
141141
n = length(env.pyenv.agents)
142142
if n == 1
143143
SingleAgent()
@@ -146,9 +146,8 @@ function RLBase.NumAgentStyle(env::PettingzooEnv)
146146
end
147147
end
148148

149-
150-
RLBase.DynamicStyle(::PettingzooEnv) = SEQUENTIAL
151-
RLBase.ActionStyle(::PettingzooEnv) = MINIMAL_ACTION_SET
152-
RLBase.InformationStyle(::PettingzooEnv) = IMPERFECT_INFORMATION
153-
RLBase.ChanceStyle(::PettingzooEnv) = EXPLICIT_STOCHASTIC
149+
RLBase.DynamicStyle(::PettingZooEnv) = SIMULTANEOUS
150+
RLBase.ActionStyle(::PettingZooEnv) = MINIMAL_ACTION_SET
151+
RLBase.InformationStyle(::PettingZooEnv) = IMPERFECT_INFORMATION
152+
RLBase.ChanceStyle(::PettingZooEnv) = EXPLICIT_STOCHASTIC
154153

src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl

+13-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1-
mutable struct PettingzooEnv{Ta,To,P} <: AbstractEnv
2-
pyenv::P
3-
observation_space::To
4-
action_space::Ta
5-
state::P
6-
seed::Union{Int, Nothing}
7-
ts::Int
1+
# Parametrization:
2+
# Ta : Type of action_space
3+
# To : Type of observation_space
4+
# P : Type of environment most common: PyObject
5+
6+
mutable struct PettingZooEnv{Ta,To,P} <: AbstractEnv
7+
pyenv::P
8+
observation_space::To
9+
action_space::Ta
10+
state::P
11+
seed::Union{Int, Nothing}
12+
current_player::Int
813
end
14+
915
export PettingzooEnv
1016

1117
struct GymEnv{T,Ta,To,P} <: AbstractEnv

0 commit comments

Comments
 (0)