PrioritizedSliceSampler

class torchrl.data.replay_buffers.PrioritizedSliceSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: torch.dtype = torch.float32, reduction: str = 'max', *, num_slices: int = None, slice_len: int = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False)

Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.

This class samples sub-trajectories with replacement, following the priority weighting presented in “Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay.” (https://arxiv.org/abs/1511.05952).

For more info see SliceSampler and PrioritizedSampler.

Parameters:
  • alpha (float) – exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case (see the formulas after this list).

  • beta (float) – importance sampling negative exponent.

  • eps (float, optional) – delta added to the priorities to ensure that the buffer does not contain null priorities. Defaults to 1e-8.

  • reduction (str, optional) – the reduction method for multidimensional tensordicts (i.e., stored trajectories). Can be one of “max”, “min”, “median” or “mean”. Defaults to “max”.
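
Following the prioritized replay scheme of Schaul et al. (2015), an item with priority p_i is drawn with probability P(i) = p_i^α / Σ_k p_k^α, and the sample is returned together with an importance-sampling weight w_i = (N · P(i))^(−β), up to normalization. These are the standard formulas from the cited paper, given here as a reading aid; the exact normalization used internally may differ.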

Keyword Arguments:
  • num_slices (int) – the number of slices to be sampled. The batch size must be greater than or equal to the num_slices argument. Exclusive with slice_len (see the sketch after this list).

  • slice_len (int) – the length of the slices to be sampled. The batch size must be greater than or equal to the slice_len argument and divisible by it. Exclusive with num_slices.

  • end_key (NestedKey, optional) – the key indicating the end of a trajectory (or episode). Defaults to ("next", "done").

  • traj_key (NestedKey, optional) – the key indicating the trajectories. Defaults to "episode" (commonly used across datasets in TorchRL).

  • ends (torch.Tensor, optional) – a 1d boolean tensor containing the end of run signals. To be used whenever the end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key.

  • trajectories (torch.Tensor, optional) – a 1d integer tensor containing the run ids. To be used whenever the end_key or traj_key is expensive to get, or when this signal is readily available. Must be used with cache_values=True and cannot be used in conjunction with end_key or traj_key.

  • cache_values (bool, optional) –

    to be used with static datasets. Will cache the start and end signals of the trajectories. This can be safely used even if the trajectory indices change during calls to extend, as this operation will erase the cache.

    Warning

    cache_values=True will not work if the sampler is used with a storage that is extended by another buffer. For instance:

    >>> buffer0 = ReplayBuffer(storage=storage,
    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
    ...     writer=ImmutableDatasetWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...     sampler=other_sampler)
    >>> # Wrong! This does not erase the cache of buffer0's sampler
    >>> buffer1.extend(data)
    

    Warning

    cache_values=True will not work as expected if the buffer is shared between processes and one process is responsible for writing and another for sampling, as erasing the cache can only be done locally.

  • truncated_key (NestedKey, optional) – If not None, this argument indicates where a truncated signal should be written in the output data. This is used to indicate to value estimators where the provided trajectory breaks. Defaults to ("next", "truncated"). This feature only works with TensorDictReplayBuffer instances (otherwise the truncated key is returned in the info dictionary returned by the sample() method).

  • strict_length (bool, optional) – if False, trajectories of length shorter than slice_len (or batch_size // num_slices) will be allowed to appear in the batch. If True, trajectories shorter than required will be filtered out. Be mindful that this can result in an effective batch_size smaller than the one asked for! Trajectories can be split using split_trajectories(). Defaults to True.

  • compile (bool or dict of kwargs, optional) – if True, the bottleneck of the sample() method will be compiled with torch.compile(). Keyword arguments can also be passed to torch.compile() with this arg. Defaults to False.
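
The two ways of specifying slice sizes are tied together by the batch size: with num_slices, each slice contains batch_size // num_slices consecutive steps, while with slice_len the number of slices drawn is batch_size // slice_len. A minimal sketch, reusing the constructor arguments of the example below:

>>> # With batch_size=6, these two samplers draw 3 slices of 2 consecutive steps each.
>>> s1 = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
>>> s2 = PrioritizedSliceSampler(max_capacity=9, slice_len=2, alpha=0.7, beta=0.9)
>>> # Passing both num_slices and slice_len is an error, since they are exclusive.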

Examples

>>> import torch
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler
>>> from tensordict import TensorDict
>>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6)
>>> data = TensorDict(
...     {
...         "observation": torch.randn(9,16),
...         "action": torch.randn(9, 1),
...         "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long),
...         "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long),
...         ("next", "observation"): torch.randn(9,16),
...         ("next", "reward"): torch.randn(9,1),
...         ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1),
...     },
...     batch_size=[9],
... )
>>> rb.extend(data)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 1, 1]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 1, 2]
>>> print("weight", info["_weight"].tolist())
weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
>>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
>>> rb.update_priority(torch.arange(0,9,1), priority=priority)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 2, 2]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 0, 1]
>>> print("weight", info["_weight"].tolist())
weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
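
The importance-sampling weights returned in info["_weight"] are typically used to correct the bias introduced by prioritized sampling, e.g. by scaling each sample's loss term. A minimal, hedged sketch continuing the example above (td_error below is a placeholder for a real temporal-difference error):

>>> sample, info = rb.sample(return_info=True)
>>> weights = torch.as_tensor(info["_weight"])
>>> td_error = torch.randn(weights.shape)       # placeholder for a real TD error
>>> loss = (weights * td_error.pow(2)).mean()   # importance-weighted loss
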
update_priority(index: Union[int, Tensor], priority: Union[float, Tensor]) → None

Updates the priority of the data pointed by the index.

Parameters:
  • index (int or torch.Tensor) – indexes of the priorities to be updated.

  • priority (Number or torch.Tensor) – new priorities of the indexed elements.
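
In a training loop, priorities are usually refreshed after each optimization step from the magnitude of the TD error. A minimal sketch, assuming the sampled indices are available (here read from the "index" entry that TensorDictReplayBuffer adds to sampled tensordicts; td_error is again a placeholder):

>>> sample = rb.sample()
>>> td_error = torch.randn(sample.shape)    # placeholder for a real TD error
>>> rb.update_priority(sample["index"], priority=td_error.abs())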
