diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index d07b40595bc..7ff3f23b7a5 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -50,7 +50,7 @@ ) # Anything from 2.5, incl. nightlies, allows for fullgraph -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) def set_default_device(): cur_device = torch.get_default_device() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 42ef4301c4d..fcdffed57a0 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hydra +from tensordict.nn import CudaGraphModule from torchrl._utils import logger as torchrl_logger from torchrl.record import VideoRecorder @@ -15,9 +16,9 @@ def main(cfg: "DictConfig"): # noqa: F821 import torch.optim import tqdm - from tensordict import TensorDict + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss @@ -25,7 +26,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_parallel_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -35,28 +40,12 @@ def main(cfg: "DictConfig"): # noqa: F821 test_interval = cfg.logger.test_interval // frame_skip # Create models (check utils_atari.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) - actor, critic, critic_head = ( - actor.to(device), - critic.to(device), - critic_head.to(device), - ) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), - policy=actor, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - ) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), + storage=LazyTensorStorage(frames_per_batch, device=device), sampler=sampler, batch_size=mini_batch_size, ) @@ -67,6 +56,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=True, + vectorized=not cfg.loss.compile, ) loss_module = A2CLoss( actor_network=actor, @@ -83,9 +73,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizer optim = torch.optim.Adam( loss_module.parameters(), - lr=cfg.optim.lr, + lr=torch.tensor(cfg.optim.lr, device=device), weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps, + capturable=device.type == "cuda", ) # Create logger @@ -115,6 +106,56 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env.eval() + # update function + def update(batch, max_grad_norm=cfg.optim.max_grad_norm): + # Forward pass A2C loss + loss = loss_module(batch) + + loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + loss_sum.backward() + gn = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad(set_to_none=True) + + return ( + loss.select("loss_critic", "loss_entropy", "loss_objective") + .detach() + .set("grad_norm", gn) + ) + + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = None + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.loss.cudagraphs: + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + adv_module = CudaGraphModule(adv_module) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + policy_device=device, + compile_policy=cfg.loss.compile_mode if cfg.loss.compile else False, + cudagraph_policy=cfg.loss.cudagraphs, + ) + # Main loop collected_frames = 0 num_network_updates = 0 @@ -122,9 +163,14 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + lr = cfg.optim.lr sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + data = next(c_iter) log_info = {} sampling_time = time.time() - sampling_start @@ -144,59 +190,53 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict({}, batch_size=[num_mini_batches]) + losses = [] training_start = time.time() # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) - - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) - - # Update the networks - optim.step() - optim.zero_grad() - + with timeit("emptying"): + data_buffer.empty() + with timeit("extending"): + data_buffer.extend(data_reshape) + + with timeit("optim"): + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"].copy_(lr * alpha) + + num_network_updates += 1 + + with timeit("optim - update"): + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch) + losses.append(loss) + + if i % 200 == 0: + timeit.print() + timeit.erase() # Get training losses training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() + for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { - "train/lr": alpha * cfg.optim.lr, + "train/lr": lr * alpha, "train/sampling_time": sampling_time, "train/training_time": training_time, + **timeit.todict(prefix="time"), } ) @@ -223,7 +263,6 @@ def main(cfg: "DictConfig"): # noqa: F821 for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() sampling_start = time.time() collector.shutdown() diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 2b390d39d2a..b6bb4b88efd 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hydra +from tensordict.nn import CudaGraphModule from torchrl._utils import logger as torchrl_logger from torchrl.record import VideoRecorder @@ -15,9 +16,8 @@ def main(cfg: "DictConfig"): # noqa: F821 import torch.optim import tqdm - from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss @@ -26,31 +26,27 @@ def main(cfg: "DictConfig"): # noqa: F821 from utils_mujoco import eval_model, make_env, make_ppo_models # Define paper hyperparameters - device = "cpu" if not torch.cuda.device_count() else "cuda" + + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size total_network_updates = ( cfg.collector.total_frames // cfg.collector.frames_per_batch ) * num_mini_batches # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.collector.frames_per_batch, - total_frames=cfg.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, + actor, critic = make_ppo_models( + cfg.env.env_name, device=device, compile=cfg.loss.compile ) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.collector.frames_per_batch), + storage=LazyTensorStorage(cfg.collector.frames_per_batch, device=device), sampler=sampler, batch_size=cfg.loss.mini_batch_size, ) @@ -61,6 +57,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=False, + vectorized=not cfg.loss.compile, ) loss_module = A2CLoss( actor_network=actor, @@ -71,8 +68,16 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + actor_optim = torch.optim.Adam( + actor.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) + critic_optim = torch.optim.Adam( + critic.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) # Create logger logger = None @@ -99,7 +104,58 @@ def main(cfg: "DictConfig"): # noqa: F821 logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] ), ) + + def update(batch): + # Forward pass A2C loss + loss = loss_module(batch) + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + + # Backward pass + (actor_loss + critic_loss).backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + + actor_optim.zero_grad(set_to_none=True) + critic_optim.zero_grad(set_to_none=True) + return loss.select("loss_critic", "loss_objective").detach() # , "loss_entropy" + + compile_mode = None + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = None + else: + compile_mode = "reduce-overhead" + + update = torch.compile(update, mode=compile_mode) + actor = torch.compile(actor, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.loss.cudagraphs: + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10) + actor = CudaGraphModule(actor, warmup=10) + adv_module = CudaGraphModule(adv_module) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + trust_policy=True, + compile_policy=compile_mode if cfg.loss.compile else False, + cudagraph_policy=cfg.loss.cudagraphs, + ) + test_env.eval() + lr = cfg.optim.lr # Main loop collected_frames = 0 @@ -128,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict({}, batch_size=[num_mini_batches]) + losses = [] training_start = time.time() # Compute GAE @@ -139,42 +195,24 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update the data buffer data_buffer.extend(data_reshape) - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) + for batch in data_buffer: # Linearly decrease the learning rate and clip epsilon alpha = 1.0 if cfg.optim.anneal_lr: alpha = 1 - (num_network_updates / total_network_updates) for group in actor_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha + group["lr"].copy_(lr * alpha) for group in critic_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha + group["lr"].copy_(lr * alpha) num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_objective" # , "loss_entropy" - ).detach() - critic_loss = loss["loss_critic"] - actor_loss = loss["loss_objective"] # + loss["loss_entropy"] - - # Backward pass - actor_loss.backward() - critic_loss.backward() - - # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch) + losses.append(loss) # Get training losses training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( @@ -209,8 +247,8 @@ def main(cfg: "DictConfig"): # noqa: F821 for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() sampling_start = time.time() + torch.compiler.cudagraph_mark_step_begin() collector.shutdown() if not test_env.is_closed: diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index dd0f43b52cb..5a7586ee95d 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -34,3 +34,7 @@ loss: critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 + compile: False + compile_mode: + cudagraphs: False + device: diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml index 03a0bde32c5..a42087b2631 100644 --- a/sota-implementations/a2c/config_mujoco.yaml +++ b/sota-implementations/a2c/config_mujoco.yaml @@ -31,3 +31,7 @@ loss: critic_coef: 0.25 entropy_coef: 0.0 loss_critic_type: l2 + compile: False + compile_mode: default + cudagraphs: False + device: diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 6a09ff715e4..bf7e23cd8f9 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -86,7 +86,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): # -------------------------------------------------------------------- -def make_ppo_modules_pixels(proof_environment): +def make_ppo_modules_pixels(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["pixels"].shape @@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), } # Define input keys @@ -113,14 +113,16 @@ def make_ppo_modules_pixels(proof_environment): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], + device=device, ) - common_cnn_output = common_cnn(torch.ones(input_shape)) + common_cnn_output = common_cnn(torch.ones(input_shape, device=device)) common_mlp = MLP( in_features=common_cnn_output.shape[-1], activation_class=torch.nn.ReLU, activate_last_layer=True, out_features=512, num_cells=[], + device=device, ) common_mlp_output = common_mlp(common_cnn_output) @@ -137,6 +139,7 @@ def make_ppo_modules_pixels(proof_environment): out_features=num_outputs, activation_class=torch.nn.ReLU, num_cells=[], + device=device, ) policy_module = TensorDictModule( module=policy_net, @@ -148,7 +151,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -161,6 +164,7 @@ def make_ppo_modules_pixels(proof_environment): in_features=common_mlp_output.shape[-1], out_features=1, num_cells=[], + device=device, ) value_module = ValueOperator( value_net, @@ -170,11 +174,11 @@ def make_ppo_modules_pixels(proof_environment): return common_module, policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device): proof_environment = make_parallel_env(env_name, 1, device="cpu") common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment + proof_environment, device=device ) # Wrap modules in a single ActorCritic operator @@ -185,8 +189,8 @@ def make_ppo_models(env_name): ) with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) + td = proof_environment.fake_tensordict().expand(1) + td = actor_critic(td.to(device)) del td actor = actor_critic.get_policy_operator() diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 996706ce4f9..e16bcefc890 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -48,7 +48,7 @@ def make_env( # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, device, *, compile: bool = False): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -57,9 +57,10 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), "tanh_loc": False, + "safe_tanh": not compile, } # Define policy architecture @@ -68,6 +69,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -79,7 +81,9 @@ def make_ppo_models_state(proof_environment): # Add state-independent normal scale policy_mlp = torch.nn.Sequential( policy_mlp, - AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + AddStateIndependentNormalScale( + proof_environment.action_spec.shape[-1], device=device + ), ) # Add probabilistic sampling of the actions @@ -90,7 +94,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -103,6 +107,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights @@ -120,9 +125,11 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device, *, compile: bool = False): proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) + actor, critic = make_ppo_models_state( + proof_environment, device=device, compile=compile + ) return actor, critic diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 3af44ee0ed7..d057b05c8c9 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -99,10 +99,15 @@ def print(prefix=None): # noqa: T202 logger.info(" -- ".join(strings)) @classmethod - def todict(cls, percall=True): + def todict(cls, percall=True, prefix=None): + def _make_key(key): + if prefix: + return f"{prefix}/{key}" + return key + if percall: - return {key: val[0] for key, val in cls._REG.items()} - return {key: val[1] for key, val in cls._REG.items()} + return {_make_key(key): val[0] for key, val in cls._REG.items()} + return {_make_key(key): val[1] for key, val in cls._REG.items()} @staticmethod def erase(): diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 995f245a8ac..07d339761b0 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -309,7 +309,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if _reward is not None: reward = reward + _reward - terminated, truncated, done, do_break = self.read_done( terminated=terminated, truncated=truncated, done=done ) @@ -323,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # if truncated/terminated is not in the keys, we just don't pass it even if it # is defined. if terminated is None: - terminated = done + terminated = done.clone() if truncated is not None: obs_dict["truncated"] = truncated obs_dict["done"] = done diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index f1724326d2a..ad25f4a4d07 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -1423,7 +1423,7 @@ def _make_compatible_policy( env_maker=None, env_maker_kwargs=None, ): - if trust_policy: + if trust_policy or isinstance(policy, torch._dynamo.eval_frame.OptimizedModule): return policy if policy is None: input_spec = None diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index 367765812bb..4cd5184ff88 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from tensordict.nn import NormalParamExtractor +from torch import distributions as torch_dist from .continuous import ( Delta, @@ -33,3 +34,16 @@ OneHotCategorical, ) } + +HAS_ENTROPY = { + Delta: False, + IndependentNormal: True, + TanhDelta: False, + TanhNormal: False, + TruncatedNormal: False, + MaskedCategorical: False, + MaskedOneHotCategorical: False, + OneHotCategorical: True, + torch_dist.Categorical: True, + torch_dist.Normal: True, +} diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 8b0d5654b8d..32862ffe1c3 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -403,15 +403,16 @@ def __init__( event_dims = min(1, loc.ndim) err_msg = "TanhNormal high values must be strictly greater than low values" - if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): - if not (high > low).all(): - raise RuntimeError(err_msg) - elif isinstance(high, Number) and isinstance(low, Number): - if not high > low: - raise RuntimeError(err_msg) - else: - if not all(high > low): - raise RuntimeError(err_msg) + if not is_dynamo_compiling(): + if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): + if not (high > low).all(): + raise RuntimeError(err_msg) + elif isinstance(high, Number) and isinstance(low, Number): + if not high > low: + raise RuntimeError(err_msg) + else: + if not all(high > low): + raise RuntimeError(err_msg) high = torch.as_tensor(high, device=loc.device) low = torch.as_tensor(low, device=loc.device) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index d2ffba30686..e91e26a3794 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - from enum import Enum from functools import wraps from typing import Any, Optional, Sequence, Union @@ -10,6 +9,9 @@ import torch import torch.distributions as D +from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits + + __all__ = [ "OneHotCategorical", "MaskedCategorical", @@ -83,6 +85,14 @@ class OneHotCategorical(D.Categorical): num_params: int = 1 + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, @@ -110,6 +120,12 @@ def mode(self) -> torch.Tensor: def deterministic_sample(self): return self.mode + def entropy(self): + min_real = torch.finfo(self.logits.dtype).min + logits = torch.clamp(self.logits, min=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + @_one_hot_wrapper(D.Categorical) def sample( self, sample_shape: Optional[Union[torch.Size, Sequence]] = None @@ -192,6 +208,14 @@ class MaskedCategorical(D.Categorical): -2.1972, -2.1972]) """ + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, @@ -363,6 +387,14 @@ class MaskedOneHotCategorical(MaskedCategorical): -2.1972, -2.1972]) """ + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index c823788b4c2..d9472bdcde8 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -20,12 +20,14 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl.modules.distributions import HAS_ENTROPY from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _get_default_device, _reduce, default_value_kwargs, distance_loss, @@ -316,10 +318,7 @@ def __init__( self.entropy_bonus = entropy_bonus and entropy_coef self.reduction = reduction - try: - device = next(self.parameters()).device - except AttributeError: - device = torch.device("cpu") + device = _get_default_device(self) self.register_buffer( "entropy_coef", torch.as_tensor(entropy_coef, device=device) @@ -347,7 +346,11 @@ def __init__( raise ValueError( f"clip_value must be a float or a scalar tensor, got {clip_value}." ) - self.register_buffer("clip_value", clip_value) + self.register_buffer( + "clip_value", torch.as_tensor(clip_value, device=device) + ) + else: + self.clip_value = None @property def functional(self): @@ -398,9 +401,9 @@ def reset(self) -> None: pass def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: - try: + if HAS_ENTROPY.get(type(dist), False): entropy = dist.entropy() - except NotImplementedError: + else: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) if is_tensor_collection(log_prob): @@ -456,7 +459,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: old_state_value = old_state_value.clone() # TODO: if the advantage is gathered by forward, this introduces an - # overhead that we could easily reduce. + # overhead that we could easily reduce. target_return = tensordict.get( self.tensor_keys.value_target, None ) # TODO: None soon to be removed @@ -487,7 +490,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: loss_value, clip_fraction = _clip_value_loss( old_state_value, state_value, - self.clip_value.to(state_value.device), + self.clip_value, target_return, loss_value, self.loss_critic_type, @@ -541,6 +544,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp = dict(default_value_kwargs(value_type)) hp.update(hyperparams) + device = _get_default_device(self) + hp["device"] = device + if hasattr(self, "gamma"): hp["gamma"] = self.gamma if value_type == ValueEstimators.TD1: diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 31954005195..e022d8078df 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -408,11 +408,12 @@ def __init__(self, network: nn.Module) -> None: def __enter__(self) -> None: if self.mode: - self.network.requires_grad_(False) + self.params = TensorDict.from_module(self.network) + self.params.data.to_module(self.network) def __exit__(self, exc_type, exc_val, exc_tb) -> None: if self.mode: - self.network.requires_grad_() + self.params.to_module(self.network) class hold_out_params(_context_manager): @@ -583,3 +584,10 @@ def _clip_value_loss( # Chose the most pessimistic value prediction between clipped and non-clipped loss_value = torch.max(loss_value, loss_value_clipped) return loss_value, clip_fraction + + +def _get_default_device(net): + for p in net.parameters(): + return p.device + else: + return torch.get_default_device() diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index e396b7e1fcc..c90f16911dc 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -37,6 +37,10 @@ vtrace_advantage_estimate, ) +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling try: from torch import vmap @@ -69,92 +73,6 @@ def new_func(self, *args, **kwargs): return new_func -def _call_value_nets( - value_net: TensorDictModuleBase, - data: TensorDictBase, - params: TensorDictBase, - next_params: TensorDictBase, - single_call: bool, - value_key: NestedKey, - detach_next: bool, - vmap_randomness: str = "error", -): - in_keys = value_net.in_keys - if single_call: - for i, name in enumerate(data.names): - if name == "time": - ndim = i + 1 - break - else: - ndim = None - if ndim is not None: - # get data at t and last of t+1 - idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) - idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) - idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False)[idx0], - ], - ndim - 1, - ) - else: - if RL_WARNINGS: - warnings.warn( - "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " - "This warning can be turned off by setting the environment variable RL_WARNINGS to False." - ) - ndim = data.ndim - idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) - idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - ndim - 1, - ) - - # next_params should be None or be identical to params - if next_params is not None and next_params is not params: - raise ValueError( - "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." - ) - if params is not None: - with params.to_module(value_net): - value_est = value_net(data_in).get(value_key) - else: - value_est = value_net(data_in).get(value_key) - value, value_ = value_est[idx], value_est[idx_] - else: - data_in = torch.stack( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - 0, - ) - if (params is not None) ^ (next_params is not None): - raise ValueError( - "params and next_params must be either both provided or not." - ) - elif params is not None: - params_stack = torch.stack([params, next_params], 0).contiguous() - data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( - data_in, params_stack - ) - else: - data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) - value_est = data_out.get(value_key) - value, value_ = value_est[0], value_est[1] - data.set(value_key, value) - data.set(("next", value_key), value_) - if detach_next: - value_ = value_.detach() - return value, value_ - - def _call_actor_net( actor_net: TensorDictModuleBase, data: TensorDictBase, @@ -432,6 +350,9 @@ def _next_value(self, tensordict, target_params, kwargs): @property def vmap_randomness(self): if self._vmap_randomness is None: + if is_dynamo_compiling(): + self._vmap_randomness = "different" + return "different" do_break = False for val in self.__dict__.values(): if isinstance(val, torch.nn.Module): @@ -467,6 +388,99 @@ def _get_time_dim(self, time_dim: int | None, data: TensorDictBase): return i return data.ndim - 1 + def _call_value_nets( + self, + data: TensorDictBase, + params: TensorDictBase, + next_params: TensorDictBase, + single_call: bool, + value_key: NestedKey, + detach_next: bool, + vmap_randomness: str = "error", + *, + value_net: TensorDictModuleBase | None = None, + ): + if value_net is None: + value_net = self.value_network + in_keys = value_net.in_keys + if single_call: + for i, name in enumerate(data.names): + if name == "time": + ndim = i + 1 + break + else: + ndim = None + if ndim is not None: + # get data at t and last of t+1 + idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) + idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) + idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False)[ + idx0 + ], + ], + ndim - 1, + ) + else: + if RL_WARNINGS: + warnings.warn( + "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " + "This warning can be turned off by setting the environment variable RL_WARNINGS to False." + ) + ndim = data.ndim + idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) + idx_ = (slice(None),) * (ndim - 1) + ( + slice(data.shape[ndim - 1], None), + ) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + ndim - 1, + ) + + # next_params should be None or be identical to params + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." + ) + if params is not None: + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) + else: + value_est = value_net(data_in).get(value_key) + value, value_ = value_est[idx], value_est[idx_] + else: + data_in = torch.stack( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + 0, + ) + if (params is not None) ^ (next_params is not None): + raise ValueError( + "params and next_params must be either both provided or not." + ) + elif params is not None: + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) + else: + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) + value_est = data_out.get(value_key) + value, value_ = value_est[0], value_est[1] + data.set(value_key, value) + data.set(("next", value_key), value_) + if detach_next: + value_ = value_.detach() + return value, value_ + class TD0Estimator(ValueEstimatorBase): """Temporal Difference (TD(0)) estimate of advantage function. @@ -623,8 +637,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -837,8 +850,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1063,8 +1075,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1155,7 +1166,7 @@ class GAE(ValueEstimatorBase): pass detached parameters for functional modules. vectorized (bool, optional): whether to use the vectorized version of the - lambda return. Default is `True`. + lambda return. Default is `True` if not compiling. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` @@ -1205,7 +1216,7 @@ def __init__( value_network: TensorDictModule, average_gae: bool = False, differentiable: bool = False, - vectorized: bool = True, + vectorized: bool | None = None, skip_existing: bool | None = None, advantage_key: NestedKey = None, value_target_key: NestedKey = None, @@ -1229,6 +1240,16 @@ def __init__( self.vectorized = vectorized self.time_dim = time_dim + @property + def vectorized(self): + if is_dynamo_compiling(): + return False + return self._vectorized + + @vectorized.setter + def vectorized(self, value): + self._vectorized = value + @_self_set_skip_existing @_self_set_grad_enabled @dispatch @@ -1328,10 +1349,10 @@ def forward( with hold_out_net(self.value_network) if ( params is None and target_params is None ) else nullcontext(): + # with torch.no_grad(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1417,8 +1438,7 @@ def value_estimate( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1682,8 +1702,7 @@ def forward( with hold_out_net(self.value_network): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index ddd688610c2..bb737d7c20d 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -12,6 +12,10 @@ import torch +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling __all__ = [ "generalized_advantage_estimate", @@ -147,7 +151,7 @@ def generalized_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -181,19 +185,25 @@ def generalized_advantage_estimate( def _geom_series_like(t, r, thr): """Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`. - Drops all elements which are smaller than `thr`. + Drops all elements which are smaller than `thr` (unless in compile mode). """ - if isinstance(r, torch.Tensor): - r = r.item() - - if r == 0.0: - return torch.zeros_like(t) - elif r >= 1.0: - lim = t.numel() + if is_dynamo_compiling(): + if isinstance(r, torch.Tensor): + rs = r.expand_as(t) + else: + rs = torch.full_like(t, r) else: - lim = int(math.log(thr) / math.log(r)) + if isinstance(r, torch.Tensor): + r = r.item() + + if r == 0.0: + return torch.zeros_like(t) + elif r >= 1.0: + lim = t.numel() + else: + lim = int(math.log(thr) / math.log(r)) - rs = torch.full_like(t[:lim], r) + rs = torch.full_like(t[:lim], r) rs[0] = 1.0 rs = rs.cumprod(0) rs = rs.unsqueeze(-1) @@ -292,7 +302,7 @@ def vec_generalized_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -391,7 +401,7 @@ def td0_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -435,7 +445,7 @@ def td0_return_estimate( """ if done is not None and terminated is None: - terminated = done + terminated = done.clone() warnings.warn( "done for td0_return_estimate is deprecated. Pass ``terminated`` instead." ) @@ -499,7 +509,7 @@ def td1_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) not_done = (~done).int() @@ -596,7 +606,7 @@ def td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -726,7 +736,7 @@ def vec_td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -804,7 +814,7 @@ def td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -910,7 +920,7 @@ def td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -1046,7 +1056,7 @@ def vec_td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -1196,7 +1206,7 @@ def vec_td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index ec1d33069a5..7910611e36d 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -301,7 +301,10 @@ def _fill_tensor(tensor): device=tensor.device, ) mask_expand = expand_right(mask, (*mask.shape, *tensor.shape[1:])) - return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1)) + # return torch.where(mask_expand, tensor, 0.0) + # return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1)) + empty_tensor[mask_expand] = tensor.reshape(-1) + return empty_tensor if isinstance(tensor, TensorDictBase): tensor = tensor.apply(_fill_tensor, batch_size=[*shape])