Skip to content

Commit

Permalink
[Quality] Better TD construction in codebase
Browse files Browse the repository at this point in the history
ghstack-source-id: 9e280d9d7d4a735e5055beb0450d933547530e55
Pull Request resolved: #2565
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent 9f8f77c commit a4c1ee3
Show file tree
Hide file tree
Showing 28 changed files with 83 additions and 97 deletions.
2 changes: 1 addition & 1 deletion examples/rlhf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def log(self, model):

class TrainLogger:
def __init__(self, size: int, log_interval: int, logger: Logger):
self.data = TensorDict({}, [size])
self.data = TensorDict(batch_size=[size])
self.counter = 0
self.log_interval = log_interval
self.logger = logger
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

losses = TensorDict({}, batch_size=[num_mini_batches])
losses = TensorDict(batch_size=[num_mini_batches])
training_start = time.time()

# Compute GAE
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821
}
)

losses = TensorDict({}, batch_size=[num_mini_batches])
losses = TensorDict(batch_size=[num_mini_batches])
training_start = time.time()

# Compute GAE
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
log_loss_td = TensorDict({}, [num_updates])
log_loss_td = TensorDict(batch_size=[num_updates])
for j in range(num_updates):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample()
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/impala_multi_node_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821
logger.log_scalar(key, value, collected_frames)
continue

losses = TensorDict({}, batch_size=[sgd_updates])
losses = TensorDict(batch_size=[sgd_updates])
training_start = time.time()
for j in range(sgd_updates):

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/impala_multi_node_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def main(cfg: "DictConfig"): # noqa: F821
logger.log_scalar(key, value, collected_frames)
continue

losses = TensorDict({}, batch_size=[sgd_updates])
losses = TensorDict(batch_size=[sgd_updates])
training_start = time.time()
for j in range(sgd_updates):

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/impala/impala_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def main(cfg: "DictConfig"): # noqa: F821
logger.log_scalar(key, value, collected_frames)
continue

losses = TensorDict({}, batch_size=[sgd_updates])
losses = TensorDict(batch_size=[sgd_updates])
training_start = time.time()
for j in range(sgd_updates):

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
cfg_optim_max_grad_norm = cfg.optim.max_grad_norm
cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])

for i, data in enumerate(collector):

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])

for i, data in enumerate(collector):

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
losses = TensorDict({}, batch_size=[num_updates])
losses = TensorDict(batch_size=[num_updates])
for i in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
Expand Down
4 changes: 2 additions & 2 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase:
self.counter += 1
state = torch.zeros(self.size) + self.counter
if tensordict is None:
tensordict = TensorDict({}, self.batch_size, device=self.device)
tensordict = TensorDict(batch_size=self.batch_size, device=self.device)
tensordict = tensordict.empty().set(self.out_key, self._get_out_obs(state))
tensordict = tensordict.set(self._out_key, self._get_out_obs(state))
tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
Expand Down Expand Up @@ -595,7 +595,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
self.step_count = 0
# state = torch.zeros(self.size) + self.counter
if tensordict is None:
tensordict = TensorDict({}, self.batch_size, device=self.device)
tensordict = TensorDict(batch_size=self.batch_size, device=self.device)

