From 13f85cf05047a678bf66f29dfc1b5c0dd4ddc41c Mon Sep 17 00:00:00 2001
From: beardyface
Date: Tue, 5 Nov 2024 12:51:05 +1300
Subject: [PATCH] Handling target vs policy update parameters

---
 .../algorithm/mbrl/DynaSAC.py               | 10 +++---
 .../algorithm/policy/LA3PSAC.py             |  4 +--
 .../algorithm/policy/LAPSAC.py              | 16 +++++----
 .../algorithm/policy/MAPERSAC.py            | 16 +++++----
 .../algorithm/policy/PERSAC.py              | 14 ++++----
 .../algorithm/policy/RDSAC.py               | 16 +++++----
 .../algorithm/policy/REDQ.py                | 16 +++++----
 .../algorithm/policy/SAC.py                 | 16 +++++----
 .../algorithm/policy/SACAE.py               |  4 +--
 .../algorithm/policy/SACD.py                | 16 +++++----
 .../algorithm/policy/TQC.py                 | 16 +++++----
 .../algorithm/value/DQN.py                  |  2 +-
 .../networks/SACD/actor.py                  |  3 +-
 .../util/configurations.py                  | 34 ++++++++++++++++++-
 cares_reinforcement_learning/util/record.py |  6 ++--
 15 files changed, 120 insertions(+), 69 deletions(-)

