[BugFix] Fix dumps for SamplerWithoutReplacement (#2506)
vmoens authored Oct 21, 2024
1 parent a27514c commit 9f6c21f
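
Likely root cause (inferred from the diff, not stated in the commit message): `SamplerWithoutReplacement.state_dict()` contains tensor-valued entries such as `_sample_list`, which `json.dump` cannot serialize. A minimal, hypothetical sketch of that failure mode (the keys other than `_sample_list` are illustrative assumptions):

# Hypothetical sketch: json cannot serialize tensor-valued state such as
# the sampler's _sample_list permutation.
import json
import torch

state = {"_sample_list": torch.randperm(27), "len_storage": 27, "drop_last": True}
json.dumps(state)  # TypeError: Object of type Tensor is not JSON serializable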
Showing 2 changed files with 31 additions and 8 deletions.
28 changes: 28 additions & 0 deletions test/test_rb.py
@@ -2200,6 +2200,34 @@ def test_sampler_without_rep_state_dict(self, backend):
         s = new_replay_buffer.sample(batch_size=1)
         assert (s.exclude("index") == 0).all()
 
+    def test_sampler_without_rep_dumps_loads(self, tmpdir):
+        d0 = tmpdir + "/save0"
+        d1 = tmpdir + "/save1"
+        d2 = tmpdir + "/dump"
+        replay_buffer = TensorDictReplayBuffer(
+            storage=LazyMemmapStorage(max_size=100, scratch_dir=d0, device="cpu"),
+            sampler=SamplerWithoutReplacement(drop_last=True),
+            batch_size=8,
+        )
+        replay_buffer2 = TensorDictReplayBuffer(
+            storage=LazyMemmapStorage(max_size=100, scratch_dir=d1, device="cpu"),
+            sampler=SamplerWithoutReplacement(drop_last=True),
+            batch_size=8,
+        )
+        td = TensorDict(
+            {"a": torch.arange(0, 27), ("b", "c"): torch.arange(1, 28)}, batch_size=[27]
+        )
+        replay_buffer.extend(td)
+        for _ in replay_buffer:
+            break
+        replay_buffer.dumps(d2)
+        replay_buffer2.loads(d2)
+        assert (
+            replay_buffer.sampler._sample_list == replay_buffer2.sampler._sample_list
+        ).all()
+        s = replay_buffer2.sample(3)
+        assert (s["a"] == s["b", "c"] - 1).all()
+
     @pytest.mark.parametrize("drop_last", [False, True])
     def test_sampler_without_replacement_cap_prefetch(self, drop_last):
         torch.manual_seed(0)
11 changes: 3 additions & 8 deletions torchrl/data/replay_buffers/samplers.py
@@ -182,16 +182,11 @@ def dumps(self, path):
         path = Path(path)
         path.mkdir(exist_ok=True)
 
-        with open(path / "sampler_metadata.json", "w") as file:
-            json.dump(
-                self.state_dict(),
-                file,
-            )
+        TensorDict(self.state_dict()).memmap(path)
 
     def loads(self, path):
-        with open(path / "sampler_metadata.json", "r") as file:
-            metadata = json.load(file)
-        self.load_state_dict(metadata)
+        sd = TensorDict.load_memmap(path).to_dict()
+        self.load_state_dict(sd)
 
     def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
         if storage is None:
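
For reference, a standalone sketch of the round trip that the new `dumps`/`loads` rely on; the temporary path and the state-dict contents are illustrative assumptions, not taken from the commit:

# Illustrative round trip via tensordict memory-mapping, mirroring the new
# dumps/loads: tensors are written to disk and read back into a plain dict.
import tempfile

import torch
from tensordict import TensorDict

state = {"_sample_list": torch.randperm(27), "len_storage": torch.tensor(27)}
path = tempfile.mkdtemp()
TensorDict(state).memmap(path)                      # what dumps now does with state_dict()
restored = TensorDict.load_memmap(path).to_dict()   # what loads reads back
assert (restored["_sample_list"] == state["_sample_list"]).all()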
