@@ -1,3 +1,4 @@
+import copy
 from typing import Dict
 
 import tensordict
@@ -36,7 +37,7 @@ def configure(self, model: Policy):
 
         avg_fn = get_ema_avg_fn(self.configuration.decay)
 
-        self._target_model = AveragedModel(model.clone(), avg_fn=avg_fn).to(self.device)
+        self._target_model = copy.deepcopy(model)  # AveragedModel(model.clone(), avg_fn=avg_fn).to(self.device)
 
     def __train__(self, model: Policy):
         for _ in range(self.configuration.epochs):
@@ -54,8 +55,6 @@ def compute_loss():
             weight = torch.tensor(info['_weight']) if '_weight' in info.keys() else torch.ones_like(q_values)
             weight = weight.to(actions.device)
 
-            print(weight, info)
-
             loss_ = (self.loss(actions, q_values) * weight).mean()
             td_error_ = torch.square(actions - q_values)
 
@@ -76,7 +75,7 @@ def compute_loss():
 
         with torch.no_grad():
             if self.optimizer.step_count % self.configuration.update_steps == 0:
-                self._target_model.update_parameters(model)
+                self._target_model = copy.deepcopy(model)
 
         self.storage.update_priority(info['index'], td_error)
 
@@ -100,7 +99,7 @@ def __get_action_values__(model: Policy, state, actions):
 
     @property
     def target_model(self):
-        return self._target_model.module
+        return self._target_model
 
     def state_dict(self):
         state_dict = super().state_dict()