ddpg.py
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from model import (Actor, Critic)
from memory import SequentialMemory
from random_process import OrnsteinUhlenbeckProcess
from util import *
# from ipdb import set_trace as debug
criterion = nn.MSELoss()
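
# Deep Deterministic Policy Gradient (DDPG) agent. An actor network maps states
# to deterministic continuous actions, a critic network estimates Q(s, a), and
# slowly-updated target copies of both stabilise the bootstrapped critic
# targets. Exploration adds Ornstein-Uhlenbeck noise to the actor's output,
# scaled by an epsilon value annealed over `args.epsilon` steps.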
class DDPG(object):
    def __init__(self, nb_states, nb_actions, args):

        if args.seed > 0:
            self.seed(args.seed)

        self.nb_states = nb_states
        self.nb_actions = nb_actions

        # Create the actor and critic networks (plus their target copies)
        net_cfg = {
            'hidden1': args.hidden1,
            'hidden2': args.hidden2,
            'init_w': args.init_w
        }
        self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg)
        self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg)
        self.actor_optim = Adam(self.actor.parameters(), lr=args.prate)

        self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg)
        self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.rate)

        # Make sure the target networks start with the same weights as the online networks
        hard_update(self.actor_target, self.actor)
        hard_update(self.critic_target, self.critic)
        # Create the replay buffer and the exploration noise process
        self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length)
        self.random_process = OrnsteinUhlenbeckProcess(
            size=nb_actions, theta=args.ou_theta, mu=args.ou_mu, sigma=args.ou_sigma
        )

        # Hyper-parameters
        self.batch_size = args.bsize
        self.tau = args.tau
        self.discount = args.discount
        self.depsilon = 1.0 / args.epsilon

        # Exploration-noise scale and episode bookkeeping
        self.epsilon = 1.0
        self.s_t = None  # Most recent state
        self.a_t = None  # Most recent action
        self.is_training = True

        if USE_CUDA: self.cuda()
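
    # One DDPG learning step on a sampled minibatch (update_policy below):
    #   critic target   y = r + discount * terminal_batch * Q_target(s', actor_target(s'))
    #   critic loss     MSE(Q(s, a), y)
    #   actor loss      -mean_s Q(s, actor(s))
    #   target networks soft-updated: theta_target <- tau*theta + (1 - tau)*theta_target
    # terminal_batch is assumed to come back from SequentialMemory as a 0/1
    # continuation mask, so the bootstrap term is dropped at episode ends.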
    def update_policy(self):
        # Sample a minibatch of transitions from the replay buffer
        state_batch, action_batch, reward_batch, \
        next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size)

        # Prepare the target Q batch (volatile tensors are excluded from autograd)
        next_q_values = self.critic_target([
            to_tensor(next_state_batch, volatile=True),
            self.actor_target(to_tensor(next_state_batch, volatile=True)),
        ])
        next_q_values.volatile = False

        target_q_batch = to_tensor(reward_batch) + \
            self.discount * to_tensor(terminal_batch.astype(np.float32)) * next_q_values

        # Critic update
        self.critic.zero_grad()

        q_batch = self.critic([to_tensor(state_batch), to_tensor(action_batch)])

        value_loss = criterion(q_batch, target_q_batch)
        value_loss.backward()

        self.critic_optim.step()

        # Actor update: maximise Q(s, actor(s)) by minimising its negation
        self.actor.zero_grad()

        policy_loss = -self.critic([
            to_tensor(state_batch),
            self.actor(to_tensor(state_batch))
        ])

        policy_loss = policy_loss.mean()
        policy_loss.backward()

        self.actor_optim.step()

        # Soft-update the target networks towards the online networks
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)
    def eval(self):
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def cuda(self):
        self.actor.cuda()
        self.actor_target.cuda()
        self.critic.cuda()
        self.critic_target.cuda()
    def observe(self, r_t, s_t1, done):
        if self.is_training:
            self.memory.append(self.s_t, self.a_t, r_t, done)
            self.s_t = s_t1

    def random_action(self):
        action = np.random.uniform(-1., 1., self.nb_actions)
        self.a_t = action
        return action

    def select_action(self, s_t, decay_epsilon=True):
        action = to_numpy(
            self.actor(to_tensor(np.array([s_t])))
        ).squeeze(0)
        # Add annealed Ornstein-Uhlenbeck exploration noise during training
        action += self.is_training * max(self.epsilon, 0) * self.random_process.sample()
        action = np.clip(action, -1., 1.)

        if decay_epsilon:
            self.epsilon -= self.depsilon

        self.a_t = action
        return action
    def reset(self, obs):
        self.s_t = obs
        self.random_process.reset_states()

    def load_weights(self, output):
        if output is None: return

        self.actor.load_state_dict(
            torch.load('{}/actor.pkl'.format(output))
        )

        self.critic.load_state_dict(
            torch.load('{}/critic.pkl'.format(output))
        )

    def save_model(self, output):
        torch.save(
            self.actor.state_dict(),
            '{}/actor.pkl'.format(output)
        )
        torch.save(
            self.critic.state_dict(),
            '{}/critic.pkl'.format(output)
        )

    def seed(self, s):
        torch.manual_seed(s)
        if USE_CUDA:
            torch.cuda.manual_seed(s)
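

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal training loop showing
# how this agent is typically driven. The environment, hyper-parameter values,
# warmup length, and the classic gym (<0.26) reset()/step() API used below are
# illustrative assumptions, not definitions from this repository.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace
    import gym

    # Illustrative hyper-parameters; the real project parses these from argparse.
    args = Namespace(
        seed=123, hidden1=400, hidden2=300, init_w=3e-3,
        prate=1e-4, rate=1e-3, rmsize=100000, window_length=1,
        ou_theta=0.15, ou_mu=0.0, ou_sigma=0.2,
        bsize=64, tau=0.001, discount=0.99, epsilon=50000,
    )

    # MountainCarContinuous-v0 already has actions in [-1, 1], matching the
    # clipping done in select_action().
    env = gym.make('MountainCarContinuous-v0')
    nb_states = env.observation_space.shape[0]
    nb_actions = env.action_space.shape[0]

    agent = DDPG(nb_states, nb_actions, args)
    warmup, max_steps = 1000, 20000

    obs = env.reset()
    agent.reset(obs)
    for step in range(max_steps):
        # Act randomly until the replay buffer holds some warmup transitions.
        if step < warmup:
            action = agent.random_action()
        else:
            action = agent.select_action(obs)

        obs, reward, done, _ = env.step(action)
        agent.observe(reward, obs, done)

        if step >= warmup:
            agent.update_policy()

        if done:
            obs = env.reset()
            agent.reset(obs)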