diff --git a/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py b/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py
index 20a6614..76cd84c 100644
--- a/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py
+++ b/cares_reinforcement_learning/algorithm/mbrl/DynaSAC.py
@@ -48,7 +48,8 @@ def __init__(
         self.action_num = self.actor_net.num_actions
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.actor_net_optimiser = torch.optim.Adam(
             self.actor_net.parameters(), lr=config.actor_lr
@@ -93,10 +94,11 @@ def _update_critic_actor(self, states, actions, rewards, next_states, dones):
         # Update Critic
         self._update_critic(states, actions, rewards, next_states, dones)
 
-        # Update Actor
-        self._update_actor(states)
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update Actor
+            self._update_actor(states)
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
     def _update_critic(self, states, actions, rewards, next_states, dones):
diff --git a/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py b/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py
index 8a2c23a..935aced 100644
--- a/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py
@@ -42,7 +42,7 @@ def __init__(
         self.prioritized_fraction = config.prioritized_fraction
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -190,7 +190,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         uniform_batch_size = int(batch_size * (1 - self.prioritized_fraction))
         priority_batch_size = int(batch_size * self.prioritized_fraction)
 
-        policy_update = self.learn_counter % self.policy_update_freq == 0
+        policy_update = self.learn_counter % self.target_update_freq == 0
 
         ######################### UNIFORM SAMPLING #########################
         experiences = memory.sample_uniform(uniform_batch_size)
diff --git a/cares_reinforcement_learning/algorithm/policy/LAPSAC.py b/cares_reinforcement_learning/algorithm/policy/LAPSAC.py
index 8edfc49..da307a8 100644
--- a/cares_reinforcement_learning/algorithm/policy/LAPSAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/LAPSAC.py
@@ -40,7 +40,8 @@ def __init__(
         self.min_priority = config.min_priority
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -178,13 +179,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         info["huber_lose_two"] = huber_lose_two
         info["critic_loss_total"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor_alpha(states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         memory.update_priorities(indices, priorities)
diff --git a/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py b/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py
index cd40ea8..dd287ad 100644
--- a/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py
@@ -43,7 +43,8 @@ def __init__(
         self.min_priority = config.min_priority
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -284,13 +285,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss_total"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor_alpha(states, weights)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor_alpha(states, weights)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         memory.update_priorities(indices, priorities)
diff --git a/cares_reinforcement_learning/algorithm/policy/PERSAC.py b/cares_reinforcement_learning/algorithm/policy/PERSAC.py
index 35466ae..714becc 100644
--- a/cares_reinforcement_learning/algorithm/policy/PERSAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/PERSAC.py
@@ -41,7 +41,8 @@ def __init__(
         self.min_priority = config.min_priority
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -175,12 +176,13 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss = self._update_actor_alpha(states)
-        info["actor_loss"] = actor_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss = self._update_actor_alpha(states)
+            info["actor_loss"] = actor_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         memory.update_priorities(indices, priorities)
diff --git a/cares_reinforcement_learning/algorithm/policy/RDSAC.py b/cares_reinforcement_learning/algorithm/policy/RDSAC.py
index 1577ddf..73cee85 100644
--- a/cares_reinforcement_learning/algorithm/policy/RDSAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/RDSAC.py
@@ -35,7 +35,8 @@ def __init__(
         self.per_alpha = config.per_alpha
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -245,13 +246,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss_total"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor_alpha(states, weights)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor_alpha(states, weights)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         memory.update_priorities(indices, priorities)
diff --git a/cares_reinforcement_learning/algorithm/policy/REDQ.py b/cares_reinforcement_learning/algorithm/policy/REDQ.py
index 5700d90..95e3de9 100644
--- a/cares_reinforcement_learning/algorithm/policy/REDQ.py
+++ b/cares_reinforcement_learning/algorithm/policy/REDQ.py
@@ -29,7 +29,8 @@ def __init__(
         self.tau = config.tau
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.device = device
 
@@ -193,13 +194,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss_totals"] = critic_loss_totals
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor_alpha(idx, states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor_alpha(idx, states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             # Update ensemble of target critics
             for critic_net, target_critic_net in zip(
                 self.ensemble_critics, self.target_ensemble_critics
diff --git a/cares_reinforcement_learning/algorithm/policy/SAC.py b/cares_reinforcement_learning/algorithm/policy/SAC.py
index e4fa616..f185e64 100644
--- a/cares_reinforcement_learning/algorithm/policy/SAC.py
+++ b/cares_reinforcement_learning/algorithm/policy/SAC.py
@@ -42,7 +42,8 @@ def __init__(
         self.reward_scale = config.reward_scale
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.target_entropy = -self.actor_net.num_actions
 
@@ -167,13 +168,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         info["critic_loss_two"] = critic_loss_two
         info["critic_loss"] = critic_loss_total
 
-        # Update the Actor and Alpha
-        actor_loss, alpha_loss = self._update_actor_alpha(states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor and Alpha
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         return info
diff --git a/cares_reinforcement_learning/algorithm/policy/SACAE.py b/cares_reinforcement_learning/algorithm/policy/SACAE.py
index ae1d6ae..e184a70 100644
--- a/cares_reinforcement_learning/algorithm/policy/SACAE.py
+++ b/cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -54,8 +54,8 @@ def __init__(
         self.reward_scale = config.reward_scale
 
         self.learn_counter = 0
-        self.policy_update_freq = 2
-        self.target_update_freq = 2
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         actor_beta = 0.9
         critic_beta = 0.9
diff --git a/cares_reinforcement_learning/algorithm/policy/SACD.py b/cares_reinforcement_learning/algorithm/policy/SACD.py
index 9dfd630..320724b 100644
--- a/cares_reinforcement_learning/algorithm/policy/SACD.py
+++ b/cares_reinforcement_learning/algorithm/policy/SACD.py
@@ -42,7 +42,8 @@ def __init__(
         self.reward_scale = config.reward_scale
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.action_num = self.actor_net.num_actions
 
@@ -182,13 +183,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss"] = critic_loss_total
 
-        # Update the Actor and Alpha
-        actor_loss, alpha_loss = self._update_actor_alpha(states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor and Alpha
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         return info
diff --git a/cares_reinforcement_learning/algorithm/policy/TQC.py b/cares_reinforcement_learning/algorithm/policy/TQC.py
index efbfbe6..5b51d53 100644
--- a/cares_reinforcement_learning/algorithm/policy/TQC.py
+++ b/cares_reinforcement_learning/algorithm/policy/TQC.py
@@ -44,7 +44,8 @@ def __init__(
         )
 
         self.learn_counter = 0
-        self.policy_update_freq = 1
+        self.policy_update_freq = config.policy_update_freq
+        self.target_update_freq = config.target_update_freq
 
         self.device = device
 
@@ -178,13 +179,14 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
         info["critic_loss"] = critic_loss_total
 
-        # Update the Actor
-        actor_loss, alpha_loss = self._update_actor(states)
-        info["actor_loss"] = actor_loss
-        info["alpha_loss"] = alpha_loss
-        info["alpha"] = self.alpha.item()
-
         if self.learn_counter % self.policy_update_freq == 0:
+            # Update the Actor
+            actor_loss, alpha_loss = self._update_actor(states)
+            info["actor_loss"] = actor_loss
+            info["alpha_loss"] = alpha_loss
+            info["alpha"] = self.alpha.item()
+
+        if self.learn_counter % self.target_update_freq == 0:
             hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau)
 
         return info
diff --git a/cares_reinforcement_learning/algorithm/value/DQN.py b/cares_reinforcement_learning/algorithm/value/DQN.py
index 9a15b50..964ebb3 100644
--- a/cares_reinforcement_learning/algorithm/value/DQN.py
+++ b/cares_reinforcement_learning/algorithm/value/DQN.py
@@ -30,7 +30,7 @@ def __init__(
             self.network.parameters(), lr=config.lr
         )
 
-    def select_action_from_policy(self, state):
+    def select_action_from_policy(self, state) -> int:
         self.network.eval()
         with torch.no_grad():
             state_tensor = torch.FloatTensor(state).to(self.device)
diff --git a/cares_reinforcement_learning/networks/SACD/actor.py b/cares_reinforcement_learning/networks/SACD/actor.py
index b7a4bbe..03d08a2 100644
--- a/cares_reinforcement_learning/networks/SACD/actor.py
+++ b/cares_reinforcement_learning/networks/SACD/actor.py
@@ -30,11 +30,12 @@ def __init__(
 
     def forward(
         self, state: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
         action_probs = self.act_net(state)
         max_probability_action = torch.argmax(action_probs)
         dist = torch.distributions.Categorical(action_probs)
         action = dist.sample()
+        # Offset any values which are zero by a small amount so no nan nonsense
         zero_offset = action_probs == 0.0
         zero_offset = zero_offset.float() * 1e-8
 
diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py
index 9d44c53..2ce9439 100644
--- a/cares_reinforcement_learning/util/configurations.py
+++ b/cares_reinforcement_learning/util/configurations.py
@@ -180,6 +180,9 @@ class SACConfig(AlgorithmConfig):
 
     log_std_bounds: list[float] = [-20, 2]
 
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
+
 
 class SACAEConfig(AlgorithmConfig):
     algorithm: str = Field("SACAE", Literal=True)
@@ -197,6 +200,9 @@ class SACAEConfig(AlgorithmConfig):
 
     log_std_bounds: list[float] = [-20, 2]
 
+    policy_update_freq: int = 2
+    target_update_freq: int = 2
+
     encoder_tau: float = 0.05
     decoder_update_freq: int = 1
 
@@ -230,6 +236,9 @@ class SACDConfig(AlgorithmConfig):
     tau: float = 0.005
     reward_scale: float = 1.0
 
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
+
 
 class DynaSACConfig(AlgorithmConfig):
     algorithm: str = Field("DynaSAC", Literal=True)
@@ -247,6 +256,9 @@ class DynaSACConfig(AlgorithmConfig):
 
     log_std_bounds: list[float] = [-20, 2]
 
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
+
     horizon: int = 3
     num_samples: int = 10
     world_model_lr: float = 0.001
@@ -269,7 +281,7 @@ class NaSATD3Config(AlgorithmConfig):
 
     vector_observation: int = 0
 
-    autoencoder_config: VanillaAEConfig | BurgessConfig = VanillaAEConfig(
+    autoencoder_config: VanillaAEConfig = VanillaAEConfig(
         latent_dim=200,
         num_layers=4,
         num_filters=32,
@@ -302,6 +314,9 @@ class REDQConfig(AlgorithmConfig):
 
     G: int = 20
 
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
+
 
 class TQCConfig(AlgorithmConfig):
     algorithm: str = Field("TQC", Literal=True)
@@ -317,6 +332,9 @@ class TQCConfig(AlgorithmConfig):
 
     log_std_bounds: list[float] = [-20, 2]
 
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
+
 
 class CTD4Config(AlgorithmConfig):
     algorithm: str = Field("CTD4", Literal=True)
@@ -361,6 +379,9 @@ class PERSACConfig(AlgorithmConfig):
 
     log_std_bounds: list[float] = [-20, 2]
 
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
+
 
 class LAPTD3Config(AlgorithmConfig):
     algorithm: str = Field("LAPTD3", Literal=True)
@@ -390,6 +411,9 @@ class LAPSACConfig(AlgorithmConfig):
 
     log_std_bounds: list[float] = [-20, 2]
 
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
+
 
 class PALTD3Config(AlgorithmConfig):
     algorithm: str = Field("PALTD3", Literal=True)
@@ -435,6 +459,8 @@ class LA3PSACConfig(AlgorithmConfig):
 
     log_std_bounds: list[float] = [-20, 2]
 
+    target_update_freq: int = 1
+
 
 class MAPERTD3Config(AlgorithmConfig):
     algorithm: str = Field("MAPERTD3", Literal=True)
@@ -477,6 +503,9 @@ class MAPERSACConfig(AlgorithmConfig):
     hidden_size: list[int] = [400, 300]
     log_std_bounds: list[float] = [-20, 2]
 
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
+
 
 class RDTD3Config(AlgorithmConfig):
     algorithm: str = Field("RDTD3", Literal=True)
@@ -504,3 +533,6 @@ class RDSACConfig(AlgorithmConfig):
     min_priority: float = 1.0
 
     log_std_bounds: list[float] = [-20, 2]
+
+    policy_update_freq: int = 1
+    target_update_freq: int = 1
diff --git a/cares_reinforcement_learning/util/record.py b/cares_reinforcement_learning/util/record.py
index 0f22696..ad65c26 100644
--- a/cares_reinforcement_learning/util/record.py
+++ b/cares_reinforcement_learning/util/record.py
@@ -31,8 +31,8 @@ def __init__(
         algorithm: str,
         task: str,
         plot_frequency: int = 10,
-        checkpoint_frequency: Optional[int] = None,
-        network: Optional[nn.Module] = None,
+        checkpoint_frequency: int | None = None,
+        network: nn.Module | None = None,
     ) -> None:
         self.best_reward = float("-inf")
 
@@ -82,7 +82,7 @@ def __init__(
 
         self.log_count = 0
 
-        self.video = None
+        self.video: cv2.VideoWriter = None
 
         self.__initialise_directories()
 