SliceSampler
class torchrl.data.replay_buffers.SliceSampler(*, num_slices: Optional[int] = None, slice_len: Optional[int] = None, end_key: Optional[NestedKey] = None, traj_key: Optional[NestedKey] = None, ends: Optional[Tensor] = None, trajectories: Optional[Tensor] = None, cache_values: bool = False, truncated_key: tensordict._nestedkey.NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: Union[bool, int, Tuple[bool | int, bool | int]] = False, use_gpu: torch.device | bool = False)
Samples slices of data along the first dimension, given start and stop signals.
This class samples sub-trajectories with replacement. For a version without replacement, see SliceSamplerWithoutReplacement.
Note
SliceSampler can be slow to retrieve the trajectory indices. To accelerate its execution, prefer using end_key over traj_key, and consider the following keyword arguments: compile, cache_values and use_gpu.
Keyword Arguments:
num_slices (int) – the number of slices to be sampled. The batch size must be greater than or equal to the num_slices argument. Exclusive with slice_len.
slice_len (int) – the length of the slices to be sampled. The batch size must be greater than or equal to the slice_len argument and divisible by it. Exclusive with num_slices.
end_key (NestedKey, optional) – the key indicating the end of a trajectory (or episode). Defaults to ("next", "done").
traj_key (NestedKey, optional) – the key indicating the trajectories. Defaults to "episode" (commonly used across datasets in TorchRL).
ends (torch.Tensor, optional) – a 1d boolean tensor containing the end-of-run signals. To be used whenever the end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key. If provided, it is assumed that the storage is at capacity and that if the last element of the ends tensor is False, the same trajectory spans across end and beginning.
trajectories (torch.Tensor, optional) – a 1d integer tensor containing the run ids. To be used whenever the end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key. If provided, it is assumed that the storage is at capacity and that if the last element of the trajectory tensor is identical to the first, the same trajectory spans across end and beginning.
cache_values (bool, optional) – to be used with static datasets. Will cache the start and end signal of the trajectory. This can be safely used even if the trajectory indices change during calls to extend, as this operation will erase the cache.
Warning
cache_values=True will not work if the sampler is used with a storage that is extended by another buffer. For instance:
>>> buffer0 = ReplayBuffer(storage=storage,
...     sampler=SliceSampler(num_slices=8, cache_values=True),
...     writer=ImmutableWriter())
>>> buffer1 = ReplayBuffer(storage=storage,
...     sampler=other_sampler)
>>> # Wrong! Does not erase the cache held by the sampler of buffer0
>>> buffer1.extend(data)
Warning
cache_values=True will not work as expected if the buffer is shared between processes and one process is responsible for writing and one process for sampling, as erasing the cache can only be done locally.
truncated_key (NestedKey, optional) – if not None, this argument indicates where a truncated signal should be written in the output data. This is used to indicate to value estimators where the provided trajectory breaks. Defaults to ("next", "truncated"). This feature only works with TensorDictReplayBuffer instances (otherwise the truncated key is returned in the info dictionary returned by the sample() method).
strict_length (bool, optional) – if False, trajectories shorter than slice_len (or batch_size // num_slices) will be allowed to appear in the batch. If True, trajectories shorter than required will be filtered out. Be mindful that this can result in an effective batch_size smaller than the one asked for! Trajectories can be split using split_trajectories(). Defaults to True.
compile (bool or dict of kwargs, optional) – if True, the bottleneck of the sample() method will be compiled with torch.compile(). Keyword arguments can also be passed to torch.compile with this arg. Defaults to False.
span (bool, int, Tuple[bool | int, bool | int], optional) – if provided, the sampled trajectory will span across the left and/or the right. This means that possibly fewer elements will be provided than what was required. A boolean value means that at least one element will be sampled per trajectory. An integer i means that at least slice_len - i samples will be gathered for each sampled trajectory. Using tuples allows fine-grained control over the span on the left (beginning of the stored trajectory) and on the right (end of the stored trajectory).
use_gpu (bool or torch.device) – if True (or if a device is passed), an accelerator will be used to retrieve the indices of the trajectory starts. This can significantly accelerate the sampling when the buffer content is large. Defaults to False.
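The ends and trajectories arguments carry the same information in two encodings: a boolean end-of-run flag per step, or an integer run id per step. The following is a minimal pure-Python sketch of how either encoding maps to (start, stop) index pairs. It is not TorchRL's implementation: the helper names boundaries_from_ends and boundaries_from_ids are hypothetical, and the wrap-around case described above (a trajectory spanning the end and the beginning of a full storage) is deliberately not modelled.

```python
from typing import List, Tuple


def boundaries_from_ends(ends: List[bool]) -> List[Tuple[int, int]]:
    """Return (start, stop) pairs, stop inclusive, from end-of-run flags.

    Assumption of this sketch: a trailing run without an end flag is reported
    as its own segment rather than being wrapped around to index 0.
    """
    segments = []
    start = 0
    for i, is_end in enumerate(ends):
        if is_end:
            segments.append((start, i))
            start = i + 1
    if start < len(ends):
        # Unterminated tail: keep it as a separate segment.
        segments.append((start, len(ends) - 1))
    return segments


def boundaries_from_ids(traj_ids: List[int]) -> List[Tuple[int, int]]:
    """Same segments derived from run ids: a run ends where the id changes."""
    ends = [
        i + 1 == len(traj_ids) or traj_ids[i] != traj_ids[i + 1]
        for i in range(len(traj_ids))
    ]
    return boundaries_from_ends(ends)


ends = [False, False, True, False, True, False]
traj_ids = [0, 0, 0, 1, 1, 2]
print(boundaries_from_ends(ends))     # [(0, 2), (3, 4), (5, 5)]
print(boundaries_from_ids(traj_ids))  # [(0, 2), (3, 4), (5, 5)]
```

Both calls yield the same segments, which is why either signal can be supplied, but not both at once.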
Note
To recover the trajectory splits in the storage, SliceSampler will first attempt to find the traj_key entry in the storage. If it cannot be found, the end_key will be used to reconstruct the episodes.
Note
When using strict_length=False, it is recommended to use split_trajectories() to split the sampled trajectories. However, if two samples from the same episode are placed next to each other, this may produce incorrect results. To avoid this issue, consider one of these solutions:
using a TensorDictReplayBuffer instance with the slice sampler
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.collectors.utils import split_trajectories
>>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
>>>
>>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
...                             sampler=SliceSampler(
...                                 slice_len=5, traj_key="episode", strict_length=False,
...                             ))
...
>>> ep_1 = TensorDict(
...     {"obs": torch.arange(100),
...     "episode": torch.zeros(100),},
...     batch_size=[100]
... )
>>> ep_2 = TensorDict(
...     {"obs": torch.arange(4),
...     "episode": torch.ones(4),},
...     batch_size=[4]
... )
>>> rb.extend(ep_1)
>>> rb.extend(ep_2)
>>>
>>> s = rb.sample(50)
>>> print(s)
TensorDict(
    fields={
        episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
        index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([46]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([46]),
    device=cpu,
    is_shared=False)
>>> t = split_trajectories(s, done_key="truncated")
>>> print(t["obs"])
tensor([[73, 74, 75, 76, 77],
        [ 0,  1,  2,  3,  0],
        [ 0,  1,  2,  3,  0],
        [41, 42, 43, 44, 45],
        [ 0,  1,  2,  3,  0],
        [67, 68, 69, 70, 71],
        [27, 28, 29, 30, 31],
        [80, 81, 82, 83, 84],
        [17, 18, 19, 20, 21],
        [ 0,  1,  2,  3,  0]])
>>> print(t["episode"])
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0.]])
using a SliceSamplerWithoutReplacement
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.collectors.utils import split_trajectories
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
>>>
>>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
...                   sampler=SliceSamplerWithoutReplacement(
...                       slice_len=5, traj_key="episode", strict_length=False
...                   ))
...
>>> ep_1 = TensorDict(
...     {"obs": torch.arange(100),
...     "episode": torch.zeros(100),},
...     batch_size=[100]
... )
>>> ep_2 = TensorDict(
...     {"obs": torch.arange(4),
...     "episode": torch.ones(4),},
...     batch_size=[4]
... )
>>> rb.extend(ep_1)
>>> rb.extend(ep_2)
>>>
>>> s = rb.sample(50)
>>> t = split_trajectories(s, trajectory_key="episode")
>>> print(t["obs"])
tensor([[75, 76, 77, 78, 79],
        [ 0,  1,  2,  3,  0]])
>>> print(t["episode"])
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0.]])
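To see why adjacent samples from the same episode are a problem, consider a flat batch in which two slices of a short episode sit back to back. The pure-Python sketch below is only an illustration of the splitting logic, not the code of split_trajectories(); the helper names split_on_flags and split_on_id_change are hypothetical. Splitting where the episode id changes merges the two slices, whereas splitting on a per-slice truncated flag keeps them apart.

```python
from typing import List


def split_on_flags(values: List[int], stops: List[bool]) -> List[List[int]]:
    """Cut `values` after every position whose stop flag is True."""
    chunks: List[List[int]] = []
    current: List[int] = []
    for value, stop in zip(values, stops):
        current.append(value)
        if stop:
            chunks.append(current)
            current = []
    if current:
        chunks.append(current)
    return chunks


def split_on_id_change(values: List[int], ids: List[int]) -> List[List[int]]:
    """Cut `values` wherever the trajectory id differs from the next one."""
    stops = [
        i + 1 == len(ids) or ids[i] != ids[i + 1]
        for i in range(len(ids))
    ]
    return split_on_flags(values, stops)


# Two slices of length 4 drawn from the same short episode (id 1), back to back.
obs = [0, 1, 2, 3, 0, 1, 2, 3]
episode = [1] * 8
truncated = [False, False, False, True, False, False, False, True]

print(split_on_id_change(obs, episode))  # [[0, 1, 2, 3, 0, 1, 2, 3]]
print(split_on_flags(obs, truncated))    # [[0, 1, 2, 3], [0, 1, 2, 3]]
```

Sampling without replacement sidesteps the issue from the other direction, since the same short episode is not drawn twice in one batch.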
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1_000_000),
...     sampler=SliceSampler(cache_values=True, num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
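The example above fixes num_slices=10 with a batch size of 320, so each slice is 320 // 10 = 32 steps long. The following pure-Python sketch spells out the relation between the batch size, num_slices and slice_len as stated in the keyword argument descriptions. The function resolve_slices is a hypothetical helper written for illustration, not part of TorchRL, and how the library handles a batch size that is not a multiple of num_slices is not modelled here.

```python
from typing import Optional, Tuple


def resolve_slices(
    batch_size: int,
    num_slices: Optional[int] = None,
    slice_len: Optional[int] = None,
) -> Tuple[int, int]:
    """Return (num_slices, slice_len) for a requested batch size."""
    # The two arguments are exclusive: exactly one must be given.
    if (num_slices is None) == (slice_len is None):
        raise ValueError("pass exactly one of num_slices or slice_len")
    if num_slices is not None:
        if batch_size < num_slices:
            raise ValueError("batch_size must be >= num_slices")
        return num_slices, batch_size // num_slices
    if batch_size < slice_len or batch_size % slice_len != 0:
        raise ValueError("batch_size must be a multiple of slice_len")
    return batch_size // slice_len, slice_len


print(resolve_slices(320, num_slices=10))  # (10, 32)
print(resolve_slices(50, slice_len=5))     # (10, 5)
```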
SliceSampler is default-compatible with most of TorchRL’s datasets:
Examples
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
...     batch = batch.reshape(num_slices, -1)
...     break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
        [14],
        [ 8],
        [10],
        [13],
        [ 4],
        [ 2],
        [ 3],
        [22],
        [ 8]])
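The reshape(num_slices, -1) call in the example relies on each slice occupying a contiguous run of the flat batch, so that every row of the reshaped batch holds a single episode. Here is a small pure-Python sketch of that check, assuming all slices have the same length (as with strict_length=True). The helper reshape_rows and the episode ids are illustrative, not taken from a real dataset.

```python
from typing import List


def reshape_rows(flat: List[int], num_rows: int) -> List[List[int]]:
    """Split a flat list into `num_rows` contiguous rows of equal length."""
    if len(flat) % num_rows != 0:
        raise ValueError("length must be divisible by num_rows")
    row_len = len(flat) // num_rows
    return [flat[i * row_len:(i + 1) * row_len] for i in range(num_rows)]


# Three slices of four steps each, laid out one after the other.
episode_ids = [19] * 4 + [14] * 4 + [8] * 4
rows = reshape_rows(episode_ids, num_rows=3)

# Every row should contain exactly one distinct episode id.
print([sorted(set(row)) for row in rows])  # [[19], [14], [8]]
```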