diff --git a/algorithms/any_percent_bc.py b/algorithms/any_percent_bc.py index 2bb00cf..0ce05a8 100644 --- a/algorithms/any_percent_bc.py +++ b/algorithms/any_percent_bc.py @@ -56,17 +56,11 @@ def __post_init__(self): def soft_update(target: nn.Module, source: nn.Module, tau: float): - for target_param, source_param in zip( - target.parameters(), source.parameters() - ): - target_param.data.copy_( - (1 - tau) * target_param.data + tau * source_param.data - ) + for target_param, source_param in zip(target.parameters(), source.parameters()): + target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) -def compute_mean_std( - states: np.ndarray, eps: float -) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]: mean = states.mean(0) std = states.std(0) + eps return mean, std @@ -116,15 +110,11 @@ def __init__( self._actions = torch.zeros( (buffer_size, action_dim), dtype=torch.float32, device=device ) - self._rewards = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._next_states = torch.zeros( (buffer_size, state_dim), dtype=torch.float32, device=device ) - self._dones = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._device = device def _to_tensor(self, data: np.ndarray) -> torch.Tensor: @@ -133,9 +123,7 @@ def _to_tensor(self, data: np.ndarray) -> torch.Tensor: # Loads data in d4rl format, i.e. from Dict[str, np.array]. def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): if self._size != 0: - raise ValueError( - "Trying to load data into non-empty replay buffer" - ) + raise ValueError("Trying to load data into non-empty replay buffer") n_transitions = data["observations"].shape[0] if n_transitions > self._buffer_size: raise ValueError( @@ -143,24 +131,16 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): ) self._states[:n_transitions] = self._to_tensor(data["observations"]) self._actions[:n_transitions] = self._to_tensor(data["actions"]) - self._rewards[:n_transitions] = self._to_tensor( - data["rewards"][..., None] - ) - self._next_states[:n_transitions] = self._to_tensor( - data["next_observations"] - ) - self._dones[:n_transitions] = self._to_tensor( - data["terminals"][..., None] - ) + self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) + self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) + self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None]) self._size += n_transitions self._pointer = min(self._size, n_transitions) print(f"Dataset size: {n_transitions}") def sample(self, batch_size: int) -> TensorBatch: - indices = np.random.randint( - 0, min(self._size, self._pointer), size=batch_size - ) + indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) states = self._states[indices] actions = self._actions[indices] rewards = self._rewards[indices] @@ -248,9 +228,7 @@ def keep_best_trajectories( cur_ids = [] cur_return = 0 reward_scale = 1.0 - for i, (reward, done) in enumerate( - zip(dataset["rewards"], dataset["terminals"]) - ): + for i, (reward, done) in enumerate(zip(dataset["rewards"], dataset["terminals"])): cur_return += reward_scale * reward cur_ids.append(i) reward_scale *= discount @@ -291,19 +269,13 @@ def __init__( self.net = nn.Sequential( 
nn.Linear(state_dim, 256), - nn.LayerNorm(256, elementwise_affine=False) - if actor_LN - else nn.Identity(), + nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(), nn.ReLU(), nn.Linear(256, 256), - nn.LayerNorm(256, elementwise_affine=False) - if actor_LN - else nn.Identity(), + nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(), nn.ReLU(), nn.Linear(256, 256), - nn.LayerNorm(256, elementwise_affine=False) - if actor_LN - else nn.Identity(), + nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(), nn.ReLU(), # nn.Linear(256, action_dim), # nn.Tanh(), @@ -313,9 +285,7 @@ def __init__( self.max_action = max_action - def forward( - self, state: torch.Tensor, deterministic=False - ) -> torch.Tensor: + def forward(self, state: torch.Tensor, deterministic=False) -> torch.Tensor: if self.soft: hidden = self.net(state) mu = self.mu(hidden) @@ -330,9 +300,9 @@ def forward( tanh_action = torch.tanh(action) # change of variables formula (SAC paper, appendix C, eq 21) log_prob = policy_dist.log_prob(action).sum(axis=-1) - log_prob = log_prob - torch.log( - 1 - tanh_action.pow(2) + 1e-6 - ).sum(axis=-1) + log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum( + axis=-1 + ) scaled_action = self.max_action * tanh_action else: hidden = self.net(state) @@ -340,26 +310,20 @@ def forward( log_prob = None return scaled_action, log_prob - def log_prob( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: hidden = self.net(state) mu = self.mu(hidden) log_sigma = self.log_sigma(hidden) log_sigma = torch.clip(log_sigma, -5, 2) policy_dist = Normal(mu, torch.exp(log_sigma)) - action = torch.clip( - action, -self.max_action + 1e-6, self.max_action - 1e-6 - ) + action = torch.clip(action, -self.max_action + 1e-6, self.max_action - 1e-6) log_prob = policy_dist.log_prob(torch.arctanh(action)).sum(axis=-1) log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-6).sum(axis=-1) return log_prob @torch.no_grad() def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray: - state = torch.tensor( - state.reshape(1, -1), device=device, dtype=torch.float32 - ) + state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32) return self(state, deterministic=True)[0].cpu().data.numpy().flatten() @@ -392,9 +356,7 @@ def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor: with torch.no_grad(): action, action_log_prob = self.actor(state) - loss = ( - -self.log_alpha * (action_log_prob + self.target_entropy) - ).mean() + loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean() return loss @@ -453,9 +415,7 @@ def train(config: TrainConfig): keep_best_trajectories(dataset, config.frac, config.discount) if config.normalize: - state_mean, state_std = compute_mean_std( - dataset["observations"], eps=1e-3 - ) + state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3) else: state_mean, state_std = 0, 1 @@ -463,9 +423,7 @@ def train(config: TrainConfig): dataset["next_observations"] = np.roll( dataset["observations"], shift=-1, axis=0 ) # Terminals/timeouts block next observations - print( - "Loaded next state observations from current state observations." 
- ) + print("Loaded next state observations from current state observations.") dataset["observations"] = normalize_states( dataset["observations"], state_mean, state_std @@ -485,9 +443,7 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) max_action = float(env.action_space.high[0]) @@ -496,9 +452,9 @@ def train(config: TrainConfig): seed = config.seed set_seed(seed, env) - actor = Actor( - state_dim, action_dim, max_action, config.actor_LN, config.soft - ).to(config.device) + actor = Actor(state_dim, action_dim, max_action, config.actor_LN, config.soft).to( + config.device + ) actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4) kwargs = { @@ -547,15 +503,9 @@ def train(config: TrainConfig): "epoch": int((t + 1) / 1000), } if hasattr(env, "get_normalized_score"): - normalized_score = ( - env.get_normalized_score(eval_returns) * 100.0 - ) - eval_log["eval/normalized_score_mean"] = np.mean( - normalized_score - ) - eval_log["eval/normalized_score_std"] = np.std( - normalized_score - ) + normalized_score = env.get_normalized_score(eval_returns) * 100.0 + eval_log["eval/normalized_score_mean"] = np.mean(normalized_score) + eval_log["eval/normalized_score_std"] = np.std(normalized_score) wandb.log(eval_log) print("---------------------------------------") diff --git a/algorithms/awac.py b/algorithms/awac.py index 745058f..f46cbd8 100644 --- a/algorithms/awac.py +++ b/algorithms/awac.py @@ -47,9 +47,7 @@ class TrainConfig: def __post_init__(self): self.name = f"{self.name}-{self.env_name}-{str(uuid.uuid4())[:8]}" if self.checkpoints_path is not None: - self.checkpoints_path = os.path.join( - self.checkpoints_path, self.name - ) + self.checkpoints_path = os.path.join(self.checkpoints_path, self.name) class ReplayBuffer: @@ -70,15 +68,11 @@ def __init__( self._actions = torch.zeros( (buffer_size, action_dim), dtype=torch.float32, device=device ) - self._rewards = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._next_states = torch.zeros( (buffer_size, state_dim), dtype=torch.float32, device=device ) - self._dones = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._device = device def _to_tensor(self, data: np.ndarray) -> torch.Tensor: @@ -86,9 +80,7 @@ def _to_tensor(self, data: np.ndarray) -> torch.Tensor: def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): if self._size != 0: - raise ValueError( - "Trying to load data into non-empty replay buffer" - ) + raise ValueError("Trying to load data into non-empty replay buffer") n_transitions = data["observations"].shape[0] if n_transitions > self._buffer_size: raise ValueError( @@ -96,24 +88,16 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): ) self._states[:n_transitions] = self._to_tensor(data["observations"]) self._actions[:n_transitions] = self._to_tensor(data["actions"]) - self._rewards[:n_transitions] = self._to_tensor( - data["rewards"][..., None] - ) - self._next_states[:n_transitions] = self._to_tensor( - data["next_observations"] - ) - self._dones[:n_transitions] = self._to_tensor( - 
data["terminals"][..., None] - ) + self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) + self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) + self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None]) self._size += n_transitions self._pointer = min(self._size, n_transitions) print(f"Dataset size: {n_transitions}") def sample(self, batch_size: int) -> TensorBatch: - indices = np.random.randint( - 0, min(self._size, self._pointer), size=batch_size - ) + indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) states = self._states[indices] actions = self._actions[indices] rewards = self._rewards[indices] @@ -148,32 +132,24 @@ def __init__( nn.ReLU(), nn.Linear(hidden_dim, action_dim), ) - self._log_std = nn.Parameter( - torch.zeros(action_dim, dtype=torch.float32) - ) + self._log_std = nn.Parameter(torch.zeros(action_dim, dtype=torch.float32)) self._min_log_std = min_log_std self._max_log_std = max_log_std self._min_action = min_action self._max_action = max_action - def _get_policy( - self, state: torch.Tensor - ) -> torch.distributions.Distribution: + def _get_policy(self, state: torch.Tensor) -> torch.distributions.Distribution: mean = self._mlp(state) log_std = self._log_std.clamp(self._min_log_std, self._max_log_std) policy = torch.distributions.Normal(mean, log_std.exp()) return policy - def log_prob( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: policy = self._get_policy(state) log_prob = policy.log_prob(action).sum(-1, keepdim=True) return log_prob - def forward( - self, state: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: policy = self._get_policy(state) action = policy.rsample() action.clamp_(self._min_action, self._max_action) @@ -209,20 +185,14 @@ def __init__( nn.Linear(hidden_dim, 1), ) - def forward( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: q_value = self._mlp(torch.cat([state, action], dim=-1)) return q_value def soft_update(target: nn.Module, source: nn.Module, tau: float): - for target_param, source_param in zip( - target.parameters(), source.parameters() - ): - target_param.data.copy_( - (1 - tau) * target_param.data + tau * source_param.data - ) + for target_param, source_param in zip(target.parameters(), source.parameters()): + target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) class AdvantageWeightedActorCritic: @@ -312,9 +282,7 @@ def _update_actor(self, states, actions): def update(self, batch: TensorBatch) -> Dict[str, float]: states, actions, rewards, next_states, dones = batch - critic_loss = self._update_critic( - states, actions, rewards, dones, next_states - ) + critic_loss = self._update_critic(states, actions, rewards, dones, next_states) actor_loss = self._update_actor(states, actions) soft_update(self._target_critic_1, self._critic_1, self._tau) @@ -349,9 +317,7 @@ def set_seed( torch.use_deterministic_algorithms(deterministic_torch) -def compute_mean_std( - states: np.ndarray, eps: float -) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]: mean = states.mean(0) std = states.std(0) + eps return mean, std @@ -463,19 +429,13 @@ def train(config: TrainConfig): actor = 
Actor(**actor_critic_kwargs) actor.to(config.device) - actor_optimizer = torch.optim.Adam( - actor.parameters(), lr=config.learning_rate - ) + actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.learning_rate) critic_1 = Critic(**actor_critic_kwargs) critic_2 = Critic(**actor_critic_kwargs) critic_1.to(config.device) critic_2.to(config.device) - critic_1_optimizer = torch.optim.Adam( - critic_1.parameters(), lr=config.learning_rate - ) - critic_2_optimizer = torch.optim.Adam( - critic_2.parameters(), lr=config.learning_rate - ) + critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=config.learning_rate) + critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=config.learning_rate) awac = AdvantageWeightedActorCritic( actor=actor, @@ -493,9 +453,7 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) full_eval_scores, full_normalized_eval_scores = [], [] @@ -515,9 +473,7 @@ def train(config: TrainConfig): full_eval_scores.append(eval_scores) wandb.log({"eval_score": eval_scores.mean()}, step=t) if hasattr(env, "get_normalized_score"): - normalized_eval_scores = ( - env.get_normalized_score(eval_scores) * 100.0 - ) + normalized_eval_scores = env.get_normalized_score(eval_scores) * 100.0 full_normalized_eval_scores.append(normalized_eval_scores) wandb.log( {"normalized_eval_score": normalized_eval_scores.mean()}, @@ -528,17 +484,13 @@ def train(config: TrainConfig): os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"), ) - with open( - os.path.join(config.checkpoints_path, "/eval_scores.npy"), "wb" - ) as f: + with open(os.path.join(config.checkpoints_path, "/eval_scores.npy"), "wb") as f: # noinspection PyTypeChecker np.save(f, np.asarray(full_eval_scores)) if len(full_normalized_eval_scores) > 0: with open( - os.path.join( - config.checkpoints_path, "/normalized_eval_scores.npy" - ), + os.path.join(config.checkpoints_path, "/normalized_eval_scores.npy"), "wb", ) as f: # noinspection PyTypeChecker diff --git a/algorithms/cql.py b/algorithms/cql.py index c6b2dab..6f03e52 100644 --- a/algorithms/cql.py +++ b/algorithms/cql.py @@ -76,17 +76,11 @@ def __post_init__(self): def soft_update(target: nn.Module, source: nn.Module, tau: float): - for target_param, source_param in zip( - target.parameters(), source.parameters() - ): - target_param.data.copy_( - (1 - tau) * target_param.data + tau * source_param.data - ) + for target_param, source_param in zip(target.parameters(), source.parameters()): + target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) -def compute_mean_std( - states: np.ndarray, eps: float -) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]: mean = states.mean(0) std = states.std(0) + eps return mean, std @@ -105,9 +99,7 @@ def discount_cumsum(x, discount, include_first=True): else: disc_cumsum[-1] = 0 for t in reversed(range(x.shape[0] - 1)): - disc_cumsum[t] = ( - discount * x[t + 1] + discount * disc_cumsum[t + 1] - ) + disc_cumsum[t] = discount * x[t + 1] + discount * disc_cumsum[t + 1] return disc_cumsum @@ -152,9 +144,7 @@ def __init__( self._actions = torch.zeros( (buffer_size, action_dim), dtype=torch.float32, device=device ) - 
self._rewards = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._returns_to_go = torch.zeros( (buffer_size, 1), dtype=torch.float32, device=device ) @@ -173,9 +163,7 @@ def __init__( self._next_states = torch.zeros( (buffer_size, state_dim), dtype=torch.float32, device=device ) - self._dones = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._discount = discount self._device = device @@ -192,9 +180,7 @@ def compute_returns_to_go(self, data: np.ndarray) -> np.ndarray: data["rewards"][i] ) # - 1* self._action_dim* torch.log(1 / torch.sqrt(torch.tensor(2 * np.pi)))) if ( - data["terminals"][i] - or data["timeouts"][i] - or i == n_transitions - 1 + data["terminals"][i] or data["timeouts"][i] or i == n_transitions - 1 ): # TODO: Deal with incomplete trajectory case episode_returns_to_go = discount_cumsum( np.array(episode_rewards), self._discount @@ -215,9 +201,7 @@ def compute_returns_to_go(self, data: np.ndarray) -> np.ndarray: ) # Terminals/timeouts block next returns to go assert next_returns_to_go[0] == returns_to_go[1] - self._returns_to_go[:n_transitions] = self._to_tensor( - returns_to_go[..., None] - ) + self._returns_to_go[:n_transitions] = self._to_tensor(returns_to_go[..., None]) self._next_returns_to_go[:n_transitions] = self._to_tensor( next_returns_to_go[..., None] ) @@ -248,9 +232,7 @@ def compute_soft_returns_to_go( for i in range(n_transitions): episode_rewards.append(self._rewards[i].cpu().item()) - episode_entropy_bonuses.append( - self._entropy_bonuses[i].cpu().item() - ) + episode_entropy_bonuses.append(self._entropy_bonuses[i].cpu().item()) if self._dones[i] or i == n_transitions - 1: episode_returns_to_go = discount_cumsum( np.array(episode_rewards), self._discount @@ -287,9 +269,7 @@ def compute_soft_returns_to_go( # Loads data in d4rl format, i.e. from Dict[str, np.array]. 
def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): if self._size != 0: - raise ValueError( - "Trying to load data into non-empty replay buffer" - ) + raise ValueError("Trying to load data into non-empty replay buffer") n_transitions = data["observations"].shape[0] if n_transitions > self._buffer_size: raise ValueError( @@ -297,15 +277,9 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): ) self._states[:n_transitions] = self._to_tensor(data["observations"]) self._actions[:n_transitions] = self._to_tensor(data["actions"]) - self._rewards[:n_transitions] = self._to_tensor( - data["rewards"][..., None] - ) - self._next_states[:n_transitions] = self._to_tensor( - data["next_observations"] - ) - self._dones[:n_transitions] = self._to_tensor( - data["terminals"][..., None] - ) + self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) + self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) + self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None]) self._size += n_transitions self._pointer = min(self._size, n_transitions) @@ -315,9 +289,7 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): print(f"Dataset size: {n_transitions}") def sample(self, batch_size: int) -> TensorBatch: - indices = np.random.randint( - 0, min(self._size, self._pointer), size=batch_size - ) + indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) states = self._states[indices] actions = self._actions[indices] rewards = self._rewards[indices] @@ -432,15 +404,11 @@ def modify_reward(dataset, env_name, max_episode_steps=1000): dataset["rewards"] -= 1.0 -def extend_and_repeat( - tensor: torch.Tensor, dim: int, repeat: int -) -> torch.Tensor: +def extend_and_repeat(tensor: torch.Tensor, dim: int, repeat: int) -> torch.Tensor: return tensor.unsqueeze(dim).repeat_interleave(repeat, dim=dim) -def init_module_weights( - module: torch.nn.Module, orthogonal_init: bool = False -): +def init_module_weights(module: torch.nn.Module, orthogonal_init: bool = False): if isinstance(module, nn.Linear): if orthogonal_init: nn.init.orthogonal_(module.weight, gain=np.sqrt(2)) @@ -495,9 +463,7 @@ def forward( else: action_sample = action_distribution.rsample() - log_prob = torch.sum( - action_distribution.log_prob(action_sample), dim=-1 - ) + log_prob = torch.sum(action_distribution.log_prob(action_sample), dim=-1) return action_sample, log_prob @@ -545,15 +511,11 @@ def __init__( def log_prob( self, observations: torch.Tensor, actions: torch.Tensor ) -> torch.Tensor: - actions = torch.clip( - actions, -self.max_action + 1e-6, self.max_action - 1e-6 - ) + actions = torch.clip(actions, -self.max_action + 1e-6, self.max_action - 1e-6) if actions.ndim == 3: observations = extend_and_repeat(observations, 1, actions.shape[1]) base_network_output = self.base_network(observations) - mean, log_std = torch.split( - base_network_output, self.action_dim, dim=-1 - ) + mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1) log_std = self.log_std_multiplier() * log_std + self.log_std_offset() return self.tanh_gaussian.log_prob(mean, log_std, actions) @@ -566,18 +528,14 @@ def forward( if repeat is not None: observations = extend_and_repeat(observations, 1, repeat) base_network_output = self.base_network(observations) - mean, log_std = torch.split( - base_network_output, self.action_dim, dim=-1 - ) + mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1) log_std = self.log_std_multiplier() * log_std + 
self.log_std_offset() actions, log_probs = self.tanh_gaussian(mean, log_std, deterministic) return self.max_action * actions, log_probs @torch.no_grad() def act(self, state: np.ndarray, device: str = "cpu"): - state = torch.tensor( - state.reshape(1, -1), device=device, dtype=torch.float32 - ) + state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32) with torch.no_grad(): actions, _ = self(state, not self.training) return actions.cpu().data.numpy().flatten() @@ -612,16 +570,14 @@ def __init__( else: init_module_weights(self.network[-1], False) - def forward( - self, observations: torch.Tensor, actions: torch.Tensor - ) -> torch.Tensor: + def forward(self, observations: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: multiple_actions = False batch_size = observations.shape[0] if actions.ndim == 3 and observations.ndim == 2: multiple_actions = True - observations = extend_and_repeat( - observations, 1, actions.shape[1] - ).reshape(-1, observations.shape[-1]) + observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape( + -1, observations.shape[-1] + ) actions = actions.reshape(-1, actions.shape[-1]) input_tensor = torch.cat([observations, actions], dim=-1) q_values = torch.squeeze(self.network(input_tensor), dim=-1) @@ -633,9 +589,7 @@ def forward( class Scalar(nn.Module): def __init__(self, init_value: float): super().__init__() - self.constant = nn.Parameter( - torch.tensor(init_value, dtype=torch.float32) - ) + self.constant = nn.Parameter(torch.tensor(init_value, dtype=torch.float32)) def forward(self) -> nn.Parameter: return self.constant @@ -727,16 +681,10 @@ def __init__( self.total_it = 0 def update_target_network(self, soft_target_update_rate: float): - soft_update( - self.target_critic_1, self.critic_1, soft_target_update_rate - ) - soft_update( - self.target_critic_2, self.critic_2, soft_target_update_rate - ) + soft_update(self.target_critic_1, self.critic_1, soft_target_update_rate) + soft_update(self.target_critic_2, self.critic_2, soft_target_update_rate) - def _alpha_and_alpha_loss( - self, observations: torch.Tensor, log_pi: torch.Tensor - ): + def _alpha_and_alpha_loss(self, observations: torch.Tensor, log_pi: torch.Tensor): if self.use_automatic_entropy_tuning: alpha_loss = -( self.log_alpha() * (log_pi + self.target_entropy).detach() @@ -833,12 +781,8 @@ def _q_loss( cql_q1_rand = self.critic_1(observations, cql_random_actions) cql_q2_rand = self.critic_2(observations, cql_random_actions) - cql_q1_current_actions = self.critic_1( - observations, cql_current_actions - ) - cql_q2_current_actions = self.critic_2( - observations, cql_current_actions - ) + cql_q1_current_actions = self.critic_1(observations, cql_current_actions) + cql_q2_current_actions = self.critic_2(observations, cql_current_actions) cql_q1_next_actions = self.critic_1(observations, cql_next_actions) cql_q2_next_actions = self.critic_2(observations, cql_next_actions) @@ -882,12 +826,8 @@ def _q_loss( dim=1, ) - cql_qf1_ood = ( - torch.logsumexp(cql_cat_q1 / self.cql_temp, dim=1) * self.cql_temp - ) - cql_qf2_ood = ( - torch.logsumexp(cql_cat_q2 / self.cql_temp, dim=1) * self.cql_temp - ) + cql_qf1_ood = torch.logsumexp(cql_cat_q1 / self.cql_temp, dim=1) * self.cql_temp + cql_qf2_ood = torch.logsumexp(cql_cat_q2 / self.cql_temp, dim=1) * self.cql_temp """Subtract the log likelihood of data""" cql_qf1_diff = torch.clamp( @@ -1017,16 +957,11 @@ def pretrain_softC(self, batch): next_action, next_log_prob = self.actor(next_observations) q_next = ( - next_return_to_go 
- - self.log_alpha().exp() * next_log_prob.unsqueeze(-1) + next_return_to_go - self.log_alpha().exp() * next_log_prob.unsqueeze(-1) ) - q_target = ( - rewards + self.discount * (1 - dones) * q_next - ).squeeze(-1) + q_target = (rewards + self.discount * (1 - dones) * q_next).squeeze(-1) - qf_loss = F.mse_loss(q1_values, q_target) + F.mse_loss( - q2_values, q_target - ) + qf_loss = F.mse_loss(q1_values, q_target) + F.mse_loss(q2_values, q_target) log_dict = dict( qf_loss=qf_loss.item(), @@ -1125,12 +1060,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self.critic_1.load_state_dict(state_dict=state_dict["critic1"]) self.critic_2.load_state_dict(state_dict=state_dict["critic2"]) - self.target_critic_1.load_state_dict( - state_dict=state_dict["critic1_target"] - ) - self.target_critic_2.load_state_dict( - state_dict=state_dict["critic2_target"] - ) + self.target_critic_1.load_state_dict(state_dict=state_dict["critic1_target"]) + self.target_critic_2.load_state_dict(state_dict=state_dict["critic2_target"]) self.critic_1_optimizer.load_state_dict( state_dict=state_dict["critic_1_optimizer"] @@ -1138,9 +1069,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self.critic_2_optimizer.load_state_dict( state_dict=state_dict["critic_2_optimizer"] ) - self.actor_optimizer.load_state_dict( - state_dict=state_dict["actor_optim"] - ) + self.actor_optimizer.load_state_dict(state_dict=state_dict["actor_optim"]) self.log_alpha = state_dict["sac_log_alpha"] self.alpha_optimizer.load_state_dict( @@ -1167,9 +1096,7 @@ def train(config: TrainConfig): modify_reward(dataset, config.env) if config.normalize: - state_mean, state_std = compute_mean_std( - dataset["observations"], eps=1e-3 - ) + state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3) else: state_mean, state_std = 0, 1 @@ -1177,9 +1104,7 @@ def train(config: TrainConfig): dataset["next_observations"] = np.roll( dataset["observations"], shift=-1, axis=0 ) # Terminals/timeouts block next observations - print( - "Loaded next state observations from current state observations." 
- ) + print("Loaded next state observations from current state observations.") dataset["observations"] = normalize_states( dataset["observations"], state_mean, state_std @@ -1202,27 +1127,21 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) # Set seeds seed = config.seed set_seed(seed, env) - critic_1 = FullyConnectedQFunction( - state_dim, action_dim, config.orthogonal_init - ).to(config.device) - critic_2 = FullyConnectedQFunction( - state_dim, action_dim, config.orthogonal_init - ).to(config.device) - critic_1_optimizer = torch.optim.Adam( - list(critic_1.parameters()), config.qf_lr + critic_1 = FullyConnectedQFunction(state_dim, action_dim, config.orthogonal_init).to( + config.device ) - critic_2_optimizer = torch.optim.Adam( - list(critic_2.parameters()), config.qf_lr + critic_2 = FullyConnectedQFunction(state_dim, action_dim, config.orthogonal_init).to( + config.device ) + critic_1_optimizer = torch.optim.Adam(list(critic_1.parameters()), config.qf_lr) + critic_2_optimizer = torch.optim.Adam(list(critic_2.parameters()), config.qf_lr) actor = TanhGaussianPolicy( state_dim, @@ -1313,15 +1232,9 @@ def train(config: TrainConfig): "epoch": int((t + 1) / 1000), } if hasattr(env, "get_normalized_score"): - normalized_score = ( - env.get_normalized_score(eval_returns) * 100.0 - ) - eval_log["eval/normalized_score_mean"] = np.mean( - normalized_score - ) - eval_log["eval/normalized_score_std"] = np.std( - normalized_score - ) + normalized_score = env.get_normalized_score(eval_returns) * 100.0 + eval_log["eval/normalized_score_mean"] = np.mean(normalized_score) + eval_log["eval/normalized_score_std"] = np.std(normalized_score) wandb.log(eval_log) print("---------------------------------------") diff --git a/algorithms/dt.py b/algorithms/dt.py index e05ac1d..10996b2 100644 --- a/algorithms/dt.py +++ b/algorithms/dt.py @@ -61,9 +61,7 @@ class TrainConfig: def __post_init__(self): self.name = f"{self.name}-{self.env_name}-{str(uuid.uuid4())[:8]}" if self.checkpoints_path is not None: - self.checkpoints_path = os.path.join( - self.checkpoints_path, self.name - ) + self.checkpoints_path = os.path.join(self.checkpoints_path, self.name) # general utils @@ -119,9 +117,7 @@ def pad_along_axis( npad = [(0, 0)] * arr.ndim npad[axis] = (0, pad_size) - return np.pad( - arr, pad_width=npad, mode="constant", constant_values=fill_value - ) + return np.pad(arr, pad_width=npad, mode="constant", constant_values=fill_value) def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray: @@ -139,17 +135,13 @@ def load_d4rl_trajectories( traj, traj_len = [], [] data_, episode_step = defaultdict(list), 0 - for i in trange( - dataset["rewards"].shape[0], desc="Processing trajectories" - ): + for i in trange(dataset["rewards"].shape[0], desc="Processing trajectories"): data_["observations"].append(dataset["observations"][i]) data_["actions"].append(dataset["actions"][i]) data_["rewards"].append(dataset["rewards"][i]) if dataset["terminals"][i] or dataset["timeouts"][i]: - episode_data = { - k: np.array(v, dtype=np.float32) for k, v in data_.items() - } + episode_data = {k: np.array(v, dtype=np.float32) for k, v in data_.items()} # return-to-go if gamma=1.0, just discounted returns else 
episode_data["returns"] = discounted_cumsum( episode_data["rewards"], gamma=gamma @@ -171,9 +163,7 @@ def load_d4rl_trajectories( class SequenceDataset(IterableDataset): - def __init__( - self, env_name: str, seq_len: int = 10, reward_scale: float = 1.0 - ): + def __init__(self, env_name: str, seq_len: int = 10, reward_scale: float = 1.0): self.dataset, info = load_d4rl_trajectories(env_name, gamma=1.0) self.reward_scale = reward_scale self.seq_len = seq_len @@ -210,9 +200,7 @@ def __prepare_sample(self, traj_idx, start_idx): def __iter__(self): while True: traj_idx = np.random.choice(len(self.dataset), p=self.sample_prob) - start_idx = random.randint( - 0, self.dataset[traj_idx]["rewards"].shape[0] - 1 - ) + start_idx = random.randint(0, self.dataset[traj_idx]["rewards"].shape[0] - 1) yield self.__prepare_sample(traj_idx, start_idx) @@ -307,9 +295,7 @@ def __init__( for _ in range(num_layers) ] ) - self.action_head = nn.Sequential( - nn.Linear(embedding_dim, action_dim), nn.Tanh() - ) + self.action_head = nn.Sequential(nn.Linear(embedding_dim, action_dim), nn.Tanh()) self.seq_len = seq_len self.embedding_dim = embedding_dim self.state_dim = state_dim @@ -394,12 +380,8 @@ def eval_rollout( dtype=torch.float, device=device, ) - returns = torch.zeros( - 1, model.episode_len + 1, dtype=torch.float, device=device - ) - time_steps = torch.arange( - model.episode_len, dtype=torch.long, device=device - ) + returns = torch.zeros(1, model.episode_len + 1, dtype=torch.float, device=device) + time_steps = torch.arange(model.episode_len, dtype=torch.long, device=device) time_steps = time_steps.view(1, -1) states[:, 0] = torch.as_tensor(env.reset(), device=device) @@ -489,18 +471,14 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") trainloader_iter = iter(trainloader) for step in trange(config.update_steps, desc="Training"): batch = next(trainloader_iter) - states, actions, returns, time_steps, mask = [ - b.to(config.device) for b in batch - ] + states, actions, returns, time_steps, mask = [b.to(config.device) for b in batch] # True value indicates that the corresponding key value will be ignored padding_mask = ~mask.to(torch.bool) @@ -511,18 +489,14 @@ def train(config: TrainConfig): time_steps=time_steps, padding_mask=padding_mask, ) - loss = F.mse_loss( - predicted_actions, actions.detach(), reduction="none" - ) + loss = F.mse_loss(predicted_actions, actions.detach(), reduction="none") # [batch_size, seq_len, action_dim] * [batch_size, seq_len, 1] loss = (loss * mask.unsqueeze(-1)).mean() optim.zero_grad() loss.backward() if config.clip_grad is not None: - torch.nn.utils.clip_grad_norm_( - model.parameters(), config.clip_grad - ) + torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad) optim.step() scheduler.step() @@ -540,9 +514,7 @@ def train(config: TrainConfig): for target_return in config.target_returns: eval_env.seed(config.eval_seed) eval_returns = [] - for _ in trange( - config.eval_episodes, desc="Evaluation", leave=False - ): + for _ in trange(config.eval_episodes, desc="Evaluation", leave=False): eval_return, eval_len = eval_rollout( model=model, env=eval_env, @@ -557,12 +529,8 @@ def 
train(config: TrainConfig): ) wandb.log( { - f"eval/{target_return}_return_mean": np.mean( - eval_returns - ), - f"eval/{target_return}_return_std": np.std( - eval_returns - ), + f"eval/{target_return}_return_mean": np.mean(eval_returns), + f"eval/{target_return}_return_std": np.std(eval_returns), f"eval/{target_return}_normalized_score_mean": np.mean( normalized_scores ), diff --git a/algorithms/edac.py b/algorithms/edac.py index a0acab4..e487fcd 100644 --- a/algorithms/edac.py +++ b/algorithms/edac.py @@ -89,12 +89,8 @@ def __post_init__(self): def soft_update(target: nn.Module, source: nn.Module, tau: float): - for target_param, source_param in zip( - target.parameters(), source.parameters() - ): - target_param.data.copy_( - (1 - tau) * target_param.data + tau * source_param.data - ) + for target_param, source_param in zip(target.parameters(), source.parameters()): + target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) def wandb_init(config: dict) -> None: @@ -121,9 +117,7 @@ def set_seed( torch.use_deterministic_algorithms(deterministic_torch) -def compute_mean_std( - states: np.ndarray, eps: float -) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]: mean = states.mean(0) std = states.std(0) + eps return mean, std @@ -145,9 +139,7 @@ def discount_cumsum(x, discount, include_first=True): else: disc_cumsum[-1] = 0 for t in reversed(range(x.shape[0] - 1)): - disc_cumsum[t] = ( - discount * x[t + 1] + discount * disc_cumsum[t + 1] - ) + disc_cumsum[t] = discount * x[t + 1] + discount * disc_cumsum[t + 1] return disc_cumsum @@ -190,9 +182,7 @@ def __init__( self._actions = torch.zeros( (buffer_size, action_dim), dtype=torch.float32, device=device ) - self._rewards = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._returns_to_go = torch.zeros( (buffer_size, 1), dtype=torch.float32, device=device ) @@ -208,9 +198,7 @@ def __init__( self._timeouts = torch.zeros( (buffer_size, 1), dtype=torch.float32, device=device ) - self._dones = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._discount = discount self._batch_size = batch_size self._device = device @@ -225,11 +213,7 @@ def compute_returns_to_go(self, data: np.ndarray): for i in range(n_transitions): episode_rewards.append(data["rewards"][i]) - if ( - data["terminals"][i] - or data["timeouts"][i] - or i == n_transitions - 1 - ): + if data["terminals"][i] or data["timeouts"][i] or i == n_transitions - 1: episode_returns_to_go = discount_cumsum( np.array(episode_rewards), self._discount ) @@ -244,9 +228,7 @@ def compute_returns_to_go(self, data: np.ndarray): ] ).flatten() - self._returns_to_go[:n_transitions] = self._to_tensor( - returns_to_go[..., None] - ) + self._returns_to_go[:n_transitions] = self._to_tensor(returns_to_go[..., None]) def compute_soft_returns_to_go(self, alpha: torch.Tensor, actor: "Actor"): n_transitions = self._states.shape[0] @@ -277,12 +259,10 @@ def compute_soft_returns_to_go(self, alpha: torch.Tensor, actor: "Actor"): episode_entropy_bonuses.append(self._entropy_bonuses[i].item()) if self._dones[i] or i == n_transitions - 1: if self._timeouts[i] or i == n_transitions - 1: - episode_rewards[-1] = episode_rewards[-1] / ( + episode_rewards[-1] = episode_rewards[-1] / (1 - self._discount) + 
episode_entropy_bonuses[-1] = episode_entropy_bonuses[-1] / ( 1 - self._discount ) - episode_entropy_bonuses[-1] = episode_entropy_bonuses[ - -1 - ] / (1 - self._discount) episode_returns_to_go = discount_cumsum( np.array(episode_rewards), self._discount ) + alpha.detach().item() * discount_cumsum( @@ -310,9 +290,7 @@ def compute_soft_returns_to_go(self, alpha: torch.Tensor, actor: "Actor"): # Loads data in d4rl format, i.e. from Dict[str, np.array]. def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): if self._size != 0: - raise ValueError( - "Trying to load data into non-empty replay buffer" - ) + raise ValueError("Trying to load data into non-empty replay buffer") n_transitions = data["observations"].shape[0] if n_transitions > self._buffer_size: raise ValueError( @@ -321,16 +299,10 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): self._states[:n_transitions] = self._to_tensor(data["observations"]) self._actions[:n_transitions] = self._to_tensor(data["actions"]) - self._rewards[:n_transitions] = self._to_tensor( - data["rewards"][..., None] - ) - self._next_states[:n_transitions] = self._to_tensor( - data["next_observations"] - ) + self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) + self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) - self._timeouts[:n_transitions] = self._to_tensor( - data["timeouts"][..., None] - ) + self._timeouts[:n_transitions] = self._to_tensor(data["timeouts"][..., None]) self._dones[:n_transitions] = self._to_tensor( (data["terminals"] + data["timeouts"])[..., None] ) @@ -343,9 +315,7 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): print(f"Dataset size: {n_transitions}") def sample(self, batch_size: int) -> TensorBatch: - indices = np.random.randint( - 0, min(self._size, self._pointer), size=batch_size - ) + indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) states = self._states[indices] actions = self._actions[indices] rewards = self._rewards[indices] @@ -371,17 +341,13 @@ def add_transition(self): # SAC Actor & Critic implementation class VectorizedLinear(nn.Module): - def __init__( - self, in_features: int, out_features: int, ensemble_size: int - ): + def __init__(self, in_features: int, out_features: int, ensemble_size: int): super().__init__() self.in_features = in_features self.out_features = out_features self.ensemble_size = ensemble_size - self.weight = nn.Parameter( - torch.empty(ensemble_size, in_features, out_features) - ) + self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features)) self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features)) self.reset_parameters() @@ -468,18 +434,14 @@ def forward( if need_log_prob: # change of variables formula (SAC paper, appendix C, eq 21) log_prob = policy_dist.log_prob(action).sum(axis=-1) - log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum( - axis=-1 - ) + log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(axis=-1) if need_policy_dist: return tanh_action * self.max_action, log_prob, policy_dist return tanh_action * self.max_action, log_prob - def log_prob( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: hidden = self.trunk(state) mu, log_sigma = self.mu(hidden), self.log_sigma(hidden) @@ -487,9 +449,7 @@ def log_prob( log_sigma = torch.clip(log_sigma, -5, 2) policy_dist = Normal(mu, torch.exp(log_sigma)) - action = torch.clip( - 
action, -self.max_action + 1e-6, self.max_action - 1e-6 - ) + action = torch.clip(action, -self.max_action + 1e-6, self.max_action - 1e-6) log_prob = policy_dist.log_prob(torch.arctanh(action)).sum(axis=-1) log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-6).sum(axis=-1) return log_prob @@ -539,9 +499,7 @@ def __init__( self.num_critics = num_critics - def forward( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: # [..., batch_size, state_dim + action_dim] state_action = torch.cat([state, action], dim=-1) if state_action.dim() != 3: @@ -605,18 +563,14 @@ def __init__( self.log_alpha = torch.tensor( [0.0], dtype=torch.float32, device=self.device, requires_grad=True ) - self.alpha_optimizer = torch.optim.Adam( - [self.log_alpha], lr=alpha_learning_rate - ) + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_learning_rate) self.alpha = self.log_alpha.exp().detach() def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor: with torch.no_grad(): action, action_log_prob = self.actor(state, need_log_prob=True) - loss = ( - -self.log_alpha * (action_log_prob + self.target_entropy) - ).mean() + loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean() return loss @@ -644,8 +598,7 @@ def _actor_loss( bc_loss = F.mse_loss(pi, action) loss = (self.alpha * log_pi - q_value_min).mean() loss = ( - loss / loss.detach() - + self.bc_regulariser * bc_loss / bc_loss.detach() + loss / loss.detach() + self.bc_regulariser * bc_loss / bc_loss.detach() ) elif self.soft_bc_regulariser > 0.0: bc_loss = (self.alpha * log_pi - log_prob_action).mean() @@ -654,8 +607,7 @@ def _actor_loss( ) loss = (self.alpha * log_pi - q_value_min).mean() loss = ( - loss / loss.detach() - - self.soft_bc_regulariser * log_prob_action.mean() + loss / loss.detach() - self.soft_bc_regulariser * log_prob_action.mean() ) else: loss = (self.alpha * log_pi - q_value_min).mean() @@ -726,9 +678,7 @@ def _critic_loss( q_values = self.critic(state, action) # [ensemble_size, batch_size] - [1, batch_size] - critic_loss = ( - ((q_values - q_target.view(1, -1)) ** 2).mean(dim=1).sum(dim=0) - ) + critic_loss = ((q_values - q_target.view(1, -1)) ** 2).mean(dim=1).sum(dim=0) diversity_loss = self._critic_diversity_loss(state, action) loss = critic_loss + self.eta * diversity_loss @@ -784,28 +734,18 @@ def pretrain_soft_critic( next_action, next_action_log_prob = self.actor( next_state, need_log_prob=True ) - q_next = ( - self.target_critic(next_state, next_action).min(0).values - ) + q_next = self.target_critic(next_state, next_action).min(0).values q_next = q_next - self.alpha * next_action_log_prob assert q_next.unsqueeze(-1).shape == done.shape == reward.shape - q_target0 = reward + self.gamma * ( - 1 - done - ) * q_next.unsqueeze(-1) + q_target0 = reward + self.gamma * (1 - done) * q_next.unsqueeze(-1) # [ensemble_size, batch_size] - [1, batch_size] TD_critic_loss = ( - ((q_values - q_target0.view(1, -1)) ** 2) - .mean(dim=1) - .sum(dim=0) + ((q_values - q_target0.view(1, -1)) ** 2).mean(dim=1).sum(dim=0) ) critic_loss = ( - ( - (1 - self.td_component) - * MC_critic_loss - / MC_critic_loss.detach() - ) + ((1 - self.td_component) * MC_critic_loss / MC_critic_loss.detach()) + self.td_component * TD_critic_loss / TD_critic_loss.detach() + self.eta * diversity_loss ) @@ -822,13 +762,9 @@ def pretrain_soft_critic( # for logging, Q-ensemble std estimate with the random actions: # a ~ U[-max_action, max_action] 
max_action = self.actor.max_action - random_actions = -max_action + 2 * max_action * torch.rand_like( - action - ) + random_actions = -max_action + 2 * max_action * torch.rand_like(action) - q_random_std = ( - self.critic(state, random_actions).std(0).mean().item() - ) + q_random_std = self.critic(state, random_actions).std(0).mean().item() log_dict["q_random_std"] = q_random_std return log_dict @@ -869,13 +805,9 @@ def update(self, batch: TensorBatch) -> Dict[str, float]: # for logging, Q-ensemble std estimate with the random actions: # a ~ U[-max_action, max_action] max_action = self.actor.max_action - random_actions = -max_action + 2 * max_action * torch.rand_like( - action - ) + random_actions = -max_action + 2 * max_action * torch.rand_like(action) - q_random_std = ( - self.critic(state, random_actions).std(0).mean().item() - ) + q_random_std = self.critic(state, random_actions).std(0).mean().item() update_info = { "alpha_loss": alpha_loss.item(), @@ -1001,17 +933,13 @@ def train(config: TrainConfig): if config.normalize_reward: modify_reward(d4rl_dataset, config.env_name) - state_mean, state_std = compute_mean_std( - d4rl_dataset["observations"], eps=1e-3 - ) + state_mean, state_std = compute_mean_std(d4rl_dataset["observations"], eps=1e-3) if "next_observations" not in d4rl_dataset.keys(): d4rl_dataset["next_observations"] = np.roll( d4rl_dataset["observations"], shift=-1, axis=0 ) # Terminals/timeouts block next observations - print( - "Loaded next state observations from current state observations." - ) + print("Loaded next state observations from current state observations.") d4rl_dataset["observations"] = normalize_states( d4rl_dataset["observations"], state_mean, state_std @@ -1041,9 +969,7 @@ def train(config: TrainConfig): config.actor_LN, ) actor.to(config.device) - actor_optimizer = torch.optim.Adam( - actor.parameters(), lr=config.actor_learning_rate - ) + actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_learning_rate) critic = VectorizedCritic( state_dim, action_dim, @@ -1080,17 +1006,13 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) total_updates = 0.0 for epoch in trange(config.num_epochs, desc="Training"): # training - for _ in trange( - config.num_updates_on_epoch, desc="Epoch", leave=False - ): + for _ in trange(config.num_updates_on_epoch, desc="Epoch", leave=False): batch = buffer.sample(config.batch_size) if config.pretrain is not None: if epoch <= config.pretrain_epochs: @@ -1107,9 +1029,7 @@ def train(config: TrainConfig): alpha=trainer.alpha, actor=trainer.actor, ) - print( - "Soft returns to go loaded for BC actor!" - ) + print("Soft returns to go loaded for BC actor!") assert buffer._soft_returns_loaded == True update_info = trainer.pretrain_soft_critic( batch, epoch, config.pretrain_epochs @@ -1120,9 +1040,7 @@ def train(config: TrainConfig): alpha=trainer.alpha, actor=trainer.actor, ) - print( - "Soft returns to go loaded for initialised actor!" 
- ) + print("Soft returns to go loaded for initialised actor!") assert buffer._soft_returns_loaded == True update_info = trainer.pretrain_soft_critic( batch, epoch, config.pretrain_epochs @@ -1134,9 +1052,7 @@ def train(config: TrainConfig): else: if epoch == config.pretrain_epochs + 1: with torch.no_grad(): - trainer.pretrained_critic = deepcopy( - trainer.critic - ) + trainer.pretrained_critic = deepcopy(trainer.critic) trainer.pretrained_actor = deepcopy(trainer.actor) update_info = trainer.update(batch) else: @@ -1164,15 +1080,9 @@ def train(config: TrainConfig): "epoch": epoch, } if hasattr(eval_env, "get_normalized_score"): - normalized_score = ( - eval_env.get_normalized_score(eval_returns) * 100.0 - ) - eval_log["eval/normalized_score_mean"] = np.mean( - normalized_score - ) - eval_log["eval/normalized_score_std"] = np.std( - normalized_score - ) + normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 + eval_log["eval/normalized_score_mean"] = np.mean(normalized_score) + eval_log["eval/normalized_score_std"] = np.std(normalized_score) wandb.log(eval_log) diff --git a/algorithms/iql.py b/algorithms/iql.py index 2f8ff49..44e31b8 100644 --- a/algorithms/iql.py +++ b/algorithms/iql.py @@ -43,9 +43,7 @@ class TrainConfig: batch_size: int = 256 # Batch size for all networks discount: float = 0.99 # Discount factor tau: float = 0.005 # Target network update rate - beta: float = ( - 3.0 # Inverse temperature. Small beta -> BC, big beta -> maximizing Q - ) + beta: float = 3.0 # Inverse temperature. Small beta -> BC, big beta -> maximizing Q iql_tau: float = 0.7 # Coefficient for asymmetric loss iql_deterministic: bool = False # Use deterministic actor normalize: bool = True # Normalize states @@ -58,23 +56,15 @@ class TrainConfig: def __post_init__(self): self.name = f"{self.name}-{self.env}-{str(uuid.uuid4())[:8]}" if self.checkpoints_path is not None: - self.checkpoints_path = os.path.join( - self.checkpoints_path, self.name - ) + self.checkpoints_path = os.path.join(self.checkpoints_path, self.name) def soft_update(target: nn.Module, source: nn.Module, tau: float): - for target_param, source_param in zip( - target.parameters(), source.parameters() - ): - target_param.data.copy_( - (1 - tau) * target_param.data + tau * source_param.data - ) + for target_param, source_param in zip(target.parameters(), source.parameters()): + target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) -def compute_mean_std( - states: np.ndarray, eps: float -) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]: mean = states.mean(0) std = states.std(0) + eps return mean, std @@ -124,15 +114,11 @@ def __init__( self._actions = torch.zeros( (buffer_size, action_dim), dtype=torch.float32, device=device ) - self._rewards = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._next_states = torch.zeros( (buffer_size, state_dim), dtype=torch.float32, device=device ) - self._dones = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._device = device def _to_tensor(self, data: np.ndarray) -> torch.Tensor: @@ -141,9 +127,7 @@ def _to_tensor(self, data: np.ndarray) -> torch.Tensor: # Loads data in d4rl format, i.e. from Dict[str, np.array]. 
def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): if self._size != 0: - raise ValueError( - "Trying to load data into non-empty replay buffer" - ) + raise ValueError("Trying to load data into non-empty replay buffer") n_transitions = data["observations"].shape[0] if n_transitions > self._buffer_size: raise ValueError( @@ -151,24 +135,16 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): ) self._states[:n_transitions] = self._to_tensor(data["observations"]) self._actions[:n_transitions] = self._to_tensor(data["actions"]) - self._rewards[:n_transitions] = self._to_tensor( - data["rewards"][..., None] - ) - self._next_states[:n_transitions] = self._to_tensor( - data["next_observations"] - ) - self._dones[:n_transitions] = self._to_tensor( - data["terminals"][..., None] - ) + self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) + self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) + self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None]) self._size += n_transitions self._pointer = min(self._size, n_transitions) print(f"Dataset size: {n_transitions}") def sample(self, batch_size: int) -> TensorBatch: - indices = np.random.randint( - 0, min(self._size, self._pointer), size=batch_size - ) + indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) states = self._states[indices] actions = self._actions[indices] rewards = self._rewards[indices] @@ -274,9 +250,7 @@ def __init__( super().__init__() n_dims = len(dims) if n_dims < 2: - raise ValueError( - "MLP requires at least two dims (input and output)" - ) + raise ValueError("MLP requires at least two dims (input and output)") layers = [] for i in range(n_dims - 2): @@ -317,14 +291,10 @@ def forward(self, obs: torch.Tensor) -> MultivariateNormal: @torch.no_grad() def act(self, state: np.ndarray, device: str = "cpu"): - state = torch.tensor( - state.reshape(1, -1), device=device, dtype=torch.float32 - ) + state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32) dist = self(state) action = dist.mean if not self.training else dist.sample() - action = torch.clamp( - self.max_action * action, -self.max_action, self.max_action - ) + action = torch.clamp(self.max_action * action, -self.max_action, self.max_action) return action.cpu().data.numpy().flatten() @@ -349,9 +319,7 @@ def forward(self, obs: torch.Tensor) -> torch.Tensor: @torch.no_grad() def act(self, state: np.ndarray, device: str = "cpu"): - state = torch.tensor( - state.reshape(1, -1), device=device, dtype=torch.float32 - ) + state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32) return ( torch.clamp( self(state) * self.max_action, @@ -383,16 +351,12 @@ def both( sa = torch.cat([state, action], 1) return self.q1(sa), self.q2(sa) - def forward( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: return torch.min(*self.both(state, action)) class ValueFunction(nn.Module): - def __init__( - self, state_dim: int, hidden_dim: int = 256, n_hidden: int = 2 - ): + def __init__(self, state_dim: int, hidden_dim: int = 256, n_hidden: int = 2): super().__init__() dims = [state_dim, *([hidden_dim] * n_hidden), 1] self.v = MLP(dims, squeeze_output=True) @@ -426,9 +390,7 @@ def __init__( self.v_optimizer = v_optimizer self.q_optimizer = q_optimizer self.actor_optimizer = actor_optimizer - self.actor_lr_schedule = CosineAnnealingLR( - self.actor_optimizer, 
max_steps - ) + self.actor_lr_schedule = CosineAnnealingLR(self.actor_optimizer, max_steps) self.iql_tau = iql_tau self.beta = beta self.discount = discount @@ -460,10 +422,7 @@ def _update_q( terminals, log_dict, ): - targets = ( - rewards - + (1.0 - terminals.float()) * self.discount * next_v.detach() - ) + targets = rewards + (1.0 - terminals.float()) * self.discount * next_v.detach() qs = self.qf.both(observations, actions) q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs) log_dict["q_loss"] = q_loss.item() @@ -556,9 +515,7 @@ def train(config: TrainConfig): modify_reward(dataset, config.env) if config.normalize: - state_mean, state_std = compute_mean_std( - dataset["observations"], eps=1e-3 - ) + state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3) else: state_mean, state_std = 0, 1 @@ -582,9 +539,7 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) # Set seeds @@ -650,9 +605,7 @@ def train(config: TrainConfig): seed=config.seed, ) eval_score = eval_scores.mean() - normalized_eval_score = ( - env.get_normalized_score(eval_score) * 100.0 - ) + normalized_eval_score = env.get_normalized_score(eval_score) * 100.0 evaluations.append(normalized_eval_score) print("---------------------------------------") print( @@ -663,9 +616,7 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: torch.save( trainer.state_dict(), - os.path.join( - config.checkpoints_path, f"checkpoint_{t}.pt" - ), + os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"), ) wandb.log( {"d4rl_normalized_score": normalized_eval_score}, diff --git a/algorithms/msg.py b/algorithms/msg.py index 97fc845..650b223 100644 --- a/algorithms/msg.py +++ b/algorithms/msg.py @@ -63,9 +63,7 @@ class TrainConfig: def __post_init__(self): self.name = f"{self.name}-{self.env_name}-{str(uuid.uuid4())[:8]}" if self.checkpoints_path is not None: - self.checkpoints_path = os.path.join( - self.checkpoints_path, self.name - ) + self.checkpoints_path = os.path.join(self.checkpoints_path, self.name) # general utils @@ -73,12 +71,8 @@ def __post_init__(self): def soft_update(target: nn.Module, source: nn.Module, tau: float): - for target_param, source_param in zip( - target.parameters(), source.parameters() - ): - target_param.data.copy_( - (1 - tau) * target_param.data + tau * source_param.data - ) + for target_param, source_param in zip(target.parameters(), source.parameters()): + target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) def wandb_init(config: dict) -> None: @@ -104,9 +98,7 @@ def set_seed( torch.use_deterministic_algorithms(deterministic_torch) -def compute_mean_std( - states: np.ndarray, eps: float -) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]: mean = states.mean(0) std = states.std(0) + eps return mean, std @@ -125,9 +117,7 @@ def discount_cumsum(x, discount, include_first=True): else: disc_cumsum[-1] = 0 for t in reversed(range(x.shape[0] - 1)): - disc_cumsum[t] = ( - discount * x[t + 1] + discount * disc_cumsum[t + 1] - ) + disc_cumsum[t] = discount * x[t + 1] + discount * disc_cumsum[t + 1] return disc_cumsum @@ -196,9 +186,7 @@ def __init__( self._actions = 
torch.zeros( (buffer_size, action_dim), dtype=torch.float32, device=device ) - self._rewards = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._returns_to_go = torch.zeros( (buffer_size, 1), dtype=torch.float32, device=device ) @@ -217,9 +205,7 @@ def __init__( self._next_states = torch.zeros( (buffer_size, state_dim), dtype=torch.float32, device=device ) - self._dones = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._discount = discount self._device = device @@ -236,9 +222,7 @@ def compute_returns_to_go(self, data: np.ndarray) -> np.ndarray: data["rewards"][i] ) # - 1* self._action_dim* torch.log(1 / torch.sqrt(torch.tensor(2 * np.pi)))) if ( - data["terminals"][i] - or data["timeouts"][i] - or i == n_transitions - 1 + data["terminals"][i] or data["timeouts"][i] or i == n_transitions - 1 ): # TODO: Deal with incomplete trajectory case episode_returns_to_go = discount_cumsum( np.array(episode_rewards), self._discount @@ -259,9 +243,7 @@ def compute_returns_to_go(self, data: np.ndarray) -> np.ndarray: ) # Terminals/timeouts block next returns to go assert next_returns_to_go[0] == returns_to_go[1] - self._returns_to_go[:n_transitions] = self._to_tensor( - returns_to_go[..., None] - ) + self._returns_to_go[:n_transitions] = self._to_tensor(returns_to_go[..., None]) self._next_returns_to_go[:n_transitions] = self._to_tensor( next_returns_to_go[..., None] ) @@ -292,9 +274,7 @@ def compute_soft_returns_to_go( for i in range(n_transitions): episode_rewards.append(self._rewards[i].cpu().item()) - episode_entropy_bonuses.append( - self._entropy_bonuses[i].cpu().item() - ) + episode_entropy_bonuses.append(self._entropy_bonuses[i].cpu().item()) if self._dones[i] or i == n_transitions - 1: episode_returns_to_go = discount_cumsum( np.array(episode_rewards), self._discount @@ -331,9 +311,7 @@ def compute_soft_returns_to_go( # Loads data in d4rl format, i.e. from Dict[str, np.array]. 
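# --------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the patch): the reverse
# discounted sum behind discount_cumsum / compute_returns_to_go above. It shows
# the standard recurrence R_t = r_t + discount * R_{t+1}, reset at episode
# boundaries; the repository's discount_cumsum also takes an include_first flag
# and may treat the first/last elements differently, so this is a minimal
# stand-in with a hypothetical helper name, not the repository's code.
import numpy as np


def returns_to_go(rewards: np.ndarray, dones: np.ndarray, discount: float) -> np.ndarray:
    rtg = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0  # do not bootstrap across terminals/timeouts
        running = rewards[t] + discount * running
        rtg[t] = running
    return rtg


if __name__ == "__main__":
    r = np.array([1.0, 1.0, 1.0, 2.0, 2.0])
    d = np.array([0, 0, 1, 0, 1])  # two episodes
    print(returns_to_go(r, d, discount=0.9))  # [2.71, 1.9, 1.0, 3.8, 2.0]
# --------------------------------------------------------------------------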
def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): if self._size != 0: - raise ValueError( - "Trying to load data into non-empty replay buffer" - ) + raise ValueError("Trying to load data into non-empty replay buffer") n_transitions = data["observations"].shape[0] if n_transitions > self._buffer_size: raise ValueError( @@ -347,12 +325,8 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): self._states[:n_transitions] = self._to_tensor(data["observations"]) self._actions[:n_transitions] = self._to_tensor(data["actions"]) - self._rewards[:n_transitions] = self._to_tensor( - data["rewards"][..., None] - ) - self._next_states[:n_transitions] = self._to_tensor( - data["next_observations"] - ) + self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) + self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) self._dones[:n_transitions] = self._to_tensor( (data["terminals"] + data["timeouts"])[..., None] @@ -366,9 +340,7 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): print(f"Dataset size: {n_transitions}") def sample(self, batch_size: int) -> TensorBatch: - indices = np.random.randint( - 0, min(self._size, self._pointer), size=batch_size - ) + indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) states = self._states[indices] actions = self._actions[indices] rewards = self._rewards[indices] @@ -397,17 +369,13 @@ def add_transition(self): # SAC Actor & Critic implementation class VectorizedLinear(nn.Module): - def __init__( - self, in_features: int, out_features: int, ensemble_size: int - ): + def __init__(self, in_features: int, out_features: int, ensemble_size: int): super().__init__() self.in_features = in_features self.out_features = out_features self.ensemble_size = ensemble_size - self.weight = nn.Parameter( - torch.empty(ensemble_size, in_features, out_features) - ) + self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features)) self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features)) self.reset_parameters() @@ -494,15 +462,11 @@ def forward( if need_log_prob: # change of variables formula (SAC paper, appendix C, eq 21) log_prob = policy_dist.log_prob(action).sum(axis=-1) - log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum( - axis=-1 - ) + log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(axis=-1) return tanh_action * self.max_action, log_prob - def log_prob( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: hidden = self.trunk(state) mu, log_sigma = self.mu(hidden), self.log_sigma(hidden) # print(mu.mean()) @@ -570,9 +534,7 @@ def __init__( self.num_critics = num_critics - def forward( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: # [..., batch_size, state_dim + action_dim] state_action = torch.cat([state, action], dim=-1) if state_action.dim() != 3: @@ -630,27 +592,21 @@ def __init__( self.log_alpha = torch.tensor( [0.0], dtype=torch.float32, device=self.device, requires_grad=True ) - self.alpha_optimizer = torch.optim.Adam( - [self.log_alpha], lr=alpha_learning_rate - ) + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_learning_rate) self.alpha = self.log_alpha.exp().detach() def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor: with torch.no_grad(): action, action_log_prob = self.actor(state, 
need_log_prob=True) - loss = ( - -self.log_alpha * (action_log_prob + self.target_entropy) - ).mean() + loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean() return loss def _actor_loss( self, state: torch.Tensor, action: torch.Tensor ) -> Tuple[torch.Tensor, float, float]: - pi, log_pi = self.actor( - state, need_log_prob=True - ) # , deterministic=True + pi, log_pi = self.actor(state, need_log_prob=True) # , deterministic=True # log_prob_action = self.actor.log_prob(state, action) @@ -691,10 +647,7 @@ def _critic_loss( # q_next = q_nexts.mean(0) - 2 * q_nexts.std(0) q_next = q_next - self.alpha * next_action_log_prob.unsqueeze(0) assert q_next.shape[1] == done.shape[0] == reward.shape[0] - q_target = ( - reward.view(1, -1) - + self.gamma * (1 - done.view(1, -1)) * q_next - ) + q_target = reward.view(1, -1) + self.gamma * (1 - done.view(1, -1)) * q_next q_values = self.critic(state, action) # [ensemble_size, batch_size] - [1, batch_size] @@ -703,9 +656,7 @@ def _critic_loss( pi, _ = self.actor(state, need_log_prob=False) q_policy_values = self.critic(state, pi) - support_regulariser = ( - (q_policy_values - q_values).mean(dim=1).sum(dim=0) - ) + support_regulariser = (q_policy_values - q_values).mean(dim=1).sum(dim=0) # loss = (1 / critic_loss.abs().mean().detach()) * loss = critic_loss + self.eta * support_regulariser @@ -768,9 +719,7 @@ def pretrain_actorcritic(self, batch: TensorBatch) -> Dict[str, float]: # self.alpha = self.log_alpha.exp().detach() # Compute actor loss - pi, action_log_prob = self.actor( - state, deterministic=True, need_log_prob=True - ) + pi, action_log_prob = self.actor(state, deterministic=True, need_log_prob=True) actor_loss = F.mse_loss( pi, action ) + self.alpha * self.actor.action_dim * torch.log( @@ -827,9 +776,7 @@ def pretrain_critic(self, batch: TensorBatch) -> Dict[str, float]: # Compute critic loss q = self.critic(state, action).mean(dim=0) diversity_loss = self._critic_diversity_loss(state, action) - critic_loss = ( - F.mse_loss(q, return_to_go.squeeze()) + self.eta * diversity_loss - ) + critic_loss = F.mse_loss(q, return_to_go.squeeze()) + self.eta * diversity_loss log_dict["critic_loss"] = critic_loss.item() # Optimize the critic self.pretrain_critic_optimizer.zero_grad() @@ -898,10 +845,7 @@ def pretrain_soft_critic( # self.target_critic(next_state, next_action).min(0).values.unsqueeze(-1) # ) - q_next = ( - next_return_to_go - - self.alpha * next_action_log_prob.unsqueeze(-1) - ) + q_next = next_return_to_go - self.alpha * next_action_log_prob.unsqueeze(-1) q_target = reward + self.gamma * (1 - done) * q_next q_values = ( @@ -911,9 +855,7 @@ def pretrain_soft_critic( pi, _ = self.actor(state, need_log_prob=False) q_policy_values = self.critic(state, pi) - support_regulariser = ( - (q_policy_values - q_values).mean(dim=1).sum(dim=0) - ) + support_regulariser = (q_policy_values - q_values).mean(dim=1).sum(dim=0) # # [ensemble_size, batch_size] - [1, batch_size] critic_loss = ( ((q_values - q_target.view(1, -1)) ** 2).mean(dim=1).sum(dim=0) @@ -939,13 +881,9 @@ def pretrain_soft_critic( # for logging, Q-ensemble std estimate with the random actions: # a ~ U[-max_action, max_action] max_action = self.actor.max_action - random_actions = -max_action + 2 * max_action * torch.rand_like( - action - ) + random_actions = -max_action + 2 * max_action * torch.rand_like(action) - q_random_std = ( - self.critic(state, random_actions).std(0).mean().item() - ) + q_random_std = self.critic(state, random_actions).std(0).mean().item() log_dict = { 
"critic_loss": critic_loss.item(), @@ -977,17 +915,13 @@ def update(self, batch: TensorBatch) -> Dict[str, float]: self.alpha = self.log_alpha.exp().detach() # Actor update - actor_loss, actor_batch_entropy, q_policy_std = self._actor_loss( - state, action - ) + actor_loss, actor_batch_entropy, q_policy_std = self._actor_loss(state, action) self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step() # Critic update - critic_loss = self._critic_loss( - state, action, reward, next_state, done - ) + critic_loss = self._critic_loss(state, action, reward, next_state, done) self.critic_optimizer.zero_grad() critic_loss.backward() self.critic_optimizer.step() @@ -998,13 +932,9 @@ def update(self, batch: TensorBatch) -> Dict[str, float]: # for logging, Q-ensemble std estimate with the random actions: # a ~ U[-max_action, max_action] max_action = self.actor.max_action - random_actions = -max_action + 2 * max_action * torch.rand_like( - action - ) + random_actions = -max_action + 2 * max_action * torch.rand_like(action) - q_random_std = ( - self.critic(state, random_actions).std(0).mean().item() - ) + q_random_std = self.critic(state, random_actions).std(0).mean().item() update_info = { "alpha_loss": alpha_loss.item(), @@ -1014,12 +944,10 @@ def update(self, batch: TensorBatch) -> Dict[str, float]: "alpha": self.alpha.item(), "q_policy_std": q_policy_std, "q_random_std": q_random_std, - "actor_lr": [ - group["lr"] for group in self.actor_optimizer.param_groups - ][0], - "critic_lr": [ - group["lr"] for group in self.critic_optimizer.param_groups - ][0], + "actor_lr": [group["lr"] for group in self.actor_optimizer.param_groups][0], + "critic_lr": [group["lr"] for group in self.critic_optimizer.param_groups][ + 0 + ], } return update_info @@ -1109,9 +1037,7 @@ def train(config: TrainConfig): if config.normalize_reward: modify_reward(d4rl_dataset, config.env_name) - state_mean, state_std = compute_mean_std( - d4rl_dataset["observations"], eps=1e-3 - ) + state_mean, state_std = compute_mean_std(d4rl_dataset["observations"], eps=1e-3) d4rl_dataset["observations"] = normalize_states( d4rl_dataset["observations"], state_mean, state_std @@ -1120,9 +1046,7 @@ def train(config: TrainConfig): d4rl_dataset["next_observations"] = np.roll( d4rl_dataset["observations"], shift=-1, axis=0 ) # Terminals/timeouts block next observations - print( - "Loaded next state observations from current state observations." 
- ) + print("Loaded next state observations from current state observations.") d4rl_dataset["next_observations"] = normalize_states( d4rl_dataset["next_observations"], state_mean, state_std @@ -1151,9 +1075,7 @@ def train(config: TrainConfig): pretrain_actor_optimizer = torch.optim.Adam( actor.parameters(), lr=5 * config.actor_learning_rate ) - actor_optimizer = torch.optim.Adam( - actor.parameters(), lr=config.actor_learning_rate - ) + actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_learning_rate) # actor_scheduler = torch.optim.lr_scheduler.LinearLR( # actor_optimizer, start_factor=0.01, total_iters=500 # ) @@ -1201,18 +1123,14 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) total_updates = 0.0 # reset_optimisers = True for epoch in trange(config.num_epochs, desc="Training"): # training - for _ in trange( - config.num_updates_on_epoch, desc="Epoch", leave=False - ): + for _ in trange(config.num_updates_on_epoch, desc="Epoch", leave=False): batch = buffer.sample(config.batch_size) if config.pretrain is not None: if epoch <= config.pretrain_epochs: @@ -1232,9 +1150,7 @@ def train(config: TrainConfig): alpha=trainer.alpha, actor=trainer.actor, ) - print( - "Soft returns to go loaded for BC actor!" - ) + print("Soft returns to go loaded for BC actor!") assert buffer._soft_returns_loaded == True update_info = trainer.pretrain_soft_critic( batch, epoch, config.pretrain_epochs @@ -1246,9 +1162,7 @@ def train(config: TrainConfig): alpha=trainer.alpha, actor=trainer.actor, ) - print( - "Soft returns to go loaded for initialised actor!" 
- ) + print("Soft returns to go loaded for initialised actor!") assert buffer._soft_returns_loaded == True update_info = trainer.pretrain_soft_critic( batch, epoch, config.pretrain_epochs @@ -1299,15 +1213,9 @@ def train(config: TrainConfig): "epoch": epoch, } if hasattr(eval_env, "get_normalized_score"): - normalized_score = ( - eval_env.get_normalized_score(eval_returns) * 100.0 - ) - eval_log["eval/normalized_score_mean"] = np.mean( - normalized_score - ) - eval_log["eval/normalized_score_std"] = np.std( - normalized_score - ) + normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 + eval_log["eval/normalized_score_mean"] = np.mean(normalized_score) + eval_log["eval/normalized_score_std"] = np.std(normalized_score) wandb.log(eval_log) diff --git a/algorithms/sac_n.py b/algorithms/sac_n.py index d367e76..4e341c6 100644 --- a/algorithms/sac_n.py +++ b/algorithms/sac_n.py @@ -63,9 +63,7 @@ class TrainConfig: def __post_init__(self): self.name = f"{self.name}-{self.env_name}-{str(uuid.uuid4())[:8]}" if self.checkpoints_path is not None: - self.checkpoints_path = os.path.join( - self.checkpoints_path, self.name - ) + self.checkpoints_path = os.path.join(self.checkpoints_path, self.name) # general utils @@ -73,12 +71,8 @@ def __post_init__(self): def soft_update(target: nn.Module, source: nn.Module, tau: float): - for target_param, source_param in zip( - target.parameters(), source.parameters() - ): - target_param.data.copy_( - (1 - tau) * target_param.data + tau * source_param.data - ) + for target_param, source_param in zip(target.parameters(), source.parameters()): + target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) def wandb_init(config: dict) -> None: @@ -105,9 +99,7 @@ def set_seed( torch.use_deterministic_algorithms(deterministic_torch) -def compute_mean_std( - states: np.ndarray, eps: float -) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]: mean = states.mean(0) std = states.std(0) + eps return mean, std @@ -162,18 +154,14 @@ def __init__( self._actions = torch.zeros( (buffer_size, action_dim), dtype=torch.float32, device=device ) - self._rewards = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._returns_to_go = torch.zeros( (buffer_size, 1), dtype=torch.float32, device=device ) self._next_states = torch.zeros( (buffer_size, state_dim), dtype=torch.float32, device=device ) - self._dones = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._discount = discount self._device = device @@ -183,9 +171,7 @@ def _to_tensor(self, data: np.ndarray) -> torch.Tensor: # Loads data in d4rl format, i.e. from Dict[str, np.array]. 
def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): if self._size != 0: - raise ValueError( - "Trying to load data into non-empty replay buffer" - ) + raise ValueError("Trying to load data into non-empty replay buffer") n_transitions = data["observations"].shape[0] if n_transitions > self._buffer_size: raise ValueError( @@ -196,11 +182,7 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): returns_to_go = [] for i in range(n_transitions): - if ( - data["terminals"][i] - or data["timeouts"][i] - or i == n_transitions - 1 - ): + if data["terminals"][i] or data["timeouts"][i] or i == n_transitions - 1: episode_rewards.append(data["rewards"][i]) episode_returns_to_go = discount_cumsum( np.array(episode_rewards), self._discount @@ -219,15 +201,11 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): self._states[:n_transitions] = self._to_tensor(data["observations"]) self._actions[:n_transitions] = self._to_tensor(data["actions"]) - self._rewards[:n_transitions] = self._to_tensor( - data["rewards"][..., None] - ) + self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) self._returns_to_go[:n_transitions] = self._to_tensor( data["returns_to_go"][..., None] ) - self._next_states[:n_transitions] = self._to_tensor( - data["next_observations"] - ) + self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) self._dones[:n_transitions] = self._to_tensor( (data["terminals"] + data["timeouts"])[..., None] ) @@ -237,9 +215,7 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): print(f"Dataset size: {n_transitions}") def sample(self, batch_size: int) -> TensorBatch: - indices = np.random.randint( - 0, min(self._size, self._pointer), size=batch_size - ) + indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) states = self._states[indices] actions = self._actions[indices] rewards = self._rewards[indices] @@ -255,17 +231,13 @@ def add_transition(self): # SAC Actor & Critic implementation class VectorizedLinear(nn.Module): - def __init__( - self, in_features: int, out_features: int, ensemble_size: int - ): + def __init__(self, in_features: int, out_features: int, ensemble_size: int): super().__init__() self.in_features = in_features self.out_features = out_features self.ensemble_size = ensemble_size - self.weight = nn.Parameter( - torch.empty(ensemble_size, in_features, out_features) - ) + self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features)) self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features)) self.reset_parameters() @@ -344,9 +316,7 @@ def forward( if need_log_prob: # change of variables formula (SAC paper, appendix C, eq 21) log_prob = policy_dist.log_prob(action).sum(axis=-1) - log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum( - axis=-1 - ) + log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(axis=-1) return tanh_action * self.max_action, log_prob @@ -388,9 +358,7 @@ def __init__( self.num_critics = num_critics - def forward( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: # [batch_size, state_dim + action_dim] state_action = torch.cat([state, action], dim=-1) # [num_critics, batch_size, state_dim + action_dim] @@ -432,24 +400,18 @@ def __init__( self.log_alpha = torch.tensor( [0.0], dtype=torch.float32, device=self.device, requires_grad=True ) - self.alpha_optimizer = torch.optim.Adam( - [self.log_alpha], lr=alpha_learning_rate - 
) + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_learning_rate) self.alpha = self.log_alpha.exp().detach() def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor: with torch.no_grad(): action, action_log_prob = self.actor(state, need_log_prob=True) - loss = ( - -self.log_alpha * (action_log_prob + self.target_entropy) - ).mean() + loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean() return loss - def _actor_loss( - self, state: torch.Tensor - ) -> Tuple[torch.Tensor, float, float]: + def _actor_loss(self, state: torch.Tensor) -> Tuple[torch.Tensor, float, float]: action, action_log_prob = self.actor(state, need_log_prob=True) q_value_dist = self.critic(state, action) assert q_value_dist.shape[0] == self.critic.num_critics @@ -551,9 +513,7 @@ def update(self, batch: TensorBatch) -> Dict[str, float]: self.actor_optimizer.step() # Critic update - critic_loss = self._critic_loss( - state, action, reward, next_state, done - ) + critic_loss = self._critic_loss(state, action, reward, next_state, done) self.critic_optimizer.zero_grad() critic_loss.backward() self.critic_optimizer.step() @@ -564,13 +524,9 @@ def update(self, batch: TensorBatch) -> Dict[str, float]: # for logging, Q-ensemble std estimate with the random actions: # a ~ U[-max_action, max_action] max_action = self.actor.max_action - random_actions = -max_action + 2 * max_action * torch.rand_like( - action - ) + random_actions = -max_action + 2 * max_action * torch.rand_like(action) - q_random_std = ( - self.critic(state, random_actions).std(0).mean().item() - ) + q_random_std = self.critic(state, random_actions).std(0).mean().item() update_info = { "alpha_loss": alpha_loss.item(), @@ -666,9 +622,7 @@ def train(config: TrainConfig): if config.normalize_reward: modify_reward(d4rl_dataset, config.env) - state_mean, state_std = compute_mean_std( - d4rl_dataset["observations"], eps=1e-3 - ) + state_mean, state_std = compute_mean_std(d4rl_dataset["observations"], eps=1e-3) d4rl_dataset["observations"] = normalize_states( d4rl_dataset["observations"], state_mean, state_std @@ -690,9 +644,7 @@ def train(config: TrainConfig): # Actor & Critic setup actor = Actor(state_dim, action_dim, config.hidden_dim, config.max_action) actor.to(config.device) - actor_optimizer = torch.optim.Adam( - actor.parameters(), lr=config.actor_learning_rate - ) + actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_learning_rate) critic = VectorizedCritic( state_dim, action_dim, config.hidden_dim, config.num_critics ) @@ -715,17 +667,13 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) total_updates = 0.0 for epoch in trange(config.num_epochs, desc="Training"): # training - for _ in trange( - config.num_updates_on_epoch, desc="Epoch", leave=False - ): + for _ in trange(config.num_updates_on_epoch, desc="Epoch", leave=False): batch = buffer.sample(config.batch_size) if config.pretrain is not None: if epoch <= config.pretrain_epochs: @@ -766,15 +714,9 @@ def train(config: TrainConfig): "epoch": epoch, } if hasattr(eval_env, "get_normalized_score"): - normalized_score = ( - eval_env.get_normalized_score(eval_returns) * 100.0 - ) - eval_log["eval/normalized_score_mean"] = np.mean( - 
normalized_score - ) - eval_log["eval/normalized_score_std"] = np.std( - normalized_score - ) + normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 + eval_log["eval/normalized_score_mean"] = np.mean(normalized_score) + eval_log["eval/normalized_score_std"] = np.std(normalized_score) wandb.log(eval_log) diff --git a/algorithms/td3_bc.py b/algorithms/td3_bc.py index 34452ef..b838616 100644 --- a/algorithms/td3_bc.py +++ b/algorithms/td3_bc.py @@ -45,9 +45,7 @@ class TrainConfig: discount: float = 0.99 # Discount ffor expl_noise: float = 0.1 # Std of Gaussian exploration noise tau: float = 0.005 # Target network update rate - policy_noise: float = ( - 0.2 # Noise added to target actor during critic update - ) + policy_noise: float = 0.2 # Noise added to target actor during critic update noise_clip: float = 0.5 # Range to clip target actor noise policy_freq: int = 2 # Frequency of delayed actor updates # TD3 + BC @@ -56,12 +54,8 @@ class TrainConfig: normalize_reward: bool = False # Normalize reward pretrain: Optional[str] = None # BC or AC pretrain_steps: int = 10000 # Number of pretraining steps - td_component: float = ( - -1.0 - ) # Proportion of TD to use (rather than MC) in pretraining - pretrain_cql_regulariser: float = ( - -1.0 - ) # CQL regularisation for pretraining + td_component: float = -1.0 # Proportion of TD to use (rather than MC) in pretraining + pretrain_cql_regulariser: float = -1.0 # CQL regularisation for pretraining cql_regulariser: float = -1.0 # CQL regularisation in training cql_n_actions: int = 10 # Number of actions to sample for CQL actor_LN: bool = True # Use LayerNorm in actor @@ -83,17 +77,11 @@ def __post_init__(self): def soft_update(target: nn.Module, source: nn.Module, tau: float): - for target_param, source_param in zip( - target.parameters(), source.parameters() - ): - target_param.data.copy_( - (1 - tau) * target_param.data + tau * source_param.data - ) + for target_param, source_param in zip(target.parameters(), source.parameters()): + target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) -def compute_mean_std( - states: np.ndarray, eps: float -) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]: mean = states.mean(0) std = states.std(0) + eps return mean, std @@ -152,18 +140,14 @@ def __init__( self._actions = torch.zeros( (buffer_size, action_dim), dtype=torch.float32, device=device ) - self._rewards = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._returns_to_go = torch.zeros( (buffer_size, 1), dtype=torch.float32, device=device ) self._next_states = torch.zeros( (buffer_size, state_dim), dtype=torch.float32, device=device ) - self._dones = torch.zeros( - (buffer_size, 1), dtype=torch.float32, device=device - ) + self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) self._discount = discount self._device = device @@ -177,15 +161,9 @@ def compute_returns_to_go(self, data: np.ndarray) -> np.ndarray: for i in range(n_transitions): episode_rewards.append(data["rewards"][i]) - if ( - data["terminals"][i] - or data["timeouts"][i] - or i == n_transitions - 1 - ): + if data["terminals"][i] or data["timeouts"][i] or i == n_transitions - 1: if data["timeouts"][i] or i == n_transitions - 1: - episode_rewards[-1] = episode_rewards[-1] / ( - 1 - self._discount - ) + episode_rewards[-1] = episode_rewards[-1] / (1 - 
self._discount) episode_returns_to_go = discount_cumsum( np.array(episode_rewards), self._discount ) @@ -204,9 +182,7 @@ def compute_returns_to_go(self, data: np.ndarray) -> np.ndarray: # Loads data in d4rl format, i.e. from Dict[str, np.array]. def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): if self._size != 0: - raise ValueError( - "Trying to load data into non-empty replay buffer" - ) + raise ValueError("Trying to load data into non-empty replay buffer") n_transitions = data["observations"].shape[0] if n_transitions > self._buffer_size: raise ValueError( @@ -215,12 +191,8 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): self._states[:n_transitions] = self._to_tensor(data["observations"]) self._actions[:n_transitions] = self._to_tensor(data["actions"]) - self._rewards[:n_transitions] = self._to_tensor( - data["rewards"][..., None] - ) - self._next_states[:n_transitions] = self._to_tensor( - data["next_observations"] - ) + self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) + self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) self._dones[:n_transitions] = self._to_tensor( (data["terminals"] + data["timeouts"])[..., None] ) @@ -235,9 +207,7 @@ def load_d4rl_dataset(self, data: Dict[str, np.ndarray]): print(f"Dataset size: {n_transitions}") def sample(self, batch_size: int) -> TensorBatch: - indices = np.random.randint( - 0, min(self._size, self._pointer), size=batch_size - ) + indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) states = self._states[indices] actions = self._actions[indices] rewards = self._rewards[indices] @@ -279,7 +249,13 @@ def wandb_init(config: dict) -> None: @torch.no_grad() def eval_actor( - env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int, render: bool, name: str, + env: gym.Env, + actor: nn.Module, + device: str, + n_episodes: int, + seed: int, + render: bool, + name: str, ) -> np.ndarray: env.seed(seed) actor.eval() @@ -366,19 +342,13 @@ def __init__( self.net = nn.Sequential( nn.Linear(state_dim, 256), - nn.LayerNorm(256, elementwise_affine=False) - if actor_LN - else nn.Identity(), + nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(), nn.ReLU(), nn.Linear(256, 256), - nn.LayerNorm(256, elementwise_affine=False) - if actor_LN - else nn.Identity(), + nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(), nn.ReLU(), nn.Linear(256, 256), - nn.LayerNorm(256, elementwise_affine=False) - if actor_LN - else nn.Identity(), + nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(), nn.ReLU(), nn.Linear(256, action_dim), nn.Tanh(), @@ -391,35 +361,25 @@ def forward(self, state: torch.Tensor) -> torch.Tensor: @torch.no_grad() def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray: - state = torch.tensor( - state.reshape(1, -1), device=device, dtype=torch.float32 - ) + state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32) return self(state).cpu().data.numpy().flatten() class Critic(nn.Module): - def __init__( - self, state_dim: int, action_dim: int, critic_LN: bool = True - ): + def __init__(self, state_dim: int, action_dim: int, critic_LN: bool = True): super(Critic, self).__init__() self.net = nn.Sequential( nn.Linear(state_dim + action_dim, 256), - nn.LayerNorm(256, elementwise_affine=False) - if critic_LN - else nn.Identity(), + nn.LayerNorm(256, elementwise_affine=False) if critic_LN else nn.Identity(), nn.ReLU(), nn.Linear(256, 256), - 
nn.LayerNorm(256, elementwise_affine=False) - if critic_LN - else nn.Identity(), + nn.LayerNorm(256, elementwise_affine=False) if critic_LN else nn.Identity(), nn.ReLU(), nn.Linear(256, 1), ) - def forward( - self, state: torch.Tensor, action: torch.Tensor - ) -> torch.Tensor: + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: sa = torch.cat([state, action], 1) return self.net(sa) @@ -506,9 +466,7 @@ def train(self, batch: TensorBatch) -> Dict[str, float]: current_q2 = self.critic_2(state, action) # Compute critic loss - critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss( - current_q2, target_q - ) + critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) log_dict["critic_loss"] = critic_loss.item() if self.cql_regulariser > 0.0: @@ -518,24 +476,20 @@ def train(self, batch: TensorBatch) -> Dict[str, float]: requires_grad=False, ).uniform_(-1, 1) repeated_state = state.repeat(self.cql_n_actions, 1) - q1_random_values = self.critic_1( - repeated_state, random_actions - ).reshape(self.cql_n_actions, -1, 1) - q2_random_values = self.critic_2( - repeated_state, random_actions - ).reshape(self.cql_n_actions, -1, 1) + q1_random_values = self.critic_1(repeated_state, random_actions).reshape( + self.cql_n_actions, -1, 1 + ) + q2_random_values = self.critic_2(repeated_state, random_actions).reshape( + self.cql_n_actions, -1, 1 + ) cql_regularisation = ( torch.logsumexp(q1_random_values, dim=0) - current_q1 - ).mean() + ( - torch.logsumexp(q2_random_values, dim=0) - current_q2 - ).mean() + ).mean() + (torch.logsumexp(q2_random_values, dim=0) - current_q2).mean() log_dict["support_regulariser"] = cql_regularisation.item() critic_loss = ( critic_loss / critic_loss.detach() - + self.cql_regulariser - * cql_regularisation - / cql_regularisation.detach() + + self.cql_regulariser * cql_regularisation / cql_regularisation.detach() ) # Optimize the critic @@ -622,9 +576,7 @@ def pretrain_actorcritic(self, batch: TensorBatch) -> Dict[str, float]: target_q = torch.min(target_q1, target_q2) target_q = reward + (1 - done) * self.discount * target_q - TD_loss = F.mse_loss(current_q1, target_q) + F.mse_loss( - current_q2, target_q - ) + TD_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) log_dict["TD_loss"] = TD_loss.item() critic_loss = ( 1 - self.td_component @@ -637,18 +589,16 @@ def pretrain_actorcritic(self, batch: TensorBatch) -> Dict[str, float]: requires_grad=False, ).uniform_(-1, 1) repeated_state = state.repeat(self.cql_n_actions, 1) - q1_policy_values = self.critic_1( - repeated_state, random_actions - ).reshape(self.cql_n_actions, -1, 1) - q2_policy_values = self.critic_2( - repeated_state, random_actions - ).reshape(self.cql_n_actions, -1, 1) + q1_policy_values = self.critic_1(repeated_state, random_actions).reshape( + self.cql_n_actions, -1, 1 + ) + q2_policy_values = self.critic_2(repeated_state, random_actions).reshape( + self.cql_n_actions, -1, 1 + ) cql_regulariser = ( torch.logsumexp(q1_policy_values, dim=0) - current_q1 - ).mean() + ( - torch.logsumexp(q2_policy_values, dim=0) - current_q2 - ).mean() + ).mean() + (torch.logsumexp(q2_policy_values, dim=0) - current_q2).mean() log_dict["support_regulariser"] = cql_regulariser.item() critic_loss = ( @@ -686,15 +636,11 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]): self.critic_1.load_state_dict(state_dict["critic_1"]) - self.critic_1_optimizer.load_state_dict( - state_dict["critic_1_optimizer"] - ) + 
self.critic_1_optimizer.load_state_dict(state_dict["critic_1_optimizer"]) self.critic_1_target = copy.deepcopy(self.critic_1) self.critic_2.load_state_dict(state_dict["critic_2"]) - self.critic_2_optimizer.load_state_dict( - state_dict["critic_2_optimizer"] - ) + self.critic_2_optimizer.load_state_dict(state_dict["critic_2_optimizer"]) self.critic_2_target = copy.deepcopy(self.critic_2) self.actor.load_state_dict(state_dict["actor"]) @@ -717,9 +663,7 @@ def train(config: TrainConfig): modify_reward(dataset, config.env) if config.normalize: - state_mean, state_std = compute_mean_std( - dataset["observations"], eps=1e-3 - ) + state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3) else: state_mean, state_std = 0, 1 @@ -731,9 +675,7 @@ def train(config: TrainConfig): dataset["next_observations"] = np.roll( dataset["observations"], shift=-1, axis=0 ) # Terminals/timeouts block next observations - print( - "Loaded next state observations from current state observations." - ) + print("Loaded next state observations from current state observations.") dataset["next_observations"] = normalize_states( dataset["next_observations"], state_mean, state_std @@ -753,27 +695,19 @@ def train(config: TrainConfig): if config.checkpoints_path is not None: print(f"Checkpoints path: {config.checkpoints_path}") os.makedirs(config.checkpoints_path, exist_ok=True) - with open( - os.path.join(config.checkpoints_path, "config.yaml"), "w" - ) as f: + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: pyrallis.dump(config, f) # Set seeds seed = config.seed set_seed(seed, env) - actor = Actor(state_dim, action_dim, max_action, config.actor_LN).to( - config.device - ) + actor = Actor(state_dim, action_dim, max_action, config.actor_LN).to(config.device) actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4) - critic_1 = Critic(state_dim, action_dim, config.critic_LN).to( - config.device - ) + critic_1 = Critic(state_dim, action_dim, config.critic_LN).to(config.device) critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=3e-4) - critic_2 = Critic(state_dim, action_dim, config.critic_LN).to( - config.device - ) + critic_2 = Critic(state_dim, action_dim, config.critic_LN).to(config.device) critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=3e-4) kwargs = { @@ -828,18 +762,12 @@ def train(config: TrainConfig): elif config.pretrain == "C": log_dict = trainer.pretrain_critic(batch) else: - raise ValueError( - f"Pretrain type {config.pretrain} not recognised." 
-                    )
+                    raise ValueError(f"Pretrain type {config.pretrain} not recognised.")
             else:
                 if t == config.pretrain_steps:
                     with torch.no_grad():
-                        trainer.pretrained_critic_1 = copy.deepcopy(
-                            trainer.critic_1
-                        )
-                        trainer.pretrained_critic_2 = copy.deepcopy(
-                            trainer.critic_2
-                        )
+                        trainer.pretrained_critic_1 = copy.deepcopy(trainer.critic_1)
+                        trainer.pretrained_critic_2 = copy.deepcopy(trainer.critic_2)
                 log_dict = trainer.train(batch)
         else:
             log_dict = trainer.train(batch)
@@ -863,15 +791,9 @@ def train(config: TrainConfig):
                 "epoch": int((t + 1) / 1000),
             }
             if hasattr(env, "get_normalized_score"):
-                normalized_score = (
-                    env.get_normalized_score(eval_returns) * 100.0
-                )
-                eval_log["eval/normalized_score_mean"] = np.mean(
-                    normalized_score
-                )
-                eval_log["eval/normalized_score_std"] = np.std(
-                    normalized_score
-                )
+                normalized_score = env.get_normalized_score(eval_returns) * 100.0
+                eval_log["eval/normalized_score_mean"] = np.mean(normalized_score)
+                eval_log["eval/normalized_score_std"] = np.std(normalized_score)
             wandb.log(eval_log)
             print("---------------------------------------")
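# --------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the patch): the CQL-style
# support regulariser used above -- penalise the logsumexp of Q over random
# actions a ~ U[-1, 1] relative to Q of the dataset action. Shapes and names
# here are hypothetical; the patch's version additionally normalises the TD and
# regulariser terms by their detached magnitudes before summing them.
import torch
import torch.nn as nn


def cql_logsumexp_penalty(
    critic: nn.Module,
    state: torch.Tensor,
    dataset_q: torch.Tensor,
    n_actions: int,
    action_dim: int,
) -> torch.Tensor:
    batch_size = state.shape[0]
    random_actions = torch.empty(
        n_actions * batch_size, action_dim, device=state.device
    ).uniform_(-1, 1)
    repeated_state = state.repeat(n_actions, 1)
    q_random = critic(repeated_state, random_actions).reshape(n_actions, batch_size, 1)
    # logsumexp over the sampled actions, compared against the dataset action's Q
    return (torch.logsumexp(q_random, dim=0) - dataset_q).mean()


if __name__ == "__main__":
    class ToyCritic(nn.Module):
        def __init__(self, state_dim: int, action_dim: int):
            super().__init__()
            self.net = nn.Linear(state_dim + action_dim, 1)

        def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
            return self.net(torch.cat([s, a], dim=1))

    critic = ToyCritic(3, 2)
    s, a = torch.randn(8, 3), torch.rand(8, 2) * 2 - 1
    penalty = cql_logsumexp_penalty(critic, s, critic(s, a), n_actions=10, action_dim=2)
    print(penalty.item())
# --------------------------------------------------------------------------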