DoubleDQN.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np
import random
from collections import deque
from random import randrange
class BasicBuffer:
    """FIFO experience replay buffer holding (state, action, reward, next_state, done) tuples."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.buffer = deque(maxlen=max_size)

    def push(self, state, action, reward, next_state, done):
        experience = (state, action, np.array([reward]), next_state, done)
        self.buffer.append(experience)

    def sample(self, batch_size):
        state_batch = []
        action_batch = []
        reward_batch = []
        next_state_batch = []
        done_batch = []

        batch = random.sample(self.buffer, batch_size)
        for experience in batch:
            state, action, reward, next_state, done = experience
            state_batch.append(state)
            action_batch.append(action)
            reward_batch.append(reward)
            next_state_batch.append(next_state)
            done_batch.append(done)

        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)

    def __len__(self):
        return len(self.buffer)

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, self.output_dim)
        )

    def forward(self, state):
        qvals = self.fc(state)
        return qvals

class DDQNAgent:
    """Double DQN agent with an online network, a target network, and a replay buffer."""

    def __init__(self, num_cluster, num_resource, learning_rate=3e-4, gamma=0.99, tau=0.01, buffer_size=10000):
        self.state_size = num_cluster * num_resource + num_resource
        self.action_space = num_cluster + 1
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.tau = tau
        self.loss = []
        self.replay_buffer = BasicBuffer(max_size=buffer_size)

        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = "cuda"

        self.model = DQN(self.state_size, self.action_space).to(self.device)
        self.target_model = DQN(self.state_size, self.action_space).to(self.device)

        # hard copy online model parameters to target model parameters
        for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
            target_param.data.copy_(param)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def get_action(self, state, eps=0.20):
        # epsilon-greedy: take a random action with probability eps
        if np.random.rand() < eps:
            return randrange(self.action_space)

        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        qvals = self.model.forward(state)
        action = np.argmax(qvals.cpu().detach().numpy())
        return action

    def compute_loss(self, batch):
        states, actions, rewards, next_states, dones = batch
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # resize tensors
        actions = actions.view(actions.size(0), 1)
        dones = dones.view(dones.size(0), 1)

        # Double DQN: the online network selects the next action,
        # the target network evaluates it
        curr_Q = self.model.forward(states).gather(1, actions)
        next_action = torch.argmax(self.model.forward(next_states), 1).view(actions.size(0), 1)
        next_Q = self.target_model.forward(next_states).gather(1, next_action)
        expected_Q = rewards + (1 - dones) * self.gamma * next_Q

        loss = F.mse_loss(curr_Q, expected_Q.detach())
        self.loss.append(loss.item())
        return loss

    def update(self, batch_size):
        batch = self.replay_buffer.sample(batch_size)
        loss = self.compute_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def get_loss(self):
        # mean loss over the updates seen so far, then reset the running list
        episode_loss = sum(self.loss) / len(self.loss)
        self.loss = []
        return episode_loss

    def update_target(self):
        # soft (Polyak) target network update
        for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
            target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)
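

# Minimal usage sketch (an assumption, not part of the original file): a dummy random
# environment stands in for whatever cluster/scheduling simulator the agent is meant to
# drive, just to show the expected loop of get_action -> replay_buffer.push -> update ->
# update_target. The class name, episode length, and hyperparameters below are illustrative.
class _DummyEnv:
    """Hypothetical environment: random states and rewards, episode ends after 50 steps."""

    def __init__(self, num_cluster, num_resource):
        self.state_size = num_cluster * num_resource + num_resource
        self.t = 0

    def reset(self):
        self.t = 0
        return np.random.rand(self.state_size).astype(np.float32)

    def step(self, action):
        self.t += 1
        next_state = np.random.rand(self.state_size).astype(np.float32)
        reward = float(np.random.rand())
        done = self.t >= 50
        return next_state, reward, done


if __name__ == "__main__":
    num_cluster, num_resource = 4, 3
    env = _DummyEnv(num_cluster, num_resource)
    agent = DDQNAgent(num_cluster, num_resource)
    batch_size = 32

    for episode in range(5):
        state = env.reset()
        done = False
        while not done:
            action = agent.get_action(state)
            next_state, reward, done = env.step(action)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            # start learning once the buffer holds at least one full batch
            if len(agent.replay_buffer) > batch_size:
                agent.update(batch_size)
                agent.update_target()
            state = next_state
        print("episode", episode, "mean loss", agent.get_loss())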