diff --git a/test/test_rb.py b/test/test_rb.py
index 24b33f89795..0e10f534728 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -2200,6 +2200,34 @@ def test_sampler_without_rep_state_dict(self, backend):
         s = new_replay_buffer.sample(batch_size=1)
         assert (s.exclude("index") == 0).all()
 
+    def test_sampler_without_rep_dumps_loads(self, tmpdir):
+        d0 = tmpdir + "/save0"
+        d1 = tmpdir + "/save1"
+        d2 = tmpdir + "/dump"
+        replay_buffer = TensorDictReplayBuffer(
+            storage=LazyMemmapStorage(max_size=100, scratch_dir=d0, device="cpu"),
+            sampler=SamplerWithoutReplacement(drop_last=True),
+            batch_size=8,
+        )
+        replay_buffer2 = TensorDictReplayBuffer(
+            storage=LazyMemmapStorage(max_size=100, scratch_dir=d1, device="cpu"),
+            sampler=SamplerWithoutReplacement(drop_last=True),
+            batch_size=8,
+        )
+        td = TensorDict(
+            {"a": torch.arange(0, 27), ("b", "c"): torch.arange(1, 28)}, batch_size=[27]
+        )
+        replay_buffer.extend(td)
+        for _ in replay_buffer:
+            break
+        replay_buffer.dumps(d2)
+        replay_buffer2.loads(d2)
+        assert (
+            replay_buffer.sampler._sample_list == replay_buffer2.sampler._sample_list
+        ).all()
+        s = replay_buffer2.sample(3)
+        assert (s["a"] == s["b", "c"] - 1).all()
+
     @pytest.mark.parametrize("drop_last", [False, True])
     def test_sampler_without_replacement_cap_prefetch(self, drop_last):
         torch.manual_seed(0)
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 45fede16cf5..1d46e5c7377 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -182,16 +182,11 @@ def dumps(self, path):
         path = Path(path)
         path.mkdir(exist_ok=True)
 
-        with open(path / "sampler_metadata.json", "w") as file:
-            json.dump(
-                self.state_dict(),
-                file,
-            )
+        TensorDict(self.state_dict()).memmap(path)
 
     def loads(self, path):
-        with open(path / "sampler_metadata.json", "r") as file:
-            metadata = json.load(file)
-        self.load_state_dict(metadata)
+        sd = TensorDict.load_memmap(path).to_dict()
+        self.load_state_dict(sd)
 
     def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
         if storage is None:
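
The patch swaps JSON serialization of sampler metadata for TensorDict memmapping, presumably because a sampler's state dict may contain tensors (such as SamplerWithoutReplacement's _sample_list) that json.dump cannot encode. Below is a minimal sketch, separate from the diff, of the dumps/loads round trip the new test exercises; the directory paths and the make_rb helper are hypothetical, and the snippet assumes torchrl and tensordict are installed:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

    def make_rb(scratch_dir):
        # Same configuration as the new test: memmap storage plus a
        # without-replacement sampler whose state must survive a dump/load.
        return TensorDictReplayBuffer(
            storage=LazyMemmapStorage(max_size=100, scratch_dir=scratch_dir, device="cpu"),
            sampler=SamplerWithoutReplacement(drop_last=True),
            batch_size=8,
        )

    rb = make_rb("/tmp/rb_scratch0")        # hypothetical paths
    rb.extend(TensorDict({"a": torch.arange(27)}, batch_size=[27]))
    rb.sample()                             # advances the sampler's internal state
    rb.dumps("/tmp/rb_checkpoint")          # sampler state is now memmapped, not JSON-encoded

    rb2 = make_rb("/tmp/rb_scratch1")
    rb2.loads("/tmp/rb_checkpoint")         # restores storage and sampler state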