diff --git a/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py b/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py
index 358067f..c18553c 100644
--- a/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py
+++ b/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py
@@ -51,7 +51,7 @@ def __init__(
         self.policy_update_freq = config.policy_update_freq
         self.target_update_freq = config.target_update_freq
 
-        self.target_entropy = -np.prod(self.actor_net.num_actions)
+        self.target_entropy = -self.action_num
 
         self.actor_net_optimiser = torch.optim.Adam(
             self.actor_net.parameters(), lr=config.actor_lr
diff --git a/cares_reinforcement_learning/algorithm/policy/LAPSAC.py b/cares_reinforcement_learning/algorithm/policy/LAPSAC.py
index 6607ee9..da307a8 100644
--- a/cares_reinforcement_learning/algorithm/policy/LAPSAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/LAPSAC.py
@@ -43,7 +43,7 @@ def __init__(
         self.policy_update_freq = config.policy_update_freq
         self.target_update_freq = config.target_update_freq
 
-        self.target_entropy = -np.prod(self.actor_net.num_actions)
+        self.target_entropy = -self.actor_net.num_actions
 
         self.actor_net_optimiser = torch.optim.Adam(
             self.actor_net.parameters(), lr=config.actor_lr
diff --git a/cares_reinforcement_learning/algorithm/policy/PERSAC.py b/cares_reinforcement_learning/algorithm/policy/PERSAC.py
index 72ac243..714becc 100644
--- a/cares_reinforcement_learning/algorithm/policy/PERSAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/PERSAC.py
@@ -44,7 +44,7 @@ def __init__(
         self.policy_update_freq = config.policy_update_freq
         self.target_update_freq = config.target_update_freq
 
-        self.target_entropy = -np.prod(self.actor_net.num_actions)
+        self.target_entropy = -self.actor_net.num_actions
 
         self.actor_net_optimiser = torch.optim.Adam(
             self.actor_net.parameters(), lr=config.actor_lr
diff --git a/cares_reinforcement_learning/algorithm/policy/RDSAC.py b/cares_reinforcement_learning/algorithm/policy/RDSAC.py
index 75d345f..73cee85 100644
--- a/cares_reinforcement_learning/algorithm/policy/RDSAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/RDSAC.py
@@ -38,7 +38,7 @@ def __init__(
         self.policy_update_freq = config.policy_update_freq
         self.target_update_freq = config.target_update_freq
 
-        self.target_entropy = -np.prod(self.actor_net.num_actions)
+        self.target_entropy = -self.actor_net.num_actions
 
         # RD-PER parameters
         self.scale_r = 1.0
diff --git a/cares_reinforcement_learning/algorithm/policy/SAC.py b/cares_reinforcement_learning/algorithm/policy/SAC.py
index 77f4348..f185e64 100644
--- a/cares_reinforcement_learning/algorithm/policy/SAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/SAC.py
@@ -45,7 +45,7 @@ def __init__(
         self.policy_update_freq = config.policy_update_freq
         self.target_update_freq = config.target_update_freq
 
-        self.target_entropy = -np.prod(self.actor_net.num_actions)
+        self.target_entropy = -self.actor_net.num_actions
 
         self.actor_net_optimiser = torch.optim.Adam(
             self.actor_net.parameters(), lr=config.actor_lr
diff --git a/cares_reinforcement_learning/algorithm/policy/SACAE.py b/cares_reinforcement_learning/algorithm/policy/SACAE.py
index fd6b7cc..121afed 100644
--- a/cares_reinforcement_learning/algorithm/policy/SACAE.py
+++ b/cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -61,7 +61,7 @@ def __init__(
         critic_beta = 0.9
         alpha_beta = 0.5
 
-        self.target_entropy = -np.prod(self.actor_net.num_actions)
+        self.target_entropy = -self.actor_net.num_actions
 
         self.actor_net_optimiser = torch.optim.Adam(
             self.actor_net.parameters(), lr=config.actor_lr, betas=(actor_beta, 0.999)
diff --git a/cares_reinforcement_learning/algorithm/policy/TQC.py b/cares_reinforcement_learning/algorithm/policy/TQC.py
index dab4a33..5b51d53 100644
--- a/cares_reinforcement_learning/algorithm/policy/TQC.py
+++ b/cares_reinforcement_learning/algorithm/policy/TQC.py
@@ -49,7 +49,7 @@ def __init__(
 
         self.device = device
 
-        self.target_entropy = -np.prod(self.actor_net.num_actions)
+        self.target_entropy = -self.actor_net.num_actions
 
         self.actor_net_optimiser = torch.optim.Adam(
             self.actor_net.parameters(), lr=config.actor_lr