5
5
import numpy as np
6
6
import packaging
7
7
8
- from tianshou .env .pettingzoo_env import PettingZooEnv
9
8
from tianshou .env .utils import ENV_TYPE , gym_new_venv_step_type
10
9
from tianshou .env .worker import (
11
10
DummyEnvWorker ,
14
13
SubprocEnvWorker ,
15
14
)
16
15
16
+ try :
17
+ from tianshou .env .pettingzoo_env import PettingZooEnv
18
+ except ImportError :
19
+ PettingZooEnv = None # type: ignore
20
+
17
21
try :
18
22
import gym as old_gym
23
+
19
24
has_old_gym = True
20
25
except ImportError :
21
26
has_old_gym = False
@@ -152,11 +157,13 @@ def __init__(
152
157
153
158
self .env_num = len (env_fns )
154
159
self .wait_num = wait_num or len (env_fns )
155
- assert 1 <= self .wait_num <= len (env_fns ), \
156
- f"wait_num should be in [1, { len (env_fns )} ], but got { wait_num } "
160
+ assert (
161
+ 1 <= self .wait_num <= len (env_fns )
162
+ ), f"wait_num should be in [1, { len (env_fns )} ], but got { wait_num } "
157
163
self .timeout = timeout
158
- assert self .timeout is None or self .timeout > 0 , \
159
- f"timeout is { timeout } , it should be positive if provided!"
164
+ assert (
165
+ self .timeout is None or self .timeout > 0
166
+ ), f"timeout is { timeout } , it should be positive if provided!"
160
167
self .is_async = self .wait_num != len (env_fns ) or timeout is not None
161
168
self .waiting_conn : List [EnvWorker ] = []
162
169
# environments in self.ready_id is actually ready
@@ -169,8 +176,9 @@ def __init__(
169
176
self .is_closed = False
170
177
171
178
def _assert_is_not_closed(self) -> None:
    """Guard against using this object after it has been closed.

    :raises AssertionError: if ``self.is_closed`` is truthy. The message
        names the concrete class of ``self``.
    """
    # Build the message first so the assert itself stays on one short line.
    class_name = self.__class__.__name__
    message = f"Methods of {class_name} cannot be called after close."
    assert not self.is_closed, message
174
182
175
183
def __len__ (self ) -> int :
176
184
"""Return len(self), which is the number of environments."""
@@ -245,10 +253,12 @@ def _wrap_id(
245
253
246
254
def _assert_id(self, id: Union[List[int], np.ndarray]) -> None:
    """Verify that every requested environment index may be used right now.

    Each index must be absent from ``self.waiting_id`` and present in
    ``self.ready_id``. Indices are checked in order and the first
    violation aborts the check.

    :param id: environment indices to validate.
    :raises AssertionError: if an index is in ``self.waiting_id`` or is
        missing from ``self.ready_id``.
    """
    for env_index in id:
        # The waiting check comes first, so an index that is still stepping
        # reports that message rather than the "not ready" one.
        assert env_index not in self.waiting_id, (
            f"Cannot interact with environment {env_index} "
            f"which is stepping now."
        )
        assert env_index in self.ready_id, (
            f"Can only interact with ready environments {self.ready_id}."
        )
252
262
253
263
def reset (
254
264
self ,
@@ -271,9 +281,10 @@ def reset(
271
281
self .workers [i ].send (None , ** kwargs )
272
282
ret_list = [self .workers [i ].recv () for i in id ]
273
283
274
- assert isinstance (ret_list [0 ], (tuple , list )) and len (
275
- ret_list [0 ]
276
- ) == 2 and isinstance (ret_list [0 ][1 ], dict )
284
+ assert (
285
+ isinstance (ret_list [0 ], (tuple , list )) and len (ret_list [0 ]) == 2
286
+ and isinstance (ret_list [0 ][1 ], dict )
287
+ )
277
288
278
289
obs_list = [r [0 ] for r in ret_list ]
279
290
@@ -367,9 +378,13 @@ def step(
367
378
obs_stack = np .stack (obs_list )
368
379
except ValueError : # different len(obs)
369
380
obs_stack = np .array (obs_list , dtype = object )
370
- return obs_stack , np .stack (rew_list ), np .stack (term_list ), np .stack (
371
- trunc_list
372
- ), np .stack (info_list )
381
+ return (
382
+ obs_stack ,
383
+ np .stack (rew_list ),
384
+ np .stack (term_list ),
385
+ np .stack (trunc_list ),
386
+ np .stack (info_list ),
387
+ )
373
388
374
389
def seed (
375
390
self ,
0 commit comments