main_rllib.py
"""
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Author : James Arambam
Date : 18 Nov 2021
Description :
Input :
Output :
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""
# ================================ Imports ================================ #
import sys
sys.dont_write_bytecode = True

import ray  # RLlib starts Ray automatically on first trainer creation
from ray.rllib.agents import ppo
# Importing the env module is expected to register 'MiniGrid-TwoRooms-v0'
# with gym as a side effect, and it exposes the class for direct use below.
from gym_minigrid.envs.rooms_james import TwoRoomsEnv

# Alternative trainers referenced by the commented-out lines in main():
# from ray.rllib.agents.pg import PGTrainer
# from rllib_agents.myagent import MyTrainer
# ============================================================================ #
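# ---------------------------------------------------------------------------- #
# A minimal sanity-check sketch (not part of the original script): confirm
# that the MiniGrid env id resolves in gym before handing it to RLlib. The
# helper name `check_env` and the assumption that importing the env module
# above registered 'MiniGrid-TwoRooms-v0' are mine, not the author's.
def check_env(env_id="MiniGrid-TwoRooms-v0"):
    import gym  # local import; the module-level imports above no longer need gym
    env = gym.make(env_id)
    env.reset()
    print("observation space:", env.observation_space)
    print("action space     :", env.action_space)
    env.close()
# ---------------------------------------------------------------------------- #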
def main():
    # Configure the algorithm.
    config = {
        "env": "MiniGrid-TwoRooms-v0",
        "env_config": {},
        # Use one environment worker (aka "rollout worker") that collects
        # samples from its own environment clone(s).
        "num_workers": 1,
        "framework": "torch",
        # Tweak the default model provided automatically by RLlib,
        # given the environment's observation and action spaces.
        "model": {
            "fcnet_hiddens": [64, 64],
            "fcnet_activation": "relu",
        },
        # Set up a separate evaluation worker set for a later
        # `trainer.evaluate()` call (see the sketch after main()).
        "evaluation_num_workers": 1,
        # Only for evaluation runs, render the env.
        "evaluation_config": {
            "render_env": False,
        },
    }

    # Create our RLlib trainer. Passing `env=TwoRoomsEnv` here takes
    # precedence over the registered id in `config["env"]`.
    trainer = ppo.PPOTrainer(env=TwoRoomsEnv, config=config)
    # trainer = PGTrainer(env=TwoRoomsEnv, config=config)
    # trainer = MyTrainer(env=TwoRoomsEnv, config=config)

    # Train indefinitely; each train() call runs one training iteration and
    # returns a result dict (episode reward, timesteps sampled, etc.).
    while True:
        print(trainer.train())
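
# ---------------------------------------------------------------------------- #
# A bounded alternative sketch (not in the original, whose loop never ends):
# train for a fixed number of iterations, then run the evaluation worker set
# that `evaluation_num_workers` configures. The function name and the default
# iteration count of 20 are arbitrary assumptions, not the author's choices.
def train_and_evaluate(num_iters=20):
    config = {
        "env": "MiniGrid-TwoRooms-v0",
        "framework": "torch",
        "num_workers": 1,
        "evaluation_num_workers": 1,
    }
    trainer = ppo.PPOTrainer(env=TwoRoomsEnv, config=config)
    for i in range(num_iters):
        result = trainer.train()
        print("iter", i, "episode_reward_mean:", result.get("episode_reward_mean"))
    # Runs episodes on the separate evaluation workers and returns a metrics dict.
    print(trainer.evaluate())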
# =============================================================================== #
if __name__ == '__main__':
main()