forked from vuoristo/dqn-agent
GazeboAgent.py
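"""Train or evaluate a DQN agent on the jde_gym_gazebo Kobuki RGB environment
('jde-gazebo-kobuki-rgb-v0').

Example invocations, derived from the argparse flags defined in main() below;
the weights path is a placeholder, not a file shipped with this repository:

    python GazeboAgent.py                              # train with default settings
    python GazeboAgent.py --evaluate -l <weights> -r   # evaluate saved weights with rendering
"""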
import gym
import jde_gym_gazebo  # registers the Gazebo environments with gym
import argparse

from ConvModel import ConvModel
from DQNAgent import DQNAgent

ENV_NAME = 'jde-gazebo-kobuki-rgb-v0'


def main():
    parser = argparse.ArgumentParser(
        description='Train or evaluate a DQN agent for Gym Gazebo environments')
    parser.add_argument('--env', '-e', default=ENV_NAME)
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--load_weights', '-l', default=None)
    parser.add_argument('--render', '-r', action='store_true', default=False)
    args = parser.parse_args()

    env_name = args.env
    weights_to_load = args.load_weights
    evaluate = args.evaluate
    render = args.render

    env = gym.make(env_name)

    # Convolutional Q-network: RGB observations (grayscale=False), a window of
    # 8 stacked frames, and soft target-network updates with tau=0.01.
    model = ConvModel(env, learning_rate=2.5e-4, momentum=0.95, gamma=0.99,
                      tau=0.01, soft_updates=True,
                      weights_to_load=weights_to_load,
                      grayscale=False, window_size=8)

    # DQN agent with linear epsilon decay from 1.0 to 0.06 over 3M steps and
    # a replay buffer of 1M transitions.
    agent = DQNAgent(env, model, linear_epsilon_decay=True,
                     epsilon_decay_steps=3.e6, epsilon=1.0, min_epsilon=0.06,
                     exp_buffer_size=1000000, batch_size=256, render=render,
                     update_freq=1, random_starts=30, max_steps=10000)

    if evaluate:
        agent.evaluate()
    else:
        agent.train()


if __name__ == '__main__':
    main()