From 14b27751391c29018e131f46b9466a58302ee409 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 19 Nov 2024 18:05:36 +0000 Subject: [PATCH 1/5] [Refactor] Deprecate recurrent_mode API to use decorators/CMs instead ghstack-source-id: 80f705e022abc111df3960fc09576d5e266ed4dd Pull Request resolved: https://github.com/pytorch/rl/pull/2584 --- docs/source/reference/modules.rst | 2 + test/test_cost.py | 24 ++++ test/test_tensordictmodules.py | 54 ++++++++- test/test_transforms.py | 6 +- torchrl/_utils.py | 23 ++++ torchrl/envs/transforms/transforms.py | 3 +- torchrl/modules/__init__.py | 2 + torchrl/modules/tensordict_module/__init__.py | 11 +- torchrl/modules/tensordict_module/rnn.py | 105 ++++++++++++++++-- torchrl/objectives/common.py | 8 +- tutorials/sphinx-tutorials/dqn_with_rnn.py | 21 ++-- tutorials/sphinx-tutorials/export.py | 4 - 12 files changed, 230 insertions(+), 33 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index ee78c68835f..c79e4f42c49 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -373,6 +373,8 @@ algorithms, such as DQN, DDPG or Dreamer. OnlineDTActor RSSMPosterior RSSMPrior + set_recurrent_mode + recurrent_mode Multi-agent-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_cost.py b/test/test_cost.py index 1e157fd7a2f..598b9ba004d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -47,6 +47,7 @@ DistributionalQValueActor, OneHotCategorical, QValueActor, + recurrent_mode, SafeSequential, WorldModelWrapper, ) @@ -15507,6 +15508,29 @@ def test_set_deprecated_keys(self, adv, kwargs): class TestBase: + def test_decorators(self): + class MyLoss(LossModule): + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + def actor_loss(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + def something_loss(self, tensordict: TensorDictBase) -> TensorDictBase: + assert recurrent_mode() + assert exploration_type() is ExplorationType.DETERMINISTIC + return TensorDict() + + loss = MyLoss() + loss.forward(None) + loss.actor_loss(None) + loss.something_loss(None) + assert not recurrent_mode() + @pytest.mark.parametrize("expand_dim", [None, 2]) @pytest.mark.parametrize("compare_against", [True, False]) @pytest.mark.skipif(not _has_functorch, reason="functorch is needed for expansion") diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index ec9322500b4..d3b7b7850f4 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -36,6 +36,7 @@ OnlineDTActor, ProbabilisticActor, SafeModule, + set_recurrent_mode, TanhDelta, TanhNormal, ValueOperator, @@ -729,6 +730,31 @@ def test_errs(self): with pytest.raises(KeyError, match="is_init"): lstm_module(td) + @pytest.mark.parametrize("default_val", [False, True, None]) + def test_set_recurrent_mode(self, default_val): + lstm_module = LSTMModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden0", "hidden1"], + out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], + default_recurrent_mode=default_val, + ) + assert lstm_module.recurrent_mode is bool(default_val) + with set_recurrent_mode(True): + assert lstm_module.recurrent_mode + with set_recurrent_mode(False): + assert not lstm_module.recurrent_mode + 
with set_recurrent_mode("recurrent"): + assert lstm_module.recurrent_mode + with set_recurrent_mode("sequential"): + assert not lstm_module.recurrent_mode + assert lstm_module.recurrent_mode + assert not lstm_module.recurrent_mode + assert lstm_module.recurrent_mode + assert lstm_module.recurrent_mode is bool(default_val) + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_set_temporal_mode(self): lstm_module = LSTMModule( input_size=3, @@ -754,7 +780,8 @@ def test_python_cudnn(self): num_layers=2, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) obs = torch.rand(10, 20, 3) hidden0 = torch.rand(10, 20, 2, 12) @@ -1109,6 +1136,31 @@ def test_errs(self): with pytest.raises(KeyError, match="is_init"): gru_module(td) + @pytest.mark.parametrize("default_val", [False, True, None]) + def test_set_recurrent_mode(self, default_val): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + default_recurrent_mode=default_val, + ) + assert gru_module.recurrent_mode is bool(default_val) + with set_recurrent_mode(True): + assert gru_module.recurrent_mode + with set_recurrent_mode(False): + assert not gru_module.recurrent_mode + with set_recurrent_mode("recurrent"): + assert gru_module.recurrent_mode + with set_recurrent_mode("sequential"): + assert not gru_module.recurrent_mode + assert gru_module.recurrent_mode + assert not gru_module.recurrent_mode + assert gru_module.recurrent_mode + assert gru_module.recurrent_mode is bool(default_val) + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_set_temporal_mode(self): gru_module = GRUModule( input_size=3, diff --git a/test/test_transforms.py b/test/test_transforms.py index 56a39218f5f..8b2ada8c93a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10885,7 +10885,8 @@ def _make_gru_module(self, input_size=4, hidden_size=4, device="cpu"): in_keys=["observation", "rhs", "is_init"], out_keys=["output", ("next", "rhs")], device=device, - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"): return LSTMModule( @@ -10895,7 +10896,8 @@ def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"): in_keys=["observation", "rhs_h", "rhs_c", "is_init"], out_keys=["output", ("next", "rhs_h"), ("next", "rhs_c")], device=device, - ).set_recurrent_mode(True) + default_recurrent_mode=True, + ) def _make_batch(self, batch_size: int = 2, sequence_length: int = 5): observation = torch.randn(batch_size, sequence_length + 1, 4) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 0b4dd03a636..d37aebb862f 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -15,9 +15,11 @@ import os import pickle import sys +import threading import time import traceback import warnings +from contextlib import nullcontext from copy import copy from distutils.util import strtobool from functools import wraps @@ -32,6 +34,11 @@ from tensordict.utils import NestedKey from torch import multiprocessing as mp +try: + from torch.compiler import is_compiling +except ImportError: + from torch._dynamo import is_compiling + LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") logger.setLevel(getattr(logging, LOGGING_LEVEL)) @@ -827,3 +834,19 @@ def 
_make_ordinal_device(device: torch.device): if device.type == "mps" and device.index is None: return torch.device("mps", index=0) return device + + +class _ContextManager: + def __init__(self): + self._mode: Any | None = None + self._lock = threading.Lock() + + def get_mode(self) -> Any | None: + cm = self._lock if not is_compiling() else nullcontext() + with cm: + return self._mode + + def set_mode(self, type: Any | None) -> None: + cm = self._lock if not is_compiling() else nullcontext() + with cm: + self._mode = type diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e02c88c5330..7bdd25591cd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -7411,7 +7411,8 @@ class BurnInTransform(Transform): ... hidden_size=10, ... in_keys=["observation", "hidden"], ... out_keys=["intermediate", ("next", "hidden")], - ... ).set_recurrent_mode(True) + ... default_recurrent_mode=True, + ... ) >>> burn_in_transform = BurnInTransform( ... modules=[gru_module], ... burn_in=5, diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 4cb6366f817..edf90a4e85b 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -80,10 +80,12 @@ QValueActor, QValueHook, QValueModule, + recurrent_mode, SafeModule, SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, SafeSequential, + set_recurrent_mode, TanhModule, ValueOperator, VmapModule, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 202f84fd173..3fb1559833a 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -34,6 +34,15 @@ SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import GRU, GRUCell, GRUModule, LSTM, LSTMCell, LSTMModule +from .rnn import ( + GRU, + GRUCell, + GRUModule, + LSTM, + LSTMCell, + LSTMModule, + recurrent_mode, + set_recurrent_mode, +) from .sequence import SafeSequential from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 6a99e85812b..f4ceb648665 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -4,7 +4,9 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Optional, Tuple +import typing +import warnings +from typing import Any, Optional, Tuple import torch import torch.nn.functional as F @@ -18,6 +20,7 @@ from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase +from torchrl._utils import _ContextManager, _DecoratorContextManager from torchrl.data.tensor_specs import Unbounded from torchrl.objectives.value.functional import ( _inv_pad_sequence, @@ -376,6 +379,9 @@ class LSTMModule(ModuleBase): device (torch.device or compatible): the device of the module. lstm (torch.nn.LSTM, optional): an LSTM instance to be wrapped. Exclusive with other nn.LSTM arguments. + default_recurrent_mode (bool, optional): if provided, the recurrent mode if it hasn't been overridden + by the :class:`~torchrl.modules.set_recurrent_mode` context manager / decorator. + Defaults to ``False``. Attributes: recurrent_mode: Returns the recurrent mode of the module. 
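
# A minimal usage sketch of the ``default_recurrent_mode`` argument documented
# above, together with the ``set_recurrent_mode`` context manager/decorator and
# the ``recurrent_mode()`` getter introduced further down in this patch. The
# keys, sizes and module signature mirror the tests added in
# test_tensordictmodules.py and are illustrative only.

from torchrl.modules import LSTMModule, recurrent_mode, set_recurrent_mode

lstm_module = LSTMModule(
    input_size=3,
    hidden_size=12,
    batch_first=True,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
    default_recurrent_mode=False,  # used only when no context manager is active
)

assert not lstm_module.recurrent_mode       # falls back to the constructor default
with set_recurrent_mode(True):              # or set_recurrent_mode("recurrent")
    assert lstm_module.recurrent_mode       # overridden for the duration of the block
with set_recurrent_mode("sequential"):      # string alias for False
    assert not lstm_module.recurrent_mode
assert recurrent_mode() is None             # no global override outside the blocks
# set_recurrent_mode can also be applied as a decorator to functions that call the module.
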
@@ -451,6 +457,7 @@ def __init__( out_keys=None, device=None, lstm=None, + default_recurrent_mode: bool | None = None, ): super().__init__() if lstm is not None: @@ -524,7 +531,7 @@ def __init__( in_keys = in_keys + ["is_init"] self.in_keys = in_keys self.out_keys = out_keys - self._recurrent_mode = False + self._recurrent_mode = default_recurrent_mode def make_python_based(self) -> LSTMModule: """Transforms the LSTM layer in its python-based version. @@ -647,12 +654,15 @@ def make_tuple(key): @property def recurrent_mode(self): - return self._recurrent_mode + rm = recurrent_mode() + if rm is None: + return bool(self._recurrent_mode) + return rm @recurrent_mode.setter def recurrent_mode(self, value): raise RuntimeError( - "recurrent_mode cannot be changed in-place. Call `module.set" + "recurrent_mode cannot be changed in-place. Please use the set_recurrent_mode context manager." ) @property @@ -662,7 +672,7 @@ def temporal_mode(self): ) def set_recurrent_mode(self, mode: bool = True): - """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). + """[DEPRECATED - use :class:`torchrl.modules.set_recurrent_mode` context manager instead] Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training): @@ -692,7 +702,13 @@ def set_recurrent_mode(self, mode: bool = True): ... >>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"]) """ - if mode is self._recurrent_mode: + warnings.warn( + "The lstm.set_recurrent_mode() API is deprecated and will be removed in v0.8. " + "To set the recurent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or " + "the `default_recurrent_mode` keyword argument in the constructor.", + category=DeprecationWarning, + ) + if mode is self.recurrent_mode: return self out = LSTMModule(lstm=self.lstm, in_keys=self.in_keys, out_keys=self.out_keys) out._recurrent_mode = mode @@ -1155,6 +1171,9 @@ class GRUModule(ModuleBase): device (torch.device or compatible): the device of the module. gru (torch.nn.GRU, optional): a GRU instance to be wrapped. Exclusive with other nn.GRU arguments. + default_recurrent_mode (bool, optional): if provided, the recurrent mode if it hasn't been overridden + by the :class:`~torchrl.modules.set_recurrent_mode` context manager / decorator. + Defaults to ``False``. Attributes: recurrent_mode: Returns the recurrent mode of the module. @@ -1256,6 +1275,7 @@ def __init__( out_keys=None, device=None, gru=None, + default_recurrent_mode: bool | None = None, ): super().__init__() if gru is not None: @@ -1326,7 +1346,7 @@ def __init__( in_keys = in_keys + ["is_init"] self.in_keys = in_keys self.out_keys = out_keys - self._recurrent_mode = False + self._recurrent_mode = default_recurrent_mode def make_python_based(self) -> GRUModule: """Transforms the GRU layer in its python-based version. @@ -1444,12 +1464,15 @@ def make_tuple(key): @property def recurrent_mode(self): - return self._recurrent_mode + rm = recurrent_mode() + if rm is None: + return bool(self._recurrent_mode) + return rm @recurrent_mode.setter def recurrent_mode(self, value): raise RuntimeError( - "recurrent_mode cannot be changed in-place. Call `module.set" + "recurrent_mode cannot be changed in-place. Please use the set_recurrent_mode context manager." 
) @property @@ -1459,7 +1482,7 @@ def temporal_mode(self): ) def set_recurrent_mode(self, mode: bool = True): - """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). + """[DEPRECATED - use :class:`torchrl.modules.set_recurrent_mode` context manager instead] Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training): @@ -1488,7 +1511,13 @@ def set_recurrent_mode(self, mode: bool = True): ... >>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"]) """ - if mode is self._recurrent_mode: + warnings.warn( + "The gru.set_recurrent_mode() API is deprecated and will be removed in v0.8. " + "To set the recurent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or " + "the `default_recurrent_mode` keyword argument in the constructor.", + category=DeprecationWarning, + ) + if mode is self.recurrent_mode: return self out = GRUModule(gru=self.gru, in_keys=self.in_keys, out_keys=self.out_keys) out._recurrent_mode = mode @@ -1598,3 +1627,57 @@ def _gru( ) out = [y, hidden] return tuple(out) + + +# Recurrent mode manager +recurrent_mode_state_manager = _ContextManager() + + +def recurrent_mode() -> bool | None: + """Returns the current sampling type.""" + return recurrent_mode_state_manager.get_mode() + + +class set_recurrent_mode(_DecoratorContextManager): + """Context manager for setting RNNs recurrent mode. + + Args: + mode (bool, "recurrent" or "stateful"): the recurrent mode to be used within the context manager. + `"recurrent"` leads to `mode=True` and `"stateful"` leads to `mode=False`. + An RNN executed with recurrent_mode "on" assumes that the data comes in time batches, otherwise + it is assumed that each data element in a tensordict is independent of the others. + The default value of this context manager is ``True``. + The default recurrent mode is ``None``, i.e., the default recurrent mode of the RNN is used + (see :class:`~torchrl.modules.LSTMModule` and :class:`~torchrl.modules.GRUModule` constructors). + + .. seealso:: :class:`~torchrl.modules.recurrent_mode``. + + .. note:: All of TorchRL methods are decorated with ``set_recurrent_mode(True)`` by default. + + """ + + def __init__( + self, mode: bool | typing.Literal["recurrent", "sequential"] | None = True + ) -> None: + super().__init__() + if isinstance(mode, str): + if mode.lower() in ("recurrent",): + mode = True + elif mode.lower() in ("sequential",): + mode = False + else: + raise ValueError( + f"Unsupported recurrent mode. 
Must be a bool, or one of {('recurrent', 'sequential')}" + ) + self.mode = mode + + def clone(self) -> set_recurrent_mode: + # override this method if your children class takes __init__ parameters + return type(self)(self.mode) + + def __enter__(self) -> None: + self.prev = recurrent_mode_state_manager.get_mode() + recurrent_mode_state_manager.set_mode(self.mode) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + recurrent_mode_state_manager.set_mode(self.prev) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 57310a5fc3d..d54671f569b 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -21,6 +21,7 @@ from torch.nn import Parameter from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import set_recurrent_mode from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators from torchrl.objectives.value import ValueEstimatorBase @@ -46,7 +47,9 @@ def _updater_check_forward_prehook(module, *args, **kwargs): def _forward_wrapper(func): @functools.wraps(func) def new_forward(self, *args, **kwargs): - with set_exploration_type(self.deterministic_sampling_mode): + with set_exploration_type(self.deterministic_sampling_mode), set_recurrent_mode( + True + ): return func(self, *args, **kwargs) return new_forward @@ -56,6 +59,9 @@ class _LossMeta(abc.ABCMeta): def __init__(cls, name, bases, attr_dict): super().__init__(name, bases, attr_dict) cls.forward = _forward_wrapper(cls.forward) + for name, value in cls.__dict__.items(): + if not name.startswith("_") and name.endswith("loss"): + setattr(cls, name, _forward_wrapper(value)) class LossModule(TensorDictModuleBase, metaclass=_LossMeta): diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index 8931f483384..58c47f68321 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -317,7 +317,7 @@ # # We can now put things together in a :class:`~tensordict.nn.TensorDictSequential` # -stoch_policy = Seq(feature, lstm, mlp, qval) +policy = Seq(feature, lstm, mlp, qval) ###################################################################### # DQN being a deterministic algorithm, exploration is a crucial part of it. @@ -330,7 +330,7 @@ annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2 ) stoch_policy = TensorDictSequential( - stoch_policy, + policy, exploration_module, ) @@ -338,20 +338,17 @@ # Using the model for the loss # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The model as we've built it is well equipped to be used in sequential settings. +# The model as we've built it is well-equipped to be used in sequential settings. # However, the class :class:`torch.nn.LSTM` can use a cuDNN-optimized backend # to run the RNN sequence faster on GPU device. We would not want to miss # such an opportunity to speed up our training loop! -# To use it, we just need to tell the LSTM module to run on "recurrent-mode" -# when used by the loss. -# As we'll usually want to have two copies of the LSTM module, we do this by -# calling a :meth:`~torchrl.modules.LSTMModule.set_recurrent_mode` method that -# will return a new instance of the LSTM (with shared weights) that will -# assume that the input data is sequential in nature. 
# -policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval) - -###################################################################### +# By default, torchrl losses will use this when executing any +# :class:`~torchrl.modules.LSTMModule` or :class:`~torchrl.modules.GRUModule` +# forward call. If you need to control this manually, the RNN modules are sensitive +# to a context manager/decorator, :class:`~torchrl.modules.set_recurrent_mode`, +# that handles the behaviour of the underlying RNN module. +# # Because we still have a couple of uninitialized parameters we should # initialize them before creating an optimizer and such. # diff --git a/tutorials/sphinx-tutorials/export.py b/tutorials/sphinx-tutorials/export.py index 0a4390abdfc..48dd8723ffc 100644 --- a/tutorials/sphinx-tutorials/export.py +++ b/tutorials/sphinx-tutorials/export.py @@ -265,10 +265,6 @@ in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", "hidden0", "hidden1"], ) -##################################### -# We set the recurrent mode to ``False`` to allow the module to read inputs one-by-one and not in batch. -# -lstm = lstm.set_recurrent_mode(False) ##################################### # If the LSTM module is not python based but CuDNN (:class:`~torch.nn.LSTM`), the :meth:`~torchrl.modules.LSTMModule.make_python_based` From a47b32c073ebb74878b8e7329ef2e79997c3e783 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 20 Nov 2024 11:49:44 +0000 Subject: [PATCH 2/5] [BugFix] make buffers zero-dim in exploration modules ghstack-source-id: fd2705eb9132169da4871b27b354f7895c644061 Pull Request resolved: https://github.com/pytorch/rl/pull/2591 --- test/_utils_internal.py | 4 +- .../modules/tensordict_module/exploration.py | 38 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index dea0d136844..a3476b31110 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -167,8 +167,8 @@ def get_available_devices(): def get_default_devices(): num_cuda = torch.cuda.device_count() if num_cuda == 0: - if torch.mps.is_available(): - return [torch.device("mps:0")] + # if torch.mps.is_available(): + # return [torch.device("mps:0")] return [torch.device("cpu")] elif num_cuda == 1: return [torch.device("cuda:0")] diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 9acfae1aa21..a1879519271 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -112,10 +112,10 @@ def __init__( super().__init__() - self.register_buffer("eps_init", torch.as_tensor([eps_init])) - self.register_buffer("eps_end", torch.as_tensor([eps_end])) + self.register_buffer("eps_init", torch.as_tensor(eps_init)) + self.register_buffer("eps_end", torch.as_tensor(eps_end)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) + self.register_buffer("eps", torch.as_tensor(eps_init, dtype=torch.float32)) if spec is not None: if not isinstance(spec, Composite) and len(self.out_keys) >= 1: @@ -275,13 +275,13 @@ def __init__( super().__init__(policy) if sigma_end > sigma_init: raise RuntimeError("sigma should decrease over time or be constant") - self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device)) - self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device)) + self.register_buffer("sigma_init", torch.tensor(sigma_init, 
device=device)) + self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("mean", torch.tensor([mean], device=device)) - self.register_buffer("std", torch.tensor([std], device=device)) + self.register_buffer("mean", torch.tensor(mean, device=device)) + self.register_buffer("std", torch.tensor(std, device=device)) self.register_buffer( - "sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device) + "sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device) ) self.action_key = action_key self.out_keys = list(self.td_module.out_keys) @@ -423,13 +423,13 @@ def __init__( super().__init__() - self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device)) - self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device)) + self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device)) + self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("mean", torch.tensor([mean], device=device)) - self.register_buffer("std", torch.tensor([std], device=device)) + self.register_buffer("mean", torch.tensor(mean, device=device)) + self.register_buffer("std", torch.tensor(std, device=device)) self.register_buffer( - "sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device) + "sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device) ) if spec is not None: @@ -628,8 +628,8 @@ def __init__( key=action_key, device=device, ) - self.register_buffer("eps_init", torch.tensor([eps_init], device=device)) - self.register_buffer("eps_end", torch.tensor([eps_end], device=device)) + self.register_buffer("eps_init", torch.tensor(eps_init, device=device)) + self.register_buffer("eps_end", torch.tensor(eps_end, device=device)) if self.eps_end > self.eps_init: raise ValueError( "eps should decrease over time or be constant, " @@ -637,7 +637,7 @@ def __init__( ) self.annealing_num_steps = annealing_num_steps self.register_buffer( - "eps", torch.tensor([eps_init], dtype=torch.float32, device=device) + "eps", torch.tensor(eps_init, dtype=torch.float32, device=device) ) self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys self.is_init_key = is_init_key @@ -840,8 +840,8 @@ def __init__( device=device, ) - self.register_buffer("eps_init", torch.tensor([eps_init], device=device)) - self.register_buffer("eps_end", torch.tensor([eps_end], device=device)) + self.register_buffer("eps_init", torch.tensor(eps_init, device=device)) + self.register_buffer("eps_end", torch.tensor(eps_end, device=device)) if self.eps_end > self.eps_init: raise ValueError( "eps should decrease over time or be constant, " @@ -849,7 +849,7 @@ def __init__( ) self.annealing_num_steps = annealing_num_steps self.register_buffer( - "eps", torch.tensor([eps_init], dtype=torch.float32, device=device) + "eps", torch.tensor(eps_init, dtype=torch.float32, device=device) ) self.in_keys = [self.ou.key] From d30599ec0799ea7a895590a558854cd87f0d73b0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 20 Nov 2024 17:10:15 +0000 Subject: [PATCH 3/5] [BugFix] action_spec_unbatched whenever necessary ghstack-source-id: ec87794dabaf5023dac85cfc898a7c000e93331d Pull Request resolved: https://github.com/pytorch/rl/pull/2592 --- .../collectors/multi_nodes/ray_train.py | 4 +- sota-implementations/a2c/utils_atari.py | 4 +- sota-implementations/a2c/utils_mujoco.py | 4 +- sota-implementations/cql/utils.py | 2 +- 
sota-implementations/crossq/utils.py | 4 +- .../decision_transformer/utils.py | 2 +- sota-implementations/dreamer/dreamer.py | 10 +- sota-implementations/gail/ppo_utils.py | 4 +- sota-implementations/iql/utils.py | 4 +- sota-implementations/multiagent/iql.py | 6 +- .../multiagent/maddpg_iddpg.py | 8 +- sota-implementations/multiagent/mappo_ippo.py | 6 +- sota-implementations/multiagent/qmix_vdn.py | 6 +- sota-implementations/multiagent/sac.py | 12 +- sota-implementations/ppo/utils_atari.py | 4 +- sota-implementations/ppo/utils_mujoco.py | 4 +- sota-implementations/redq/utils.py | 2 +- sota-implementations/sac/utils.py | 4 +- test/mocking_classes.py | 46 ++----- test/test_env.py | 24 ++-- test/test_libs.py | 7 + torchrl/envs/common.py | 126 +++++++++++++++--- tutorials/sphinx-tutorials/coding_ppo.py | 4 +- .../multiagent_competitive_ddpg.py | 4 +- tutorials/sphinx-tutorials/multiagent_ppo.py | 6 +- 25 files changed, 191 insertions(+), 116 deletions(-) diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index 5697d88dc61..e52584c4ac4 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -85,8 +85,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.action_spec.space.low, - "high": env.action_spec.space.high, + "low": env.action_spec_unbatched.space.low, + "high": env.action_spec_unbatched.space.high, }, return_log_prob=True, ) diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 99c3ce2338c..a0cea48b510 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -101,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment, device): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low.to(device), - "high": proof_environment.action_spec.space.high.to(device), + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), } # Define input keys diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 87587d092f0..645bc806265 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low.to(device), - "high": proof_environment.action_spec.space.high.to(device), + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), "tanh_loc": False, "safe_tanh": True, } diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index c1d6fb52024..51134b6828d 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -191,7 +191,7 @@ def make_offline_replay_buffer(rb_cfg): def make_cql_model(cfg, train_env, eval_env, device="cpu"): model_cfg = cfg.model - action_spec = train_env.action_spec + action_spec = train_env.action_spec_unbatched actor_net, q_net = make_cql_modules_state(model_cfg, eval_env) in_keys = ["observation"] diff --git 
a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index 9883bc50b17..483bf257c63 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -147,9 +147,7 @@ def make_crossQ_agent(cfg, train_env, device): """Make CrossQ agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] + action_spec = train_env.action_spec_unbatched actor_net_kwargs = { "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index ee2cc6e424c..7f905c72366 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -393,7 +393,7 @@ def make_dt_model(cfg): make_base_env(env_cfg), env_cfg, obs_loc=0, obs_std=1 ) - action_spec = proof_environment.action_spec + action_spec = proof_environment.action_spec_unbatched for key, value in proof_environment.observation_spec.items(): if key == "observation": state_dim = value.shape[-1] diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 992abea64e0..d97066b87c5 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -20,7 +20,7 @@ ) # mixed precision training -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.utils import clip_grad_norm_ from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -321,6 +321,14 @@ def compile_rssms(module): t_collect_init = time.time() + test_env.close() + train_env.close() + collector.shutdown() + + del test_env + del train_env + del collector + if __name__ == "__main__": main() diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index 7986738f8e6..63310113e98 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec_unbatched.space.low, + "high": proof_environment.action_spec_unbatched.space.high, "tanh_loc": False, } diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index a24c6168375..ff84d0d8138 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -195,9 +195,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): model_cfg = cfg.model in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] + action_spec = train_env.action_spec_unbatched actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env) out_keys = ["loc", "scale"] diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index 39750c5d425..66cc3b6659e 100644 --- a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -72,7 +72,7 @@ def train(cfg: "DictConfig"): # noqa: F821 # Policy net = MultiAgentMLP( 
n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], - n_agent_outputs=env.action_spec.space.n, + n_agent_outputs=env.full_action_spec["agents", "action"].space.n, n_agents=env.n_agents, centralised=False, share_params=cfg.model.shared_parameters, @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821 ("agents", "action_value"), ("agents", "chosen_action_value"), ], - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, action_space=None, ) qnet = SafeSequential(module, value_module) @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, ), ) diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index 6199a888344..1485e3e8c0b 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -91,13 +91,13 @@ def train(cfg: "DictConfig"): # noqa: F821 ) policy = ProbabilisticActor( module=policy_module, - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, in_keys=[("agents", "param")], out_keys=[env.action_key], distribution_class=TanhDelta, distribution_kwargs={ - "low": env.unbatched_action_spec[("agents", "action")].space.low, - "high": env.unbatched_action_spec[("agents", "action")].space.high, + "low": env.full_action_spec_unbatched[("agents", "action")].space.low, + "high": env.full_action_spec_unbatched[("agents", "action")].space.high, }, return_log_prob=False, ) @@ -105,7 +105,7 @@ def train(cfg: "DictConfig"): # noqa: F821 policy_explore = TensorDictSequential( policy, AdditiveGaussianModule( - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, device=cfg.train.device, diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index d2e218b843a..06cc2cd1fce 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -92,13 +92,13 @@ def train(cfg: "DictConfig"): # noqa: F821 ) policy = ProbabilisticActor( module=policy_module, - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, in_keys=[("agents", "loc"), ("agents", "scale")], out_keys=[env.action_key], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.unbatched_action_spec[("agents", "action")].space.low, - "high": env.unbatched_action_spec[("agents", "action")].space.high, + "low": env.full_action_spec_unbatched[("agents", "action")].space.low, + "high": env.full_action_spec_unbatched[("agents", "action")].space.high, }, return_log_prob=True, ) diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index c5993f902c6..a6e24cf9414 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -72,7 +72,7 @@ def train(cfg: "DictConfig"): # noqa: F821 # Policy net = MultiAgentMLP( n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1], - n_agent_outputs=env.action_spec.space.n, + n_agent_outputs=env.full_action_spec["agents", "action"].space.n, n_agents=env.n_agents, centralised=False, share_params=cfg.model.shared_parameters, @@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821 ("agents", "action_value"), ("agents", 
"chosen_action_value"), ], - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, action_space=None, ) qnet = SafeSequential(module, value_module) @@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, ), ) diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index cfafdd47c96..694083e5b0f 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -96,13 +96,13 @@ def train(cfg: "DictConfig"): # noqa: F821 policy = ProbabilisticActor( module=policy_module, - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, in_keys=[("agents", "loc"), ("agents", "scale")], out_keys=[env.action_key], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.unbatched_action_spec[("agents", "action")].space.low, - "high": env.unbatched_action_spec[("agents", "action")].space.high, + "low": env.full_action_spec_unbatched[("agents", "action")].space.low, + "high": env.full_action_spec_unbatched[("agents", "action")].space.high, }, return_log_prob=True, ) @@ -146,7 +146,7 @@ def train(cfg: "DictConfig"): # noqa: F821 ) policy = ProbabilisticActor( module=policy_module, - spec=env.unbatched_action_spec, + spec=env.full_action_spec_unbatched, in_keys=[("agents", "logits")], out_keys=[env.action_key], distribution_class=OneHotCategorical @@ -194,7 +194,7 @@ def train(cfg: "DictConfig"): # noqa: F821 actor_network=policy, qvalue_network=value_module, delay_qvalue=True, - action_spec=env.unbatched_action_spec, + action_spec=env.full_action_spec_unbatched, ) loss_module.set_keys( state_action_value=("agents", "state_action_value"), @@ -209,7 +209,7 @@ def train(cfg: "DictConfig"): # noqa: F821 qvalue_network=value_module, delay_qvalue=True, num_actions=env.action_spec.space.n, - action_space=env.unbatched_action_spec, + action_space=env.full_action_spec_unbatched, ) loss_module.set_keys( action_value=("agents", "action_value"), diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 50f91ed49cd..debc8f9e211 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec_unbatched.space.low, + "high": proof_environment.action_spec_unbatched.space.high, } # Define input keys diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index a05d205b000..6c7a1b80fd7 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec_unbatched.space.low, + "high": proof_environment.action_spec_unbatched.space.high, "tanh_loc": False, } diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 2823858af60..9953fcb3112 
100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -410,7 +410,7 @@ def make_redq_model( default_policy_scale = cfg.network.default_policy_scale gSDE = cfg.exploration.gSDE - action_spec = proof_environment.action_spec + action_spec = proof_environment.action_spec_unbatched if actor_net_kwargs is None: actor_net_kwargs = {} diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py index b8630f77ab0..d1dbb2db791 100644 --- a/sota-implementations/sac/utils.py +++ b/sota-implementations/sac/utils.py @@ -161,9 +161,7 @@ def make_sac_agent(cfg, train_env, eval_env, device): """Make SAC agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] + action_spec = train_env.action_spec_unbatched actor_net_kwargs = { "num_cells": cfg.network.hidden_sizes, "out_features": 2 * action_spec.shape[-1], diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 225b978ec14..eb517429c08 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1040,7 +1040,9 @@ def _step( action = tensordict.get(self.action_key) self.count += action.to( dtype=torch.int, - device=self.action_spec.device if self.device is None else self.device, + device=self.full_action_spec[self.action_key].device + if self.device is None + else self.device, ) tensordict = TensorDict( source={ @@ -1388,17 +1390,17 @@ def _make_specs(self): obs_spec_unlazy = consolidate_spec(obs_specs) action_specs = torch.stack(action_specs, dim=0) - self.unbatched_observation_spec = Composite( + self.observation_spec_unbatched = Composite( lazy=obs_spec_unlazy, state=Unbounded(shape=(64, 64, 3)), device=self.device, ) - self.unbatched_action_spec = Composite( + self.action_spec_unbatched = Composite( lazy=action_specs, device=self.device, ) - self.unbatched_reward_spec = Composite( + self.reward_spec_unbatched = Composite( { "lazy": Composite( {"reward": Unbounded(shape=(self.n_nested_dim, 1))}, @@ -1407,7 +1409,7 @@ def _make_specs(self): }, device=self.device, ) - self.unbatched_done_spec = Composite( + self.done_spec_unbatched = Composite( { "lazy": Composite( { @@ -1423,19 +1425,6 @@ def _make_specs(self): device=self.device, ) - self.action_spec = self.unbatched_action_spec.expand( - *self.batch_size, *self.unbatched_action_spec.shape - ) - self.observation_spec = self.unbatched_observation_spec.expand( - *self.batch_size, *self.unbatched_observation_spec.shape - ) - self.reward_spec = self.unbatched_reward_spec.expand( - *self.batch_size, *self.unbatched_reward_spec.shape - ) - self.done_spec = self.unbatched_done_spec.expand( - *self.batch_size, *self.unbatched_done_spec.shape - ) - def get_agent_obs_spec(self, i): camera = Bounded(low=0, high=200, shape=(7, 7, 3)) vector_3d = Unbounded(shape=(3,)) @@ -1610,21 +1599,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): self.make_specs() - self.action_spec = self.unbatched_action_spec.expand( - *self.batch_size, *self.unbatched_action_spec.shape - ) - self.observation_spec = self.unbatched_observation_spec.expand( - *self.batch_size, *self.unbatched_observation_spec.shape - ) - self.reward_spec = self.unbatched_reward_spec.expand( - *self.batch_size, *self.unbatched_reward_spec.shape - ) - self.done_spec = self.unbatched_done_spec.expand( - *self.batch_size, *self.unbatched_done_spec.shape - ) - def make_specs(self): - self.unbatched_observation_spec = Composite( + 
self.observation_spec_unbatched = Composite( nested_1=Composite( observation=Bounded(low=0, high=200, shape=(self.nested_dim_1, 3)), shape=(self.nested_dim_1,), @@ -1642,7 +1618,7 @@ def make_specs(self): ), ) - self.unbatched_action_spec = Composite( + self.action_spec_unbatched = Composite( nested_1=Composite( action=Categorical(n=2, shape=(self.nested_dim_1,)), shape=(self.nested_dim_1,), @@ -1654,7 +1630,7 @@ def make_specs(self): action=OneHot(n=2), ) - self.unbatched_reward_spec = Composite( + self.reward_spec_unbatched = Composite( nested_1=Composite( gift=Unbounded(shape=(self.nested_dim_1, 1)), shape=(self.nested_dim_1,), @@ -1666,7 +1642,7 @@ def make_specs(self): reward=Unbounded(shape=(1,)), ) - self.unbatched_done_spec = Composite( + self.done_spec_unbatched = Composite( nested_1=Composite( done=Categorical( n=2, diff --git a/test/test_env.py b/test/test_env.py index 05d8308494a..ab854a3b4be 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3512,18 +3512,18 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi def test_single_env_spec(): env = NestedCountingEnv(batch_size=[3, 1, 7]) - assert not env.single_full_action_spec.shape - assert not env.single_full_done_spec.shape - assert not env.single_input_spec.shape - assert not env.single_full_observation_spec.shape - assert not env.single_output_spec.shape - assert not env.single_full_reward_spec.shape - - assert env.single_action_spec.shape - assert env.single_reward_spec.shape - - assert env.output_spec.is_in(env.single_output_spec.zeros(env.shape)) - assert env.input_spec.is_in(env.single_input_spec.zeros(env.shape)) + assert not env.full_action_spec_unbatched.shape + assert not env.full_done_spec_unbatched.shape + assert not env.input_spec_unbatched.shape + assert not env.full_observation_spec_unbatched.shape + assert not env.output_spec_unbatched.shape + assert not env.full_reward_spec_unbatched.shape + + assert env.action_spec_unbatched.shape + assert env.reward_spec_unbatched.shape + + assert env.output_spec.is_in(env.output_spec_unbatched.zeros(env.shape)) + assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) if __name__ == "__main__": diff --git a/test/test_libs.py b/test/test_libs.py index defa486da6a..b3ba8d54c3d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2253,6 +2253,13 @@ def test_vmas_batch_size(self, scenario_name, num_envs, n_agents): max_steps=n_rollout_samples, return_contiguous=False if env.het_specs else True, ) + assert ( + env.full_action_spec_unbatched.shape == env.unbatched_action_spec.shape + ), ( + env.action_spec, + env.batch_size, + ) + env.close() if env.het_specs: diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index cf784f5659d..a7c004bfcc5 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -172,10 +172,11 @@ def __call__(cls, *args, **kwargs): # we create the done spec by adding a done/terminated entry if one is missing instance._create_done_specs() # we access lazy attributed to make sure they're built properly. - # This isn't done in `__init__` because we don't know if supre().__init__ + # This isn't done in `__init__` because we don't know if super().__init__ # will be called before or after the specs, batch size etc are set. 
_ = instance.done_spec - _ = instance.reward_spec + _ = instance.reward_keys + # _ = instance.action_keys _ = instance.state_spec if auto_reset: from torchrl.envs.transforms.transforms import ( @@ -658,7 +659,7 @@ def action_keys(self) -> List[NestedKey]: action_keys = self.__dict__.get("_action_keys") if action_keys is not None: return action_keys - keys = self.input_spec["full_action_spec"].keys(True, True) + keys = self.full_action_spec.keys(True, True) if not len(keys): raise AttributeError("Could not find action spec") keys = sorted(keys, key=_repr_by_depth) @@ -778,6 +779,14 @@ def action_spec(self) -> TensorSpec: if len(self.action_keys) > 1: out = action_spec else: + if len(self.action_keys) == 1 and self.action_keys[0] != "action": + warnings.warn( + "You are querying a non-trivial, single action_spec, i.e., there is only " + "one action known by the environment but it is not named `'action'`. " + "Currently, env.action_spec returns the leaf but for consistency with the " + "setter, this will return the full spec instead (from v0.8 and on).", + category=DeprecationWarning, + ) try: out = action_spec[self.action_key] except KeyError: @@ -807,7 +816,8 @@ def action_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " + "Please use `env.action_spec_unbatched = value` to set unbatched versions instead." ) if isinstance(value, Composite): @@ -984,6 +994,14 @@ def reward_spec(self) -> TensorSpec: if len(reward_keys) > 1 or not len(reward_keys): return reward_spec else: + if len(self.reward_keys) == 1 and self.reward_keys[0] != "reward": + warnings.warn( + "You are querying a non-trivial, single reward_spec, i.e., there is only " + "one reward known by the environment but it is not named `'reward'`. " + "Currently, env.reward_spec returns the leaf but for consistency with the " + "setter, this will return the full spec instead (from v0.8 and on).", + category=DeprecationWarning, + ) return reward_spec[self.reward_keys[0]] @reward_spec.setter @@ -1002,7 +1020,8 @@ def reward_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " + "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead." ) if isinstance(value, Composite): for _ in value.values(True, True): # noqa: B007 @@ -1053,7 +1072,18 @@ def full_reward_spec(self) -> Composite: domain=continuous), device=None, shape=torch.Size([])), device=cpu, shape=torch.Size([])) """ - return self.output_spec["full_reward_spec"] + try: + return self.output_spec["full_reward_spec"] + except KeyError: + # populate the "reward" entry + # this will be raised if there is not full_reward_spec (unlikely) or no reward_key + # Since output_spec is lazily populated with an empty composite spec for + # reward_spec, the second case is much more likely to occur. 
+ self.reward_spec = Unbounded( + shape=(*self.batch_size, 1), + device=self.device, + ) + return self.output_spec["full_reward_spec"] @full_reward_spec.setter def full_reward_spec(self, spec: Composite) -> None: @@ -1493,65 +1523,125 @@ def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec: return spec[idx] @property - def single_full_action_spec(self) -> Composite: + def full_action_spec_unbatched(self) -> Composite: """Returns the action spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_action_spec) + @full_action_spec_unbatched.setter + def full_action_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_action_spec = spec + @property - def single_action_spec(self) -> TensorSpec: + def action_spec_unbatched(self) -> TensorSpec: """Returns the action spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.action_spec) + @action_spec_unbatched.setter + def action_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.action_spec = spec + @property - def single_full_observation_spec(self) -> Composite: + def full_observation_spec_unbatched(self) -> Composite: """Returns the observation spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_action_spec) + @full_observation_spec_unbatched.setter + def full_observation_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_observation_spec = spec + @property - def single_observation_spec(self) -> Composite: + def observation_spec_unbatched(self) -> Composite: """Returns the observation spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.observation_spec) + @observation_spec_unbatched.setter + def observation_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.observation_spec = spec + @property - def single_full_reward_spec(self) -> Composite: + def full_reward_spec_unbatched(self) -> Composite: """Returns the reward spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_action_spec) + @full_reward_spec_unbatched.setter + def full_reward_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_reward_spec = spec + @property - def single_reward_spec(self) -> TensorSpec: + def reward_spec_unbatched(self) -> TensorSpec: """Returns the reward spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.reward_spec) + @reward_spec_unbatched.setter + def reward_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.reward_spec = spec + @property - def single_full_done_spec(self) -> Composite: + def full_done_spec_unbatched(self) -> Composite: """Returns the done spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_action_spec) + @full_done_spec_unbatched.setter + def full_done_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_done_spec = spec + @property - def single_done_spec(self) -> TensorSpec: + def done_spec_unbatched(self) -> TensorSpec: """Returns the done spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.done_spec) + @done_spec_unbatched.setter + def done_spec_unbatched(self, spec: Composite): + 
spec = spec.expand(self.batch_size + spec.shape) + self.done_spec = spec + @property - def single_output_spec(self) -> Composite: + def output_spec_unbatched(self) -> Composite: """Returns the output spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.output_spec) + @output_spec_unbatched.setter + def output_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.output_spec = spec + @property - def single_input_spec(self) -> Composite: + def input_spec_unbatched(self) -> Composite: """Returns the input spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.input_spec) + @input_spec_unbatched.setter + def input_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.input_spec = spec + @property - def single_full_state_spec(self) -> Composite: + def full_state_spec_unbatched(self) -> Composite: """Returns the state spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.full_state_spec) + @full_state_spec_unbatched.setter + def full_state_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.full_state_spec = spec + @property - def single_state_spec(self) -> TensorSpec: + def state_spec_unbatched(self) -> TensorSpec: """Returns the state spec of the env as if it had no batch dimensions.""" return self._make_single_env_spec(self.state_spec) + @state_spec_unbatched.setter + def state_spec_unbatched(self, spec: Composite): + spec = spec.expand(self.batch_size + spec.shape) + self.state_spec = spec + def step(self, tensordict: TensorDictBase) -> TensorDictBase: """Makes a step in the environment. diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 25e72dc40f4..a0373ba4b46 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -431,8 +431,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.action_spec.space.low, - "high": env.action_spec.space.high, + "low": env.action_spec_unbatched.space.low, + "high": env.action_spec_unbatched.space.high, }, return_log_prob=True, # we'll need the log-prob for the numerator of the importance weights diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py index 0d0c6360958..a7bd74a4deb 100644 --- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py +++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py @@ -486,8 +486,8 @@ out_keys=[(group, "action")], distribution_class=TanhDelta, distribution_kwargs={ - "low": env.full_action_spec[group, "action"].space.low, - "high": env.full_action_spec[group, "action"].space.high, + "low": env.full_action_spec_unbatched[group, "action"].space.low, + "high": env.full_action_spec_unbatched[group, "action"].space.high, }, return_log_prob=False, ) diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index ec24de6cddd..e2ca3f6ecd8 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -445,13 +445,13 @@ policy = ProbabilisticActor( module=policy_module, - spec=env.unbatched_action_spec, + spec=env.action_spec_unbatched, in_keys=[("agents", "loc"), ("agents", "scale")], out_keys=[env.action_key], distribution_class=TanhNormal, distribution_kwargs={ - "low": 
env.unbatched_action_spec[env.action_key].space.low, - "high": env.unbatched_action_spec[env.action_key].space.high, + "low": env.action_spec_unbatched[env.action_key].space.low, + "high": env.action_spec_unbatched[env.action_key].space.high, }, return_log_prob=True, log_prob_key=("agents", "sample_log_prob"), From a126a6f94594246f3fb4f9fd38f0700501ee5d58 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 20 Nov 2024 17:10:16 +0000 Subject: [PATCH 4/5] [Refactor] Use _unbatched in VMAS ghstack-source-id: 2190278de44ba59a3bc8d38398fddae9ecc42a84 Pull Request resolved: https://github.com/pytorch/rl/pull/2593 --- sota-implementations/multiagent/qmix_vdn.py | 2 +- torchrl/envs/libs/vmas.py | 61 +++++++++++++++------ 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index a6e24cf9414..1bcc2dbd10e 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -110,7 +110,7 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.loss.mixer_type == "qmix": mixer = TensorDictModule( module=QMixer( - state_shape=env.unbatched_observation_spec[ + state_shape=env.observation_spec_unbatched[ "agents", "observation" ].shape, mixing_embed_dim=32, diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 8d2e3387e3c..140fb191cae 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -5,6 +5,7 @@ from __future__ import annotations import importlib.util +import warnings from typing import Dict, List, Optional, Union @@ -328,9 +329,9 @@ def _make_specs( self.group_map = self.group_map.get_group_map(self.agent_names) check_marl_grouping(self.group_map, self.agent_names) - self.unbatched_action_spec = Composite(device=self.device) - self.unbatched_observation_spec = Composite(device=self.device) - self.unbatched_reward_spec = Composite(device=self.device) + full_action_spec_unbatched = Composite(device=self.device) + full_observation_spec_unbatched = Composite(device=self.device) + full_reward_spec_unbatched = Composite(device=self.device) self.het_specs = False self.het_specs_map = {} @@ -341,18 +342,18 @@ def _make_specs( group_reward_spec, group_info_spec, ) = self._make_unbatched_group_specs(group) - self.unbatched_action_spec[group] = group_action_spec - self.unbatched_observation_spec[group] = group_observation_spec - self.unbatched_reward_spec[group] = group_reward_spec + full_action_spec_unbatched[group] = group_action_spec + full_observation_spec_unbatched[group] = group_observation_spec + full_reward_spec_unbatched[group] = group_reward_spec if group_info_spec is not None: - self.unbatched_observation_spec[(group, "info")] = group_info_spec + full_observation_spec_unbatched[(group, "info")] = group_info_spec group_het_specs = isinstance( group_observation_spec, StackedComposite ) or isinstance(group_action_spec, StackedComposite) self.het_specs_map[group] = group_het_specs self.het_specs = self.het_specs or group_het_specs - self.unbatched_done_spec = Composite( + full_done_spec_unbatched = Composite( { "done": Categorical( n=2, @@ -363,18 +364,42 @@ def _make_specs( }, ) - self.action_spec = self.unbatched_action_spec.expand( - *self.batch_size, *self.unbatched_action_spec.shape + self.full_action_spec_unbatched = full_action_spec_unbatched + self.full_observation_spec_unbatched = full_observation_spec_unbatched + self.full_reward_spec_unbatched = full_reward_spec_unbatched + self.full_done_spec_unbatched = 
full_done_spec_unbatched + + @property + def unbatched_action_spec(self): + warnings.warn( + "unbatched_action_spec is deprecated and will be removed in v0.9. " + "Please use full_action_spec_unbatched instead." ) - self.observation_spec = self.unbatched_observation_spec.expand( - *self.batch_size, *self.unbatched_observation_spec.shape + return self.full_action_spec_unbatched + + @property + def unbatched_observation_spec(self): + warnings.warn( + "unbatched_observation_spec is deprecated and will be removed in v0.9. " + "Please use full_observation_spec_unbatched instead." ) - self.reward_spec = self.unbatched_reward_spec.expand( - *self.batch_size, *self.unbatched_reward_spec.shape + return self.full_observation_spec_unbatched + + @property + def unbatched_reward_spec(self): + warnings.warn( + "unbatched_reward_spec is deprecated and will be removed in v0.9. " + "Please use full_reward_spec_unbatched instead." ) - self.done_spec = self.unbatched_done_spec.expand( - *self.batch_size, *self.unbatched_done_spec.shape + return self.full_reward_spec_unbatched + + @property + def unbatched_done_spec(self): + warnings.warn( + "unbatched_done_spec is deprecated and will be removed in v0.9. " + "Please use full_done_spec_unbatched instead." ) + return self.full_done_spec_unbatched def _make_unbatched_group_specs(self, group: str): # Agent specs @@ -618,7 +643,9 @@ def read_reward(self, rewards): def read_action(self, action, group: str = "agents"): if not self.continuous_actions and not self.categorical_actions: - action = self.unbatched_action_spec[group, "action"].to_categorical(action) + action = self.full_action_spec_unbatched[group, "action"].to_categorical( + action + ) agent_actions = action.unbind(dim=1) return agent_actions From b4b59444a5e894711ba6d062f9cddc6aafa0e095 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 20 Nov 2024 21:34:39 +0000 Subject: [PATCH 5/5] [Versioning] Fix moviepy 2.0 incompatibility ghstack-source-id: 1597321b84a812ce3b891c8f1f6851f5018cd956 Pull Request resolved: https://github.com/pytorch/rl/pull/2594 --- .github/unittest/linux/scripts/environment.yml | 2 +- .github/unittest/linux_distributed/scripts/environment.yml | 2 +- .github/unittest/linux_libs/scripts_envpool/environment.yml | 2 +- .github/unittest/linux_libs/scripts_gym/environment.yml | 2 +- .github/unittest/linux_libs/scripts_robohive/environment.yml | 2 +- .github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml | 2 +- .github/unittest/linux_sota/scripts/environment.yml | 2 +- README.md | 2 +- setup.py | 2 +- sota-check/README.md | 2 +- sota-implementations/multiagent/README.md | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index 2234683a497..1e49def1635 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -9,7 +9,7 @@ dependencies: - future - cloudpickle - pygame - - moviepy + - moviepy<2.0.0 - tqdm - pytest - pytest-cov diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml index 76160f7a16a..d2d833b4ad9 100644 --- a/.github/unittest/linux_distributed/scripts/environment.yml +++ b/.github/unittest/linux_distributed/scripts/environment.yml @@ -9,7 +9,7 @@ dependencies: - future - cloudpickle - pygame - - moviepy + - moviepy<2.0.0 - tqdm - pytest - pytest-cov diff --git 
a/.github/unittest/linux_libs/scripts_envpool/environment.yml b/.github/unittest/linux_libs/scripts_envpool/environment.yml index 74a3c91cf06..b15b949ecbf 100644 --- a/.github/unittest/linux_libs/scripts_envpool/environment.yml +++ b/.github/unittest/linux_libs/scripts_envpool/environment.yml @@ -9,7 +9,7 @@ dependencies: - future - cloudpickle - pygame - - moviepy + - moviepy<2.0.0 - pytest-cov - pytest-mock - pytest-instafail diff --git a/.github/unittest/linux_libs/scripts_gym/environment.yml b/.github/unittest/linux_libs/scripts_gym/environment.yml index d30aa6d0f91..edca0978673 100644 --- a/.github/unittest/linux_libs/scripts_gym/environment.yml +++ b/.github/unittest/linux_libs/scripts_gym/environment.yml @@ -11,7 +11,7 @@ dependencies: - future - cloudpickle - pygame - - moviepy + - moviepy<2.0.0 - tqdm - pytest - pytest-cov diff --git a/.github/unittest/linux_libs/scripts_robohive/environment.yml b/.github/unittest/linux_libs/scripts_robohive/environment.yml index 873d19ad0c8..e4366928c02 100644 --- a/.github/unittest/linux_libs/scripts_robohive/environment.yml +++ b/.github/unittest/linux_libs/scripts_robohive/environment.yml @@ -11,7 +11,7 @@ dependencies: - future - cloudpickle - pygame - - moviepy + - moviepy<2.0.0 - tqdm - pytest - pytest-cov diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml index 042a447a1e1..d55b2091483 100644 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml @@ -10,7 +10,7 @@ dependencies: - cloudpickle - gym[atari]==0.13 - pygame - - moviepy + - moviepy<2.0.0 - tqdm - pytest - pytest-cov diff --git a/.github/unittest/linux_sota/scripts/environment.yml b/.github/unittest/linux_sota/scripts/environment.yml index f7dddbc5e3c..12b16ccb54e 100644 --- a/.github/unittest/linux_sota/scripts/environment.yml +++ b/.github/unittest/linux_sota/scripts/environment.yml @@ -9,7 +9,7 @@ dependencies: - future - cloudpickle - pygame - - moviepy + - moviepy<2.0.0 - tqdm - pytest - pytest-cov diff --git a/README.md b/README.md index abcf7349192..633637ac462 100644 --- a/README.md +++ b/README.md @@ -921,7 +921,7 @@ make of torchrl: pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher # rendering -pip3 install moviepy +pip3 install "moviepy<2.0.0" # deepmind control suite pip3 install dm_control diff --git a/setup.py b/setup.py index 75a2486815e..823ec307052 100644 --- a/setup.py +++ b/setup.py @@ -208,7 +208,7 @@ def _main(argv): ], "dm_control": ["dm_control"], "gym_continuous": ["gymnasium<1.0", "mujoco"], - "rendering": ["moviepy"], + "rendering": ["moviepy<2.0.0"], "tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"], "utils": [ "tensorboard", diff --git a/sota-check/README.md b/sota-check/README.md index 0529e091b50..f67245bd851 100644 --- a/sota-check/README.md +++ b/sota-check/README.md @@ -26,7 +26,7 @@ export MUJOCO_GL=egl conda create -n rl-sota-bench python=3.10 -y conda install anaconda::libglu -y pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -pip3 install "gymnasium[accept-rom-license,atari,mujoco]" vmas tqdm wandb pygame moviepy imageio submitit hydra-core transformers +pip3 install "gymnasium[accept-rom-license,atari,mujoco]" vmas tqdm wandb pygame "moviepy<2.0.0" imageio submitit hydra-core transformers cd /path/to/tensordict python setup.py develop diff --git a/sota-implementations/multiagent/README.md 
b/sota-implementations/multiagent/README.md index 1f9921e6bab..2adacca74af 100644 --- a/sota-implementations/multiagent/README.md +++ b/sota-implementations/multiagent/README.md @@ -27,7 +27,7 @@ Install vmas and dependencies: ```bash pip install vmas -pip install wandb moviepy +pip install wandb "moviepy<2.0.0" pip install hydra-core ```
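A minimal usage sketch of the accessor migration carried by the patches above: it reads per-env action bounds through the new `full_action_spec_unbatched` property, checks the batched/unbatched shape relation that the `*_spec_unbatched` setters establish, and verifies that the deprecated `unbatched_action_spec` attribute still resolves while emitting the new warning. The `"navigation"` scenario and `num_envs=32` are illustrative assumptions, not values taken from the patches, and a working VMAS install is required.

```python
import warnings

from torchrl.envs.libs.vmas import VmasEnv

# Illustrative construction; the scenario and num_envs are arbitrary choices here.
env = VmasEnv(scenario="navigation", num_envs=32, continuous_actions=True)

# Per-env (batch-free) view of the action spec, as used in read_action() above.
unbatched = env.full_action_spec_unbatched["agents", "action"]
low, high = unbatched.space.low, unbatched.space.high

# The batched spec keeps the leading num_envs dimension: the unbatched setters
# expand the spec by env.batch_size before storing it.
batched = env.full_action_spec["agents", "action"]
assert batched.shape == env.batch_size + unbatched.shape

# The legacy attribute still resolves to the same values, but now warns and is
# slated for removal in v0.9.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = env.unbatched_action_spec["agents", "action"]
assert any("deprecated" in str(w.message) for w in caught)
assert (legacy.space.low == low).all()
```

The same pattern applies to the observation, reward, and done specs, each of which gains a `full_*_spec_unbatched` counterpart and a warning-emitting legacy alias in the VMAS hunk.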