SliceSamplerWithoutReplacement
- class torchrl.data.replay_buffers.SliceSamplerWithoutReplacement(*, num_slices: int | None = None, slice_len: int | None = None, drop_last: bool = False, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False)
Samples slices of data along the first dimension, given start and stop signals, without replacement.
This class is meant to be used with static replay buffers, or in between two replay buffer extensions. Extending the replay buffer resets the sampler, and continuous sampling without replacement (across extensions) is currently not supported.
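The reset-on-extend behaviour can be pictured with a small pure-Python toy. The class below, `ToyTrajectorySampler`, is hypothetical and is not part of TorchRL; it only mimics the idea that each trajectory id is handed out at most once per pass and that adding data starts a fresh pass.

```python
import random
from typing import List


class ToyTrajectorySampler:
    """Hypothetical toy, not TorchRL code: hands out trajectory ids
    without replacement and restarts its pass when data is added."""

    def __init__(self, num_trajectories: int, seed: int = 0) -> None:
        self.num_trajectories = num_trajectories
        self._rng = random.Random(seed)
        self.reset()

    def reset(self) -> None:
        # Rebuild and shuffle the pool of not-yet-visited trajectory ids
        self._pool = list(range(self.num_trajectories))
        self._rng.shuffle(self._pool)

    def extend(self, num_new: int) -> None:
        # New trajectories invalidate the current pass, so start over
        self.num_trajectories += num_new
        self.reset()

    def sample(self, num_slices: int) -> List[int]:
        # Take up to num_slices unvisited ids; an empty list means the pass is over
        batch = self._pool[:num_slices]
        self._pool = self._pool[num_slices:]
        return batch


sampler = ToyTrajectorySampler(num_trajectories=4)
first = sampler.sample(3)      # three distinct ids
second = sampler.sample(3)     # only the one remaining id
exhausted = sampler.sample(3)  # empty until reset() or extend() is called
sampler.extend(2)              # six trajectories now, fresh pass
after_extend = sampler.sample(6)
```

In this toy, the "extend" step discards whatever was left of the previous pass, which is the situation the paragraph above warns about.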
- Keyword Arguments:
  - drop_last (bool, optional) – if True, the last incomplete sample (if any) will be dropped. If False, this last sample will be kept. Defaults to False.
  - num_slices (int) – the number of slices to be sampled. The batch-size must be greater than or equal to the num_slices argument. Exclusive with slice_len.
  - slice_len (int) – the length of the slices to be sampled. The batch-size must be greater than or equal to the slice_len argument and divisible by it. Exclusive with num_slices.
  - end_key (NestedKey, optional) – the key indicating the end of a trajectory (or episode). Defaults to ("next", "done").
  - traj_key (NestedKey, optional) – the key indicating the trajectories. Defaults to "episode" (commonly used across datasets in TorchRL).
  - ends (torch.Tensor, optional) – a 1d boolean tensor containing the end-of-run signals. To be used whenever the end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key.
  - trajectories (torch.Tensor, optional) – a 1d integer tensor containing the run ids. To be used whenever the end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key.
  - truncated_key (NestedKey, optional) – if not None, this argument indicates where a truncated signal should be written in the output data. This is used to indicate to value estimators where the provided trajectory breaks. Defaults to ("next", "truncated"). This feature only works with TensorDictReplayBuffer instances (otherwise the truncated key is returned in the info dictionary returned by the sample() method).
  - strict_length (bool, optional) – if False, trajectories shorter than slice_len (or batch_size // num_slices) will be allowed to appear in the batch. If True, trajectories shorter than required will be filtered out. Be mindful that this can result in an effective batch_size smaller than the one asked for! Trajectories can be split using split_trajectories(). Defaults to True.
  - shuffle (bool, optional) – if False, the order of the trajectories is not shuffled. Defaults to True.
  - compile (bool or dict of kwargs, optional) – if True, the bottleneck of the sample() method will be compiled with torch.compile(). Keyword arguments can also be passed to torch.compile through this argument. Defaults to False.
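The interplay between batch_size, num_slices, slice_len and strict_length is mostly arithmetic. The sketch below restates the constraints listed above in plain Python; `resolve_slices` and `eligible_trajectories` are hypothetical helpers written for illustration, not TorchRL functions, and they do not reproduce the library's internal checks or error messages.

```python
from typing import List, Optional, Tuple


def resolve_slices(
    batch_size: int,
    num_slices: Optional[int] = None,
    slice_len: Optional[int] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: return (num_slices, slice_len) for a batch size."""
    if (num_slices is None) == (slice_len is None):
        # num_slices and slice_len are mutually exclusive, and one is needed
        raise ValueError("pass exactly one of num_slices or slice_len")
    if slice_len is not None:
        if batch_size < slice_len or batch_size % slice_len:
            raise ValueError("batch_size must be >= slice_len and divisible by it")
        return batch_size // slice_len, slice_len
    if batch_size < num_slices:
        raise ValueError("batch_size must be >= num_slices")
    # Slice length implied by num_slices, as in "batch_size // num_slices"
    return num_slices, batch_size // num_slices


def eligible_trajectories(
    traj_lengths: List[int], slice_len: int, strict_length: bool = True
) -> List[int]:
    """Hypothetical helper: indices of trajectories that may be sampled."""
    if not strict_length:
        return list(range(len(traj_lengths)))
    # strict_length=True drops trajectories too short to hold a full slice
    return [i for i, n in enumerate(traj_lengths) if n >= slice_len]


print(resolve_slices(320, num_slices=10))                   # (10, 32)
print(resolve_slices(320, slice_len=64))                    # (5, 64)
print(eligible_trajectories([300, 20, 150], slice_len=32))  # [0, 2]
```

The last call shows why strict_length=True can shrink the effective batch: the 20-step trajectory cannot hold a 32-step slice, so only two trajectories remain eligible.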
Note
To recover the trajectory splits in the storage, SliceSamplerWithoutReplacement will first attempt to find the traj_key entry in the storage. If it cannot be found, the end_key will be used to reconstruct the episodes.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
>>>
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1000),
...     # asking for 10 slices for a total of 320 elements, i.e., 10 trajectories of 32 transitions each
...     sampler=SliceSamplerWithoutReplacement(num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> # since we want trajectories of 32 transitions but there are only 4 episodes to
>>> # sample from, we only get 4 x 32 = 128 transitions in this batch
>>> print("sample:", sample)
>>> print("trajectories in sample", sample.get("episode").unique())
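The Note above says that, when no traj_key entry is found, episodes are rebuilt from end_key. A pure-Python sketch of that idea follows; `trajectory_ids_from_done` and `trajectory_lengths` are hypothetical helpers, not TorchRL functions. It also reproduces the 4 x 32 = 128 count from the example's comment, assuming (as that comment implies) that each trajectory contributes at most one slice per pass.

```python
from typing import List, Optional


def trajectory_ids_from_done(done: List[bool]) -> List[int]:
    """Hypothetical helper: label each step with a trajectory id,
    starting a new id right after every step whose done flag is True."""
    ids: List[int] = []
    current = 0
    for flag in done:
        ids.append(current)
        if flag:
            current += 1
    return ids


def trajectory_lengths(ids: List[int]) -> List[int]:
    # Count consecutive steps sharing the same id (ids are non-decreasing here)
    lengths: List[int] = []
    previous: Optional[int] = None
    for i in ids:
        if i != previous:
            lengths.append(0)
            previous = i
        lengths[-1] += 1
    return lengths


# Rebuild the episode layout of the example: episodes end at steps 299, 549, 699, 999
done = [False] * 1000
for end in (299, 549, 699, 999):
    done[end] = True
ids = trajectory_ids_from_done(done)
lengths = trajectory_lengths(ids)  # [300, 250, 150, 300]

# At most one slice per trajectory per pass: min(num_slices, #trajectories) slices
num_slices, slice_len = 10, 320 // 10
num_transitions = min(num_slices, len(lengths)) * slice_len  # 4 * 32 = 128
```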
SliceSamplerWithoutReplacement is default-compatible with most of TorchRL’s datasets, and allows users to consume datasets in a dataloader-like fashion:
Examples
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSamplerWithoutReplacement
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320,
...     sampler=SliceSamplerWithoutReplacement(num_slices=num_slices))
>>> # the last sample is kept, since drop_last=False by default
>>> for i, batch in enumerate(data):
...     print(batch.get("episode").unique())
tensor([ 5, 6, 8, 11, 12, 14, 16, 17, 19, 24])
tensor([ 1, 2, 7, 9, 10, 13, 15, 18, 21, 22])
tensor([ 0, 3, 4, 20, 23])
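The printed output covers episode ids 0 to 24 exactly once, split into batches of 10, 10 and 5 trajectories. The pure-Python generator below, `epoch_batches`, is a hypothetical sketch of that dataloader-like pass (not TorchRL's implementation); it uses Python's `random` module, so the ids inside each batch will not match the torch-seeded output, only the batch sizes and the coverage.

```python
import random
from typing import Iterator, List


def epoch_batches(
    num_trajectories: int,
    num_slices: int,
    drop_last: bool = False,
    shuffle: bool = True,
    seed: int = 0,
) -> Iterator[List[int]]:
    """Hypothetical sketch: yield groups of trajectory ids, each id at most
    once per epoch, mimicking one pass over the dataset."""
    order = list(range(num_trajectories))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, num_trajectories, num_slices):
        batch = order[start:start + num_slices]
        if drop_last and len(batch) < num_slices:
            break  # discard the final incomplete batch
        yield sorted(batch)


# 25 trajectories with num_slices=10, matching the batch sizes printed above
sizes = [len(b) for b in epoch_batches(25, 10)]  # [10, 10, 5]
sizes_dropped = [len(b) for b in epoch_batches(25, 10, drop_last=True)]  # [10, 10]
```

With drop_last=True the trailing batch of 5 trajectories would be skipped, which is the behaviour the drop_last argument describes.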