Skip to content

Commit f0afdea

Browse files
authored
update version to 0.5.0 (#826)
1 parent 73600ed commit f0afdea

10 files changed

+52
-63
lines changed

.github/ISSUE_TEMPLATE.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
88
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
99
```python
10-
import tianshou, gym, torch, numpy, sys
10+
import tianshou, gymnasium as gym, torch, numpy, sys
1111
print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
1212
```

.github/workflows/extra_sys.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ jobs:
1212
python-version: [3.7, 3.8]
1313
steps:
1414
- name: Cancel previous run
15-
uses: styfle/cancel-workflow-action@0.9.1
15+
uses: styfle/cancel-workflow-action@0.11.0
1616
with:
1717
access_token: ${{ github.token }}
18-
- uses: actions/checkout@v2
18+
- uses: actions/checkout@v3
1919
- name: Set up Python ${{ matrix.python-version }}
20-
uses: actions/setup-python@v2
20+
uses: actions/setup-python@v4
2121
with:
2222
python-version: ${{ matrix.python-version }}
2323
- name: Upgrade pip

.github/workflows/gputest.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ jobs:
88
if: "!contains(github.event.head_commit.message, 'ci skip')"
99
steps:
1010
- name: Cancel previous run
11-
uses: styfle/cancel-workflow-action@0.9.1
11+
uses: styfle/cancel-workflow-action@0.11.0
1212
with:
1313
access_token: ${{ github.token }}
14-
- uses: actions/checkout@v2
14+
- uses: actions/checkout@v3
1515
- name: Set up Python 3.8
16-
uses: actions/setup-python@v2
16+
uses: actions/setup-python@v4
1717
with:
1818
python-version: 3.8
1919
- name: Upgrade pip

.github/workflows/lint_and_docs.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ jobs:
77
runs-on: ubuntu-latest
88
steps:
99
- name: Cancel previous run
10-
uses: styfle/cancel-workflow-action@0.9.1
10+
uses: styfle/cancel-workflow-action@0.11.0
1111
with:
1212
access_token: ${{ github.token }}
13-
- uses: actions/checkout@v2
13+
- uses: actions/checkout@v3
1414
- name: Set up Python 3.8
15-
uses: actions/setup-python@v2
15+
uses: actions/setup-python@v4
1616
with:
1717
python-version: 3.8
1818
- name: Upgrade pip

.github/workflows/profile.yml

-27
This file was deleted.

.github/workflows/pytest.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ jobs:
1111
python-version: [3.7, 3.8, 3.9]
1212
steps:
1313
- name: Cancel previous run
14-
uses: styfle/cancel-workflow-action@0.9.1
14+
uses: styfle/cancel-workflow-action@0.11.0
1515
with:
1616
access_token: ${{ github.token }}
17-
- uses: actions/checkout@v2
17+
- uses: actions/checkout@v3
1818
- name: Set up Python ${{ matrix.python-version }}
19-
uses: actions/setup-python@v2
19+
uses: actions/setup-python@v4
2020
with:
2121
python-version: ${{ matrix.python-version }}
2222
- name: Upgrade pip

docs/requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
gym
22
numba
33
numpy>=1.20
4-
sphinx<4
4+
sphinx
55
sphinxcontrib-bibtex
66
tensorboard
77
torch
88
tqdm
9-
protobuf~=3.19.0
9+
protobuf
1010
pettingzoo

test/offline/test_discrete_bcq.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_args():
4040
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
4141
parser.add_argument("--test-num", type=int, default=100)
4242
parser.add_argument("--logdir", type=str, default="log")
43-
parser.add_argument("--render", type=float, default=0.)
43+
parser.add_argument("--render", type=float, default=0.0)
4444
parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
4545
parser.add_argument(
4646
"--device",
@@ -59,7 +59,7 @@ def test_discrete_bcq(args=get_args()):
5959
args.state_shape = env.observation_space.shape or env.observation_space.n
6060
args.action_shape = env.action_space.shape or env.action_space.n
6161
if args.reward_threshold is None:
62-
default_reward_threshold = {"CartPole-v0": 190}
62+
default_reward_threshold = {"CartPole-v0": 185}
6363
args.reward_threshold = default_reward_threshold.get(
6464
args.task, env.spec.reward_threshold
6565
)
@@ -123,7 +123,8 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
123123
{
124124
"model": policy.state_dict(),
125125
"optim": optim.state_dict(),
126-
}, ckpt_path
126+
},
127+
ckpt_path,
127128
)
128129
return ckpt_path
129130

