Commit 1878f6d

Fix CartPole for equivalence to vector implementation (#1121)
1 parent a5d518a commit 1878f6d

3 files changed, +45 -46 lines
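This commit aims to keep the single CartPole environment and its NumPy-vectorized counterpart in lockstep when seeded identically. A rough sketch of the property being targeted (not part of the commit; it mirrors the updated test further down, and assumes `gym.make_vec` resolves to CartPole's dedicated vector entry point):

import numpy as np
import gymnasium as gym

env = gym.make("CartPole-v1")
envs = gym.make_vec("CartPole-v1", num_envs=1)

# identical seeds should give identical initial states
obs, _ = env.reset(seed=123)
vec_obs, _ = envs.reset(seed=123)
assert np.all(obs == vec_obs[0])

# stepping with the same action should keep both implementations in agreement
obs, reward, term, trunc, _ = env.step(1)
vec_obs, vec_reward, vec_term, vec_trunc, _ = envs.step(np.array([1]))
assert np.all(obs == vec_obs[0])
assert np.all(env.unwrapped.state == envs.unwrapped.state[:, 0])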

gymnasium/envs/classic_control/cartpole.py

+22-29
@@ -157,7 +157,7 @@ def __init__(
         self.screen = None
         self.clock = None
         self.isopen = True
-        self.state = None
+        self.state: np.ndarray | None = None
 
         self.steps_beyond_terminated = None

@@ -168,16 +168,17 @@ def step(self, action):
         assert self.state is not None, "Call reset before using step method."
         x, x_dot, theta, theta_dot = self.state
         force = self.force_mag if action == 1 else -self.force_mag
-        costheta = math.cos(theta)
-        sintheta = math.sin(theta)
+        costheta = np.cos(theta)
+        sintheta = np.sin(theta)
 
         # For the interested reader:
         # https://coneural.org/florian/papers/05_cart_pole.pdf
         temp = (
-            force + self.polemass_length * theta_dot**2 * sintheta
+            force + self.polemass_length * np.square(theta_dot) * sintheta
         ) / self.total_mass
         thetaacc = (self.gravity * sintheta - costheta * temp) / (
-            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
+            self.length
+            * (4.0 / 3.0 - self.masspole * np.square(costheta) / self.total_mass)
         )
         xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
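The dynamics are unchanged; only the `math` calls are swapped for NumPy ufuncs so the same expressions work for scalars and per-environment arrays alike. Restated as a standalone helper for readability (a sketch, not part of the diff; parameter names follow the environment's attributes):

import numpy as np

def cartpole_accelerations(force, theta, theta_dot,
                           gravity, masspole, total_mass, polemass_length, length):
    # cart and pole accelerations for the classic cart-pole dynamics
    # (see https://coneural.org/florian/papers/05_cart_pole.pdf)
    costheta = np.cos(theta)
    sintheta = np.sin(theta)
    temp = (force + polemass_length * np.square(theta_dot) * sintheta) / total_mass
    thetaacc = (gravity * sintheta - costheta * temp) / (
        length * (4.0 / 3.0 - masspole * np.square(costheta) / total_mass)
    )
    xacc = temp - polemass_length * thetaacc * costheta / total_mass
    return xacc, thetaacc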

@@ -192,7 +193,7 @@ def step(self, action):
         theta_dot = theta_dot + self.tau * thetaacc
         theta = theta + self.tau * theta_dot
 
-        self.state = (x, x_dot, theta, theta_dot)
+        self.state = np.array((x, x_dot, theta, theta_dot), dtype=np.float64)
 
         terminated = bool(
             x < -self.x_threshold
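Storing the state as a float64 array instead of a tuple keeps the scalar environment's internal representation aligned with the vector implementation's per-environment columns. A quick check, assuming a Gymnasium build that includes this change:

import numpy as np
import gymnasium as gym

env = gym.make("CartPole-v1")
env.reset(seed=0)
env.step(env.action_space.sample())

# the internal state is a float64 array after both reset and step
assert isinstance(env.unwrapped.state, np.ndarray)
assert env.unwrapped.state.dtype == np.float64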
@@ -202,33 +203,25 @@ def step(self, action):
         )
 
         if not terminated:
-            if self._sutton_barto_reward:
-                reward = 0.0
-            elif not self._sutton_barto_reward:
-                reward = 1.0
+            reward = 0.0 if self._sutton_barto_reward else 1.0
         elif self.steps_beyond_terminated is None:
             # Pole just fell!
             self.steps_beyond_terminated = 0
-            if self._sutton_barto_reward:
-                reward = -1.0
-            else:
-                reward = 1.0
+
+            reward = -1.0 if self._sutton_barto_reward else 1.0
         else:
             if self.steps_beyond_terminated == 0:
                 logger.warn(
-                    "You are calling 'step()' even though this "
-                    "environment has already returned terminated = True. You "
-                    "should always call 'reset()' once you receive 'terminated = "
-                    "True' -- any further steps are undefined behavior."
+                    "You are calling 'step()' even though this environment has already returned terminated = True. "
+                    "You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior."
                 )
             self.steps_beyond_terminated += 1
-            if self._sutton_barto_reward:
-                reward = -1.0
-            else:
-                reward = 0.0
+
+            reward = -1.0 if self._sutton_barto_reward else 0.0
 
         if self.render_mode == "human":
             self.render()
+
         # truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
         return np.array(self.state, dtype=np.float32), reward, terminated, False, {}
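The nested if/else blocks collapse into conditional expressions without changing the reward scheme. Restated compactly as a sketch (not the commit's code):

def cartpole_reward(sutton_barto_reward, terminated, steps_beyond_terminated):
    if not terminated:
        # episode still running
        return 0.0 if sutton_barto_reward else 1.0
    if steps_beyond_terminated is None:
        # the step on which the pole fell
        return -1.0 if sutton_barto_reward else 1.0
    # stepping past termination (undefined behaviour, warned about above)
    return -1.0 if sutton_barto_reward else 0.0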

@@ -439,10 +432,11 @@ def step(
         # For the interested reader:
         # https://coneural.org/florian/papers/05_cart_pole.pdf
         temp = (
-            force + self.polemass_length * theta_dot**2 * sintheta
+            force + self.polemass_length * np.square(theta_dot) * sintheta
         ) / self.total_mass
         thetaacc = (self.gravity * sintheta - costheta * temp) / (
-            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
+            self.length
+            * (4.0 / 3.0 - self.masspole * np.square(costheta) / self.total_mass)
         )
         xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

@@ -470,7 +464,7 @@ def step(
 
         truncated = self.steps >= self.max_episode_steps
 
-        if self._sutton_barto_reward is True:
+        if self._sutton_barto_reward:
             reward = -np.array(terminated, dtype=np.float32)
         else:
             reward = np.ones_like(terminated, dtype=np.float32)
@@ -484,7 +478,7 @@ def step(
         terminated[self.prev_done] = False
         truncated[self.prev_done] = False
 
-        self.prev_done = terminated | truncated
+        self.prev_done = np.logical_or(terminated, truncated)
 
         return self.state.T.astype(np.float32), reward, terminated, truncated, {}
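In the vector implementation, `prev_done` marks the sub-environments that finished on the previous step; on the following step they are auto-reset, so their flags are cleared before `prev_done` is recomputed. An illustrative run-through with made-up flag values (a sketch of the masking above, not code from the commit):

import numpy as np

prev_done = np.array([True, False])    # env 0 finished on the previous step
terminated = np.array([True, False])   # raw flags produced on the current step
truncated = np.array([False, False])

# sub-environments that were already done have just been auto-reset,
# so their termination/truncation flags are cleared for this step
terminated[prev_done] = False
truncated[prev_done] = False

prev_done = np.logical_or(terminated, truncated)
assert not prev_done.any()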

@@ -497,9 +491,8 @@ def reset(
         super().reset(seed=seed)
         # Note that if you use custom reset bounds, it may lead to out-of-bound
         # state/observations.
-        self.low, self.high = utils.maybe_parse_reset_bounds(
-            options, -0.05, 0.05  # default low
-        )  # default high
+        # -0.05 and 0.05 is the default low and high bounds
+        self.low, self.high = utils.maybe_parse_reset_bounds(options, -0.05, 0.05)
         self.state = self.np_random.uniform(
             low=self.low, high=self.high, size=(4, self.num_envs)
         )
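For reference, `maybe_parse_reset_bounds` falls back to the (-0.05, 0.05) defaults when no options are given. A hypothetical call with custom bounds on a vectorized CartPole such as the `envs` created earlier, assuming the usual classic-control "low"/"high" keys in `options`:

# default bounds
obs, info = envs.reset(seed=0)

# custom bounds via `options` (may yield out-of-bound observations, as the comment above warns)
obs, info = envs.reset(seed=0, options={"low": -0.1, "high": 0.1})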

gymnasium/wrappers/vector/dict_info_to_list.py

+1-2
@@ -37,7 +37,6 @@ class DictInfoToList(VectorWrapper):
     Example for vector environments:
         >>> import numpy as np
        >>> import gymnasium as gym
-        >>> from gymnasium.spaces import Dict, Box
         >>> envs = gym.make_vec("CartPole-v1", num_envs=3)
         >>> obs, info = envs.reset(seed=123)
         >>> info
@@ -61,7 +60,7 @@ class DictInfoToList(VectorWrapper):
         >>> _ = envs.action_space.seed(123)
         >>> _, _, _, _, infos = envs.step(envs.action_space.sample())
         >>> infos
-        [{'x_position': np.float64(0.03332210900362942), 'x_velocity': np.float64(-0.06296527291998533), 'reward_run': np.float64(-0.06296527291998533), 'reward_ctrl': np.float32(-0.24503504)}, {'x_position': np.float64(0.10172354684460168), 'x_velocity': np.float64(0.8934584807363618), 'reward_run': np.float64(0.8934584807363618), 'reward_ctrl': np.float32(-0.21944423)}]
+        [{'x_position': np.float64(0.0333221090036294), 'x_velocity': np.float64(-0.06296527291998574), 'reward_run': np.float64(-0.06296527291998574), 'reward_ctrl': np.float32(-0.24503504)}, {'x_position': np.float64(0.10172354684460168), 'x_velocity': np.float64(0.8934584807363618), 'reward_run': np.float64(0.8934584807363618), 'reward_ctrl': np.float32(-0.21944423)}]
 
     Change logs:
      * v0.24.0 - Initially added as ``VectorListInfo``
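For context, `DictInfoToList` converts the vector environments' dict-of-arrays `info` into one dict per sub-environment, which is what the doctest output above shows. A minimal usage sketch:

import gymnasium as gym
from gymnasium.wrappers.vector import DictInfoToList

envs = DictInfoToList(gym.make_vec("CartPole-v1", num_envs=3))
obs, info = envs.reset(seed=123)
# `info` is now a list with one dict per sub-environment rather than a dict of arrays
assert isinstance(info, list)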

tests/envs/test_env_implementation.py

+22-15
@@ -283,11 +283,15 @@ def test_cartpole_vector_equiv():
     assert env.action_space == envs.single_action_space
     assert env.observation_space == envs.single_observation_space
 
-    # reset
+    # for seed in range(0, 10_000):
     seed = np.random.randint(0, 1000)
+
+    # reset
     obs, info = env.reset(seed=seed)
     vec_obs, vec_info = envs.reset(seed=seed)
 
+    env.action_space.seed(seed=seed)
+
     assert obs in env.observation_space
     assert vec_obs in envs.observation_space
     assert np.all(obs == vec_obs[0])
@@ -315,24 +319,27 @@ def test_cartpole_vector_equiv():
 
         assert np.all(env.unwrapped.state == envs.unwrapped.state[:, 0])
 
-        if term:
+        if term or trunc:
             break
 
-    obs, info = env.reset()
-    # the vector action shouldn't matter as autoreset
-    vec_obs, vec_reward, vec_term, vec_trunc, vec_info = envs.step(
-        envs.action_space.sample()
-    )
+    # if the sub-environment episode ended
+    if term or trunc:
+        obs, info = env.reset()
+        # the vector action shouldn't matter as autoreset
+        assert envs.unwrapped.prev_done
+        vec_obs, vec_reward, vec_term, vec_trunc, vec_info = envs.step(
+            envs.action_space.sample()
+        )
 
-    assert obs in env.observation_space
-    assert vec_obs in envs.observation_space
-    assert np.all(obs == vec_obs[0])
-    assert vec_reward == np.array([0])
-    assert vec_term == np.array([False])
-    assert vec_trunc == np.array([False])
-    assert info == vec_info
+        assert obs in env.observation_space
+        assert vec_obs in envs.observation_space
+        assert np.all(obs == vec_obs[0])
+        assert vec_reward == np.array([0])
+        assert vec_term == np.array([False])
+        assert vec_trunc == np.array([False])
+        assert info == vec_info
 
-    assert np.all(env.unwrapped.state == envs.unwrapped.state[:, 0])
+        assert np.all(env.unwrapped.state == envs.unwrapped.state[:, 0])
 
     env.close()
     envs.close()
