
PrioritizedSampler

class torchrl.data.replay_buffers.PrioritizedSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, reduction: str = 'max', max_priority_within_buffer: bool = False)

Prioritized sampler for replay buffers.

Presented in “Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay.” (https://arxiv.org/abs/1511.05952)

Parameters:
  • max_capacity (int) – maximum capacity of the buffer.

  • alpha (float) – exponent α that determines how strongly sampling is prioritized; α = 0 corresponds to uniform sampling.

  • beta (float) – negative exponent β used to compute the importance-sampling weights; β = 1 fully compensates for the non-uniform sampling probabilities.

  • eps (float, optional) – small constant added to the priorities so that no element in the buffer has a zero priority. Defaults to 1e-8.

  • reduction (str, optional) – the reduction used to collapse the priorities of multidimensional tensordicts (i.e., stored trajectories) into a single value. Can be one of “max”, “min”, “median” or “mean”. Defaults to “max”.

  • max_priority_within_buffer (bool, optional) – if True, the max priority is tracked among the elements currently in the buffer. If False, it is the maximum value observed since the sampler was instantiated. Defaults to False.
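To make α, β, eps and reduction concrete, here is a minimal sketch of the sampling probabilities and importance-sampling weights defined in the PER paper: P(i) = p_i^α / Σ_k p_k^α and w_i = (N·P(i))^(-β), normalised by the largest weight. It is not TorchRL's implementation. The function per_probabilities_and_weights is hypothetical, and the sketch assumes that eps is added to each priority before the α exponent and that reduction collapses every trailing dimension of a priority tensor into one value per stored element.

```python
from typing import Tuple

import torch


def per_probabilities_and_weights(
    priorities: torch.Tensor,
    alpha: float,
    beta: float,
    eps: float = 1e-8,
    reduction: str = "max",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: PER sampling probabilities and normalised IS weights.

    `priorities` has shape [N] or [N, ...]; trailing dims (e.g. the time steps
    of a stored trajectory) are collapsed with `reduction` (assumption).
    """
    priorities = priorities.to(torch.float64)
    if priorities.ndim > 1:
        flat = priorities.reshape(priorities.shape[0], -1)
        if reduction == "max":
            priorities = flat.amax(dim=1)
        elif reduction == "min":
            priorities = flat.amin(dim=1)
        elif reduction == "median":
            priorities = flat.median(dim=1).values
        elif reduction == "mean":
            priorities = flat.mean(dim=1)
        else:
            raise ValueError(f"unknown reduction: {reduction!r}")
    # Assumption: eps is added before the alpha exponent so zero priorities stay valid.
    p = (priorities + eps) ** alpha
    # P(i) = p_i / sum_k p_k; alpha = 0 makes every p_i equal to 1, i.e. uniform sampling.
    probs = p / p.sum()
    # w_i = (N * P(i)) ** -beta, divided by max_j w_j as in the paper.
    weights = (p.numel() * probs) ** (-beta)
    weights = weights / weights.max()
    return probs, weights


probs, weights = per_probabilities_and_weights(
    torch.tensor([0.0, 1000.0]), alpha=1.0, beta=1.0
)
print(probs)    # index 1 is drawn with probability close to 1
print(weights)  # approximately [1.0, 1e-11]
```

Dividing by the largest weight means weights only ever scale an update down, which the paper motivates as a stability measure.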

Examples

>>> import torch
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> rb.add(data_0)
>>> rb.add(data_1)
>>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample)
TensorDict(
        fields={
            action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
            reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([10]),
        device=cpu,
        is_shared=False)
>>> print(info)
{'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
       1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
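The values in info follow from the PER formulas. Assuming, as in the paper, that element i is drawn with probability proportional to its priority raised to α, that eps = 1e-8 is added to each stored priority, and that the weights are normalised by their maximum, the arithmetic for α = β = 1 and priorities [0, 1000] is:

```latex
\begin{aligned}
p_0 &= (0 + 10^{-8})^{\alpha} = 10^{-8}, \qquad p_1 = (1000 + 10^{-8})^{\alpha} \approx 10^{3} \quad (\alpha = 1) \\
P(1) &= \frac{p_1}{p_0 + p_1} \approx 1 - 10^{-11} \\
w_1 &= \frac{(N\,P(1))^{-\beta}}{\max_j (N\,P(j))^{-\beta}}
     = \left(\frac{p_1}{p_{\min}}\right)^{-\beta}
     \approx \left(\frac{10^{3}}{10^{-8}}\right)^{-1} = 10^{-11} \quad (\beta = 1)
\end{aligned}
```

Normalising by the maximum cancels both N and the shared denominator, so the weight only depends on the ratio to the smallest priority p_min = p_0. Essentially every draw therefore returns index 1, with a weight of about 1e-11, as in the output.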

Note

Using a TensorDictReplayBuffer simplifies updating the priorities:

>>> import torch
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = TDRB(
...     storage=LazyTensorStorage(10),
...     sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
...     priority_key="priority",  # This kwarg isn't present in regular RBs
... )
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> data = torch.stack([data_0, data_1])
>>> rb.extend(data)
>>> rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample['index'])  # The index is packed with the tensordict
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

update_priority(index: Union[int, Tensor], priority: Union[float, Tensor], *, storage: torchrl.data.replay_buffers.storages.TensorStorage | None = None) -> None

Updates the priorities of the data pointed to by index.

Parameters:
  • index (int or torch.Tensor) – indices of the priorities to be updated.

  • priority (Number or torch.Tensor) – new priorities of the indexed elements.

Keyword Arguments:
  • storage (Storage, optional) – a storage used to map an N-d index to the 1-d index space of the sum_tree and min_tree. Only required when index.ndim > 2.
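The sum_tree and min_tree named above are segment trees. A sum tree supports drawing an index with probability proportional to its priority, and a min tree returns the smallest priority, which is what a weight normalisation of the form (p_i / p_min)^(-β) needs. The sketch below illustrates the idea only: the class ToySegmentTrees and its update and sample methods are hypothetical, not TorchRL's segment-tree API, and they accept 1-d indices only.

```python
import torch


class ToySegmentTrees:
    """Hypothetical sum/min segment trees over 1-d indices (illustration only)."""

    def __init__(self, capacity: int):
        # Round the number of leaves up to a power of two.
        size = 1
        while size < capacity:
            size *= 2
        self.size = size
        # Node 1 is the root; leaf i is stored at node size + i.
        self.sum_tree = torch.zeros(2 * size, dtype=torch.float64)
        self.min_tree = torch.full((2 * size,), float("inf"), dtype=torch.float64)

    def update(self, index, priority) -> None:
        """Set the priority of each index, then refresh its ancestors."""
        index = torch.as_tensor(index).reshape(-1)
        priority = torch.broadcast_to(
            torch.as_tensor(priority, dtype=torch.float64), index.shape
        )
        for i, p in zip(index.tolist(), priority.tolist()):
            node = self.size + i
            self.sum_tree[node] = p
            self.min_tree[node] = p
            node //= 2
            # Walk up to the root, recomputing each parent from its two children.
            while node >= 1:
                left, right = 2 * node, 2 * node + 1
                self.sum_tree[node] = self.sum_tree[left] + self.sum_tree[right]
                self.min_tree[node] = torch.minimum(self.min_tree[left], self.min_tree[right])
                node //= 2

    def _find_leaf(self, mass: float) -> int:
        # Descend from the root: go left if the left subtree holds more than `mass`,
        # otherwise subtract the left total and go right.
        node = 1
        while node < self.size:
            left = 2 * node
            left_sum = self.sum_tree[left].item()
            if left_sum > mass:
                node = left
            else:
                mass -= left_sum
                node = left + 1
        return node - self.size

    def sample(self, batch_size: int, generator=None) -> torch.Tensor:
        """Draw indices with probability proportional to their stored priority."""
        total = self.sum_tree[1].item()
        masses = torch.rand(batch_size, generator=generator, dtype=torch.float64) * total
        return torch.tensor([self._find_leaf(m) for m in masses.tolist()], dtype=torch.long)


trees = ToySegmentTrees(capacity=4)
trees.update(torch.tensor([0, 1, 2]), torch.tensor([1.0, 3.0, 0.0]))
idx = trees.sample(1000, generator=torch.Generator().manual_seed(0))
print(trees.sum_tree[1].item(), trees.min_tree[1].item())  # 4.0 0.0
```

Each update and each draw walks a single root-to-leaf path, so both cost O(log capacity) instead of O(capacity). Leaves with zero priority, such as index 2 here, are never drawn.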