tensordict = tensordict.empty()
tensordict.update(self.observation_spec.rand())
Expand Down
10 changes: 5 additions & 5 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,7 @@ def test_steptensordict(
tds[1]["but", "not", "this", "one"] = torch.ones(2)
tds[0]["next", "this", "one"] = torch.ones(2) * 2
tensordict = LazyStackedTensorDict.lazy_stack(tds, 0)
next_tensordict = TensorDict({}, [4]) if has_out else None
next_tensordict = TensorDict(batch_size=[4]) if has_out else None
if has_out and lazy_stack:
next_tensordict = LazyStackedTensorDict.lazy_stack(
next_tensordict.unbind(0), 0
Expand Down Expand Up @@ -1550,9 +1550,9 @@ def test_nested(
nested_key = ("data",)
td = TensorDict(
{
nested_key: TensorDict({}, nested_batch_size),
nested_key: TensorDict(batch_size=nested_batch_size),
"next": {
nested_key: TensorDict({}, nested_batch_size),
nested_key: TensorDict(batch_size=nested_batch_size),
},
},
td_batch_size,
Expand Down Expand Up @@ -1670,7 +1670,7 @@ def test_nested_partially(
# Nested only in root
td = TensorDict(
{
nested_key: TensorDict({}, nested_batch_size),
nested_key: TensorDict(batch_size=nested_batch_size),
"next": {},
},
td_batch_size,
Expand Down Expand Up @@ -1711,7 +1711,7 @@ def test_nested_partially(
# Nested only in next
td = TensorDict(
{
"next": {nested_key: TensorDict({}, nested_batch_size)},
"next": {nested_key: TensorDict(batch_size=nested_batch_size)},
},
td_batch_size,
)
Expand Down
4 changes: 2 additions & 2 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def test_CEM_model_free_env(self, device, batch_size, seed=1):
num_candidates=100,
top_k=2,
)
td = env.reset(TensorDict({}, batch_size=batch_size).to(device))
td = env.reset(TensorDict(batch_size=batch_size).to(device))
td_copy = td.clone()
td = planner(td)
assert (
Expand Down Expand Up @@ -408,7 +408,7 @@ def test_MPPI(self, device, batch_size, seed=1):
num_candidates=100,
top_k=2,
)
td = env.reset(TensorDict({}, batch_size=batch_size).to(device))
td = env.reset(TensorDict(batch_size=batch_size).to(device))
td_copy = td.clone()
td = planner(td)
assert (
Expand Down
2 changes: 1 addition & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3317,7 +3317,7 @@ def _make_storage(self, storage_type, data_type):
return LazyMemmapStorage(max_size=100)
if storage_type is TensorStorage:
if data_type is TensorDict:
return TensorStorage(TensorDict({}, [100]))
return TensorStorage(TensorDict(batch_size=[100]))
elif data_type is torch.Tensor:
return TensorStorage(torch.zeros(100))
else:
Expand Down
44 changes: 18 additions & 26 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3945,7 +3945,7 @@ def test_nested(self, skip=4):
def test_transform_model(self):
t = FrameSkipTransform(2)
t = nn.Sequential(t, nn.Identity())
tensordict = TensorDict({}, [])
tensordict = TensorDict()
with pytest.raises(
RuntimeError,
match="FrameSkipTransform can only be used when appended to a transformed env",
Expand Down Expand Up @@ -4252,15 +4252,15 @@ def test_transform_compose(self):

def test_transform_model(self):
t = nn.Sequential(NoopResetEnv(), nn.Identity())
td = TensorDict({}, [])
td = TensorDict()
t(td)

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, rbclass):
t = NoopResetEnv()
rb = rbclass(storage=LazyTensorStorage(10))
rb.append_transform(t)
td = TensorDict({}, [10])
td = TensorDict(batch_size=[10])
rb.extend(td)
rb.sample(1)

Expand Down Expand Up @@ -6917,7 +6917,7 @@ def test_transform_no_env(self):
def test_transform_model(self):
t = TensorDictPrimer(mykey=Unbounded([3]))
model = nn.Sequential(t, nn.Identity())
td = TensorDict({}, [])
td = TensorDict()
model(td)
assert "mykey" in td.keys()

Expand Down Expand Up @@ -7507,7 +7507,7 @@ def test_transform_model(self):
action_dim = 5
t = gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,))
model = nn.Sequential(t, nn.Identity())
td = TensorDict({}, [])
td = TensorDict()
model(td)
assert "_eps_gSDE" in td.keys()
assert (td["_eps_gSDE"] != 0.0).all()
Expand Down Expand Up @@ -9736,7 +9736,7 @@ def test_transform_model(self):
with pytest.raises(
NotImplementedError, match="InitTracker cannot be executed without a parent"
):
td = TensorDict({}, [])
td = TensorDict()
chain = nn.Sequential(InitTracker())
chain(td)

Expand Down Expand Up @@ -10169,7 +10169,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
def test_transform_no_env(self):
t = ActionMask()
with pytest.raises(RuntimeError, match="parent cannot be None"):
t._call(TensorDict({}, []))
t._call(TensorDict())

def test_transform_compose(self):
env = self._env_class()
Expand Down Expand Up @@ -10197,7 +10197,7 @@ def test_transform_env(self):
def test_transform_model(self):
t = ActionMask()
with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))):
t(TensorDict({}, []))
t(TensorDict())

def test_transform_rb(self):
t = ActionMask()
Expand Down Expand Up @@ -10526,18 +10526,12 @@ def make_env():

def test_transform_no_env(self):
t = DeviceCastTransform("cpu:1", "cpu:0")
assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device(
"cpu:1"
)
assert t._call(TensorDict(device="cpu:0")).device == torch.device("cpu:1")

def test_transform_compose(self):
t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device(
"cpu:1"
)
assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device(
"cpu:0"
)
assert t._call(TensorDict(device="cpu:0")).device == torch.device("cpu:1")
assert t._inv_call(TensorDict(device="cpu:1")).device == torch.device("cpu:0")

def test_transform_env(self):
env = ContinuousActionVecMockEnv(device="cpu:0")
Expand All @@ -10550,7 +10544,7 @@ def test_transform_env(self):
def test_transform_model(self):
t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
m = nn.Sequential(t)
assert t(TensorDict({}, [], device="cpu:0")).device == torch.device("cpu:1")
assert t(TensorDict(device="cpu:0")).device == torch.device("cpu:1")

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
@pytest.mark.parametrize("storage", [TensorStorage, LazyTensorStorage])
Expand All @@ -10574,9 +10568,7 @@ def test_transform_rb(self, rbclass, storage):

def test_transform_inverse(self):
t = DeviceCastTransform("cpu:1", "cpu:0")
assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device(
"cpu:0"
)
assert t._inv_call(TensorDict(device="cpu:1")).device == torch.device("cpu:0")


class TestPermuteTransform(TransformBase):
Expand Down Expand Up @@ -10804,12 +10796,12 @@ def make():
def test_transform_no_env(self):
t = EndOfLifeTransform()
with pytest.raises(RuntimeError, match=t.NO_PARENT_ERR.format(type(t))):
t._step(TensorDict({}, []), TensorDict({}, []))
t._step(TensorDict(), TensorDict())

def test_transform_compose(self):
t = EndOfLifeTransform()
with pytest.raises(RuntimeError, match=t.NO_PARENT_ERR.format(type(t))):
Compose(t)._step(TensorDict({}, []), TensorDict({}, []))
Compose(t)._step(TensorDict(), TensorDict())

@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
Expand Down Expand Up @@ -10838,7 +10830,7 @@ def test_transform_env(self, eol_key, lives_key):
def test_transform_model(self):
t = EndOfLifeTransform()
with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))):
nn.Sequential(t)(TensorDict({}, []))
nn.Sequential(t)(TensorDict())

