TD3AE vector + image
beardyFace committed Oct 8, 2024
1 parent 7626747 commit de5c5f0
Showing 8 changed files with 116 additions and 44 deletions.
7 changes: 3 additions & 4 deletions cares_reinforcement_learning/algorithm/policy/SACAE.py
@@ -15,7 +15,6 @@
import torch.nn.functional as F

import cares_reinforcement_learning.util.helpers as hlp
from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig
from cares_reinforcement_learning.encoders.losses import AELoss
from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.util.configurations import SACAEConfig
@@ -113,12 +112,12 @@ def select_action_from_policy(
image_tensor = image_tensor.unsqueeze(0).to(self.device)
image_tensor = image_tensor / 255

state = {"image": image_tensor, "vector": vector_tensor}
state_tensor = {"image": image_tensor, "vector": vector_tensor}

if evaluation:
(_, _, action) = self.actor_net(state)
(_, _, action) = self.actor_net(state_tensor)
else:
(action, _, _) = self.actor_net(state)
(action, _, _) = self.actor_net(state_tensor)
action = action.cpu().data.numpy().flatten()
self.actor_net.train()
return action
68 changes: 50 additions & 18 deletions cares_reinforcement_learning/algorithm/policy/TD3AE.py
@@ -13,7 +13,6 @@
import torch.nn.functional as F

import cares_reinforcement_learning.util.helpers as hlp
from cares_reinforcement_learning.encoders.configurations import VanillaAEConfig
from cares_reinforcement_learning.encoders.losses import AELoss
from cares_reinforcement_learning.memory import MemoryBuffer
from cares_reinforcement_learning.util.configurations import TD3AEConfig
@@ -78,13 +77,21 @@ def __init__(
)

def select_action_from_policy(
self, state: np.ndarray, evaluation: bool = False, noise_scale: float = 0.1
self,
state: dict[str, np.ndarray],
evaluation: bool = False,
noise_scale: float = 0.1,
) -> np.ndarray:
self.actor_net.eval()
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
state_tensor = state_tensor / 255
vector_tensor = torch.FloatTensor(state["vector"])
vector_tensor = vector_tensor.unsqueeze(0).to(self.device)

image_tensor = torch.FloatTensor(state["image"])
image_tensor = image_tensor.unsqueeze(0).to(self.device)
image_tensor = image_tensor / 255

state_tensor = {"image": image_tensor, "vector": vector_tensor}

action = self.actor_net(state_tensor)
action = action.cpu().data.numpy().flatten()
@@ -98,7 +105,7 @@ def select_action_from_policy(

def _update_critic(
self,
states: torch.Tensor,
states: dict[str, torch.Tensor],
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
@@ -132,7 +139,10 @@ def _update_critic(

return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item()

def _update_actor(self, states: torch.Tensor) -> float:
def _update_actor(
self,
states: dict[str, torch.Tensor],
) -> float:
actions = self.actor_net(states, detach_encoder=True)
actor_q_values, _ = self.critic_net(states, actions, detach_encoder=True)
actor_loss = -actor_q_values.mean()
@@ -167,36 +177,58 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
experiences = memory.sample_uniform(batch_size)
states, actions, rewards, next_states, dones, _ = experiences

batch_size = len(states)
states_images = [state["image"] for state in states]
states_vector = [state["vector"] for state in states]

next_states_images = [next_state["image"] for next_state in next_states]
next_states_vector = [next_state["vector"] for next_state in next_states]

batch_size = len(states_images)

# Convert into tensor
states = torch.FloatTensor(np.asarray(states)).to(self.device)
states_images = torch.FloatTensor(np.asarray(states_images)).to(self.device)
states_vector = torch.FloatTensor(np.asarray(states_vector)).to(self.device)

# Normalise states and next_states - image portion
# This because the states are [0-255] and the predictions are [0-1]
states_images = states_images / 255

states = {"image": states_images, "vector": states_vector}

actions = torch.FloatTensor(np.asarray(actions)).to(self.device)
rewards = torch.FloatTensor(np.asarray(rewards)).to(self.device)
next_states = torch.FloatTensor(np.asarray(next_states)).to(self.device)

next_states_images = torch.FloatTensor(np.asarray(next_states_images)).to(
self.device
)
next_states_vector = torch.FloatTensor(np.asarray(next_states_vector)).to(
self.device
)

# Normalise states and next_states - image portion
# This because the states are [0-255] and the predictions are [0-1]
next_states_images = next_states_images / 255

next_states = {"image": next_states_images, "vector": next_states_vector}

dones = torch.LongTensor(np.asarray(dones)).to(self.device)

# Reshape to batch_size
rewards = rewards.unsqueeze(0).reshape(batch_size, 1)
dones = dones.unsqueeze(0).reshape(batch_size, 1)

# Normalise states and next_states
# This because the states are [0-255] and the predictions are [0-1]
states_normalised = states / 255
next_states_normalised = next_states / 255

info = {}

critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic(
states_normalised, actions, rewards, next_states_normalised, dones
states, actions, rewards, next_states, dones
)
info["critic_loss_one"] = critic_loss_one
info["critic_loss_two"] = critic_loss_two
info["critic_loss"] = critic_loss_total

if self.learn_counter % self.policy_update_freq == 0:
# Update Actor
actor_loss = self._update_actor(states_normalised)
actor_loss = self._update_actor(states)
info["actor_loss"] = actor_loss

# Update target network params
@@ -222,7 +254,7 @@ def train_policy(self, memory: MemoryBuffer, batch_size: int) -> dict[str, Any]:
)

if self.learn_counter % self.decoder_update_freq == 0:
ae_loss = self._update_autoencoder(states_normalised)
ae_loss = self._update_autoencoder(states["image"])
info["ae_loss"] = ae_loss

return info
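A minimal standalone sketch of the per-modality batching that train_policy now performs: states come out of the replay buffer as a list of {"image", "vector"} dicts, are stacked per key, and only the image half is rescaled from [0-255] to [0-1]. The shapes, batch size, and random data are illustrative only; only numpy and torch are required.

    import numpy as np
    import torch

    # Toy replay sample: each state is a dict, matching the new TD3AE state format.
    # The 3x32x32 image, 4-dim vector, and batch of 8 are made up for illustration.
    states = [
        {
            "image": np.random.randint(0, 256, (3, 32, 32), dtype=np.uint8),
            "vector": np.ones(4, dtype=np.float32),
        }
        for _ in range(8)
    ]

    # Split per modality, stack into tensors, and rescale raw [0-255] pixels to [0-1],
    # mirroring the conversion done in train_policy above.
    states_images = torch.FloatTensor(np.asarray([s["image"] for s in states])) / 255
    states_vector = torch.FloatTensor(np.asarray([s["vector"] for s in states]))
    batch = {"image": states_images, "vector": states_vector}

    print(batch["image"].shape, batch["vector"].shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8, 4])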
8 changes: 4 additions & 4 deletions cares_reinforcement_learning/networks/SACAE/actor.py
@@ -8,7 +8,7 @@
class Actor(SACActor):
def __init__(
self,
vector_observation: int,
vector_observation_size: int,
encoder: Encoder,
num_actions: int,
hidden_size: list[int] = None,
@@ -20,15 +20,15 @@ def __init__(
log_std_bounds = [-10, 2]

super().__init__(
encoder.latent_dim + vector_observation,
encoder.latent_dim + vector_observation_size,
num_actions,
hidden_size,
log_std_bounds,
)

self.encoder = encoder

self.vector_observation = vector_observation
self.vector_observation_size = vector_observation_size

self.apply(hlp.weight_init)

@@ -39,7 +39,7 @@ def forward(
state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)

actor_input = state_latent
if self.vector_observation > 0:
if self.vector_observation_size > 0:
actor_input = torch.cat([state["vector"], actor_input], dim=1)

return super().forward(actor_input)
8 changes: 4 additions & 4 deletions cares_reinforcement_learning/networks/SACAE/critic.py
@@ -8,7 +8,7 @@
class Critic(SACCritic):
def __init__(
self,
vector_observation: int,
vector_observation_size: int,
encoder: Encoder,
num_actions: int,
hidden_size: list[int] = None,
@@ -17,10 +17,10 @@ def __init__(
hidden_size = [1024, 1024]

super().__init__(
encoder.latent_dim + vector_observation, num_actions, hidden_size
encoder.latent_dim + vector_observation_size, num_actions, hidden_size
)

self.vector_observation = vector_observation
self.vector_observation_size = vector_observation_size

self.encoder = encoder

@@ -36,7 +36,7 @@ def forward(
state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)

critic_input = state_latent
if self.vector_observation > 0:
if self.vector_observation_size > 0:
critic_input = torch.cat([state["vector"], critic_input], dim=1)

return super().forward(critic_input, action)
18 changes: 14 additions & 4 deletions cares_reinforcement_learning/networks/TD3AE/actor.py
@@ -8,22 +8,32 @@
class Actor(TD3Actor):
def __init__(
self,
vector_observation_size: int,
encoder: Encoder,
num_actions: int,
hidden_size: list[int] = None,
):
if hidden_size is None:
hidden_size = [1024, 1024]

super().__init__(encoder.latent_dim, num_actions, hidden_size)
super().__init__(
encoder.latent_dim + vector_observation_size, num_actions, hidden_size
)

self.encoder = encoder

self.apply(hlp.weight_init)

self.vector_observation_size = vector_observation_size

def forward(
self, state: torch.Tensor, detach_encoder: bool = False
self, state: dict[str, torch.Tensor], detach_encoder: bool = False
) -> torch.Tensor:
# Detach at the CNN layer to prevent backpropagation through the encoder
state_latent = self.encoder(state, detach_cnn=detach_encoder)
return super().forward(state_latent)
state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)

actor_input = state_latent
if self.vector_observation_size > 0:
actor_input = torch.cat([state["vector"], actor_input], dim=1)

return super().forward(actor_input)
21 changes: 17 additions & 4 deletions cares_reinforcement_learning/networks/TD3AE/critic.py
@@ -8,22 +8,35 @@
class Critic(TD3Critic):
def __init__(
self,
vector_observation_size: int,
encoder: Encoder,
num_actions: int,
hidden_size: list[int] = None,
):
if hidden_size is None:
hidden_size = [1024, 1024]

super().__init__(encoder.latent_dim, num_actions, hidden_size)
super().__init__(
encoder.latent_dim + vector_observation_size, num_actions, hidden_size
)

self.vector_observation_size = vector_observation_size

self.encoder = encoder

self.apply(hlp.weight_init)

def forward(
self, state: torch.Tensor, action: torch.Tensor, detach_encoder: bool = False
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
detach_encoder: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# Detach at the CNN layer to prevent backpropagation through the encoder
state_latent = self.encoder(state, detach_cnn=detach_encoder)
return super().forward(state_latent, action)
state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)

critic_input = state_latent
if self.vector_observation_size > 0:
critic_input = torch.cat([state["vector"], critic_input], dim=1)

return super().forward(critic_input, action)
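Both the TD3AE actor and critic now build their network input the same way the SACAE networks already did: the encoder latent from state["image"] is concatenated with state["vector"] whenever vector_observation_size is non-zero. A self-contained sketch of just that concatenation, with random tensors standing in for the encoder output and arbitrary sizes:

    import torch

    latent_dim, vector_observation_size, batch_size = 50, 4, 2

    state_latent = torch.randn(batch_size, latent_dim)               # stand-in for encoder(state["image"])
    state_vector = torch.randn(batch_size, vector_observation_size)  # stand-in for state["vector"]

    # Vector first, latent second, along the feature dimension, as in the forward passes above.
    network_input = torch.cat([state_vector, state_latent], dim=1)
    assert network_input.shape == (batch_size, latent_dim + vector_observation_size)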
2 changes: 2 additions & 0 deletions cares_reinforcement_learning/util/configurations.py
@@ -156,6 +156,8 @@ class TD3AEConfig(AlgorithmConfig):
encoder_tau: Optional[float] = 0.05
decoder_update_freq: Optional[int] = 1

vector_observation: Optional[int] = 0

autoencoder_config: Optional[VanillaAEConfig] = VanillaAEConfig(
latent_dim=50,
num_layers=4,
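A hedged configuration sketch: vector_observation is the only field this commit adds, and the factory below treats it as a flag for whether the environment's vector size is appended to the encoder latent. Whether other TD3AEConfig or AlgorithmConfig fields must be set explicitly is not shown in this diff and is assumed to default sensibly.

    from cares_reinforcement_learning.util.configurations import TD3AEConfig

    # Non-zero enables the vector branch; 0 (the default) keeps image-only behaviour.
    config = TD3AEConfig(vector_observation=1)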
28 changes: 22 additions & 6 deletions cares_reinforcement_learning/util/network_factory.py
@@ -147,17 +147,19 @@ def create_SACAE(observation_size, action_num, config: AlgorithmConfig):
actor_encoder = copy.deepcopy(autoencoder.encoder)
critic_encoder = copy.deepcopy(autoencoder.encoder)

vector_observation = observation_size["vector"] if config.vector_observation else 0
vector_observation_size = (
observation_size["vector"] if config.vector_observation else 0
)

actor = Actor(
vector_observation,
vector_observation_size,
actor_encoder,
action_num,
hidden_size=config.hidden_size,
log_std_bounds=config.log_std_bounds,
)
critic = Critic(
vector_observation,
vector_observation_size,
critic_encoder,
action_num,
hidden_size=config.hidden_size,
@@ -232,14 +234,28 @@ def create_TD3AE(observation_size, action_num, config: AlgorithmConfig):

ae_factory = AEFactory()
autoencoder = ae_factory.create_autoencoder(
observation_size=observation_size, config=config.autoencoder_config
observation_size=observation_size["image"], config=config.autoencoder_config
)

actor_encoder = copy.deepcopy(autoencoder.encoder)
critic_encoder = copy.deepcopy(autoencoder.encoder)

actor = Actor(actor_encoder, action_num, hidden_size=config.hidden_size)
critic = Critic(critic_encoder, action_num, hidden_size=config.hidden_size)
vector_observation_size = (
observation_size["vector"] if config.vector_observation else 0
)

actor = Actor(
vector_observation_size,
actor_encoder,
action_num,
hidden_size=config.hidden_size,
)
critic = Critic(
vector_observation_size,
critic_encoder,
action_num,
hidden_size=config.hidden_size,
)

device = hlp.get_device()
agent = TD3AE(
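Putting the factory and algorithm changes together in one sketch. The import path, the structure of observation_size (channel-first image shape plus flat vector length), and create_TD3AE returning the constructed agent are assumptions inferred from this diff rather than a documented API.

    import numpy as np
    from cares_reinforcement_learning.util.configurations import TD3AEConfig
    from cares_reinforcement_learning.util.network_factory import create_TD3AE

    # Assumed observation description: image shape for the autoencoder plus vector length.
    observation_size = {"image": (3, 84, 84), "vector": 4}
    config = TD3AEConfig(vector_observation=1)

    agent = create_TD3AE(observation_size, 2, config)  # 2 = action_num, illustrative

    # The policy now takes a dict observation; the image is passed as raw [0-255] pixels
    # and is scaled to [0-1] inside select_action_from_policy.
    state = {
        "image": np.random.randint(0, 256, (3, 84, 84), dtype=np.uint8),
        "vector": np.zeros(4, dtype=np.float32),
    }
    action = agent.select_action_from_policy(state, evaluation=True)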
