
Commit f2aa673

Fix codestyle issues.
AdamJelley committed Jun 18, 2024
1 parent 59f76eb commit f2aa673
Showing 9 changed files with 327 additions and 911 deletions.
algorithms/any_percent_bc.py (112 changes: 31 additions & 81 deletions)
@@ -56,17 +56,11 @@ def __post_init__(self):


 def soft_update(target: nn.Module, source: nn.Module, tau: float):
-    for target_param, source_param in zip(
-        target.parameters(), source.parameters()
-    ):
-        target_param.data.copy_(
-            (1 - tau) * target_param.data + tau * source_param.data
-        )
+    for target_param, source_param in zip(target.parameters(), source.parameters()):
+        target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data)


-def compute_mean_std(
-    states: np.ndarray, eps: float
-) -> Tuple[np.ndarray, np.ndarray]:
+def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
     mean = states.mean(0)
     std = states.std(0) + eps
     return mean, std
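
Note: soft_update above is Polyak averaging, target <- (1 - tau) * target + tau * source, and compute_mean_std produces the statistics later used to normalize observations. A minimal usage sketch of the two helpers (the layer sizes, tau value, and toy dataset are illustrative, not taken from this file):

import numpy as np
import torch.nn as nn

source, target = nn.Linear(4, 2), nn.Linear(4, 2)
target.load_state_dict(source.state_dict())  # start from identical weights
soft_update(target, source, tau=0.005)       # target moves 0.5% toward source

states = np.random.randn(1000, 4).astype(np.float32)
mean, std = compute_mean_std(states, eps=1e-3)
normalized = (states - mean) / std  # the usual normalization, presumably what normalize_states applies
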
@@ -116,15 +110,11 @@ def __init__(
         self._actions = torch.zeros(
             (buffer_size, action_dim), dtype=torch.float32, device=device
         )
-        self._rewards = torch.zeros(
-            (buffer_size, 1), dtype=torch.float32, device=device
-        )
+        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
         self._next_states = torch.zeros(
             (buffer_size, state_dim), dtype=torch.float32, device=device
         )
-        self._dones = torch.zeros(
-            (buffer_size, 1), dtype=torch.float32, device=device
-        )
+        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
         self._device = device

     def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
@@ -133,34 +123,24 @@ def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
     # Loads data in d4rl format, i.e. from Dict[str, np.array].
     def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
         if self._size != 0:
-            raise ValueError(
-                "Trying to load data into non-empty replay buffer"
-            )
+            raise ValueError("Trying to load data into non-empty replay buffer")
         n_transitions = data["observations"].shape[0]
         if n_transitions > self._buffer_size:
             raise ValueError(
                 "Replay buffer is smaller than the dataset you are trying to load!"
             )
         self._states[:n_transitions] = self._to_tensor(data["observations"])
         self._actions[:n_transitions] = self._to_tensor(data["actions"])
-        self._rewards[:n_transitions] = self._to_tensor(
-            data["rewards"][..., None]
-        )
-        self._next_states[:n_transitions] = self._to_tensor(
-            data["next_observations"]
-        )
-        self._dones[:n_transitions] = self._to_tensor(
-            data["terminals"][..., None]
-        )
+        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
+        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
+        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
         self._size += n_transitions
         self._pointer = min(self._size, n_transitions)

         print(f"Dataset size: {n_transitions}")

     def sample(self, batch_size: int) -> TensorBatch:
-        indices = np.random.randint(
-            0, min(self._size, self._pointer), size=batch_size
-        )
+        indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size)
         states = self._states[indices]
         actions = self._actions[indices]
         rewards = self._rewards[indices]
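
A sketch of how this buffer is used end to end. The constructor signature ReplayBuffer(state_dim, action_dim, buffer_size, device) and the [states, actions, rewards, next_states, dones] batch order are assumptions (neither appears in this hunk), and the array shapes are placeholders:

data = {
    "observations": np.zeros((10, 17), dtype=np.float32),
    "actions": np.zeros((10, 6), dtype=np.float32),
    "rewards": np.zeros(10, dtype=np.float32),
    "next_observations": np.zeros((10, 17), dtype=np.float32),
    "terminals": np.zeros(10, dtype=np.float32),
}
buffer = ReplayBuffer(state_dim=17, action_dim=6, buffer_size=100, device="cpu")
buffer.load_d4rl_dataset(data)   # rewards/terminals are expanded to shape (N, 1) internally
states, actions, rewards, next_states, dones = buffer.sample(batch_size=4)
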
@@ -248,9 +228,7 @@ def keep_best_trajectories(
     cur_ids = []
     cur_return = 0
     reward_scale = 1.0
-    for i, (reward, done) in enumerate(
-        zip(dataset["rewards"], dataset["terminals"])
-    ):
+    for i, (reward, done) in enumerate(zip(dataset["rewards"], dataset["terminals"])):
         cur_return += reward_scale * reward
         cur_ids.append(i)
         reward_scale *= discount
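
The loop above accumulates a discounted return per trajectory (presumably reset at trajectory boundaries, given the done flag in the loop), which keep_best_trajectories then uses to retain the top frac of trajectories. Stated on its own for a single reward sequence, the accumulated quantity is:

def discounted_return(rewards, discount):
    # Equivalent to the cur_return / reward_scale bookkeeping above,
    # written for a standalone list of rewards from one trajectory.
    total, scale = 0.0, 1.0
    for r in rewards:
        total += scale * r
        scale *= discount
    return total  # sum over t of discount**t * r_t
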
@@ -291,19 +269,13 @@ def __init__(

         self.net = nn.Sequential(
             nn.Linear(state_dim, 256),
-            nn.LayerNorm(256, elementwise_affine=False)
-            if actor_LN
-            else nn.Identity(),
+            nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.LayerNorm(256, elementwise_affine=False)
-            if actor_LN
-            else nn.Identity(),
+            nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.LayerNorm(256, elementwise_affine=False)
-            if actor_LN
-            else nn.Identity(),
+            nn.LayerNorm(256, elementwise_affine=False) if actor_LN else nn.Identity(),
             nn.ReLU(),
             # nn.Linear(256, action_dim),
             # nn.Tanh(),
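
The actor_LN flag simply toggles a non-affine LayerNorm between each hidden Linear layer and its ReLU. One hypothetical way to express the same trunk without repeating the ternary (the helper name and structure are not from the file):

def hidden_block(in_dim: int, out_dim: int, use_ln: bool) -> list:
    # Linear -> optional non-affine LayerNorm -> ReLU, as in self.net above.
    norm = nn.LayerNorm(out_dim, elementwise_affine=False) if use_ln else nn.Identity()
    return [nn.Linear(in_dim, out_dim), norm, nn.ReLU()]

# e.g. nn.Sequential(*hidden_block(state_dim, 256, actor_LN), *hidden_block(256, 256, actor_LN), ...)
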
@@ -313,9 +285,7 @@ def __init__(

         self.max_action = max_action

-    def forward(
-        self, state: torch.Tensor, deterministic=False
-    ) -> torch.Tensor:
+    def forward(self, state: torch.Tensor, deterministic=False) -> torch.Tensor:
         if self.soft:
             hidden = self.net(state)
             mu = self.mu(hidden)
@@ -330,36 +300,30 @@ def forward(
             tanh_action = torch.tanh(action)
             # change of variables formula (SAC paper, appendix C, eq 21)
             log_prob = policy_dist.log_prob(action).sum(axis=-1)
-            log_prob = log_prob - torch.log(
-                1 - tanh_action.pow(2) + 1e-6
-            ).sum(axis=-1)
+            log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(
+                axis=-1
+            )
             scaled_action = self.max_action * tanh_action
         else:
             hidden = self.net(state)
             scaled_action = torch.tanh(self.mu(hidden))
             log_prob = None
         return scaled_action, log_prob

-    def log_prob(
-        self, state: torch.Tensor, action: torch.Tensor
-    ) -> torch.Tensor:
+    def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
         hidden = self.net(state)
         mu = self.mu(hidden)
         log_sigma = self.log_sigma(hidden)
         log_sigma = torch.clip(log_sigma, -5, 2)
         policy_dist = Normal(mu, torch.exp(log_sigma))
-        action = torch.clip(
-            action, -self.max_action + 1e-6, self.max_action - 1e-6
-        )
+        action = torch.clip(action, -self.max_action + 1e-6, self.max_action - 1e-6)
         log_prob = policy_dist.log_prob(torch.arctanh(action)).sum(axis=-1)
         log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-6).sum(axis=-1)
         return log_prob

     @torch.no_grad()
     def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
-        state = torch.tensor(
-            state.reshape(1, -1), device=device, dtype=torch.float32
-        )
+        state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
         return self(state, deterministic=True)[0].cpu().data.numpy().flatten()
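
Both log-density corrections above are the tanh change-of-variables from the SAC paper: for u ~ Normal(mu, sigma) and a = tanh(u), log pi(a) = log N(u; mu, sigma) - sum_i log(1 - tanh(u_i)^2), with a 1e-6 term inside the log for numerical stability. A standalone numerical sketch of that identity (illustrative only, not part of the file):

import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(3), torch.ones(3))
u = dist.rsample()   # pre-squash sample
a = torch.tanh(u)    # squashed action in (-1, 1)
log_pi = dist.log_prob(u).sum(-1) - torch.log(1 - a.pow(2) + 1e-6).sum(-1)
# log_pi is the log-density of a under the tanh-squashed Gaussian policy,
# matching the two-step correction applied in forward() and log_prob() above.
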


@@ -392,9 +356,7 @@ def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
             action, action_log_prob = self.actor(state)

-        loss = (
-            -self.log_alpha * (action_log_prob + self.target_entropy)
-        ).mean()
+        loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean()

         return loss
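
This is the standard SAC temperature objective: alpha (parameterized through log_alpha) is trained so that the policy's entropy stays near target_entropy, i.e. loss = E[-log_alpha * (log pi(a|s) + target_entropy)]. A hedged sketch of the surrounding update step, where the trainer object and its alpha_optimizer attribute are assumed names, not taken from this diff:

# Hypothetical names: trainer, alpha_optimizer, and the detached alpha cache are assumptions.
alpha_loss = trainer._alpha_loss(states)
trainer.alpha_optimizer.zero_grad()
alpha_loss.backward()
trainer.alpha_optimizer.step()
trainer.alpha = trainer.log_alpha.exp().detach()  # used when weighting the entropy term
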

@@ -453,19 +415,15 @@ def train(config: TrainConfig):
     keep_best_trajectories(dataset, config.frac, config.discount)

     if config.normalize:
-        state_mean, state_std = compute_mean_std(
-            dataset["observations"], eps=1e-3
-        )
+        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
     else:
         state_mean, state_std = 0, 1

     if "next_observations" not in dataset.keys():
         dataset["next_observations"] = np.roll(
             dataset["observations"], shift=-1, axis=0
         )  # Terminals/timeouts block next observations
-        print(
-            "Loaded next state observations from current state observations."
-        )
+        print("Loaded next state observations from current state observations.")

     dataset["observations"] = normalize_states(
         dataset["observations"], state_mean, state_std
@@ -485,9 +443,7 @@ def train(config: TrainConfig):
     if config.checkpoints_path is not None:
         print(f"Checkpoints path: {config.checkpoints_path}")
         os.makedirs(config.checkpoints_path, exist_ok=True)
-        with open(
-            os.path.join(config.checkpoints_path, "config.yaml"), "w"
-        ) as f:
+        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
             pyrallis.dump(config, f)

     max_action = float(env.action_space.high[0])
@@ -496,9 +452,9 @@ def train(config: TrainConfig):
     seed = config.seed
     set_seed(seed, env)

-    actor = Actor(
-        state_dim, action_dim, max_action, config.actor_LN, config.soft
-    ).to(config.device)
+    actor = Actor(state_dim, action_dim, max_action, config.actor_LN, config.soft).to(
+        config.device
+    )
     actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

     kwargs = {
@@ -547,15 +503,9 @@ def train(config: TrainConfig):
                 "epoch": int((t + 1) / 1000),
             }
             if hasattr(env, "get_normalized_score"):
-                normalized_score = (
-                    env.get_normalized_score(eval_returns) * 100.0
-                )
-                eval_log["eval/normalized_score_mean"] = np.mean(
-                    normalized_score
-                )
-                eval_log["eval/normalized_score_std"] = np.std(
-                    normalized_score
-                )
+                normalized_score = env.get_normalized_score(eval_returns) * 100.0
+                eval_log["eval/normalized_score_mean"] = np.mean(normalized_score)
+                eval_log["eval/normalized_score_std"] = np.std(normalized_score)

             wandb.log(eval_log)
             print("---------------------------------------")