From 121bb41bc75d568a6d220d14081488a321721fbc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 19 Nov 2024 21:34:09 +0000 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/cql/utils.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 5d44e7d9d18..a6a6d2311fa 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -6,6 +6,7 @@ import torch.nn import torch.optim +from tensordict import TensorDict, TensorDictParams from tensordict.nn import TensorDictModule, TensorDictSequential from tensordict.nn.distributions import NormalParamExtractor @@ -216,12 +217,19 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"): in_keys=["loc", "scale"], spec=action_spec, distribution_class=TanhNormal, - distribution_kwargs={ - "low": torch.as_tensor(action_spec.space.low, device=device), - "high": torch.as_tensor(action_spec.space.high, device=device), - "tanh_loc": False, - "safe_tanh": not cfg.compile.compile, - }, + # Wrapping the kwargs in a TensorDictParams such that these items are + # send to device when necessary + distribution_kwargs=TensorDictParams( + TensorDict( + { + "low": torch.as_tensor(action_spec.space.low, device=device), + "high": torch.as_tensor(action_spec.space.high, device=device), + "tanh_loc": False, + "safe_tanh": not cfg.compile.compile, + } + ), + no_convert=True, + ), default_interaction_type=ExplorationType.RANDOM, )