1
1
import numpy as np
2
+ import torch
2
3
from gymnasium .wrappers import FlattenObservation
3
4
4
5
from openrl .configs .config import create_config_parser
5
6
from openrl .envs .common import make
6
7
from openrl .envs .wrappers .base_wrapper import BaseWrapper
7
- from openrl .envs .wrappers .extra_wrappers import FrameSkip , GIFWrapper
8
+ from openrl .envs .wrappers .extra_wrappers import (
9
+ ConvertEmptyBoxWrapper ,
10
+ FrameSkip ,
11
+ GIFWrapper ,
12
+ )
8
13
from openrl .modules .common import PPONet as Net
9
14
from openrl .runners .common import PPOAgent as Agent
10
15
@@ -18,15 +23,15 @@ def train():
18
23
cfg = cfg_parser .parse_args (["--config" , "ppo.yaml" ])
19
24
20
25
# create environment, set environment parallelism to 64
26
+ env_num = 64
21
27
env = make (
22
28
env_name ,
23
- env_num = 64 ,
24
- cfg = cfg ,
29
+ env_num = env_num ,
25
30
asynchronous = True ,
26
- env_wrappers = [FrameSkip , FlattenObservation ],
31
+ env_wrappers = [FrameSkip , FlattenObservation , ConvertEmptyBoxWrapper ],
27
32
)
28
33
29
- net = Net (env , cfg = cfg , device = "cuda" )
34
+ net = Net (env , cfg = cfg , device = "cuda" if torch . cuda . is_available () else "cpu" )
30
35
# initialize the trainer
31
36
agent = Agent (
32
37
net ,
@@ -44,18 +49,18 @@ def evaluation():
44
49
# begin to test
45
50
# Create an environment for testing and set the number of environments to interact with to 4. Set rendering mode to group_rgb_array.
46
51
render_mode = "group_rgb_array"
52
+
47
53
env = make (
48
54
env_name ,
49
55
render_mode = render_mode ,
50
56
env_num = 4 ,
51
57
asynchronous = True ,
52
- env_wrappers = [FrameSkip , FlattenObservation ],
53
- cfg = cfg ,
58
+ env_wrappers = [FrameSkip , FlattenObservation , ConvertEmptyBoxWrapper ],
54
59
)
55
60
# Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
56
61
env = GIFWrapper (env , gif_path = "./new.gif" , fps = 5 )
57
62
58
- net = Net (env , cfg = cfg , device = "cuda" )
63
+ net = Net (env , cfg = cfg , device = "cuda" if torch . cuda . is_available () else "cpu" )
59
64
# initialize the trainer
60
65
agent = Agent (
61
66
net ,
0 commit comments