Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] CQL compatibility with compile #2553

Open
wants to merge 31 commits into
base: gh/vmoens/36/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 83 additions & 54 deletions sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
Expand Down Expand Up @@ -69,6 +72,9 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create agent
model = make_cql_model(cfg, train_env, eval_env, device)
del train_env
if hasattr(eval_env, "start"):
# To set the number of threads to the definitive value
eval_env.start()

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
Expand All @@ -81,81 +87,104 @@ def main(cfg: "DictConfig"): # noqa: F821
alpha_prime_optim,
) = make_continuous_cql_optimizer(cfg, loss_module)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
# Group optimizers
optimizer = group_optimizers(
policy_optim, critic_optim, alpha_optim, alpha_prime_optim
)

gradient_steps = cfg.optim.gradient_steps
policy_eval_start = cfg.optim.policy_eval_start
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
for i in range(gradient_steps):
pbar.update(1)
# sample data
data = replay_buffer.sample()
# compute loss
loss_vals = loss_module(data.clone().to(device))
def update(data, policy_eval_start, iteration):
loss_vals = loss_module(data.to(device))

# official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
if i >= policy_eval_start:
actor_loss = loss_vals["loss_actor"]
else:
actor_loss = loss_vals["loss_actor_bc"]
actor_loss = torch.where(
iteration >= policy_eval_start,
loss_vals["loss_actor"],
loss_vals["loss_actor_bc"],
)
q_loss = loss_vals["loss_qvalue"]
cql_loss = loss_vals["loss_cql"]

q_loss = q_loss + cql_loss
loss_vals["q_loss"] = q_loss

# update model
alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]
if alpha_prime_loss is None:
alpha_prime_loss = 0

alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss

policy_optim.zero_grad()
actor_loss.backward()
policy_optim.step()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

if alpha_prime_optim is not None:
alpha_prime_optim.zero_grad()
alpha_prime_loss.backward(retain_graph=True)
alpha_prime_optim.step()
# update qnet_target params
target_net_updater.step()

critic_optim.zero_grad()
# TODO: we have the option to compute losses independently retain is not needed?
q_loss.backward(retain_graph=False)
critic_optim.step()
return loss.detach(), loss_vals.detach()

loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
compile_mode = None
if cfg.compile.compile:
if cfg.compile.compile_mode not in (None, ""):
compile_mode = cfg.compile.compile_mode
elif cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

gradient_steps = cfg.optim.gradient_steps
policy_eval_start = cfg.optim.policy_eval_start
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
policy_eval_start = torch.tensor(policy_eval_start, device=device)
for i in range(gradient_steps):
pbar.update(1)
# sample data
with timeit("sample"):
data = replay_buffer.sample()

with timeit("update"):
# compute loss
i_device = torch.tensor(i, device=device)
loss, loss_vals = update(
data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
)

# log metrics
to_log = {
"loss": loss.item(),
"loss_actor_bc": loss_vals["loss_actor_bc"].item(),
"loss_actor": loss_vals["loss_actor"].item(),
"loss_qvalue": q_loss.item(),
"loss_cql": cql_loss.item(),
"loss_alpha": alpha_loss.item(),
"loss_alpha_prime": alpha_prime_loss.item(),
"loss": loss.cpu(),
**loss_vals.cpu(),
}

# update qnet_target params
target_net_updater.step()

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward

log_metrics(logger, to_log, i)
with timeit("log/eval"):
if i % evaluation_interval == 0:
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward

with timeit("log"):
if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, to_log, i)
if i % 200 == 0:
timeit.print()
timeit.erase()

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
Expand Down
Loading
Loading