From de5c5f07293f4505345ba8e779eb66ea2f6fd37a Mon Sep 17 00:00:00 2001 From: beardyface Date: Tue, 8 Oct 2024 14:11:36 +1300 Subject: [PATCH] TD3AE vector + image --- .../algorithm/policy/SACAE.py | 7 +- .../algorithm/policy/TD3AE.py | 68 ++++++++++++++----- .../networks/SACAE/actor.py | 8 +-- .../networks/SACAE/critic.py | 8 +-- .../networks/TD3AE/actor.py | 18 +++-- .../networks/TD3AE/critic.py | 21 ++++-- .../util/configurations.py | 2 + .../util/network_factory.py | 28 ++++++-- 8 files changed, 116 insertions(+), 44 deletions(-) diff --git a/cares_reinforcement_learning/algorithm/policy/SACAE.py b/cares_reinforcement_learning/algorithm/policy/SACAE.py index 3e12514d..18f4bd4b 100644 --- a/cares_reinforcement_learning/algorithm/policy/SACAE.py +++ b/cares_reinforcement_learning/algorithm/policy/SACAE.py @@ -15,7 +15,6 @@ import torch.nn.functional as F import cares_reinforcement_learning.util.helpers as hlp -from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig from cares_reinforcement_learning.encoders.losses import AELoss from cares_reinforcement_learning.memory import MemoryBuffer from cares_reinforcement_learning.util.configurations import SACAEConfig @@ -113,12 +112,12 @@ def select_action_from_policy( image_tensor = image_tensor.unsqueeze(0).to(self.device) image_tensor = image_tensor / 255 - state = {"image": image_tensor, "vector": vector_tensor} + state_tensor = {"image": image_tensor, "vector": vector_tensor} if evaluation: - (_, _, action) = self.actor_net(state) + (_, _, action) = self.actor_net(state_tensor) else: - (action, _, _) = self.actor_net(state) + (action, _, _) = self.actor_net(state_tensor) action = action.cpu().data.numpy().flatten() self.actor_net.train() return action diff --git a/cares_reinforcement_learning/algorithm/policy/TD3AE.py b/cares_reinforcement_learning/algorithm/policy/TD3AE.py index 10d5b544..2a05596d 100644 --- a/cares_reinforcement_learning/algorithm/policy/TD3AE.py +++ b/cares_reinforcement_learning/algorithm/policy/TD3AE.py @@ -13,7 +13,6 @@ import torch.nn.functional as F import cares_reinforcement_learning.util.helpers as hlp -from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig from cares_reinforcement_learning.encoders.losses import AELoss from cares_reinforcement_learning.memory import MemoryBuffer from cares_reinforcement_learning.util.configurations import TD3AEConfig @@ -78,13 +77,21 @@ def __init__( ) def select_action_from_policy( - self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0.1 + self, + state: dict[str, np.ndarray], + evaluation: bool = False, + noise_scale: float = 0.1, ) -> np.ndarray: self.actor_net.eval() with torch.no_grad(): - state_tensor = torch.FloatTensor(state).to(self.device) - state_tensor = state_tensor.unsqueeze(0) - state_tensor = state_tensor / 255 + vector_tensor = torch.FloatTensor(state["vector"]) + vector_tensor = vector_tensor.unsqueeze(0).to(self.device) + + image_tensor = torch.FloatTensor(state["image"]) + image_tensor = image_tensor.unsqueeze(0).to(self.device) + image_tensor = image_tensor / 255 + + state_tensor = {"image": image_tensor, "vector": vector_tensor} action = self.actor_net(state_tensor) action = action.cpu().data.numpy().flatten() @@ -98,7 +105,7 @@ def select_action_from_policy( def _update_critic( self, - states: torch.Tensor, + states: dict[str, torch.Tensor], actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, @@ -132,7 +139,10 @@ def _update_critic( return critic_loss_one.item(), 
critic_loss_two.item(), critic_loss_total.item() - def _update_actor(self, states: torch.Tensor) -> float: + def _update_actor( + self, + states: dict[str, torch.Tensor], + ) -> float: actions = self.actor_net(states, detach_encoder=True) actor_q_values, _ = self.critic_net(states, actions, detach_encoder=True) actor_loss = -actor_q_values.mean() @@ -167,28 +177,50 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]: experiences = memory.sample_uniform(batch_size) states, actions, rewards, next_states, dones, _ = experiences - batch_size = len(states) + states_images = [state["image"] for state in states] + states_vector = [state["vector"] for state in states] + + next_states_images = [next_state["image"] for next_state in next_states] + next_states_vector = [next_state["vector"] for next_state in next_states] + + batch_size = len(states_images) # Convert into tensor - states = torch.FloatTensor(np.asarray(states)).to(self.device) + states_images = torch.FloatTensor(np.asarray(states_images)).to(self.device) + states_vector = torch.FloatTensor(np.asarray(states_vector)).to(self.device) + + # Normalise states and next_states - image portion + # This because the states are [0-255] and the predictions are [0-1] + states_images = states_images / 255 + + states = {"image": states_images, "vector": states_vector} + actions = torch.FloatTensor(np.asarray(actions)).to(self.device) rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device) - next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device) + + next_states_images = torch.FloatTensor(np.asarray(next_states_images)).to( + self.device + ) + next_states_vector = torch.FloatTensor(np.asarray(next_states_vector)).to( + self.device + ) + + # Normalise states and next_states - image portion + # This because the states are [0-255] and the predictions are [0-1] + next_states_images = next_states_images / 255 + + next_states = {"image": next_states_images, "vector": next_states_vector} + dones = torch.LongTensor(np.asarray(dones)).to(self.device) # Reshape to batch_size rewards = rewards.unsqueeze(0).reshape(batch_size, 1) dones = dones.unsqueeze(0).reshape(batch_size, 1) - # Normalise states and next_states - # This because the states are [0-255] and the predictions are [0-1] - states_normalised = states / 255 - next_states_normalised = next_states / 255 - info = {} critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic( - states_normalised, actions, rewards, next_states_normalised, dones + states, actions, rewards, next_states, dones ) info["critic_loss_one"] = critic_loss_one info["critic_loss_two"] = critic_loss_two @@ -196,7 +228,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]: if self.learn_counter % self.policy_update_freq == 0: # Update Actor - actor_loss = self._update_actor(states_normalised) + actor_loss = self._update_actor(states) info["actor_loss"] = actor_loss # Update target network params @@ -222,7 +254,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]: ) if self.learn_counter % self.decoder_update_freq == 0: - ae_loss = self._update_autoencoder(states_normalised) + ae_loss = self._update_autoencoder(states["image"]) info["ae_loss"] = ae_loss return info diff --git a/cares_reinforcement_learning/networks/SACAE/actor.py b/cares_reinforcement_learning/networks/SACAE/actor.py index a895e776..46928189 100644 --- a/cares_reinforcement_learning/networks/SACAE/actor.py +++ 
b/cares_reinforcement_learning/networks/SACAE/actor.py @@ -8,7 +8,7 @@ class Actor(SACActor): def __init__( self, - vector_observation: int, + vector_observation_size: int, encoder: Encoder, num_actions: int, hidden_size: list[int] = None, @@ -20,7 +20,7 @@ def __init__( log_std_bounds = [-10, 2] super().__init__( - encoder.latent_dim + vector_observation, + encoder.latent_dim + vector_observation_size, num_actions, hidden_size, log_std_bounds, @@ -28,7 +28,7 @@ def __init__( self.encoder = encoder - self.vector_observation = vector_observation + self.vector_observation_size = vector_observation_size self.apply(hlp.weight_init) @@ -39,7 +39,7 @@ def forward( state_latent = self.encoder(state["image"], detach_cnn=detach_encoder) actor_input = state_latent - if self.vector_observation > 0: + if self.vector_observation_size > 0: actor_input = torch.cat([state["vector"], actor_input], dim=1) return super().forward(actor_input) diff --git a/cares_reinforcement_learning/networks/SACAE/critic.py b/cares_reinforcement_learning/networks/SACAE/critic.py index 837ad577..5d9e59c4 100644 --- a/cares_reinforcement_learning/networks/SACAE/critic.py +++ b/cares_reinforcement_learning/networks/SACAE/critic.py @@ -8,7 +8,7 @@ class Critic(SACCritic): def __init__( self, - vector_observation: int, + vector_observation_size: int, encoder: Encoder, num_actions: int, hidden_size: list[int] = None, @@ -17,10 +17,10 @@ def __init__( hidden_size = [1024, 1024] super().__init__( - encoder.latent_dim + vector_observation, num_actions, hidden_size + encoder.latent_dim + vector_observation_size, num_actions, hidden_size ) - self.vector_observation = vector_observation + self.vector_observation_size = vector_observation_size self.encoder = encoder @@ -36,7 +36,7 @@ def forward( state_latent = self.encoder(state["image"], detach_cnn=detach_encoder) critic_input = state_latent - if self.vector_observation > 0: + if self.vector_observation_size > 0: critic_input = torch.cat([state["vector"], critic_input], dim=1) return super().forward(critic_input, action) diff --git a/cares_reinforcement_learning/networks/TD3AE/actor.py b/cares_reinforcement_learning/networks/TD3AE/actor.py index 24c9bb11..75ddf3af 100644 --- a/cares_reinforcement_learning/networks/TD3AE/actor.py +++ b/cares_reinforcement_learning/networks/TD3AE/actor.py @@ -8,6 +8,7 @@ class Actor(TD3Actor): def __init__( self, + vector_observation_size: int, encoder: Encoder, num_actions: int, hidden_size: list[int] = None, @@ -15,15 +16,24 @@ def __init__( if hidden_size is None: hidden_size = [1024, 1024] - super().__init__(encoder.latent_dim, num_actions, hidden_size) + super().__init__( + encoder.latent_dim + vector_observation_size, num_actions, hidden_size + ) self.encoder = encoder self.apply(hlp.weight_init) + self.vector_observation_size = vector_observation_size + def forward( - self, state: torch.Tensor, detach_encoder: bool = False + self, state: dict[str, torch.Tensor], detach_encoder: bool = False ) -> torch.Tensor: # Detach at the CNN layer to prevent backpropagation through the encoder - state_latent = self.encoder(state, detach_cnn=detach_encoder) - return super().forward(state_latent) + state_latent = self.encoder(state["image"], detach_cnn=detach_encoder) + + actor_input = state_latent + if self.vector_observation_size > 0: + actor_input = torch.cat([state["vector"], actor_input], dim=1) + + return super().forward(actor_input) diff --git a/cares_reinforcement_learning/networks/TD3AE/critic.py b/cares_reinforcement_learning/networks/TD3AE/critic.py 
b/cares_reinforcement_learning/networks/TD3AE/critic.py
index f9e62ca2..7401def9 100644 --- a/cares_reinforcement_learning/networks/TD3AE/critic.py +++ b/cares_reinforcement_learning/networks/TD3AE/critic.py @@ -8,6 +8,7 @@ class Critic(TD3Critic): def __init__( self, + vector_observation_size: int, encoder: Encoder, num_actions: int, hidden_size: list[int] = None, @@ -15,15 +16,27 @@ def __init__( if hidden_size is None: hidden_size = [1024, 1024] - super().__init__(encoder.latent_dim, num_actions, hidden_size) + super().__init__( + encoder.latent_dim + vector_observation_size, num_actions, hidden_size + ) + + self.vector_observation_size = vector_observation_size self.encoder = encoder self.apply(hlp.weight_init) def forward( - self, state: torch.Tensor, action: torch.Tensor, detach_encoder: bool = False + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + detach_encoder: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: # Detach at the CNN layer to prevent backpropagation through the encoder - state_latent = self.encoder(state, detach_cnn=detach_encoder) - return super().forward(state_latent, action) + state_latent = self.encoder(state["image"], detach_cnn=detach_encoder) + + critic_input = state_latent + if self.vector_observation_size > 0: + critic_input = torch.cat([state["vector"], critic_input], dim=1) + + return super().forward(critic_input, action) diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py index 46c83990..d06ea459 100644 --- a/cares_reinforcement_learning/util/configurations.py +++ b/cares_reinforcement_learning/util/configurations.py @@ -156,6 +156,8 @@ class TD3AEConfig(AlgorithmConfig): encoder_tau: Optional[float] = 0.05 decoder_update_freq: Optional[int] = 1 + vector_observation: Optional[int] = 0 + autoencoder_config: Optional[VanillaAEConfig] = VanillaAEConfig( latent_dim=50, num_layers=4, diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py index 37f46908..af3c726e 100644 --- a/cares_reinforcement_learning/util/network_factory.py +++ b/cares_reinforcement_learning/util/network_factory.py @@ -147,17 +147,19 @@ def create_SACAE(observation_size, action_num, config: AlgorithmConfig): actor_encoder = copy.deepcopy(autoencoder.encoder) critic_encoder = copy.deepcopy(autoencoder.encoder) - vector_observation = observation_size["vector"] if config.vector_observation else 0 + vector_observation_size = ( + observation_size["vector"] if config.vector_observation else 0 + ) actor = Actor( - vector_observation, + vector_observation_size, actor_encoder, action_num, hidden_size=config.hidden_size, log_std_bounds=config.log_std_bounds, ) critic = Critic( - vector_observation, + vector_observation_size, critic_encoder, action_num, hidden_size=config.hidden_size, @@ -232,14 +234,28 @@ def create_TD3AE(observation_size, action_num, config: AlgorithmConfig): ae_factory = AEFactory() autoencoder = ae_factory.create_autoencoder( - observation_size=observation_size, config=config.autoencoder_config + observation_size=observation_size["image"], config=config.autoencoder_config ) actor_encoder = copy.deepcopy(autoencoder.encoder) critic_encoder = copy.deepcopy(autoencoder.encoder) - actor = Actor(actor_encoder, action_num, hidden_size=config.hidden_size) - critic = Critic(critic_encoder, action_num, hidden_size=config.hidden_size) + vector_observation_size = ( + observation_size["vector"] if config.vector_observation else 0 + ) + + actor = Actor( + vector_observation_size, + actor_encoder, 
+ action_num, + hidden_size=config.hidden_size, + ) + critic = Critic( + vector_observation_size, + critic_encoder, + action_num, + hidden_size=config.hidden_size, + ) device = hlp.get_device() agent = TD3AE(