
SliceSampler

class torchrl.data.replay_buffers.SliceSampler(*, num_slices: int = None, slice_len: int = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | Tuple[bool | int, bool | int] = False)

Samples slices of data along the first dimension, given start and stop signals.

This class samples sub-trajectories with replacement. For a version without replacement, see SliceSamplerWithoutReplacement.
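As a rough illustration of what sampling with replacement means here, the sketch below draws each slice independently. The same trajectory, or even the same window, can therefore be picked more than once. This is a pure-Python sketch, not TorchRL's implementation. The helper `sample_slices_with_replacement`, the `(start, length)` representation of trajectories, and the uniform choice of trajectory and start offset are all assumptions made for illustration. The last step of each slice is flagged to mimic the role of truncated_key.

```python
import random
from typing import List, Tuple


def sample_slices_with_replacement(
    traj_bounds: List[Tuple[int, int]],
    num_slices: int,
    slice_len: int,
    seed: int = 0,
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Hypothetical sketch, not TorchRL's implementation.

    traj_bounds holds one (start index, length) pair per stored trajectory.
    Returns the storage indices of each slice and a per-step truncated flag.
    """
    rng = random.Random(seed)
    # Only trajectories that can hold a full slice are candidates here.
    eligible = [(start, length) for start, length in traj_bounds if length >= slice_len]
    if not eligible:
        raise ValueError("no stored trajectory is at least slice_len steps long")
    slices, truncated = [], []
    for _ in range(num_slices):
        # Each draw is independent, so trajectories (and windows) can repeat.
        start, length = rng.choice(eligible)
        offset = rng.randint(0, length - slice_len)  # randint includes both ends
        slices.append(list(range(start + offset, start + offset + slice_len)))
        # Flag the last step of each slice as the point where the trajectory breaks.
        truncated.append([False] * (slice_len - 1) + [True])
    return slices, truncated


# Episode layout of the first example below: lengths 300, 250, 150 and 300.
bounds = [(0, 300), (300, 250), (550, 150), (700, 300)]
idx, trunc = sample_slices_with_replacement(bounds, num_slices=10, slice_len=32)
```

Each of the ten returned index lists is a contiguous 32-step window that lies entirely inside one episode. A without-replacement variant would instead keep track of which windows have already been handed out.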

Keyword Arguments:
  • num_slices (int) – the number of slices to be sampled. The batch size must be greater than or equal to num_slices. Mutually exclusive with slice_len.

  • slice_len (int) – the length of the slices to be sampled. The batch size must be greater than or equal to slice_len and divisible by it. Mutually exclusive with num_slices.

  • end_key (NestedKey, optional) – the key indicating the end of a trajectory (or episode). Defaults to ("next", "done").

  • traj_key (NestedKey, optional) – the key indicating the trajectories. Defaults to "episode" (commonly used across datasets in TorchRL).

  • ends (torch.Tensor, optional) – a 1-D boolean tensor containing the end-of-run signals. Use it when the end_key or traj_key is expensive to retrieve, or when this signal is readily available. Must be used with cache_values=True and cannot be combined with end_key or traj_key. If provided, the storage is assumed to be at capacity. In that case, if the last element of ends is False, the same trajectory is assumed to span the end and the beginning of the storage.

  • trajectories (torch.Tensor, optional) – a 1-D integer tensor containing the run ids. Use it when the end_key or traj_key is expensive to retrieve, or when this signal is readily available. Must be used with cache_values=True and cannot be combined with end_key or traj_key. If provided, the storage is assumed to be at capacity. In that case, if the last element of trajectories is identical to the first, the same trajectory is assumed to span the end and the beginning of the storage.

  • cache_values (bool, optional) –

    intended for static datasets. Caches the start and end signals of the trajectories. This is safe even if the trajectory indices change between calls to extend, because extend erases the cache.

    Warning

    cache_values=True will not work if the sampler is used with a storage that is extended by another buffer. For instance:

    >>> buffer0 = ReplayBuffer(storage=storage,
    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
    ...     writer=ImmutableWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...     sampler=other_sampler)
    >>> # Wrong! This does not erase the cache of buffer0's sampler
    >>> buffer1.extend(data)
    

    Warning

    cache_values=True will not work as expected if the buffer is shared between processes and one process is responsible for writing and one process for sampling, as erasing the cache can only be done locally.

  • truncated_key (NestedKey, optional) – if not None, the key under which a truncated signal is written in the output data. This signal tells value estimators where the sampled trajectory breaks. Defaults to ("next", "truncated"). This feature only works with TensorDictReplayBuffer instances. With other buffers, the truncated key is returned in the info dictionary returned by the sample() method.

  • strict_length (bool, optional) – if False, trajectories shorter than slice_len (or batch_size // num_slices) are allowed to appear in the batch. Be mindful that this can result in an effective batch size smaller than the one requested. If True, trajectories shorter than required are filtered out. Trajectories can be split using split_trajectories(). Defaults to True.

  • compile (bool or dict of kwargs, optional) – if True, the bottleneck of the sample() method is compiled with torch.compile(). A dict of keyword arguments to pass to torch.compile() can also be given. Defaults to False.

  • span (bool, int, Tuple[bool | int, bool | int], optional) – if provided, sampled slices may extend past the left and/or right boundary of the stored trajectory, in which case fewer elements than requested may be returned. A boolean value means that at least one element will be sampled per trajectory. An integer i means that at least slice_len - i samples will be gathered for each sampled trajectory. A tuple gives fine-grained control over the span on the left (beginning of the stored trajectory) and on the right (end of the stored trajectory).
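The length-related arguments above can be summarized with a little arithmetic. The pure-Python sketch below encodes the documented constraints on batch_size, num_slices and slice_len, and the strict_length filter. It also gives one reading of span: a slice may overhang each end of the stored trajectory by up to slice_len - 1 positions for True, or by up to i positions for an integer i. The helper names (`resolve_slice_len`, `eligible_lengths`, `candidate_windows`) are hypothetical and not part of TorchRL. The exact checks and sampling rules TorchRL applies may differ.

```python
from typing import List, Optional, Tuple, Union

Span = Union[bool, int]


def resolve_slice_len(batch_size: int, num_slices: Optional[int] = None,
                      slice_len: Optional[int] = None) -> int:
    """Hypothetical helper: length of each slice implied by the arguments."""
    if (num_slices is None) == (slice_len is None):
        raise ValueError("pass exactly one of num_slices or slice_len")
    if num_slices is not None:
        if batch_size < num_slices:
            raise ValueError("batch_size must be >= num_slices")
        return batch_size // num_slices
    if batch_size < slice_len or batch_size % slice_len != 0:
        raise ValueError("batch_size must be >= slice_len and divisible by it")
    return slice_len


def eligible_lengths(traj_lengths: List[int], required: int,
                     strict_length: bool) -> List[int]:
    """Hypothetical helper: lengths of the trajectories that may be sampled."""
    if strict_length:
        # Trajectories shorter than the required slice length are filtered out.
        return [n for n in traj_lengths if n >= required]
    # Shorter trajectories are kept; slices taken from them contain fewer steps.
    return list(traj_lengths)


def _max_overhang(span: Span, slice_len: int) -> int:
    """How many positions a slice may extend past one end of a trajectory."""
    if span is False:
        return 0
    if span is True:
        return slice_len - 1  # at least one element stays inside
    return int(span)  # integer i: at least slice_len - i elements stay inside


def candidate_windows(traj_len: int, slice_len: int,
                      span: Union[Span, Tuple[Span, Span]] = False) -> List[Tuple[int, int]]:
    """Hypothetical helper: clipped [start, stop) windows within one trajectory."""
    left, right = span if isinstance(span, tuple) else (span, span)
    first_start = -_max_overhang(left, slice_len)
    last_start = traj_len - slice_len + _max_overhang(right, slice_len)
    # Any part of a window outside [0, traj_len) is dropped, shortening the slice.
    return [(max(0, s), min(traj_len, s + slice_len))
            for s in range(first_start, last_start + 1)]


slice_len = resolve_slice_len(320, num_slices=10)  # 320 // 10 == 32
print(eligible_lengths([300, 250, 150, 20], slice_len, strict_length=True))  # [300, 250, 150]
print(candidate_windows(5, 3))             # [(0, 3), (1, 4), (2, 5)]
print(candidate_windows(5, 3, span=True))  # adds (0, 1), (0, 2), (3, 5), (4, 5)
```

A tuple such as span=(True, False) only allows overhang at the beginning of the trajectory. Under this reading, the first window becomes (0, 1) while the last stays (2, 5).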

Note

To recover the trajectory splits in the storage, SliceSampler will first attempt to find the traj_key entry in the storage. If it cannot be found, the end_key will be used to reconstruct the episodes.
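Both reconstruction routes can be sketched in plain Python:

- `trajs_from_ends` works from a 1-D end-of-trajectory signal, such as what end_key or ends provide.
- `trajs_from_ids` works from per-step trajectory ids, such as what traj_key or trajectories provide.

Both are hypothetical helpers returning (start, length) pairs, not TorchRL functions. Both follow the wrap-around convention documented above for ends and trajectories, which assumes a storage at capacity.

```python
from typing import List, Sequence, Tuple


def trajs_from_ends(ends: Sequence[bool]) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, length) of each run from end-of-run flags."""
    n = len(ends)
    stops = [i for i, done in enumerate(ends) if done]
    if not stops:
        return [(0, n)]  # no end signal at all: treat the storage as one run
    if ends[-1]:
        starts = [0] + [s + 1 for s in stops[:-1]]
        return [(b, s - b + 1) for b, s in zip(starts, stops)]
    # Last flag is False: the run after the last end wraps around to index 0.
    runs = [(a + 1, b - a) for a, b in zip(stops[:-1], stops[1:])]
    runs.append((stops[-1] + 1, (n - stops[-1] - 1) + (stops[0] + 1)))
    return runs


def trajs_from_ids(ids: Sequence[int]) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, length) of each run of identical ids."""
    runs, start = [], 0
    for i in range(1, len(ids) + 1):
        if i == len(ids) or ids[i] != ids[start]:
            runs.append((start, i - start))
            start = i
    if len(runs) > 1 and ids[0] == ids[-1]:
        # First and last ids match: merge the tail run with the head run (wrap-around).
        last_start, last_len = runs.pop()
        first_len = runs.pop(0)[1]
        runs.append((last_start, last_len + first_len))
    return runs
```

Both routes agree on a wrapped storage. `trajs_from_ends([False, True, False, False, True, False])` and `trajs_from_ids([1, 1, 2, 2, 2, 1])` both yield `[(2, 3), (5, 3)]`, where the run starting at index 5 continues at indices 0 and 1.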

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1_000_000),
...     sampler=SliceSampler(cache_values=True, num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)

With its default arguments, SliceSampler is compatible with most of TorchRL’s datasets:

Examples

>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
...     batch = batch.reshape(num_slices, -1)
...     break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
        [14],
        [ 8],
        [10],
        [13],
        [ 4],
        [ 2],
        [ 3],
        [22],
        [ 8]])
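The check in this example relies on the slices being laid out contiguously in the flat batch, so that row k of batch.reshape(num_slices, -1) holds the k-th slice. The same check can be written in pure Python on a flat list of episode ids. The helper `rows_single_episode` is hypothetical and only mirrors the reshape-and-unique step above.

```python
from typing import List, Sequence


def rows_single_episode(episode_ids: Sequence[int], num_slices: int) -> bool:
    """Check that each contiguous block of the flat batch comes from one episode.

    Pure-Python analogue of batch.reshape(num_slices, -1) followed by a
    per-row uniqueness check.
    """
    if len(episode_ids) % num_slices != 0:
        raise ValueError("batch size must be divisible by num_slices")
    row_len = len(episode_ids) // num_slices
    rows: List[Sequence[int]] = [
        episode_ids[k * row_len:(k + 1) * row_len] for k in range(num_slices)
    ]
    # Each row must contain a single distinct episode id.
    return all(len(set(row)) == 1 for row in rows)
```

With batch_size=320 and num_slices=10, each row has 32 entries. A row mixing two episodes would make the function return False.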
