Skip to content

Commit 8d6154b

Browse files
committed
Cleaner code and more complete testing
1 parent 115fd3e commit 8d6154b

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

gymnasium/envs/toy_text/taxi.py

+10 -10
Original file line number | Diff line number | Diff line change
@@ -152,7 +152,7 @@ class TaxiEnv(Env):
152152
153153
## Version History
154154
* v3: Map Correction + Cleaner Domain Description, v0.25.0 action masking added to the reset and step information
155-
- In Gymnasium `1.0.0a3` the `is_rainy` and `fickle_passenger` arguments were added to align with Dietterich paper
155+
- In Gymnasium `1.1.0` the `is_rainy` and `fickle_passenger` arguments were added to align with Dietterich paper
156156
* v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
157157
* v1: Remove (3,2) from locs, add passidx<4 check
158158
* v0: Initial version release
@@ -290,7 +290,7 @@ def __init__(
290290

291291
self.render_mode = render_mode
292292
self.fickle_passenger = fickle_passenger
293-
self.fickle_step = True
293+
self.fickle_step = self.fickle_passenger and self.np_random.random() < 0.3
294294

295295
# pygame utils
296296
self.window = None
@@ -363,17 +363,17 @@ def step(self, a):
363363
# If we are in the fickle step, the passenger has been in the vehicle for at least a step and this step the
364364
# position changed
365365
if (
366-
self.fickle_step
366+
self.fickle_passenger
367+
and self.fickle_step
367368
and shadow_pass_loc == 4
368369
and (taxi_row != shadow_row or taxi_col != shadow_col)
369370
):
370371
self.fickle_step = False
371-
if self.fickle_passenger and self.np_random.random() < 0.3:
372-
possible_destinations = [
373-
i for i in range(len(self.locs)) if i != shadow_dest_idx
374-
]
375-
dest_idx = self.np_random.choice(possible_destinations)
376-
s = self.encode(taxi_row, taxi_col, pass_loc, dest_idx)
372+
possible_destinations = [
373+
i for i in range(len(self.locs)) if i != shadow_dest_idx
374+
]
375+
dest_idx = self.np_random.choice(possible_destinations)
376+
s = self.encode(taxi_row, taxi_col, pass_loc, dest_idx)
377377

378378
self.s = s
379379

@@ -391,7 +391,7 @@ def reset(
391391
super().reset(seed=seed)
392392
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
393393
self.lastaction = None
394-
self.fickle_step = True
394+
self.fickle_step = self.fickle_passenger and self.np_random.random() < 0.3
395395
self.taxi_orientation = 0
396396

397397
if self.render_mode == "human":

tests/envs/test_env_implementation.py

+24
Original file line number | Diff line number | Diff line change
@@ -263,6 +263,30 @@ def test_taxi_disallowed_transitions():
263263
) not in disallowed_transitions
264264

265265

266+
def test_taxi_fickle_passenger():
267+
env = TaxiEnv(fickle_passenger=True)
268+
_, _ = env.reset()
269+
# Force passenger being in a fickle state
270+
env.fickle_step = True
271+
state, reward, done, _, _ = env.step(0)
272+
taxi_row, taxi_col, pass_idx, orig_dest_idx = env.decode(state)
273+
# force taxi to passenger location
274+
env.s = env.encode(
275+
env.locs[pass_idx][0], env.locs[pass_idx][1], pass_idx, orig_dest_idx
276+
)
277+
# pick up the passenger
278+
_, _, _, _, _ = env.step(4)
279+
if env.locs[pass_idx][0] == 0:
280+
# if we're on the top row, move down
281+
state, _, _, _, _ = env.step(0)
282+
else:
283+
# otherwise move up
284+
state, _, _, _, _ = env.step(1)
285+
taxi_row, taxi_col, pass_idx, dest_idx = env.decode(state)
286+
# check that passenger has changed their destination
287+
assert orig_dest_idx != dest_idx
288+
289+
266290
@pytest.mark.parametrize(
267291
"env_name",
268292
["Acrobot-v1", "CartPole-v1", "MountainCar-v0", "MountainCarContinuous-v0"],

0 commit comments

Comments
 (0)