@@ -157,7 +157,7 @@ def __init__(
157
157
self .screen = None
158
158
self .clock = None
159
159
self .isopen = True
160
- self .state = None
160
+ self .state : np . ndarray | None = None
161
161
162
162
self .steps_beyond_terminated = None
163
163
@@ -168,16 +168,17 @@ def step(self, action):
168
168
assert self .state is not None , "Call reset before using step method."
169
169
x , x_dot , theta , theta_dot = self .state
170
170
force = self .force_mag if action == 1 else - self .force_mag
171
- costheta = math .cos (theta )
172
- sintheta = math .sin (theta )
171
+ costheta = np .cos (theta )
172
+ sintheta = np .sin (theta )
173
173
174
174
# For the interested reader:
175
175
# https://coneural.org/florian/papers/05_cart_pole.pdf
176
176
temp = (
177
- force + self .polemass_length * theta_dot ** 2 * sintheta
177
+ force + self .polemass_length * np . square ( theta_dot ) * sintheta
178
178
) / self .total_mass
179
179
thetaacc = (self .gravity * sintheta - costheta * temp ) / (
180
- self .length * (4.0 / 3.0 - self .masspole * costheta ** 2 / self .total_mass )
180
+ self .length
181
+ * (4.0 / 3.0 - self .masspole * np .square (costheta ) / self .total_mass )
181
182
)
182
183
xacc = temp - self .polemass_length * thetaacc * costheta / self .total_mass
183
184
@@ -192,7 +193,7 @@ def step(self, action):
192
193
theta_dot = theta_dot + self .tau * thetaacc
193
194
theta = theta + self .tau * theta_dot
194
195
195
- self .state = ( x , x_dot , theta , theta_dot )
196
+ self .state = np . array (( x , x_dot , theta , theta_dot ), dtype = np . float64 )
196
197
197
198
terminated = bool (
198
199
x < - self .x_threshold
@@ -202,33 +203,25 @@ def step(self, action):
202
203
)
203
204
204
205
if not terminated :
205
- if self ._sutton_barto_reward :
206
- reward = 0.0
207
- elif not self ._sutton_barto_reward :
208
- reward = 1.0
206
+ reward = 0.0 if self ._sutton_barto_reward else 1.0
209
207
elif self .steps_beyond_terminated is None :
210
208
# Pole just fell!
211
209
self .steps_beyond_terminated = 0
212
- if self ._sutton_barto_reward :
213
- reward = - 1.0
214
- else :
215
- reward = 1.0
210
+
211
+ reward = - 1.0 if self ._sutton_barto_reward else 1.0
216
212
else :
217
213
if self .steps_beyond_terminated == 0 :
218
214
logger .warn (
219
- "You are calling 'step()' even though this "
220
- "environment has already returned terminated = True. You "
221
- "should always call 'reset()' once you receive 'terminated = "
222
- "True' -- any further steps are undefined behavior."
215
+ "You are calling 'step()' even though this environment has already returned terminated = True. "
216
+ "You should always call 'reset()' once you receive 'terminated = True' -- any further steps are undefined behavior."
223
217
)
224
218
self .steps_beyond_terminated += 1
225
- if self ._sutton_barto_reward :
226
- reward = - 1.0
227
- else :
228
- reward = 0.0
219
+
220
+ reward = - 1.0 if self ._sutton_barto_reward else 0.0
229
221
230
222
if self .render_mode == "human" :
231
223
self .render ()
224
+
232
225
# truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
233
226
return np .array (self .state , dtype = np .float32 ), reward , terminated , False , {}
234
227
@@ -439,10 +432,11 @@ def step(
439
432
# For the interested reader:
440
433
# https://coneural.org/florian/papers/05_cart_pole.pdf
441
434
temp = (
442
- force + self .polemass_length * theta_dot ** 2 * sintheta
435
+ force + self .polemass_length * np . square ( theta_dot ) * sintheta
443
436
) / self .total_mass
444
437
thetaacc = (self .gravity * sintheta - costheta * temp ) / (
445
- self .length * (4.0 / 3.0 - self .masspole * costheta ** 2 / self .total_mass )
438
+ self .length
439
+ * (4.0 / 3.0 - self .masspole * np .square (costheta ) / self .total_mass )
446
440
)
447
441
xacc = temp - self .polemass_length * thetaacc * costheta / self .total_mass
448
442
@@ -470,7 +464,7 @@ def step(
470
464
471
465
truncated = self .steps >= self .max_episode_steps
472
466
473
- if self ._sutton_barto_reward is True :
467
+ if self ._sutton_barto_reward :
474
468
reward = - np .array (terminated , dtype = np .float32 )
475
469
else :
476
470
reward = np .ones_like (terminated , dtype = np .float32 )
@@ -484,7 +478,7 @@ def step(
484
478
terminated [self .prev_done ] = False
485
479
truncated [self .prev_done ] = False
486
480
487
- self .prev_done = terminated | truncated
481
+ self .prev_done = np . logical_or ( terminated , truncated )
488
482
489
483
return self .state .T .astype (np .float32 ), reward , terminated , truncated , {}
490
484
@@ -497,9 +491,8 @@ def reset(
497
491
super ().reset (seed = seed )
498
492
# Note that if you use custom reset bounds, it may lead to out-of-bound
499
493
# state/observations.
500
- self .low , self .high = utils .maybe_parse_reset_bounds (
501
- options , - 0.05 , 0.05 # default low
502
- ) # default high
494
+ # -0.05 and 0.05 is the default low and high bounds
495
+ self .low , self .high = utils .maybe_parse_reset_bounds (options , - 0.05 , 0.05 )
503
496
self .state = self .np_random .uniform (
504
497
low = self .low , high = self .high , size = (4 , self .num_envs )
505
498
)
0 commit comments