Skip to content

Commit e557788

Browse files
committed
Update
[ghstack-poisoned]
1 parent 4aa576d commit e557788

File tree

1 file changed

+156
-0
lines changed

1 file changed

+156
-0
lines changed

torchrl/data/replay_buffers/samplers.py

+156
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,112 @@ class SliceSampler(Sampler):
802802
attempt to find the ``traj_key`` entry in the storage. If it cannot be
803803
found, the ``end_key`` will be used to reconstruct the episodes.
804804
805+
.. note:: When using `strict_length=False`, it is recommended to use
806+
:func:`~torchrl.collectors.utils.split_trajectories` to split the sampled trajectories.
807+
However, if two samples from the same episode are placed next to each other,
808+
this may produce incorrect results. To avoid this issue, consider one of these solutions:
809+
810+
- using a :class:`~torchrl.data.TensorDictReplayBuffer` instance with the slice sampler
811+
812+
>>> import torch
813+
>>> from tensordict import TensorDict
814+
>>> from torchrl.collectors.utils import split_trajectories
815+
>>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
816+
>>>
817+
>>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
818+
... sampler=SliceSampler(
819+
... slice_len=5, traj_key="episode",strict_length=False,
820+
... ))
821+
...
822+
>>> ep_1 = TensorDict(
823+
... {"obs": torch.arange(100),
824+
... "episode": torch.zeros(100),},
825+
... batch_size=[100]
826+
... )
827+
>>> ep_2 = TensorDict(
828+
... {"obs": torch.arange(4),
829+
... "episode": torch.ones(4),},
830+
... batch_size=[4]
831+
... )
832+
>>> rb.extend(ep_1)
833+
>>> rb.extend(ep_2)
834+
>>>
835+
>>> s = rb.sample(50)
836+
>>> print(s)
837+
TensorDict(
838+
fields={
839+
episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
840+
index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
841+
next: TensorDict(
842+
fields={
843+
done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
844+
terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
845+
truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
846+
batch_size=torch.Size([46]),
847+
device=cpu,
848+
is_shared=False),
849+
obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
850+
batch_size=torch.Size([46]),
851+
device=cpu,
852+
is_shared=False)
853+
>>> t = split_trajectories(s, done_key="truncated")
854+
>>> print(t["obs"])
855+
tensor([[73, 74, 75, 76, 77],
856+
[ 0, 1, 2, 3, 0],
857+
[ 0, 1, 2, 3, 0],
858+
[41, 42, 43, 44, 45],
859+
[ 0, 1, 2, 3, 0],
860+
[67, 68, 69, 70, 71],
861+
[27, 28, 29, 30, 31],
862+
[80, 81, 82, 83, 84],
863+
[17, 18, 19, 20, 21],
864+
[ 0, 1, 2, 3, 0]])
865+
>>> print(t["episode"])
866+
tensor([[0., 0., 0., 0., 0.],
867+
[1., 1., 1., 1., 0.],
868+
[1., 1., 1., 1., 0.],
869+
[0., 0., 0., 0., 0.],
870+
[1., 1., 1., 1., 0.],
871+
[0., 0., 0., 0., 0.],
872+
[0., 0., 0., 0., 0.],
873+
[0., 0., 0., 0., 0.],
874+
[0., 0., 0., 0., 0.],
875+
[1., 1., 1., 1., 0.]])
876+
877+
- using a :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`
878+
879+
>>> import torch
880+
>>> from tensordict import TensorDict
881+
>>> from torchrl.collectors.utils import split_trajectories
882+
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
883+
>>>
884+
>>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
885+
... sampler=SliceSamplerWithoutReplacement(
886+
... slice_len=5, traj_key="episode",strict_length=False
887+
... ))
888+
...
889+
>>> ep_1 = TensorDict(
890+
... {"obs": torch.arange(100),
891+
... "episode": torch.zeros(100),},
892+
... batch_size=[100]
893+
... )
894+
>>> ep_2 = TensorDict(
895+
... {"obs": torch.arange(4),
896+
... "episode": torch.ones(4),},
897+
... batch_size=[4]
898+
... )
899+
>>> rb.extend(ep_1)
900+
>>> rb.extend(ep_2)
901+
>>>
902+
>>> s = rb.sample(50)
903+
>>> t = split_trajectories(s, trajectory_key="episode")
904+
>>> print(t["obs"])
905+
tensor([[75, 76, 77, 78, 79],
906+
[ 0, 1, 2, 3, 0]])
907+
>>> print(t["episode"])
908+
tensor([[0., 0., 0., 0., 0.],
909+
[1., 1., 1., 1., 0.]])
910+
805911
Examples:
806912
>>> import torch
807913
>>> from tensordict import TensorDict
@@ -1427,6 +1533,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
14271533
class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
14281534
"""Samples slices of data along the first dimension, given start and stop signals, without replacement.
14291535
1536+
In this context, ``without replacement`` means that the same element (NOT trajectory) will not be sampled twice
1537+
before the counter is automatically reset. Within a single sample, however, only one slice of a given trajectory
1538+
will appear (see example below).
1539+
14301540
This class is to be used with static replay buffers or in between two
14311541
replay buffer extensions. Extending the replay buffer will reset the
14321542
sampler, and continuous sampling without replacement is currently not
@@ -1533,6 +1643,52 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
15331643
tensor([ 1, 2, 7, 9, 10, 13, 15, 18, 21, 22])
15341644
tensor([ 0, 3, 4, 20, 23])
15351645
1646+
When requesting a large total number of samples with few trajectories and small span, the batch will contain
1647+
at most one sample of each trajectory:
1648+
1649+
Examples:
1650+
>>> import torch
1651+
>>> from tensordict import TensorDict
1652+
>>> from torchrl.collectors.utils import split_trajectories
1653+
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
1654+
>>>
1655+
>>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
1656+
... sampler=SliceSamplerWithoutReplacement(
1657+
... slice_len=5, traj_key="episode",strict_length=False
1658+
... ))
1659+
...
1660+
>>> ep_1 = TensorDict(
1661+
... {"obs": torch.arange(100),
1662+
... "episode": torch.zeros(100),},
1663+
... batch_size=[100]
1664+
... )
1665+
>>> ep_2 = TensorDict(
1666+
... {"obs": torch.arange(51),
1667+
... "episode": torch.ones(51),},
1668+
... batch_size=[51]
1669+
... )
1670+
>>> rb.extend(ep_1)
1671+
>>> rb.extend(ep_2)
1672+
>>>
1673+
>>> s = rb.sample(50)
1674+
>>> t = split_trajectories(s, trajectory_key="episode")
1675+
>>> print(t["obs"])
1676+
tensor([[14, 15, 16, 17, 18],
1677+
[ 3, 4, 5, 6, 7]])
1678+
>>> print(t["episode"])
1679+
tensor([[0., 0., 0., 0., 0.],
1680+
[1., 1., 1., 1., 1.]])
1681+
>>>
1682+
>>> s = rb.sample(50)
1683+
>>> t = split_trajectories(s, trajectory_key="episode")
1684+
>>> print(t["obs"])
1685+
tensor([[ 4, 5, 6, 7, 8],
1686+
[26, 27, 28, 29, 30]])
1687+
>>> print(t["episode"])
1688+
tensor([[0., 0., 0., 0., 0.],
1689+
[1., 1., 1., 1., 1.]])
1690+
1691+
15361692
"""
15371693

15381694
def __init__(

0 commit comments

Comments
 (0)