diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 341e516bf..84860b354 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -18,8 +18,7 @@ # TODO: (1) better device management from collections import deque -from copy import deepcopy -from typing import Callable, Optional, Sequence, Tuple +from typing import Callable, Optional, Sequence, Tuple, Union import einops import numpy as np @@ -72,8 +71,8 @@ def __init__( encoder=encoder_critic, network=MLP( input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], - **config.critic_network_kwargs - ) + **config.critic_network_kwargs, + ), ) critic_nets.append(critic_net) @@ -83,8 +82,8 @@ def __init__( encoder=encoder_critic, network=MLP( input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], - **config.critic_network_kwargs - ) + **config.critic_network_kwargs, + ), ) target_critic_nets.append(target_critic_net) @@ -93,16 +92,13 @@ def __init__( self.actor = Policy( encoder=encoder_actor, - network=MLP( - input_dim=encoder_actor.output_dim, - **config.actor_network_kwargs - ), + network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), action_dim=config.output_shapes["action"][0], - **config.policy_kwargs + **config.policy_kwargs, ) if config.target_entropy is None: - config.target_entropy = -np.prod(config.output_shapes["action"][0])/2 # (-dim(A)/2) - self.temperature = LagrangeMultiplier(init_value=config.temperature_init) + config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2) + self.temperature = LagrangeMultiplier(init_value=config.temperature_init) def reset(self): """ @@ -125,15 +121,17 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: actions, _, _ = self.actor(batch) actions = self.unnormalize_outputs({"action": actions})["action"] return actions - - def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor: + + def critic_forward( + self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False + ) -> Tensor: """Forward pass through a critic network ensemble - + Args: observations: Dictionary of observations actions: Action tensor use_target: If True, use target critics, otherwise use ensemble critics - + Returns: Tensor of Q-values from all critics """ @@ -141,15 +139,14 @@ def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_t q_values = torch.stack([critic(observations, actions) for critic in critics]) return q_values - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: """Run the batch through the model and compute the loss. - + Returns a dictionary with loss as a tensor, and other information as native floats. """ batch = self.normalize_inputs(batch) - # batch shape is (b, 2, ...) where index 1 returns the current observation and - # the next observation for calculating the right td index. + # batch shape is (b, 2, ...) where index 1 returns the current observation and + # the next observation for calculating the right td index. actions = batch["action"][:, 0] rewards = batch["next.reward"][:, 0] observations = {} @@ -158,12 +155,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: if k.startswith("observation."): observations[k] = batch[k][:, 0] next_observations[k] = batch[k][:, 1] - + # perform image augmentation - # reward bias from HIL-SERL code base + # reward bias from HIL-SERL code base # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch - + # calculate critics loss # 1- compute actions from policy with torch.no_grad(): @@ -175,103 +172,105 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: # subsample critics to prevent overfitting if use high UTD (update to date) if self.config.num_subsample_critics is not None: indices = torch.randperm(self.config.num_critics) - indices = indices[:self.config.num_subsample_critics] + indices = indices[: self.config.num_subsample_critics] q_targets = q_targets[indices] # critics subsample size min_q, _ = q_targets.min(dim=0) # Get values from min operation # breakpoint() if self.config.use_backup_entropy: - min_q -= self.temperature() * log_probs * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] + min_q -= ( + self.temperature() + * log_probs + * ~batch["observation.state_is_pad"][:, 0] + * ~batch["action_is_pad"][:, 0] + ) # shape: [batch_size, horizon] td_target = rewards + self.config.discount * min_q * ~batch["next.done"] - # td_target -= self.config.discount * self.temperature() * log_probs \ - # * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] + # td_target -= self.config.discount * self.temperature() * log_probs \ + # * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] # print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}") - + # 3- compute predicted qs q_preds = self.critic_forward(observations, actions, use_target=False) # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. - critics_loss = ( + critics_loss = ( F.mse_loss( - q_preds, - einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), - reduction="none", - ).sum(0) # sum over ensemble - # `q_preds_ensemble` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1] - * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] - # q_targets depends on the reward and the next observations. - * ~batch["next.reward_is_pad"][:,0] # shape: [batch_size, horizon] - * ~batch["observation.state_is_pad"][:,1] # shape: [batch_size, horizon+1] - ).mean() - + q_preds, + einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), + reduction="none", + ).sum(0) # sum over ensemble + # `q_preds_ensemble` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1] + * ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon] + # q_targets depends on the reward and the next observations. + * ~batch["next.reward_is_pad"][:, 0] # shape: [batch_size, horizon] + * ~batch["observation.state_is_pad"][:, 1] # shape: [batch_size, horizon+1] + ).mean() + # calculate actors loss # 1- temperature temperature = self.temperature() # 2- get actions (batch_size, action_dim) and log probs (batch_size,) actions, log_probs, _ = self.actor(observations) # 3- get q-value predictions - # with torch.inference_mode(): - q_preds = self.critic_forward(observations, actions, use_target=False) + with torch.inference_mode(): + q_preds = self.critic_forward(observations, actions, use_target=False) # q_preds_min = torch.min(q_preds, axis=0) - min_q_preds = q_preds.min(dim=0)[0] + min_q_preds = q_preds.min(dim=0)[0] # print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}") # print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}") # breakpoint() actor_loss = ( -(min_q_preds - temperature * log_probs).mean() - * ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1] - * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] + * ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1] + * ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon] ).mean() - # calculate temperature loss # 1- calculate entropy with torch.no_grad(): actions, log_probs, _ = self.actor(observations) entropy = -log_probs.mean() - temperature_loss = self.temperature( - lhs=entropy, - rhs=self.config.target_entropy - ) + temperature_loss = self.temperature(lhs=entropy, rhs=self.config.target_entropy) loss = critics_loss + actor_loss + temperature_loss return { - "critics_loss": critics_loss.item(), - "actor_loss": actor_loss.item(), - "mean_q_predicts": min_q_preds.mean().item(), - "min_q_predicts":min_q_preds.min().item(), - "max_q_predicts":min_q_preds.max().item(), - "temperature_loss": temperature_loss.item(), - "temperature": temperature.item(), - "mean_log_probs": log_probs.mean().item(), - "min_log_probs": log_probs.min().item(), - "max_log_probs": log_probs.max().item(), - "td_target_mean": td_target.mean().item(), - "td_target_mean": td_target.max().item(), - "action_mean": actions.mean().item(), - "entropy": entropy.item(), - "loss": loss, - } - + "critics_loss": critics_loss.item(), + "actor_loss": actor_loss.item(), + "mean_q_predicts": min_q_preds.mean().item(), + "min_q_predicts": min_q_preds.min().item(), + "max_q_predicts": min_q_preds.max().item(), + "temperature_loss": temperature_loss.item(), + "temperature": temperature.item(), + "mean_log_probs": log_probs.mean().item(), + "min_log_probs": log_probs.min().item(), + "max_log_probs": log_probs.max().item(), + "td_target_mean": td_target.mean().item(), + "td_target_max": td_target.max().item(), + "action_mean": actions.mean().item(), + "entropy": entropy.item(), + "loss": loss, + } + def update(self): # TODO: implement UTD update # First update only critics for utd_ratio-1 times - #for critic_step in range(self.config.utd_ratio - 1): - # only update critic and critic target + # for critic_step in range(self.config.utd_ratio - 1): + # only update critic and critic target # Then update critic, critic target, actor and temperature """Update target networks with exponential moving average""" with torch.no_grad(): for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False): for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False): target_param.data.copy_( - param.data * self.config.critic_target_update_weight + - target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - + + class MLP(nn.Module): def __init__( self, @@ -284,52 +283,54 @@ def __init__( super().__init__() self.activate_final = activate_final layers = [] - + # First layer uses input_dim layers.append(nn.Linear(input_dim, hidden_dims[0])) - + # Add activation after first layer if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(hidden_dims[0])) layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) - + # Rest of the layers for i in range(1, len(hidden_dims)): - layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i])) - + layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i])) + if i + 1 < len(hidden_dims) or activate_final: if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(hidden_dims[i])) - layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) - + layers.append( + activations if isinstance(activations, nn.Module) else getattr(nn, activations)() + ) + self.net = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) - - + + class Critic(nn.Module): def __init__( self, encoder: Optional[nn.Module], network: nn.Module, init_final: Optional[float] = None, - device: str = "cuda" + device: str = "cuda", ): super().__init__() self.device = torch.device(device) self.encoder = encoder self.network = network self.init_final = init_final - + # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): out_features = layer.out_features break - + # Output layer if init_final is not None: self.output_layer = nn.Linear(out_features, 1) @@ -338,22 +339,20 @@ def __init__( else: self.output_layer = nn.Linear(out_features, 1) orthogonal_init()(self.output_layer.weight) - + self.to(self.device) def forward( - self, - observations: dict[str, torch.Tensor], + self, + observations: dict[str, torch.Tensor], actions: torch.Tensor, ) -> torch.Tensor: # Move each tensor in observations to device - observations = { - k: v.to(self.device) for k, v in observations.items() - } + observations = {k: v.to(self.device) for k, v in observations.items()} actions = actions.to(self.device) - + obs_enc = observations if self.encoder is None else self.encoder(observations) - + inputs = torch.cat([obs_enc, actions], dim=-1) x = self.network(inputs) value = self.output_layer(x) @@ -371,7 +370,7 @@ def __init__( fixed_std: Optional[torch.Tensor] = None, init_final: Optional[float] = None, use_tanh_squash: bool = False, - device: str = "cuda" + device: str = "cuda", ): super().__init__() self.device = torch.device(device) @@ -382,13 +381,13 @@ def __init__( self.log_std_max = log_std_max self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.use_tanh_squash = use_tanh_squash - + # Find the last Linear layer's output dimension for layer in reversed(network.net): if isinstance(layer, nn.Linear): out_features = layer.out_features break - + # Mean layer self.mean_layer = nn.Linear(out_features, action_dim) if init_final is not None: @@ -396,7 +395,7 @@ def __init__( nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) else: orthogonal_init()(self.mean_layer.weight) - + # Standard deviation layer or parameter if fixed_std is None: self.std_layer = nn.Linear(out_features, action_dim) @@ -405,43 +404,47 @@ def __init__( nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: orthogonal_init()(self.std_layer.weight) - + self.to(self.device) def forward( - self, + self, observations: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - # Encode observations if encoder exists obs_enc = observations if self.encoder is None else self.encoder(observations) # Get network outputs outputs = self.network(obs_enc) means = self.mean_layer(outputs) - + # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) - log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!" + if self.use_tanh_squash: log_std = torch.tanh(log_std) + log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0) + else: + log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) else: log_std = self.fixed_std.expand_as(means) - - # uses tahn activation function to squash the action to be in the range of [-1, 1] + + # uses tanh activation function to squash the action to be in the range of [-1, 1] normal = torch.distributions.Normal(means, torch.exp(log_std)) - x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) - x_t = torch.clamp(x_t, -2.0, 2.0) - log_probs = normal.log_prob(x_t) + x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1)) + log_probs = normal.log_prob(x_t) # Base log probability before Tanh + if self.use_tanh_squash: actions = torch.tanh(x_t) - log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) - log_probs = log_probs.sum(-1) # sum over action dim - means = torch.tanh(means) + log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh + else: + actions = x_t # No Tanh; raw Gaussian sample + log_probs = log_probs.sum(-1) # Sum over action dimensions return actions, log_probs, means - + def get_features(self, observations: torch.Tensor) -> torch.Tensor: """Get encoded features from observations""" observations = observations.to(self.device) @@ -495,9 +498,7 @@ def __init__(self, config: SACConfig): ) if "observation.environment_state" in config.input_shapes: self.env_state_enc_layers = nn.Sequential( - nn.Linear( - config.input_shapes["observation.environment_state"][0], config.latent_dim - ), + nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim), nn.LayerNorm(config.latent_dim), nn.Tanh(), ) @@ -519,7 +520,7 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: feat.append(self.state_enc_layers(obs_dict["observation.state"])) # TODO(ke-wang): currently average over all features, concatenate all features maybe a better way return torch.stack(feat, dim=0).mean(0) - + @property def output_dim(self) -> int: """Returns the dimension of the encoder output""" @@ -527,48 +528,47 @@ def output_dim(self) -> int: class LagrangeMultiplier(nn.Module): - def __init__( - self, - init_value: float = 1.0, - constraint_shape: Sequence[int] = (), - device: str = "cuda" - ): + def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"): super().__init__() self.device = torch.device(device) - # init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1) - init_value = torch.tensor(init_value, device=self.device) - - # Initialize the Lagrange multiplier as a parameter - self.lagrange = nn.Parameter( - torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device) - ) - + # Parameterize log(alpha) directly to ensure positivity + log_alpha = torch.log(torch.tensor(init_value, dtype=torch.float32, device=self.device)) + self.log_alpha = nn.Parameter(torch.full(constraint_shape, log_alpha)) + def forward( - self, - lhs: Optional[torch.Tensor | float | int] = None, - rhs: Optional[torch.Tensor | float | int] = None + self, + lhs: Optional[Union[torch.Tensor, float, int]] = None, + rhs: Optional[Union[torch.Tensor, float, int]] = None, ) -> torch.Tensor: - # Get the multiplier value based on parameterization - # multiplier = torch.nn.functional.softplus(self.lagrange) - log_multiplier = torch.log(self.lagrange) + # Compute alpha = exp(log_alpha) + alpha = self.log_alpha.exp() - # Return the raw multiplier if no constraint values provided + # Return alpha directly if no constraints provided if lhs is None: - return log_multiplier.exp() - + return alpha + # Convert inputs to tensors and move to device - lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device) + lhs = ( + torch.tensor(lhs, device=self.device) + if not isinstance(lhs, torch.Tensor) + else lhs.to(self.device) + ) if rhs is not None: - rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device) + rhs = ( + torch.tensor(rhs, device=self.device) + if not isinstance(rhs, torch.Tensor) + else rhs.to(self.device) + ) else: rhs = torch.zeros_like(lhs, device=self.device) - + + # Compute the difference and apply the multiplier diff = lhs - rhs - - assert diff.shape == log_multiplier.shape, f"Shape mismatch: {diff.shape} vs {log_multiplier.shape}" - - return log_multiplier.exp() * diff # numerically better + + assert diff.shape == alpha.shape, f"Shape mismatch: {diff.shape} vs {alpha.shape}" + + return alpha * diff def orthogonal_init(): @@ -580,6 +580,7 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}" return nn.ModuleList(critics).to(device) + # borrowed from tdmpc def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor. @@ -587,7 +588,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens Args: fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return (B, *), where * is any number of dimensions. - image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and can be more than 1 dimensions, generally different from *. Returns: A return value from the callable reshaped to (**, *). @@ -597,4 +598,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens start_dims = image_tensor.shape[:-3] inp = torch.flatten(image_tensor, end_dim=-4) flat_out = fn(inp) - return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) \ No newline at end of file + return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))