Commit 20cf3d3

merge dm_control to gymnasium

2 parents d378b77 + 0491ad4 commit 20cf3d3

5 files changed: +14 -59 lines changed

examples/dm_control/README.md (+1 -1)

@@ -1,6 +1,6 @@
 ## Installation
 ```bash
-pip install shimmy[dm-control]
+pip install "shimmy[dm-control]"
 ```
 
 ## Usage
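
The added quotes keep shells such as zsh from treating the square brackets in `shimmy[dm-control]` as a glob pattern. As a quick sanity check after installation, the snippet below is a minimal sketch assuming Shimmy registers its dm_control ids with Gymnasium; the id `dm_control/cartpole-balance-v0` is only an example and is not part of this commit.

```python
# Minimal sanity check, assuming Shimmy has registered the dm_control ids
# with Gymnasium; "dm_control/cartpole-balance-v0" is just one example id.
import gymnasium as gym

env = gym.make("dm_control/cartpole-balance-v0")
obs, info = env.reset()
print(env.observation_space)
print(env.action_space)
env.close()
```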

examples/dm_control/train_ppo.py (+13 -8)

@@ -1,10 +1,15 @@
 import numpy as np
+import torch
 from gymnasium.wrappers import FlattenObservation
 
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.envs.wrappers.base_wrapper import BaseWrapper
-from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper
+from openrl.envs.wrappers.extra_wrappers import (
+    ConvertEmptyBoxWrapper,
+    FrameSkip,
+    GIFWrapper,
+)
 from openrl.modules.common import PPONet as Net
 from openrl.runners.common import PPOAgent as Agent
 
@@ -18,15 +23,15 @@ def train():
     cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
 
     # create environment, set environment parallelism to 64
+    env_num = 64
     env = make(
         env_name,
-        env_num=64,
-        cfg=cfg,
+        env_num=env_num,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation],
+        env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
     )
 
-    net = Net(env, cfg=cfg, device="cuda")
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer
     agent = Agent(
         net,
@@ -44,18 +49,18 @@ def evaluation():
     # begin to test
     # Create an environment for testing and set the number of environments to interact with to 4. Set rendering mode to group_rgb_array.
     render_mode = "group_rgb_array"
+
     env = make(
         env_name,
         render_mode=render_mode,
         env_num=4,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation],
-        cfg=cfg,
+        env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
     )
     # Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
     env = GIFWrapper(env, gif_path="./new.gif", fps=5)
 
-    net = Net(env, cfg=cfg, device="cuda")
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer
     agent = Agent(
         net,
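
The `device` argument now falls back to the CPU when CUDA is unavailable, so the example also runs on machines without a GPU. Below is a minimal sketch of that pattern in isolation; the variable names and the stand-in layer are illustrative and not part of the commit.

```python
# Minimal sketch of the CUDA-with-CPU-fallback pattern used for PPONet above;
# the layer below is only a stand-in to show the device being applied.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"selected device: {device}")

layer = torch.nn.Linear(4, 2).to(device)
```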

openrl/envs/common/registration.py (-7)

@@ -78,13 +78,6 @@ def make(
         env_fns = make_snake_envs(
             id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
         )
-
-    elif id.startswith("dm_control/"):
-        from openrl.envs.dmc import make_dmc_envs
-
-        env_fns = make_dmc_envs(
-            id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
-        )
     elif id.startswith("GymV21Environment-v0:") or id.startswith(
         "GymV26Environment-v0:"
     ):
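
With the dedicated `dm_control/` branch removed, dm_control ids are presumably meant to resolve through the regular Gymnasium path, since Shimmy registers them in Gymnasium's registry. The following is a hedged sketch of what creating such an environment through OpenRL's `make` could look like after this commit; the id is an example and the wrapper list mirrors `examples/dm_control/train_ppo.py` rather than anything in this file.

```python
# Hedged sketch: a dm_control id resolved through OpenRL's generic Gymnasium
# path. The id is an example; any task registered by Shimmy should work the
# same way.
from gymnasium.wrappers import FlattenObservation

from openrl.envs.common import make
from openrl.envs.wrappers.extra_wrappers import ConvertEmptyBoxWrapper, FrameSkip

env = make(
    "dm_control/cartpole-balance-v0",
    env_num=2,
    asynchronous=True,
    env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
)
# ...train or evaluate as in examples/dm_control/train_ppo.py...
env.close()
```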

openrl/envs/dmc/__init__.py (-30)

This file was deleted.

openrl/envs/dmc/dmc_env.py (-13)

This file was deleted.
