@@ -802,6 +802,112 @@ class SliceSampler(Sampler):
802
802
attempt to find the ``traj_key`` entry in the storage. If it cannot be
803
803
found, the ``end_key`` will be used to reconstruct the episodes.
804
804
805
+ .. note:: When using `strict_length=False`, it is recommended to use
806
+ :func:`~torchrl.collectors.utils.split_trajectories` to split the sampled trajectories.
807
+ However, if two samples from the same episode are placed next to each other,
808
+ this may produce incorrect results. To avoid this issue, consider one of these solutions:
809
+
810
+ - using a :class:`~torchrl.data.TensorDictReplayBuffer` instance with the slice sampler
811
+
812
+ >>> import torch
813
+ >>> from tensordict import TensorDict
814
+ >>> from torchrl.collectors.utils import split_trajectories
815
+ >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
816
+ >>>
817
+ >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
818
+ ... sampler=SliceSampler(
819
+ ... slice_len=5, traj_key="episode",strict_length=False,
820
+ ... ))
821
+ ...
822
+ >>> ep_1 = TensorDict(
823
+ ... {"obs": torch.arange(100),
824
+ ... "episode": torch.zeros(100),},
825
+ ... batch_size=[100]
826
+ ... )
827
+ >>> ep_2 = TensorDict(
828
+ ... {"obs": torch.arange(4),
829
+ ... "episode": torch.ones(4),},
830
+ ... batch_size=[4]
831
+ ... )
832
+ >>> rb.extend(ep_1)
833
+ >>> rb.extend(ep_2)
834
+ >>>
835
+ >>> s = rb.sample(50)
836
+ >>> print(s)
837
+ TensorDict(
838
+ fields={
839
+ episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
840
+ index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
841
+ next: TensorDict(
842
+ fields={
843
+ done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
844
+ terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
845
+ truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
846
+ batch_size=torch.Size([46]),
847
+ device=cpu,
848
+ is_shared=False),
849
+ obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
850
+ batch_size=torch.Size([46]),
851
+ device=cpu,
852
+ is_shared=False)
853
+ >>> t = split_trajectories(s, done_key="truncated")
854
+ >>> print(t["obs"])
855
+ tensor([[73, 74, 75, 76, 77],
856
+ [ 0, 1, 2, 3, 0],
857
+ [ 0, 1, 2, 3, 0],
858
+ [41, 42, 43, 44, 45],
859
+ [ 0, 1, 2, 3, 0],
860
+ [67, 68, 69, 70, 71],
861
+ [27, 28, 29, 30, 31],
862
+ [80, 81, 82, 83, 84],
863
+ [17, 18, 19, 20, 21],
864
+ [ 0, 1, 2, 3, 0]])
865
+ >>> print(t["episode"])
866
+ tensor([[0., 0., 0., 0., 0.],
867
+ [1., 1., 1., 1., 0.],
868
+ [1., 1., 1., 1., 0.],
869
+ [0., 0., 0., 0., 0.],
870
+ [1., 1., 1., 1., 0.],
871
+ [0., 0., 0., 0., 0.],
872
+ [0., 0., 0., 0., 0.],
873
+ [0., 0., 0., 0., 0.],
874
+ [0., 0., 0., 0., 0.],
875
+ [1., 1., 1., 1., 0.]])
876
+
877
+ - using a :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`
878
+
879
+ >>> import torch
880
+ >>> from tensordict import TensorDict
881
+ >>> from torchrl.collectors.utils import split_trajectories
882
+ >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
883
+ >>>
884
+ >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
885
+ ... sampler=SliceSamplerWithoutReplacement(
886
+ ... slice_len=5, traj_key="episode",strict_length=False
887
+ ... ))
888
+ ...
889
+ >>> ep_1 = TensorDict(
890
+ ... {"obs": torch.arange(100),
891
+ ... "episode": torch.zeros(100),},
892
+ ... batch_size=[100]
893
+ ... )
894
+ >>> ep_2 = TensorDict(
895
+ ... {"obs": torch.arange(4),
896
+ ... "episode": torch.ones(4),},
897
+ ... batch_size=[4]
898
+ ... )
899
+ >>> rb.extend(ep_1)
900
+ >>> rb.extend(ep_2)
901
+ >>>
902
+ >>> s = rb.sample(50)
903
+ >>> t = split_trajectories(s, trajectory_key="episode")
904
+ >>> print(t["obs"])
905
+ tensor([[75, 76, 77, 78, 79],
906
+ [ 0, 1, 2, 3, 0]])
907
+ >>> print(t["episode"])
908
+ tensor([[0., 0., 0., 0., 0.],
909
+ [1., 1., 1., 1., 0.]])
910
+
805
911
Examples:
806
912
>>> import torch
807
913
>>> from tensordict import TensorDict
@@ -1427,6 +1533,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
1427
1533
class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
1428
1534
"""Samples slices of data along the first dimension, given start and stop signals, without replacement.
1429
1535
1536
+ In this context, ``without replacement`` means that the same element (NOT trajectory) will not be sampled twice
1537
+ before the counter is automatically reset. Within a single sample, however, only one slice of a given trajectory
1538
+ will appear (see example below).
1539
+
1430
1540
This class is to be used with static replay buffers or in between two
1431
1541
replay buffer extensions. Extending the replay buffer will reset the
1432
1542
sampler, and continuous sampling without replacement is currently not
@@ -1533,6 +1643,52 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
1533
1643
tensor([ 1, 2, 7, 9, 10, 13, 15, 18, 21, 22])
1534
1644
tensor([ 0, 3, 4, 20, 23])
1535
1645
1646
+ When requesting a large total number of samples with few trajectories and small span, the batch will contain
1647
+ at most one sample of each trajectory:
1648
+
1649
+ Examples:
1650
+ >>> import torch
1651
+ >>> from tensordict import TensorDict
1652
+ >>> from torchrl.collectors.utils import split_trajectories
1653
+ >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
1654
+ >>>
1655
+ >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
1656
+ ... sampler=SliceSamplerWithoutReplacement(
1657
+ ... slice_len=5, traj_key="episode",strict_length=False
1658
+ ... ))
1659
+ ...
1660
+ >>> ep_1 = TensorDict(
1661
+ ... {"obs": torch.arange(100),
1662
+ ... "episode": torch.zeros(100),},
1663
+ ... batch_size=[100]
1664
+ ... )
1665
+ >>> ep_2 = TensorDict(
1666
+ ... {"obs": torch.arange(51),
1667
+ ... "episode": torch.ones(51),},
1668
+ ... batch_size=[51]
1669
+ ... )
1670
+ >>> rb.extend(ep_1)
1671
+ >>> rb.extend(ep_2)
1672
+ >>>
1673
+ >>> s = rb.sample(50)
1674
+ >>> t = split_trajectories(s, trajectory_key="episode")
1675
+ >>> print(t["obs"])
1676
+ tensor([[14, 15, 16, 17, 18],
1677
+ [ 3, 4, 5, 6, 7]])
1678
+ >>> print(t["episode"])
1679
+ tensor([[0., 0., 0., 0., 0.],
1680
+ [1., 1., 1., 1., 1.]])
1681
+ >>>
1682
+ >>> s = rb.sample(50)
1683
+ >>> t = split_trajectories(s, trajectory_key="episode")
1684
+ >>> print(t["obs"])
1685
+ tensor([[ 4, 5, 6, 7, 8],
1686
+ [26, 27, 28, 29, 30]])
1687
+ >>> print(t["episode"])
1688
+ tensor([[0., 0., 0., 0., 0.],
1689
+ [1., 1., 1., 1., 1.]])
1690
+
1691
+
1536
1692
"""
1537
1693
1538
1694
def __init__(
0 commit comments