diff --git a/cares_reinforcement_learning/algorithm/policy/SACAE.py b/cares_reinforcement_learning/algorithm/policy/SACAE.py
index a1081ed8..3e12514d 100644
--- a/cares_reinforcement_learning/algorithm/policy/SACAE.py
+++ b/cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -97,19 +97,28 @@ def __init__(
 
     # pylint: disable-next=unused-argument
     def select_action_from_policy(
-        self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0
+        self,
+        state: dict[str, np.ndarray],
+        evaluation: bool = False,
+        noise_scale: float = 0,
     ) -> np.ndarray:
         # note that when evaluating this algorithm we need to select mu as action
         self.actor_net.eval()
         with torch.no_grad():
-            state_tensor = torch.FloatTensor(state)
-            state_tensor = state_tensor.unsqueeze(0).to(self.device)
-            state_tensor = state_tensor / 255
+
+            vector_tensor = torch.FloatTensor(state["vector"])
+            vector_tensor = vector_tensor.unsqueeze(0).to(self.device)
+
+            image_tensor = torch.FloatTensor(state["image"])
+            image_tensor = image_tensor.unsqueeze(0).to(self.device)
+            image_tensor = image_tensor / 255
+
+            state = {"image": image_tensor, "vector": vector_tensor}
 
             if evaluation:
-                (_, _, action) = self.actor_net(state_tensor)
+                (_, _, action) = self.actor_net(state)
             else:
-                (action, _, _) = self.actor_net(state_tensor)
+                (action, _, _) = self.actor_net(state)
             action = action.cpu().data.numpy().flatten()
         self.actor_net.train()
         return action
@@ -120,12 +129,13 @@ def alpha(self) -> torch.Tensor:
 
     def _update_critic(
         self,
-        states: torch.Tensor,
+        states: dict[str, torch.Tensor],
         actions: torch.Tensor,
         rewards: torch.Tensor,
         next_states: torch.Tensor,
         dones: torch.Tensor,
     ) -> tuple[float, float, float]:
+
         with torch.no_grad():
             next_actions, next_log_pi, _ = self.actor_net(next_states)
 
@@ -153,7 +163,9 @@ def _update_critic(
 
         return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()
 
-    def _update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]:
+    def _update_actor_alpha(
+        self, states: dict[str, torch.Tensor]
+    ) -> tuple[float, float]:
         pi, log_pi, _ = self.actor_net(states, detach_encoder=True)
         qf1_pi, qf2_pi = self.critic_net(states, pi, detach_encoder=True)
 
@@ -197,29 +209,51 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         experiences = memory.sample_uniform(batch_size)
         states, actions, rewards, next_states, dones, _ = experiences
 
-        batch_size = len(states)
+        states_images = [state["image"] for state in states]
+        states_vector = [state["vector"] for state in states]
+
+        next_states_images = [next_state["image"] for next_state in next_states]
+        next_states_vector = [next_state["vector"] for next_state in next_states]
+
+        batch_size = len(states_images)
 
         # Convert into tensor
-        states = torch.FloatTensor(np.asarray(states)).to(self.device)
+        states_images = torch.FloatTensor(np.asarray(states_images)).to(self.device)
+        states_vector = torch.FloatTensor(np.asarray(states_vector)).to(self.device)
+
+        # Normalise states and next_states - image portion
+        # This because the states are [0-255] and the predictions are [0-1]
+        states_images = states_images / 255
+
+        states = {"image": states_images, "vector": states_vector}
+
         actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
         rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
-        next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)
+
+        next_states_images = torch.FloatTensor(np.asarray(next_states_images)).to(
+            self.device
+        )
+        next_states_vector = torch.FloatTensor(np.asarray(next_states_vector)).to(
+            self.device
+        )
+
+        # Normalise states and next_states - image portion
+        # This because the states are [0-255] and the predictions are [0-1]
+        next_states_images = next_states_images / 255
+
+        next_states = {"image": next_states_images, "vector": next_states_vector}
+
         dones = torch.LongTensor(np.asarray(dones)).to(self.device)
 
         # Reshape to batch_size x whatever
         rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
         dones = dones.unsqueeze(0).reshape(batch_size, 1)
 
-        # Normalise states and next_states
-        # This because the states are [0-255] and the predictions are [0-1]
-        states_normalised = states / 255
-        next_states_normalised = next_states / 255
-
         info = {}
 
         # Update the Critic
         critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
-            states_normalised, actions, rewards, next_states_normalised, dones
+            states, actions, rewards, next_states, dones
         )
         info["critic_loss_one"] = critic_loss_one
         info["critic_loss_two"] = critic_loss_two
@@ -227,7 +261,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
 
         # Update the Actor
         if self.learn_counter % self.policy_update_freq == 0:
-            actor_loss, alpha_loss = self._update_actor_alpha(states_normalised)
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
             info["actor_loss"] = actor_loss
             info["alpha_loss"] = alpha_loss
             info["alpha"] = self.alpha.item()
@@ -247,7 +281,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
         )
 
         if self.learn_counter % self.decoder_update_freq == 0:
-            ae_loss = self._update_autoencoder(states_normalised)
+            ae_loss = self._update_autoencoder(states["image"])
             info["ae_loss"] = ae_loss
 
         return info
diff --git a/cares_reinforcement_learning/networks/SACAE/actor.py b/cares_reinforcement_learning/networks/SACAE/actor.py
index a9180442..a895e776 100644
--- a/cares_reinforcement_learning/networks/SACAE/actor.py
+++ b/cares_reinforcement_learning/networks/SACAE/actor.py
@@ -8,6 +8,7 @@
 class Actor(SACActor):
     def __init__(
         self,
+        vector_observation: int,
         encoder: Encoder,
         num_actions: int,
         hidden_size: list[int] = None,
@@ -18,15 +19,27 @@ def __init__(
         if log_std_bounds is None:
             log_std_bounds = [-10, 2]
 
-        super().__init__(encoder.latent_dim, num_actions, hidden_size, log_std_bounds)
+        super().__init__(
+            encoder.latent_dim + vector_observation,
+            num_actions,
+            hidden_size,
+            log_std_bounds,
+        )
 
         self.encoder = encoder
+        self.vector_observation = vector_observation
+
         self.apply(hlp.weight_init)
 
     def forward(
-        self, state: torch.Tensor, detach_encoder: bool = False
+        self, state: dict[str, torch.Tensor], detach_encoder: bool = False
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state, detach_cnn=detach_encoder)
-        return super().forward(state_latent)
+        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+
+        actor_input = state_latent
+        if self.vector_observation > 0:
+            actor_input = torch.cat([state["vector"], actor_input], dim=1)
+
+        return super().forward(actor_input)
 
diff --git a/cares_reinforcement_learning/networks/SACAE/critic.py b/cares_reinforcement_learning/networks/SACAE/critic.py
index f42358b3..837ad577 100644
--- a/cares_reinforcement_learning/networks/SACAE/critic.py
+++ b/cares_reinforcement_learning/networks/SACAE/critic.py
@@ -8,6 +8,7 @@
 class Critic(SACCritic):
     def __init__(
         self,
+        vector_observation: int,
         encoder: Encoder,
         num_actions: int,
         hidden_size: list[int] = None,
@@ -15,15 +16,27 @@ def __init__(
         if hidden_size is None:
             hidden_size = [1024, 1024]
 
-        super().__init__(encoder.latent_dim, num_actions, hidden_size)
+        super().__init__(
+            encoder.latent_dim + vector_observation, num_actions, hidden_size
+        )
+
+        self.vector_observation = vector_observation
 
         self.encoder = encoder
         self.apply(hlp.weight_init)
 
     def forward(
-        self, state: torch.Tensor, action: torch.Tensor, detach_encoder: bool = False
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        detach_encoder: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state, detach_cnn=detach_encoder)
-        return super().forward(state_latent, action)
+        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+
+        critic_input = state_latent
+        if self.vector_observation > 0:
+            critic_input = torch.cat([state["vector"], critic_input], dim=1)
+
+        return super().forward(critic_input, action)
 
diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py
index 8d47343d..46c83990 100644
--- a/cares_reinforcement_learning/util/configurations.py
+++ b/cares_reinforcement_learning/util/configurations.py
@@ -199,7 +199,7 @@ class SACAEConfig(AlgorithmConfig):
     encoder_tau: Optional[float] = 0.05
     decoder_update_freq: Optional[int] = 1
 
-    include_vector_observation: Optional[int] = 0
+    vector_observation: Optional[int] = 0
 
     autoencoder_config: Optional[VanillaAEConfig] = VanillaAEConfig(
         latent_dim=50,
diff --git a/cares_reinforcement_learning/util/network_factory.py b/cares_reinforcement_learning/util/network_factory.py
index 6c42aaba..37f46908 100644
--- a/cares_reinforcement_learning/util/network_factory.py
+++ b/cares_reinforcement_learning/util/network_factory.py
@@ -141,19 +141,27 @@ def create_SACAE(observation_size, action_num, config: AlgorithmConfig):
 
     ae_factory = AEFactory()
     autoencoder = ae_factory.create_autoencoder(
-        observation_size=observation_size, config=config.autoencoder_config
+        observation_size=observation_size["image"], config=config.autoencoder_config
     )
 
     actor_encoder = copy.deepcopy(autoencoder.encoder)
     critic_encoder = copy.deepcopy(autoencoder.encoder)
 
+    vector_observation = observation_size["vector"] if config.vector_observation else 0
+
     actor = Actor(
+        vector_observation,
         actor_encoder,
         action_num,
         hidden_size=config.hidden_size,
         log_std_bounds=config.log_std_bounds,
    )
-    critic = Critic(critic_encoder, action_num, hidden_size=config.hidden_size)
+    critic = Critic(
+        vector_observation,
+        critic_encoder,
+        action_num,
+        hidden_size=config.hidden_size,
+    )
 
     device = hlp.get_device()
     agent = SACAE(
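
For reviewers, a minimal usage sketch (not part of the patch) of how the dict-based state introduced above is meant to flow end to end. The "image"/"vector" keys, the vector_observation config field, and the create_SACAE(observation_size, action_num, config) signature come from the diff; the observation shapes, action dimension, and the assumption that the remaining SACAEConfig fields have workable defaults are illustrative only.

import numpy as np

from cares_reinforcement_learning.util.configurations import SACAEConfig
from cares_reinforcement_learning.util.network_factory import create_SACAE

# Assumed environment: a 3x84x84 image observation plus a 10-dim proprioceptive vector.
observation_size = {"image": (3, 84, 84), "vector": 10}
action_num = 4  # assumed action dimension

# vector_observation enables the vector portion of the state; other fields keep their defaults.
config = SACAEConfig(vector_observation=1)

agent = create_SACAE(observation_size, action_num, config)

# States are dicts: the image stays uint8 in [0, 255] and is normalised inside the agent,
# while the vector is concatenated with the encoder latent inside the actor/critic.
state = {
    "image": np.random.randint(0, 256, size=(3, 84, 84), dtype=np.uint8),
    "vector": np.zeros(10, dtype=np.float32),
}
action = agent.select_action_from_policy(state, evaluation=False)

Note that train_policy() expects the replay buffer to store states in the same dict form, since it splits each sampled state into its "image" and "vector" parts before batching.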