Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 20, 2024
2 parents a7f5446 + 0093d10 commit 91c0305
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/multiagent/qmix_vdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.loss.mixer_type == "qmix":
mixer = TensorDictModule(
module=QMixer(
state_shape=env.unbatched_observation_spec[
state_shape=env.observation_spec_unbatched[
"agents", "observation"
].shape,
mixing_embed_dim=32,
Expand Down
4 changes: 3 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,9 @@ def _step(
action = tensordict.get(self.action_key)
self.count += action.to(
dtype=torch.int,
device=self.action_spec.device if self.device is None else self.device,
device=self.full_action_spec[self.action_key].device
if self.device is None
else self.device,
)
tensordict = TensorDict(
source={
Expand Down
5 changes: 3 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,11 @@ def __call__(cls, *args, **kwargs):
# we create the done spec by adding a done/terminated entry if one is missing
instance._create_done_specs()
# we access lazy attributed to make sure they're built properly.
# This isn't done in `__init__` because we don't know if supre().__init__
# This isn't done in `__init__` because we don't know if super().__init__
# will be called before or after the specs, batch size etc are set.
_ = instance.done_spec
_ = instance.reward_spec
_ = instance.reward_keys
_ = instance.action_keys
_ = instance.state_spec
if auto_reset:
from torchrl.envs.transforms.transforms import (
Expand Down

0 comments on commit 91c0305

Please sign in to comment.