Skip to content

Commit

Permalink
[Feature] Discrete SAC compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 5b3d6a2100ad9cb96b9dd00d798ff628add59ca7
Pull Request resolved: #2569
  • Loading branch information
vmoens committed Nov 15, 2024
1 parent 3efbf04 commit 6ab0011
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 82 deletions.
1 change: 0 additions & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def update(sampled_tensordict):
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
torch.compiler.cudagraph_mark_step_begin()
tensordict = next(c_iter)
pbar.update(tensordict.numel())
# update weights of the inference policy
Expand Down
3 changes: 3 additions & 0 deletions sota-implementations/discrete_sac/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ network:
hidden_sizes: [256, 256]
activation: relu
device: null
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
Expand Down
170 changes: 90 additions & 80 deletions sota-implementations/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
The helper functions are coded in the utils.py associated with this script.
"""
import time

import warnings

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger

from tensordict.nn import CudaGraphModule
from torchrl._utils import timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
Expand Down Expand Up @@ -73,9 +74,6 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create discrete SAC loss
loss_module, target_net_updater = make_loss_module(cfg, model)

# Create off-policy collector
collector = make_collector(cfg, train_env, model[0])

# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optim.batch_size,
Expand All @@ -89,9 +87,57 @@ def main(cfg: "DictConfig"): # noqa: F821
optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer(
cfg, loss_module
)
optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
del optimizer_actor, optimizer_critic, optimizer_alpha

def update(sampled_tensordict):
    """Run one optimization step on a sampled batch.

    Clears the optimizer gradients, evaluates ``loss_module`` on the batch,
    backpropagates the sum of the Q-value, actor and alpha losses in a single
    pass, steps the optimizer and then steps ``target_net_updater``.

    Returns the detached output of ``loss_module``.
    """
    optimizer.zero_grad(set_to_none=True)

    # Forward pass: the loss module returns every loss term at once.
    losses = loss_module(sampled_tensordict)

    # A single backward pass over the summed objective updates the critic,
    # the actor and alpha together (same summation order as before).
    total_loss = losses["loss_qvalue"] + losses["loss_actor"] + losses["loss_alpha"]
    total_loss.backward()
    optimizer.step()

    # Update target params
    target_net_updater.step()

    return losses.detach()

compile_mode = None
if cfg.network.compile:
compile_mode = cfg.network.compile_mode
if compile_mode in ("", None):
if cfg.network.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.network.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Create off-policy collector
collector = make_collector(
cfg,
train_env,
model[0],
compile=compile_mode is not None,
compile_mode=compile_mode,
cudagraphs=cfg.network.cudagraphs,
)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

Expand All @@ -106,129 +152,93 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch

sampling_start = time.time()
for i, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
collected_data = next(c_iter)

# Update weights of the inference policy
collector.update_policy_weights_()
current_frames = collected_data.numel()

pbar.update(tensordict.numel())
pbar.update(current_frames)

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_data = collected_data.reshape(-1)
with timeit("rb - extend"):
# Add to replay buffer
replay_buffer.extend(collected_data)
collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
alpha_losses,
) = ([], [], [])
tds = []
for _ in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
loss_out = loss_module(sampled_tensordict)

actor_loss, q_loss, alpha_loss = (
loss_out["loss_actor"],
loss_out["loss_qvalue"],
loss_out["loss_alpha"],
)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())
with timeit("rb - sample"):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()

# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
sampled_tensordict = sampled_tensordict.to(device)
loss_out = update(sampled_tensordict)

actor_losses.append(actor_loss.item())

# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
optimizer_alpha.step()

alpha_losses.append(alpha_loss.item())

# Update target params
target_net_updater.step()
tds.append(loss_out)

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
tds = torch.stack(tds).mean()

training_time = time.time() - training_start
# Logging
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
collected_data["next", "done"]
if collected_data["next", "done"].any()
else collected_data["next", "truncated"]
)
episode_rewards = tensordict["next", "episode_reward"][episode_end]
episode_rewards = collected_data["next", "episode_reward"][episode_end]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
episode_length = collected_data["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/a_loss"] = tds["loss_actor"]
metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

# Evaluation
prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter
cur_test_frame = (i * frames_per_batch) // eval_iter
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if i % 50 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
Expand Down
11 changes: 10 additions & 1 deletion sota-implementations/discrete_sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,14 @@ def make_environment(cfg, logger=None):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(
cfg,
train_env,
actor_model_explore,
compile=False,
compile_mode=None,
cudagraphs=False,
):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
Expand All @@ -129,6 +136,8 @@ def make_collector(cfg, train_env, actor_model_explore):
reset_at_each_iter=cfg.collector.reset_at_each_iter,
device=device,
storing_device="cpu",
compile_policy=False if not compile else {"mode": compile_mode},
cudagraph_policy=cudagraphs,
)
collector.set_seed(cfg.env.seed)
return collector
Expand Down

0 comments on commit 6ab0011

Please sign in to comment.