@@ -1,25 +1,26 @@
-from dataclasses import dataclass
 from typing import Dict
 
 import tensordict
 from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 
 from agents.utils.memory import NotReadyException
-from agents.utils.rl.rl import *
 from agents.utils.return_estimator import ValueFetchMethod
+from agents.utils.rl.rl import *
 
 
 class DeepQTrainer(RLTrainer):
     @dataclass
     class Configuration:
         decay: float = 0.99
         update_steps: int = 10
+        epochs: int = 1
 
         @staticmethod
         def from_cli(parameters: Dict):
             return DeepQTrainer.Configuration(
                 decay=parameters.get('decay', 0.99),
-                update_steps=parameters.get('update_steps', 20)
+                update_steps=parameters.get('update_steps', 20),
+                epochs=parameters.get('epochs', 1)
             )
 
     def __init__(self, configuration: Configuration, *args, **kwargs):
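Usage note: `Configuration.from_cli` reads the new `epochs` key from the same parameter dict as the existing fields, falling back to 1 when the key is absent. A minimal sketch (the parameter values are made up for illustration):

```python
# Hypothetical CLI parameter dict; any missing key falls back to the
# defaults in from_cli (decay=0.99, update_steps=20, epochs=1).
parameters = {'decay': 0.95, 'update_steps': 10, 'epochs': 4}

configuration = DeepQTrainer.Configuration.from_cli(parameters)

assert configuration.decay == 0.95
assert configuration.epochs == 4
```

Note that the dataclass default for `update_steps` (10) still differs from the `from_cli` fallback (20), so the two construction paths disagree when the key is omitted.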
@@ -37,34 +38,35 @@ def configure(self, model: Policy):
         self._target_model = AveragedModel(model.clone(), avg_fn=avg_fn).to(self.device)
 
     def __train__(self, model: Policy):
-        try:
-            batch, info = self.storage.sample(device=self.device)
-        except NotReadyException:
-            return
+        for _ in range(self.configuration.epochs):
+            try:
+                batch, info = self.storage.sample(device=self.device)
+            except NotReadyException:
+                return
 
-        with torch.no_grad():
-            q_values = self.estimate_q(model, batch)
+            with torch.no_grad():
+                q_values = self.estimate_q(model, batch)
 
-        def compute_loss():
-            actions = self.__get_action_values__(model, batch.state, batch.action)
+            def compute_loss():
+                actions = self.__get_action_values__(model, batch.state, batch.action)
 
-            loss_ = self.loss(actions, q_values)
-            td_error_ = torch.square(actions - q_values)
+                loss_ = self.loss(actions, q_values)
+                td_error_ = torch.square(actions - q_values)
 
-            entropy = torch.distributions.Categorical(logits=actions).entropy().mean()
+                entropy = torch.distributions.Categorical(logits=actions).entropy().mean()
 
-            return loss_, (td_error_, entropy)
+                return loss_, (td_error_, entropy)
 
-        loss, result = self.step(compute_loss, self.optimizer)
+            loss, result = self.step(compute_loss, self.optimizer)
 
-        td_error, entropy = result
-        td_error_mean = td_error.mean()
+            td_error, entropy = result
+            td_error_mean = td_error.mean()
 
-        self.record_loss(loss)
-        self.record_loss(td_error_mean, key='td_error')
-        self.record_loss(entropy, key='entropy')
+            self.record_loss(loss)
+            self.record_loss(td_error_mean, key='td_error')
+            self.record_loss(entropy, key='entropy')
 
-        print(f'loss: {loss}, td_error: {td_error_mean}, entropy: {entropy}')
+            print(f'loss: {loss}, td_error: {td_error_mean}, entropy: {entropy}')
 
         with torch.no_grad():
             if self.optimizer.step_count % self.configuration.update_steps == 0:
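The tail of this hunk is truncated, but the visible pieces (`get_ema_avg_fn` in the imports, `AveragedModel(model.clone(), avg_fn=avg_fn)` in `configure`, and the `step_count % update_steps` gate above) match the standard `torch.optim.swa_utils` recipe for an EMA target network. A minimal self-contained sketch of that recipe, with a stand-in `nn.Linear` in place of this repo's `Policy`:

```python
import torch
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn

# Stand-in for the online Q-network; the trainer itself clones a Policy.
model = torch.nn.Linear(4, 2)

# get_ema_avg_fn(decay) averages as: avg = decay * avg + (1 - decay) * current,
# matching the `decay` field in DeepQTrainer.Configuration.
target = AveragedModel(model, avg_fn=get_ema_avg_fn(0.99))

update_steps = 10

for step in range(1, 101):
    # ... one optimizer step on `model` would happen here ...
    if step % update_steps == 0:
        # Fold the online weights into the EMA target, which is
        # presumably what the truncated torch.no_grad() block does.
        target.update_parameters(model)
```

Since `__train__` now runs `epochs` sample/step iterations per call, `step_count` advances roughly `epochs` times as fast, so the target refresh fires correspondingly more often per trainer invocation.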
0 commit comments