def test_transform_rb(self):
pass
Expand Down Expand Up @@ -11286,7 +11278,7 @@ def _reset(self, tensordict):

def _step(self, tensordict):
return (
TensorDict({}, batch_size=[])
TensorDict()
.update(self.observation_spec.rand())
.update(self.full_done_spec.zero())
.update(self.full_reward_spec.rand())
Expand Down Expand Up @@ -11378,7 +11370,7 @@ def test_transform_inverse(self):
t.inv(td)
assert len(td.keys()) != 0
env = TransformedEnv(self.DummyEnv(), RemoveEmptySpecs())
td2 = env.transform.inv(TensorDict({}, []))
td2 = env.transform.inv(TensorDict())
assert ("state", "sub") in td2.keys(True)


Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/datasets/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def _download_and_proc_split(
@classmethod
def _preproc_run(cls, path, gz_files, run):
files = gz_files[run]
td = TensorDict({}, [])
td = TensorDict()
path = Path(path)
for file in files:
name = str(Path(file).parts[-1]).split(".")[0]
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/datasets/minari_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _download_and_preproc(self):
minari.download_dataset(dataset_id=self.dataset_id)
parent_dir = Path(tmpdir) / self.dataset_id / "data"

td_data = TensorDict({}, [])
td_data = TensorDict()
total_steps = 0
torchrl_logger.info("first read through data to create data structure...")
h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5")
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/datasets/roboset.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _download_and_preproc(self):
return self._preproc_h5(h5_data_files)

def _preproc_h5(self, h5_data_files):
td_data = TensorDict({}, [])
td_data = TensorDict()
total_steps = 0
torchrl_logger.info(
f"first read through data files {h5_data_files} to create data structure..."
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2173,7 +2173,7 @@ def sample(self, storage, batch_size):
[
TensorDict.from_dict(info, batch_dims=samples.ndim - 1)
if info
else TensorDict({}, [])
else TensorDict()
for info in infos
]
)
Expand Down
12 changes: 3 additions & 9 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,7 @@ def load_state_dict(self, state_dict):
if isinstance(elt, torch.Tensor):
self._storage.append(elt)
elif isinstance(elt, (dict, OrderedDict)):
self._storage.append(
TensorDict({}, []).load_state_dict(elt, strict=False)
)
self._storage.append(TensorDict().load_state_dict(elt, strict=False))
else:
raise TypeError(
f"Objects of type {type(elt)} are not supported by ListStorage.load_state_dict"
Expand Down Expand Up @@ -675,9 +673,7 @@ def load_state_dict(self, state_dict):
if is_tensor_collection(self._storage):
self._storage.load_state_dict(_storage, strict=False)
elif self._storage is None:
self._storage = TensorDict({}, []).load_state_dict(
_storage, strict=False
)
self._storage = TensorDict().load_state_dict(_storage, strict=False)
else:
raise RuntimeError(
f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}. If your storage is pytree-based, use the dumps/load API instead."
Expand Down Expand Up @@ -1193,9 +1189,7 @@ def load_state_dict(self, state_dict):
"It is preferable to load a storage onto a"
"pre-allocated one whenever possible."
)
self._storage = TensorDict({}, []).load_state_dict(
_storage, strict=False
)
self._storage = TensorDict().load_state_dict(_storage, strict=False)
self._storage.memmap_()
else:
raise RuntimeError(
Expand Down
Loading

1 comment on commit a4c1ee3

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: a4c1ee3 Previous: 9f8f77c Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 38.83322698958793 iter/sec (stddev: 0.15353294800274053) 244.23257474796227 iter/sec (stddev: 0.0007217646505882698) 6.29

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.