#!/usr/bin/python
# -*- coding: utf-8 -*-
import os

import torch
import torch.nn as nn
import numpy as np
import numpy.random as rd
from copy import deepcopy
from core import *
from torch.utils.tensorboard import SummaryWriter

class AgentBase:
    def __init__(self):
        self.learning_rate = 1e-4
        self.soft_update_tau = 2 ** -8  # 5e-3 ~= 2 ** -8
        self.state = None  # set in explore_env(); initialize it before training
        self.device = None

        self.act = self.act_target = None
        self.cri = self.cri_target = None
        self.act_optimizer = None
        self.cri_optimizer = None
        self.criterion = None

        self.writer = SummaryWriter()
        self.update_num = 0

    def init(self, net_dim, state_dim, action_dim):
        """initialize the attributes declared in `__init__()`

        replaced by different DRL algorithms.
        explicitly call self.init() for multiprocessing.

        :int net_dim: the dimension of the networks (the width of the neural networks)
        :int state_dim: the dimension of the state (the length of the state vector)
        :int action_dim: the dimension of the action (the number of discrete actions)
        """

    def select_action(self, state) -> np.ndarray:
        """Select an action for exploration

        :array state: state.shape == (state_dim, )
        :return array action: action.shape == (action_dim, ), (action.min(), action.max()) == (-1, +1)
        """
        states = torch.as_tensor((state,), dtype=torch.float32, device=self.device).detach_()
        action = self.act(states)[0]
        return action.cpu().numpy()

    def select_actions(self, states) -> np.ndarray:
        """Select a batch of actions for exploration

        :array states: (state, ) or (state, state, ...) or states.shape == (n, *state_dim)
        :return array actions: actions.shape == (n, action_dim), (actions.min(), actions.max()) == (-1, +1)
        """
        states = torch.as_tensor(states, dtype=torch.float32, device=self.device).detach_()
        actions = self.act(states)
        return actions.cpu().numpy()  # -1 < action < +1

    def explore_env(self, env, buffer, target_step, reward_scale, gamma) -> int:
        """actor explores in env, then stores the env transitions in the ReplayBuffer

        :env: RL training environment. env.reset() env.step()
        :buffer: Experience Replay Buffer. buffer.append_buffer() buffer.extend_buffer()
        :int target_step: explore for target_step steps in env
        :float reward_scale: scale the reward, 'reward * reward_scale'
        :float gamma: discount factor, 'mask = 0.0 if done else gamma'
        :return int target_step: the number of collected steps in env
        """
        for _ in range(target_step):
            action = self.select_action(self.state)
            next_s, reward, done, _ = env.step(action)
            other = (reward * reward_scale, 0.0 if done else gamma, *action)  # (reward, mask, *action)
            buffer.append_buffer(self.state, other)
            self.state = env.reset() if done else next_s
        return target_step

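    # Note on the buffer row layout implied above (a remark, not part of the algorithm):
    #     (state, other)  with  other = (reward * reward_scale, mask, *action),
    # where mask = 0.0 at terminal states and gamma otherwise. update_net() later samples
    # these rows back as (reward, mask, action, state, next_state) batches.
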
    def update_net(self, buffer, target_step, batch_size, repeat_times) -> (float, float):
        """update the neural networks by sampling batch data from the ReplayBuffer

        replaced by different DRL algorithms.
        return the objective values as training information to help fine-tuning.

        :buffer: Experience Replay Buffer. buffer.append_buffer() buffer.extend_buffer()
        :int target_step: explore for target_step steps in env
        :int batch_size: sample batch_size of data for Stochastic Gradient Descent
        :float repeat_times: the number of sampled batches = int(target_step * repeat_times) in off-policy
        :return float obj_a: the objective value of the actor
        :return float obj_c: the objective value of the critic
        """

    def save_load_model(self, cwd, if_save):
        """save or load model files

        :str cwd: current working directory, where the model files are saved
        :bool if_save: True: save the model; False: load the model
        """
        act_save_path = '{}/actor.pth'.format(cwd)
        cri_save_path = '{}/critic.pth'.format(cwd)

        def load_torch_file(network, save_path):
            network_dict = torch.load(save_path, map_location=lambda storage, loc: storage)
            network.load_state_dict(network_dict)

        if if_save:
            if self.act is not None:
                torch.save(self.act.state_dict(), act_save_path)
            if self.cri is not None:
                torch.save(self.cri.state_dict(), cri_save_path)
        elif (self.act is not None) and os.path.exists(act_save_path):
            # for the DQN-family agents below, self.act is self.cri, so loading the actor is enough
            load_torch_file(self.act, act_save_path)
            print("Loaded act:", cwd)
        elif (self.cri is not None) and os.path.exists(cri_save_path):
            load_torch_file(self.cri, cri_save_path)
            print("Loaded cri:", cwd)
        else:
            print("FileNotFound when load_model: {}".format(cwd))

    @staticmethod
    def soft_update(target_net, current_net, tau):
        """soft update a target network via the current network

        target_param <- tau * current_param + (1 - tau) * target_param

        :nn.Module target_net: the target network, updated via the current network; it is more stable
        :nn.Module current_net: the current network, updated via an optimizer
        :float tau: the soft update rate, e.g. 2 ** -8
        """
        for tar, cur in zip(target_net.parameters(), current_net.parameters()):
            tar.data.copy_(cur.data * tau + tar.data * (1 - tau))

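# A typical way to drive an AgentBase subclass (a sketch for orientation, not part of the
# original training pipeline; `env` and `buffer` are assumed to provide the reset()/step()
# and append_buffer()/sample_batch() interfaces referenced in the methods above):
#
#     agent.init(net_dim, state_dim, action_dim)
#     agent.state = env.reset()
#     for _ in range(num_iterations):
#         agent.explore_env(env, buffer, target_step, reward_scale, gamma)
#         obj_a, obj_c = agent.update_net(buffer, target_step, batch_size, repeat_times)
#     agent.save_load_model(cwd, if_save=True)
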

class AgentDQN(AgentBase):
    def __init__(self):
        super().__init__()
        self.explore_rate = 0.1  # the probability of choosing an action randomly in epsilon-greedy
        self.action_dim = None  # choose a discrete action randomly in epsilon-greedy

    def init(self, net_dim, state_dim, action_dim):
        self.action_dim = action_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.cri = QNet(net_dim, state_dim, action_dim).to(self.device)
        self.cri_target = deepcopy(self.cri)
        self.act = self.cri  # to keep the same interface as the Actor-Critic framework

        self.criterion = torch.nn.MSELoss()
        self.cri_optimizer = torch.optim.Adam(self.cri.parameters(), lr=self.learning_rate)

    def select_action(self, state) -> int:  # for a discrete action space
        if rd.rand() < self.explore_rate:  # epsilon-greedy
            a_int = rd.randint(self.action_dim)  # choose an action randomly
        else:
            states = torch.as_tensor((state,), dtype=torch.float32, device=self.device).detach_()
            action = self.act(states)[0]
            a_int = action.argmax(dim=0).cpu().numpy()
        return a_int

    def explore_env(self, env, buffer, target_step, reward_scale, gamma) -> int:
        for _ in range(target_step):
            action = self.select_action(self.state)
            next_s, reward, done, _ = env.step(action)

            other = (reward * reward_scale, 0.0 if done else gamma, action)  # action is an int
            buffer.append_buffer(self.state, other)
            self.state = env.reset() if done else next_s
        return target_step

    def update_net(self, buffer, target_step, batch_size, repeat_times) -> (float, float):
        buffer.update_now_len_before_sample()

        next_q = obj_critic = None
        for _ in range(int(target_step * repeat_times)):
            with torch.no_grad():
                reward, mask, action, state, next_s = buffer.sample_batch(batch_size)  # next_state
                next_q = self.cri_target(next_s).max(dim=1, keepdim=True)[0]
                q_label = reward + mask * next_q  # Q-learning target: r + gamma * max_a' Q_target(s', a')
            q_eval = self.cri(state).gather(1, action.type(torch.long))
            obj_critic = self.criterion(q_eval, q_label)

            self.cri_optimizer.zero_grad()
            obj_critic.backward()
            self.cri_optimizer.step()
            self.soft_update(self.cri_target, self.cri, self.soft_update_tau)
        return next_q.mean().item(), obj_critic.item()


class AgentDoubleDQN(AgentDQN):
    def __init__(self):
        super().__init__()
        self.explore_rate = 0.25  # the probability of choosing an action randomly in epsilon-greedy
        self.softmax = torch.nn.Softmax(dim=1)

    def init(self, net_dim, state_dim, action_dim):
        self.action_dim = action_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.cri = QNetTwin(net_dim, state_dim, action_dim).to(self.device)
        self.cri_target = deepcopy(self.cri)
        self.act = self.cri

        self.criterion = torch.nn.SmoothL1Loss()
        self.cri_optimizer = torch.optim.Adam(self.act.parameters(), lr=self.learning_rate)

    def select_action(self, state) -> int:  # for a discrete action space
        states = torch.as_tensor((state,), dtype=torch.float32, device=self.device).detach_()
        actions = self.act(states)
        if rd.rand() < self.explore_rate:  # epsilon-greedy
            action = self.softmax(actions)[0]
            a_prob = action.detach().cpu().numpy()  # choose an action according to the Q values
            a_int = rd.choice(self.action_dim, p=a_prob)
        else:
            action = actions[0]
            a_int = action.argmax(dim=0).cpu().numpy()
        return a_int

    def update_net(self, buffer, target_step, batch_size, repeat_times) -> (float, float):
        """Contribution of DDQN (Double DQN)

        Twin Q-Networks. Use min(q1, q2) to reduce over-estimation.
        """
        buffer.update_now_len_before_sample()

        next_q = obj_critic = None
        for _ in range(int(target_step * repeat_times)):
            with torch.no_grad():
                reward, mask, action, state, next_s = buffer.sample_batch(batch_size)
                next_q = torch.min(*self.cri_target.get_q1_q2(next_s))
                next_q = next_q.max(dim=1, keepdim=True)[0]
                q_label = reward + mask * next_q
            act_int = action.type(torch.long)
            q1, q2 = [qs.gather(1, act_int) for qs in self.act.get_q1_q2(state)]
            obj_critic = self.criterion(q1, q_label) + self.criterion(q2, q_label)

            self.cri_optimizer.zero_grad()
            obj_critic.backward()
            self.cri_optimizer.step()
            self.soft_update(self.cri_target, self.cri, self.soft_update_tau)

            self.update_num += 1
            self.writer.add_scalar('loss_Q', obj_critic, self.update_num)
        return next_q.mean().item(), obj_critic.item() / 2

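# The target used in AgentDoubleDQN.update_net() above, written out for reference:
#     q_label = reward + mask * max_a' min(Q1_target(s', a'), Q2_target(s', a'))
# with mask = 0.0 at terminal states and gamma otherwise.
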

class AgentD3QN(AgentDoubleDQN):  # D3QN: Dueling Double DQN
    def __init__(self):
        super().__init__()

    def init(self, net_dim, state_dim, action_dim):
        """Contribution of D3QN (Dueling Double DQN)

        D3QN makes no new contribution of its own:
        DoubleDQN is obviously compatible with DuelingDQN,
        and any beginner can come up with this idea (D3QN) independently.
        """
        self.action_dim = action_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.cri = QNetTwinDuel(net_dim, state_dim, action_dim).to(self.device)
        self.cri_target = deepcopy(self.cri)
        self.act = self.cri

        self.criterion = torch.nn.SmoothL1Loss()
        self.cri_optimizer = torch.optim.Adam(self.act.parameters(), lr=self.learning_rate)


if __name__ == "__main__":
    agent = AgentD3QN()
    agent.init(128, 3, 1)
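    # Minimal smoke test (a sketch added for illustration, not part of the training
    # pipeline): sample one discrete action for a random 3-dimensional state.
    demo_state = np.random.randn(3).astype(np.float32)
    print("sampled discrete action:", agent.select_action(demo_state))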