SACAE image + vector
beardyFace committed Oct 8, 2024
1 parent 8358970 commit 7626747
Showing 5 changed files with 98 additions and 30 deletions.
72 changes: 53 additions & 19 deletions cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -97,19 +97,28 @@ def __init__(

    # pylint: disable-next=unused-argument
    def select_action_from_policy(
-        self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0
+        self,
+        state: dict[str, np.ndarray],
+        evaluation: bool = False,
+        noise_scale: float = 0,
    ) -> np.ndarray:
        # note that when evaluating this algorithm we need to select mu as action
        self.actor_net.eval()
        with torch.no_grad():
-            state_tensor = torch.FloatTensor(state)
-            state_tensor = state_tensor.unsqueeze(0).to(self.device)
-            state_tensor = state_tensor / 255

+            vector_tensor = torch.FloatTensor(state["vector"])
+            vector_tensor = vector_tensor.unsqueeze(0).to(self.device)
+
+            image_tensor = torch.FloatTensor(state["image"])
+            image_tensor = image_tensor.unsqueeze(0).to(self.device)
+            image_tensor = image_tensor / 255
+
+            state = {"image": image_tensor, "vector": vector_tensor}

            if evaluation:
-                (_, _, action) = self.actor_net(state_tensor)
+                (_, _, action) = self.actor_net(state)
            else:
-                (action, _, _) = self.actor_net(state_tensor)
+                (action, _, _) = self.actor_net(state)
            action = action.cpu().data.numpy().flatten()
        self.actor_net.train()
        return action
@@ -120,12 +129,13 @@ def alpha(self) -> torch.Tensor:

    def _update_critic(
        self,
-        states: torch.Tensor,
+        states: dict[str, torch.Tensor],
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        dones: torch.Tensor,
    ) -> tuple[float, float, float]:

        with torch.no_grad():
            next_actions, next_log_pi, _ = self.actor_net(next_states)

@@ -153,7 +163,9 @@ def _update_critic(

        return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()

-    def _update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]:
+    def _update_actor_alpha(
+        self, states: dict[str, torch.Tensor]
+    ) -> tuple[float, float]:
        pi, log_pi, _ = self.actor_net(states, detach_encoder=True)
        qf1_pi, qf2_pi = self.critic_net(states, pi, detach_encoder=True)

@@ -197,37 +209,59 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
        experiences = memory.sample_uniform(batch_size)
        states, actions, rewards, next_states, dones, _ = experiences

-        batch_size = len(states)
+        states_images = [state["image"] for state in states]
+        states_vector = [state["vector"] for state in states]
+
+        next_states_images = [next_state["image"] for next_state in next_states]
+        next_states_vector = [next_state["vector"] for next_state in next_states]
+
+        batch_size = len(states_images)

        # Convert into tensor
-        states = torch.FloatTensor(np.asarray(states)).to(self.device)
+        states_images = torch.FloatTensor(np.asarray(states_images)).to(self.device)
+        states_vector = torch.FloatTensor(np.asarray(states_vector)).to(self.device)
+
+        # Normalise states and next_states - image portion
+        # This because the states are [0-255] and the predictions are [0-1]
+        states_images = states_images / 255
+
+        states = {"image": states_images, "vector": states_vector}

        actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
        rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
-        next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)

+        next_states_images = torch.FloatTensor(np.asarray(next_states_images)).to(
+            self.device
+        )
+        next_states_vector = torch.FloatTensor(np.asarray(next_states_vector)).to(
+            self.device
+        )
+
+        # Normalise states and next_states - image portion
+        # This because the states are [0-255] and the predictions are [0-1]
+        next_states_images = next_states_images / 255
+
+        next_states = {"image": next_states_images, "vector": next_states_vector}

        dones = torch.LongTensor(np.asarray(dones)).to(self.device)

        # Reshape to batch_size x whatever
        rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
        dones = dones.unsqueeze(0).reshape(batch_size, 1)

-        # Normalise states and next_states
-        # This because the states are [0-255] and the predictions are [0-1]
-        states_normalised = states / 255
-        next_states_normalised = next_states / 255

        info = {}

        # Update the Critic
        critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
-            states_normalised, actions, rewards, next_states_normalised, dones
+            states, actions, rewards, next_states, dones
        )
        info["critic_loss_one"] = critic_loss_one
        info["critic_loss_two"] = critic_loss_two
        info["critic_loss"] = critic_loss_total

        # Update the Actor
        if self.learn_counter % self.policy_update_freq == 0:
-            actor_loss, alpha_loss = self._update_actor_alpha(states_normalised)
+            actor_loss, alpha_loss = self._update_actor_alpha(states)
            info["actor_loss"] = actor_loss
            info["alpha_loss"] = alpha_loss
            info["alpha"] = self.alpha.item()
@@ -247,7 +281,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
            )

        if self.learn_counter % self.decoder_update_freq == 0:
-            ae_loss = self._update_autoencoder(states_normalised)
+            ae_loss = self._update_autoencoder(states["image"])
            info["ae_loss"] = ae_loss

        return info
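The train_policy changes above replace the single image tensor with a dict holding an image tensor and a vector tensor, and only the image half is divided by 255. Below is a minimal standalone sketch of that batching pattern; the batch size, image shape (3x84x84) and vector length (5) are illustrative assumptions, not values from the repository.

```python
# Sketch of the dict-observation batching used in train_policy (illustrative shapes).
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A sampled batch of dict observations, as the memory buffer would hand them back.
states = [
    {
        "image": np.random.randint(0, 256, (3, 84, 84), dtype=np.uint8),
        "vector": np.random.randn(5).astype(np.float32),
    }
    for _ in range(4)
]

# Split the dicts into per-key lists, then stack into batched tensors.
states_images = torch.FloatTensor(np.asarray([s["image"] for s in states])).to(device)
states_vector = torch.FloatTensor(np.asarray([s["vector"] for s in states])).to(device)

# Images are stored as [0-255]; the networks expect [0-1]. The vector is left untouched.
states_images = states_images / 255

batch = {"image": states_images, "vector": states_vector}
print(batch["image"].shape, batch["vector"].shape)  # [4, 3, 84, 84] and [4, 5]
```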
21 changes: 17 additions & 4 deletions cares_reinforcement_learning/networks/SACAE/actor.py
@@ -8,6 +8,7 @@
class Actor(SACActor):
    def __init__(
        self,
+        vector_observation: int,
        encoder: Encoder,
        num_actions: int,
        hidden_size: list[int] = None,
@@ -18,15 +19,27 @@ def __init__(
        if log_std_bounds is None:
            log_std_bounds = [-10, 2]

-        super().__init__(encoder.latent_dim, num_actions, hidden_size, log_std_bounds)
+        super().__init__(
+            encoder.latent_dim + vector_observation,
+            num_actions,
+            hidden_size,
+            log_std_bounds,
+        )

        self.encoder = encoder

+        self.vector_observation = vector_observation

        self.apply(hlp.weight_init)

    def forward(
-        self, state: torch.Tensor, detach_encoder: bool = False
+        self, state: dict[str, torch.Tensor], detach_encoder: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state, detach_cnn=detach_encoder)
-        return super().forward(state_latent)
+        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+
+        actor_input = state_latent
+        if self.vector_observation > 0:
+            actor_input = torch.cat([state["vector"], actor_input], dim=1)
+
+        return super().forward(actor_input)
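With the vector pathway enabled, Actor.forward above encodes state["image"] and concatenates state["vector"] in front of the latent, which is why the SACActor base class is now constructed with encoder.latent_dim + vector_observation as its input size. A minimal sketch of that shape bookkeeping, with illustrative dimensions:

```python
# Shape bookkeeping for the latent/vector concatenation (illustrative dimensions).
import torch

latent_dim = 50          # stands in for encoder.latent_dim
vector_observation = 5   # length of the proprioceptive vector
batch_size = 4

state_latent = torch.randn(batch_size, latent_dim)    # encoder output for state["image"]
vector = torch.randn(batch_size, vector_observation)  # state["vector"]

actor_input = torch.cat([vector, state_latent], dim=1)  # vector first, as in the diff
assert actor_input.shape == (batch_size, latent_dim + vector_observation)
```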
21 changes: 17 additions & 4 deletions cares_reinforcement_learning/networks/SACAE/critic.py
@@ -8,22 +8,35 @@
class Critic(SACCritic):
    def __init__(
        self,
+        vector_observation: int,
        encoder: Encoder,
        num_actions: int,
        hidden_size: list[int] = None,
    ):
        if hidden_size is None:
            hidden_size = [1024, 1024]

-        super().__init__(encoder.latent_dim, num_actions, hidden_size)
+        super().__init__(
+            encoder.latent_dim + vector_observation, num_actions, hidden_size
+        )

+        self.vector_observation = vector_observation

        self.encoder = encoder

        self.apply(hlp.weight_init)

    def forward(
-        self, state: torch.Tensor, action: torch.Tensor, detach_encoder: bool = False
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        detach_encoder: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Detach at the CNN layer to prevent backpropagation through the encoder
-        state_latent = self.encoder(state, detach_cnn=detach_encoder)
-        return super().forward(state_latent, action)
+        state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)
+
+        critic_input = state_latent
+        if self.vector_observation > 0:
+            critic_input = torch.cat([state["vector"], critic_input], dim=1)
+
+        return super().forward(critic_input, action)
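The Critic mirrors the Actor: it encodes state["image"], concatenates state["vector"] when the vector pathway is enabled, and feeds the result to the twin Q-heads. Both networks also keep the detach_encoder flag, which the actor/alpha update above passes as True so the policy loss does not backpropagate into the shared encoder. A toy sketch of what detaching does, using stand-in nn.Linear modules rather than the repository's Encoder and Critic classes:

```python
# Toy illustration of detach_encoder: gradients stop at the detached latent,
# so this loss leaves the (stand-in) encoder untouched.
import torch
import torch.nn as nn

encoder = nn.Linear(10, 4)  # stand-in for the CNN encoder
head = nn.Linear(4, 1)      # stand-in for the actor/critic MLP head

x = torch.randn(8, 10)
latent = encoder(x)

detach_encoder = True
if detach_encoder:
    latent = latent.detach()  # block backpropagation into the encoder

loss = head(latent).mean()
loss.backward()
print(encoder.weight.grad is None)   # True: the encoder received no gradient
print(head.weight.grad is not None)  # True: the head still learns
```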
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/util/configurations.py
@@ -199,7 +199,7 @@ class SACAEConfig(AlgorithmConfig):
    encoder_tau: Optional[float] = 0.05
    decoder_update_freq: Optional[int] = 1

-    include_vector_observation: Optional[int] = 0
+    vector_observation: Optional[int] = 0

    autoencoder_config: Optional[VanillaAEConfig] = VanillaAEConfig(
        latent_dim=50,
12 changes: 10 additions & 2 deletions cares_reinforcement_learning/util/network_factory.py
@@ -141,19 +141,27 @@ def create_SACAE(observation_size, action_num, config: AlgorithmConfig):

    ae_factory = AEFactory()
    autoencoder = ae_factory.create_autoencoder(
-        observation_size=observation_size, config=config.autoencoder_config
+        observation_size=observation_size["image"], config=config.autoencoder_config
    )

    actor_encoder = copy.deepcopy(autoencoder.encoder)
    critic_encoder = copy.deepcopy(autoencoder.encoder)

+    vector_observation = observation_size["vector"] if config.vector_observation else 0

    actor = Actor(
+        vector_observation,
        actor_encoder,
        action_num,
        hidden_size=config.hidden_size,
        log_std_bounds=config.log_std_bounds,
    )
-    critic = Critic(critic_encoder, action_num, hidden_size=config.hidden_size)
+    critic = Critic(
+        vector_observation,
+        critic_encoder,
+        action_num,
+        hidden_size=config.hidden_size,
+    )

    device = hlp.get_device()
    agent = SACAE(
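In create_SACAE above, the autoencoder is now built from observation_size["image"], and the vector size is only forwarded to the Actor and Critic when config.vector_observation is set. A standalone sketch of that gating, with a stand-in config object and illustrative observation sizes rather than values from the repository:

```python
# Stand-in illustration of the vector_observation gating in create_SACAE.
observation_size = {"image": (9, 84, 84), "vector": 5}  # illustrative sizes


class DummyConfig:
    """Stand-in for SACAEConfig: 0 disables the vector pathway, non-zero enables it."""

    vector_observation = 1


config = DummyConfig()

# Mirrors the factory logic: only pass the vector length through when enabled.
vector_observation = observation_size["vector"] if config.vector_observation else 0
print(vector_observation)  # 5 when enabled, 0 otherwise
```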
