diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh index d5bb8695c44..05eb63c2b51 100755 --- a/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh +++ b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh @@ -23,6 +23,7 @@ conda deactivate && conda activate ./env python -c "import mlagents_envs" python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_transforms.py --instafail -v --durations 200 --capture no -k test_transform_env[unity] coverage combine coverage xml -i diff --git a/test/mocking_classes.py b/test/mocking_classes.py index eb517429c08..1b6b5ad1663 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import Dict, List, Optional import torch import torch.nn as nn @@ -24,7 +24,12 @@ from torchrl.data.utils import consolidate_spec from torchrl.envs.common import EnvBase from torchrl.envs.model_based.common import ModelBasedEnvBase -from torchrl.envs.utils import _terminated_or_truncated +from torchrl.envs.utils import ( + _terminated_or_truncated, + check_marl_grouping, + MarlGroupMapType, +) + spec_dict = { "bounded": Bounded, @@ -1057,6 +1062,154 @@ def _step( return tensordict +class MultiAgentCountingEnv(EnvBase): + """A multi-agent env that is done after a given number of steps. + + All agents have identical specs. + + The count is incremented by 1 on each step. + + """ + + def __init__( + self, + n_agents: int, + group_map: MarlGroupMapType + | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + max_steps: int = 5, + start_val: int = 0, + **kwargs, + ): + super().__init__(**kwargs) + self.max_steps = max_steps + self.start_val = start_val + self.n_agents = n_agents + self.agent_names = [f"agent_{idx}" for idx in range(n_agents)] + + if isinstance(group_map, MarlGroupMapType): + group_map = group_map.get_group_map(self.agent_names) + check_marl_grouping(group_map, self.agent_names) + + self.group_map = group_map + + observation_specs = {} + reward_specs = {} + done_specs = {} + action_specs = {} + + for group_name, agents in group_map.items(): + observation_specs[group_name] = {} + reward_specs[group_name] = {} + done_specs[group_name] = {} + action_specs[group_name] = {} + + for agent_name in agents: + observation_specs[group_name][agent_name] = Composite( + observation=Unbounded( + ( + *self.batch_size, + 3, + 4, + ), + dtype=torch.float32, + device=self.device, + ), + shape=self.batch_size, + device=self.device, + ) + reward_specs[group_name][agent_name] = Composite( + reward=Unbounded( + ( + *self.batch_size, + 1, + ), + device=self.device, + ), + shape=self.batch_size, + device=self.device, + ) + done_specs[group_name][agent_name] = Composite( + done=Categorical( + 2, + dtype=torch.bool, + shape=( + *self.batch_size, + 1, + ), + device=self.device, + ), + shape=self.batch_size, + device=self.device, + ) + action_specs[group_name][agent_name] = Composite( + action=Binary(n=1, shape=[*self.batch_size, 1], device=self.device), + shape=self.batch_size, + device=self.device, + ) + + self.observation_spec = Composite(observation_specs) + self.reward_spec = Composite(reward_specs) + self.done_spec = Composite(done_specs) + self.action_spec = Composite(action_specs) + self.register_buffer( + "count", + torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int), + ) + + def _set_seed(self, seed: Optional[int]): + torch.manual_seed(seed) + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + if tensordict is not None and "_reset" in tensordict.keys(): + _reset = tensordict.get("_reset") + self.count[_reset] = self.start_val + else: + self.count[:] = self.start_val + + source = {} + for group_name, agents in self.group_map.items(): + source[group_name] = {} + for agent_name in agents: + source[group_name][agent_name] = TensorDict( + source={ + "observation": torch.rand( + (*self.batch_size, 3, 4), device=self.device + ), + "done": self.count > self.max_steps, + "terminated": self.count > self.max_steps, + }, + batch_size=self.batch_size, + device=self.device, + ) + + tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device) + return tensordict + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + self.count += 1 + source = {} + for group_name, agents in self.group_map.items(): + source[group_name] = {} + for agent_name in agents: + source[group_name][agent_name] = TensorDict( + source={ + "observation": torch.rand( + (*self.batch_size, 3, 4), device=self.device + ), + "done": self.count > self.max_steps, + "terminated": self.count > self.max_steps, + "reward": torch.zeros_like(self.count, dtype=torch.float), + }, + batch_size=self.batch_size, + device=self.device, + ) + tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device) + return tensordict + + class IncrementingEnv(CountingEnv): # Same as CountingEnv but always increments the count by 1 regardless of the action. def _step( diff --git a/test/test_transforms.py b/test/test_transforms.py index 8b2ada8c93a..171ea93e81a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -44,6 +44,7 @@ IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, + MultiAgentCountingEnv, MultiKeyCountingEnv, MultiKeyCountingEnvPolicy, NestedCountingEnv, @@ -69,6 +70,7 @@ IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, + MultiAgentCountingEnv, MultiKeyCountingEnv, MultiKeyCountingEnvPolicy, NestedCountingEnv, @@ -132,6 +134,7 @@ SerialEnv, SignTransform, SqueezeTransform, + Stack, StepCounter, TargetReturn, TensorDictPrimer, @@ -139,12 +142,14 @@ ToTensorImage, TrajCounter, TransformedEnv, + UnityMLAgentsEnv, UnsqueezeTransform, VC1Transform, VIPTransform, ) from torchrl.envs.libs.dm_control import _has_dm_control from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend +from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents from torchrl.envs.transforms import VecNorm from torchrl.envs.transforms.r3m import _R3MNet from torchrl.envs.transforms.rlhf import KLRewardTransform @@ -157,7 +162,7 @@ ) from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform -from torchrl.envs.utils import check_env_specs, step_mdp +from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal from torchrl.modules.utils import get_primers_from_module @@ -2147,6 +2152,379 @@ def test_transform_no_env(self, device, batch): pytest.skip("TrajCounter cannot be called without env") +class TestStack(TransformBase): + def test_single_trans_env_check(self): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-1, + del_keys=False, + ) + env = TransformedEnv(ContinuousActionVecMockEnv(), t) + check_env_specs(env) + + def test_serial_trans_env_check(self): + def make_env(): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-1, + del_keys=False, + ) + return TransformedEnv(ContinuousActionVecMockEnv(), t) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + def make_env(): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-1, + del_keys=False, + ) + return TransformedEnv(ContinuousActionVecMockEnv(), t) + + env = maybe_fork_ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + def test_trans_serial_env_check(self): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-2, + del_keys=False, + ) + + env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), t) + check_env_specs(env) + + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-2, + del_keys=False, + ) + + env = TransformedEnv(maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), t) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("del_keys", [True, False]) + def test_transform_del_keys(self, del_keys): + td_orig = TensorDict( + { + "group_0": TensorDict( + { + "agent_0": TensorDict({"obs": torch.randn(10)}), + "agent_1": TensorDict({"obs": torch.randn(10)}), + } + ), + "group_1": TensorDict( + { + "agent_2": TensorDict({"obs": torch.randn(10)}), + "agent_3": TensorDict({"obs": torch.randn(10)}), + } + ), + } + ) + t = Stack( + in_keys=[ + ("group_0", "agent_0", "obs"), + ("group_0", "agent_1", "obs"), + ("group_1", "agent_2", "obs"), + ("group_1", "agent_3", "obs"), + ], + out_key="observations", + del_keys=del_keys, + ) + td = td_orig.clone() + t(td) + keys = td.keys(include_nested=True) + if del_keys: + assert ("group_0",) not in keys + assert ("group_0", "agent_0", "obs") not in keys + assert ("group_0", "agent_1", "obs") not in keys + assert ("group_1", "agent_2", "obs") not in keys + assert ("group_1", "agent_3", "obs") not in keys + else: + assert ("group_0", "agent_0", "obs") in keys + assert ("group_0", "agent_1", "obs") in keys + assert ("group_1", "agent_2", "obs") in keys + assert ("group_1", "agent_3", "obs") in keys + + assert ("observations",) in keys + + def _test_transform_no_env_tensor(self, compose=False): + td_orig = TensorDict( + { + "key1": torch.rand(1, 3), + "key2": torch.rand(1, 3), + "key3": torch.rand(1, 3), + }, + [1], + ) + td = td_orig.clone() + t = Stack( + in_keys=[("key1",), ("key2",)], + out_key=("stacked",), + dim=-2, + ) + if compose: + t = Compose(t) + + td = t(td) + + assert ("key1",) not in td.keys() + assert ("key2",) not in td.keys() + assert ("key3",) in td.keys() + assert ("stacked",) in td.keys() + + assert td["stacked"].shape == torch.Size([1, 2, 3]) + assert (td["stacked"][:, 0] == td_orig["key1"]).all() + assert (td["stacked"][:, 1] == td_orig["key2"]).all() + + td = t.inv(td) + assert (td == td_orig).all() + + def _test_transform_no_env_tensordict(self, compose=False): + def gen_value(): + return TensorDict( + { + "a": torch.rand(3), + "b": torch.rand(2, 4), + } + ) + + td_orig = TensorDict( + { + "key1": gen_value(), + "key2": gen_value(), + "key3": gen_value(), + }, + [], + ) + td = td_orig.clone() + t = Stack( + in_keys=[("key1",), ("key2",)], + out_key=("stacked",), + dim=0, + allow_positive_dim=True, + ) + if compose: + t = Compose(t) + td = t(td) + + assert ("key1",) not in td.keys() + assert ("key2",) not in td.keys() + assert ("stacked", "a") in td.keys(include_nested=True) + assert ("stacked", "b") in td.keys(include_nested=True) + assert ("key3",) in td.keys() + + assert td["stacked", "a"].shape == torch.Size([2, 3]) + assert td["stacked", "b"].shape == torch.Size([2, 2, 4]) + assert (td["stacked"][0] == td_orig["key1"]).all() + assert (td["stacked"][1] == td_orig["key2"]).all() + assert (td["key3"] == td_orig["key3"]).all() + + td = t.inv(td) + assert (td == td_orig).all() + + @pytest.mark.parametrize("datatype", ["tensor", "tensordict"]) + def test_transform_no_env(self, datatype): + if datatype == "tensor": + self._test_transform_no_env_tensor() + + elif datatype == "tensordict": + self._test_transform_no_env_tensordict() + + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + @pytest.mark.parametrize("datatype", ["tensor", "tensordict"]) + def test_transform_compose(self, datatype): + if datatype == "tensor": + self._test_transform_no_env_tensor(compose=True) + + elif datatype == "tensordict": + self._test_transform_no_env_tensordict(compose=True) + + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + @pytest.mark.parametrize("envtype", ["mock", "unity"]) + def test_transform_env(self, envtype): + if envtype == "mock": + base_env = MultiAgentCountingEnv( + n_agents=5, + ) + rollout_len = 6 + t = Stack( + in_keys=[ + ("agents", "agent_0"), + ("agents", "agent_2"), + ("agents", "agent_3"), + ], + out_key="stacked_agents", + ) + + elif envtype == "unity": + if not _has_unity_mlagents: + raise pytest.skip("mlagents not installed") + base_env = UnityMLAgentsEnv( + registered_name="3DBall", + no_graphics=True, + group_map=MarlGroupMapType.ALL_IN_ONE_GROUP, + ) + rollout_len = 200 + t = Stack( + in_keys=[("agents", f"agent_{idx}") for idx in range(12)], + out_key="stacked_agents", + ) + + try: + env = TransformedEnv(base_env, t) + check_env_specs(env) + + if envtype == "mock": + base_env.set_seed(123) + td_orig = base_env.reset() + if envtype == "mock": + env.set_seed(123) + td = env.reset() + + td_keys = td.keys(include_nested=True) + + if envtype == "mock": + assert ("agents", "agent_0") not in td_keys + assert ("agents", "agent_2") not in td_keys + assert ("agents", "agent_3") not in td_keys + assert ("agents", "agent_1") in td_keys + assert ("agents", "agent_4") in td_keys + assert ("stacked_agents",) in td_keys + + assert (td["stacked_agents"][0] == td_orig["agents", "agent_0"]).all() + assert (td["stacked_agents"][1] == td_orig["agents", "agent_2"]).all() + assert (td["stacked_agents"][2] == td_orig["agents", "agent_3"]).all() + assert (td["agents", "agent_1"] == td_orig["agents", "agent_1"]).all() + assert (td["agents", "agent_4"] == td_orig["agents", "agent_4"]).all() + else: + assert ("agents",) not in td_keys + assert ("stacked_agents",) in td_keys + assert td["stacked_agents"].shape[0] == 12 + + assert ("agents",) not in env.full_action_spec.keys(include_nested=True) + assert ("stacked_agents",) in env.full_action_spec.keys( + include_nested=True + ) + + td = env.step(env.full_action_spec.rand()) + td = env.rollout(rollout_len) + + if envtype == "mock": + assert td["next", "stacked_agents", "done"].shape == torch.Size( + [6, 3, 1] + ) + assert not (td["next", "stacked_agents", "done"][:-1]).any() + assert (td["next", "stacked_agents", "done"][-1]).all() + finally: + base_env.close() + + def test_transform_model(self): + t = Stack( + in_keys=[("next", "observation"), ("observation",)], + out_key="observation_out", + dim=-2, + del_keys=True, + ) + model = nn.Sequential(t, nn.Identity()) + td = TensorDict( + {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, [] + ) + td = model(td) + assert "observation_out" in td.keys() + assert "observation" not in td.keys() + assert ("next", "observation") not in td.keys(True) + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + t = Stack( + in_keys=[("next", "observation"), "observation"], + out_key="observation_out", + dim=-2, + del_keys=True, + ) + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform(t) + td = TensorDict( + { + "observation": TensorDict({"stuff": torch.randn(3, 4)}, [3, 4]), + "next": TensorDict( + {"observation": TensorDict({"stuff": torch.randn(3, 4)}, [3, 4])}, + [], + ), + }, + [], + ).expand(10) + rb.extend(td) + td = rb.sample(2) + assert "observation_out" in td.keys() + assert "observation" not in td.keys() + assert ("next", "observation") not in td.keys(True) + + def test_transform_inverse(self): + td_orig = TensorDict( + { + "stacked": torch.rand(1, 2, 3), + "key3": torch.rand(1, 3), + }, + [1], + ) + td = td_orig.clone() + t = Stack( + in_keys=[("key1",), ("key2",)], + out_key=("stacked",), + dim=1, + allow_positive_dim=True, + ) + + td = t.inv(td) + + assert ("key1",) in td.keys() + assert ("key2",) in td.keys() + assert ("key3",) in td.keys() + assert ("stacked",) not in td.keys() + assert (td["key1"] == td_orig["stacked"][:, 0]).all() + assert (td["key2"] == td_orig["stacked"][:, 1]).all() + + td = t(td) + assert (td == td_orig).all() + + # Check that if `out_key` is not in the tensordict, + # then the inverse transform does nothing. + t = Stack( + in_keys=[("key1",), ("key2",)], + out_key=("sacked",), + dim=1, + allow_positive_dim=True, + ) + td = t.inv(td) + assert (td == td_orig).all() + + class TestCatTensors(TransformBase): @pytest.mark.parametrize("append", [True, False]) def test_cattensors_empty(self, append): diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 4cfb00cc307..36e4ec1a908 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -87,6 +87,7 @@ SelectTransform, SignTransform, SqueezeTransform, + Stack, StepCounter, TargetReturn, TensorDictPrimer, diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py index 95c2460bc83..5aeabc4d0aa 100644 --- a/torchrl/envs/libs/unity_mlagents.py +++ b/torchrl/envs/libs/unity_mlagents.py @@ -363,12 +363,12 @@ def _make_td_out(self, tensordict_in, is_reset=False): # Add rewards if not is_reset: source[group_name][agent_name]["reward"] = torch.tensor( - steps.reward[steps_idx], + [steps.reward[steps_idx]], device=self.device, dtype=torch.float32, ) source[group_name][agent_name]["group_reward"] = torch.tensor( - steps.group_reward[steps_idx], + [steps.group_reward[steps_idx]], device=self.device, dtype=torch.float32, ) diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index bccbd9a4543..77f6ecc03bf 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -48,6 +48,7 @@ SelectTransform, SignTransform, SqueezeTransform, + Stack, StepCounter, TargetReturn, TensorDictPrimer, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7bdd25591cd..d5e25efc28d 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4323,6 +4323,198 @@ def __repr__(self) -> str: ) +class Stack(Transform): + """Stacks tensors and tensordicts. + + This transform is useful for environments that have multiple agents with + identical specs under different keys. The specs and tensordicts for the + agents can be stacked together under a shared key, in order to run MARL + algorithms that expect the tensors for observations, rewards, etc. to + contain batched data for all the agents. + + Args: + in_keys (sequence of NestedKey): keys to be stacked. If `None` or not + provided, the keys will be retrieved from the group map of the + environment the first time the transform is used. This behavior will + only work if a parent is set. + out_key (NestedKey): key of the resulting tensor. + dim (int, optional): dimension to insert. Default is ``-1``. + allow_positive_dim (bool, optional): if ``True``, positive dimensions + are accepted. Defaults to ``False``, ie. non-negative dimensions are + not permitted. + + Keyword Args: + del_keys (bool, optional): if ``True``, the input values will be deleted + after stacking. Default is ``True``. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.envs import Stack + >>> td = TensorDict({"key1": torch.zeros(3), "key2": torch.ones(3)}, []) + >>> td + TensorDict( + fields={ + key1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + key2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> transform = Stack(in_keys=["key1", "key2"], out_key="out", dim=-2) + >>> transform(td) + TensorDict( + fields={ + out: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> td["out"] + tensor([[0., 0., 0.], + [1., 1., 1.]]) + + >>> agent_0 = TensorDict({"obs": torch.rand(4, 5), "reward": torch.zeros(1)}) + >>> agent_1 = TensorDict({"obs": torch.rand(4, 5), "reward": torch.zeros(1)}) + >>> td = TensorDict({"agent_0": agent_0, "agent_1": agent_1}) + >>> transform = Stack(in_keys=["agent_0", "agent_1"], out_key="agents") + >>> transform(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + obs: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + """ + + invertible = True + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_key: NestedKey, + dim: int = -1, + allow_positive_dim: bool = False, + *, + del_keys: bool = True, + ): + if not allow_positive_dim and dim >= 0: + raise ValueError( + "dim should be negative to accommodate for envs of different " + "batch_sizes. If you need dim to be positive, set " + "allow_positive_dim=True." + ) + super(Stack, self).__init__( + in_keys=in_keys, + out_keys=[out_key], + in_keys_inv=[out_key], + out_keys_inv=copy(in_keys), + ) + + for in_key in self.in_keys: + if len(in_key) == len(self.out_keys[0]): + if all(k1 == k2 for k1, k2 in zip(in_key, self.out_keys[0])): + raise ValueError(f"{self}: out_key cannot be in in_keys") + parent_keys = [] + for key in self.in_keys: + if isinstance(key, (list, tuple)): + for parent_level in range(1, len(key)): + parent_key = tuple(key[:-parent_level]) + if parent_key not in parent_keys: + parent_keys.append(parent_key) + self._maybe_del_parent_keys = sorted(parent_keys, key=len, reverse=True) + self.dim = dim + self._del_keys = del_keys + self._keys_to_exclude = None + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + values = [] + for in_key in self.in_keys: + value = tensordict.get(in_key, default=None) + if value is not None: + values.append(value) + elif not self.missing_tolerance: + raise KeyError( + f"{self}: '{in_key}' not found in tensordict {tensordict}" + ) + + out_tensor = torch.stack(values, dim=self.dim) + tensordict.set(self.out_keys[0], out_tensor) + if self._del_keys: + tensordict.exclude(*self.in_keys, inplace=True) + for parent_key in self._maybe_del_parent_keys: + if len(tensordict[parent_key].keys()) == 0: + tensordict.exclude(parent_key, inplace=True) + return tensordict + + forward = _call + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + if self.in_keys_inv[0] not in tensordict.keys(include_nested=True): + return tensordict + values = torch.unbind(tensordict[self.in_keys_inv[0]], dim=self.dim) + for value, out_key_inv in _zip_strict(values, self.out_keys_inv): + tensordict = tensordict.set(out_key_inv, value) + return tensordict.exclude(self.in_keys_inv[0]) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + with _set_missing_tolerance(self, True): + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset + + def _transform_spec(self, spec: TensorSpec) -> TensorSpec: + if not isinstance(spec, Composite): + raise TypeError(f"{self}: Only specs of type Composite can be transformed") + + spec_keys = spec.keys(include_nested=True) + keys_to_stack = [key for key in spec_keys if key in self.in_keys] + specs_to_stack = [spec[key] for key in keys_to_stack] + + if len(specs_to_stack) == 0: + return spec + + stacked_specs = torch.stack(specs_to_stack, dim=self.dim) + spec.set(self.out_keys[0], stacked_specs) + + if self._del_keys: + for key in keys_to_stack: + del spec[key] + for parent_key in self._maybe_del_parent_keys: + if len(spec[parent_key]) == 0: + del spec[parent_key] + + return spec + + def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + self._transform_spec(input_spec["full_state_spec"]) + self._transform_spec(input_spec["full_action_spec"]) + return input_spec + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + return self._transform_spec(observation_spec) + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + return self._transform_spec(reward_spec) + + def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec: + return self._transform_spec(done_spec) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"in_keys={self.in_keys}, " + f"out_key={self.out_keys[0]}, " + f"dim={self.dim}" + ")" + ) + + class DiscreteActionProjection(Transform): """Projects discrete actions from a high dimensional space to a low dimensional space.