tianshou/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from tianshou import data, env, exploration, policy, trainer, utils
22

3-
__version__ = "0.4.11"
3+
__version__ = "0.5.0"
44

55
__all__ = [
66
"env",

tianshou/env/venvs.py

+32-17
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import packaging
77

8-
from tianshou.env.pettingzoo_env import PettingZooEnv
98
from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type
109
from tianshou.env.worker import (
1110
DummyEnvWorker,
@@ -14,8 +13,14 @@
1413
SubprocEnvWorker,
1514
)
1615

16+
try:
17+
from tianshou.env.pettingzoo_env import PettingZooEnv
18+
except ImportError:
19+
PettingZooEnv = None # type: ignore
20+
1721
try:
1822
import gym as old_gym
23+
1924
has_old_gym = True
2025
except ImportError:
2126
has_old_gym = False
@@ -152,11 +157,13 @@ def __init__(
152157

153158
self.env_num = len(env_fns)
154159
self.wait_num = wait_num or len(env_fns)
155-
assert 1 <= self.wait_num <= len(env_fns), \
156-
f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
160+
assert (
161+
1 <= self.wait_num <= len(env_fns)
162+
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
157163
self.timeout = timeout
158-
assert self.timeout is None or self.timeout > 0, \
159-
f"timeout is {timeout}, it should be positive if provided!"
164+
assert (
165+
self.timeout is None or self.timeout > 0
166+
), f"timeout is {timeout}, it should be positive if provided!"
160167
self.is_async = self.wait_num != len(env_fns) or timeout is not None
161168
self.waiting_conn: List[EnvWorker] = []
162169
# environments in self.ready_id is actually ready
@@ -169,8 +176,9 @@ def __init__(
169176
self.is_closed = False
170177

171178
def _assert_is_not_closed(self) -> None:
172-
assert not self.is_closed, \
173-
f"Methods of {self.__class__.__name__} cannot be called after close."
179+
assert (
180+
not self.is_closed
181+
), f"Methods of {self.__class__.__name__} cannot be called after close."
174182

175183
def __len__(self) -> int:
176184
"""Return len(self), which is the number of environments."""
@@ -245,10 +253,12 @@ def _wrap_id(
245253

246254
def _assert_id(self, id: Union[List[int], np.ndarray]) -> None:
247255
for i in id:
248-
assert i not in self.waiting_id, \
249-
f"Cannot interact with environment {i} which is stepping now."
250-
assert i in self.ready_id, \
251-
f"Can only interact with ready environments {self.ready_id}."
256+
assert (
257+
i not in self.waiting_id
258+
), f"Cannot interact with environment {i} which is stepping now."
259+
assert (
260+
i in self.ready_id
261+
), f"Can only interact with ready environments {self.ready_id}."
252262

253263
def reset(
254264
self,
@@ -271,9 +281,10 @@ def reset(
271281
self.workers[i].send(None, **kwargs)
272282
ret_list = [self.workers[i].recv() for i in id]
273283

274-
assert isinstance(ret_list[0], (tuple, list)) and len(
275-
ret_list[0]
276-
) == 2 and isinstance(ret_list[0][1], dict)
284+
assert (
285+
isinstance(ret_list[0], (tuple, list)) and len(ret_list[0]) == 2
286+
and isinstance(ret_list[0][1], dict)
287+
)
277288

278289
obs_list = [r[0] for r in ret_list]
279290

@@ -367,9 +378,13 @@ def step(
367378
obs_stack = np.stack(obs_list)
368379
except ValueError: # different len(obs)
369380
obs_stack = np.array(obs_list, dtype=object)
370-
return obs_stack, np.stack(rew_list), np.stack(term_list), np.stack(
371-
trunc_list
372-
), np.stack(info_list)
381+
return (
382+
obs_stack,
383+
np.stack(rew_list),
384+
np.stack(term_list),
385+
np.stack(trunc_list),
386+
np.stack(info_list),
387+
)
373388

374389
def seed(
375390
self,

0 commit comments

Comments
 (0)