PrioritizedReplayBuffer
- class torchrl.data.PrioritizedReplayBuffer(*, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, storage: Optional[Storage] = None, collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, transform: Optional[Transform] = None, batch_size: Optional[int] = None)
Prioritized replay buffer.
All arguments are keyword-only arguments.
- Presented in “Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay.” (https://arxiv.org/abs/1511.05952)
- Parameters:
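alpha (float) – exponent α that determines how much prioritization is used, with α = 0 corresponding to the uniform case.
beta (float) – exponent β of the importance-sampling correction: sampled items are weighted by (N·P(i))^(−β), so β = 0 disables the correction and β = 1 fully compensates for the non-uniform sampling probabilities.
eps (float) – small constant added to the priorities so that no entry in the buffer has a zero priority.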
storage (Storage, optional) – the storage to be used. If none is provided, a default ListStorage with max_size of 1_000 will be created.
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used for batched loading from a map-style dataset. The default value is decided based on the storage type.
pin_memory (bool) – whether pin_memory() should be called on the replay buffer samples.
prefetch (int, optional) – number of next batches to be prefetched using multithreading. Defaults to None (no prefetching).
transform (Transform, optional) – Transform to be executed when sample() is called. To chain transforms, use the Compose class. Transforms should be used with tensordict.TensorDict content. If used with other structures, the transforms should be encoded with a "data" leading key that will be used to construct a tensordict from the non-tensordict content.
batch_size (int, optional) – the batch size to be used when sample() is called.
Note: the batch size can be specified at construction time via the batch_size argument, or at sampling time. The former should be preferred whenever the batch size is consistent across the experiment. If the batch size is likely to change, it can be passed to the sample() method. Passing it at sampling time is incompatible with prefetching (since prefetching requires knowing the batch size in advance) as well as with samplers that have a drop_last argument.
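A minimal sketch of how α, β and ε interact, assuming the proportional variant from the cited paper: each stored priority p_i (plus ε) is raised to α to obtain sampling probabilities P(i) = p_i^α / Σ_k p_k^α, and each sample is reweighted by w_i = (N·P(i))^(−β), normalized so that the largest weight is 1. The helper names sampling_probs and is_weights are hypothetical; this is plain Python, not torchrl's implementation. For instance, with α = 0.7 and β = 0.9, an entry with priority 5 in a buffer whose smallest priority is 1 gets a normalized weight of (5^0.7)^(−0.9) = 5^(−0.63) ≈ 0.363.

```python
def sampling_probs(priorities, alpha, eps=1e-8):
    """Proportional prioritization: P(i) = (p_i + eps)^alpha / sum_k (p_k + eps)^alpha."""
    scaled = [(p + eps) ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def is_weights(probs, beta):
    """Importance-sampling weights (N * P(i))^-beta, divided by the largest weight."""
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    w_max = max(raw)
    return [w / w_max for w in raw]


priorities = [1.0, 1.0, 5.0, 5.0]
probs = sampling_probs(priorities, alpha=0.7)  # high-priority entries are drawn more often
weights = is_weights(probs, beta=0.9)          # ...and are down-weighted to correct the bias
```

Setting alpha=0.0 makes every probability equal (the uniform case), and setting beta=0.0 makes every weight 1.0 (no correction).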
Note
Generic prioritized replay buffers (i.e. non-tensordict backed) require calling sample() with the return_info argument set to True to access the sampled indices, which are needed to update the priorities. Using tensordict.TensorDict and the related TensorDictPrioritizedReplayBuffer simplifies this process.
Examples
>>> import torch
>>>
>>> from torchrl.data import ListStorage, PrioritizedReplayBuffer
>>>
>>> torch.manual_seed(0)
>>>
>>> rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=ListStorage(10))
>>> data = range(10)
>>> rb.extend(data)
>>> sample = rb.sample(3)
>>> print(sample)
tensor([1, 0, 1])
>>> # get the info to find what the indices are
>>> sample, info = rb.sample(5, return_info=True)
>>> print(sample, info)
tensor([2, 7, 4, 3, 5]) {'_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
>>> # update priority
>>> priority = torch.ones(5) * 5
>>> rb.update_priority(info["index"], priority)
>>> # and now a new sample, the weights should be updated
>>> sample, info = rb.sample(5, return_info=True)
>>> print(sample, info)
tensor([2, 5, 2, 2, 5]) {'_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465], dtype=float32), 'index': array([2, 5, 2, 2, 5])}
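The note above explains why the indices returned with return_info=True matter: they identify the entries whose priorities update_priority() should change. The cited paper implements its proportional variant with a sum tree, which makes both proportional sampling and priority updates O(log N). The sketch below illustrates that data structure in plain Python; SumTree and sample_indices are hypothetical names, the leaves store raw priorities (in the proportional scheme they would hold (p + ε)^α), and nothing here claims to mirror torchrl's internal sampler.

```python
import random


class SumTree:
    """Leaves hold per-entry priorities; each internal node holds the sum of its children."""

    def __init__(self, capacity):
        self.capacity = capacity
        # Node 1 is the root; the leaf for index i lives at position capacity + i.
        self.nodes = [0.0] * (2 * capacity)

    def update(self, index, priority):
        # Write the leaf, then refresh each ancestor's sum on the way up: O(log N).
        pos = index + self.capacity
        self.nodes[pos] = priority
        pos //= 2
        while pos >= 1:
            self.nodes[pos] = self.nodes[2 * pos] + self.nodes[2 * pos + 1]
            pos //= 2

    def total(self):
        return self.nodes[1]

    def find(self, mass):
        # Walk down from the root; go right (subtracting the left sum) when mass exceeds it.
        pos = 1
        while pos < self.capacity:
            left = 2 * pos
            if mass < self.nodes[left]:
                pos = left
            else:
                mass -= self.nodes[left]
                pos = left + 1
        return pos - self.capacity


def sample_indices(tree, batch_size, rng):
    # Each index is drawn with probability priority / total.
    return [tree.find(rng.random() * tree.total()) for _ in range(batch_size)]


tree = SumTree(capacity=4)
for i in range(4):
    tree.update(i, 1.0)  # start with equal priorities
tree.update(2, 5.0)      # like a priority update: index 2 becomes 5x as likely as each other index
indices = sample_indices(tree, 1000, random.Random(0))
```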
- add(data: Any) → int
Add a single element to the replay buffer.
- Parameters:
data (Any) – data to be added to the replay buffer
- Returns:
index where the data lives in the replay buffer.
- append_transform(transform: Transform) → None
Appends a transform at the end of the transform chain.
Transforms are applied in order when sample() is called.
- Parameters:
transform (Transform) – The transform to be appended
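Appending order is execution order, so the sequence matters whenever two transforms do not commute. The plain-Python sketch below shows this with two list-based functions; TransformChain, add_one and times_ten are hypothetical stand-ins, whereas real torchrl transforms are Transform objects operating on tensordict content.

```python
class TransformChain:
    """Hypothetical Compose-like chain: transforms run in the order they were appended."""

    def __init__(self):
        self.transforms = []

    def append(self, transform):
        self.transforms.append(transform)

    def __call__(self, batch):
        for transform in self.transforms:  # the first appended transform runs first
            batch = transform(batch)
        return batch


def add_one(batch):
    return [x + 1 for x in batch]


def times_ten(batch):
    return [x * 10 for x in batch]


chain = TransformChain()
chain.append(add_one)
chain.append(times_ten)
result = chain([1, 2, 3])  # computes (x + 1) * 10 for each element
```

Appending the same two functions in the opposite order would compute x * 10 + 1 instead.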
- empty()
Empties the replay buffer and resets the cursor to 0.
- extend(data: Sequence) → Tensor
Extends the replay buffer with one or more elements contained in an iterable.
If present, the inverse transforms will be called.
- Parameters:
data (iterable) – collection of data to be added to the replay buffer.
- Returns:
Indices of the data added to the replay buffer.
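The add(), empty() and extend() entries above all revolve around a write cursor. The plain-Python sketch below assumes a round-robin policy that overwrites the oldest slot once the storage is full, and that the inverse transform is applied to each element before it is written. RingStorage is hypothetical (it is not torchrl's ListStorage), and its extend() returns a list of indices where the documented method returns a Tensor.

```python
class RingStorage:
    """Hypothetical fixed-size storage with a round-robin write cursor."""

    def __init__(self, max_size, inv_transform=None):
        self.max_size = max_size
        self.inv_transform = inv_transform
        self._data = []
        self._cursor = 0

    def add(self, item):
        # Assumed: the inverse transform runs on the write path.
        if self.inv_transform is not None:
            item = self.inv_transform(item)
        index = self._cursor
        if len(self._data) < self.max_size:
            self._data.append(item)   # still filling up
        else:
            self._data[index] = item  # full: overwrite the oldest slot
        self._cursor = (self._cursor + 1) % self.max_size
        return index                  # the slot the item was written to

    def extend(self, items):
        # One index per element, in insertion order.
        return [self.add(item) for item in items]

    def empty(self):
        # Drop the contents and reset the cursor to 0.
        self._data.clear()
        self._cursor = 0


storage = RingStorage(max_size=4, inv_transform=lambda x: 2 * x)
indices = storage.extend(range(6))  # [0, 1, 2, 3, 0, 1]: the last two writes wrap around
contents = list(storage._data)      # [8, 10, 4, 6]
storage.empty()
first_after_empty = storage.add(3)  # 0, because empty() reset the cursor
```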
- insert_transform(index: int, transform: Transform) → None
Inserts a transform at the given position.
Transforms are executed in order when sample() is called.
- Parameters:
index (int) – Position to insert the transform.
transform (Transform) – The transform to be inserted
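Inserting rather than appending lets a transform run before existing ones. The sketch below assumes that insert_transform(index, ...) follows the same position semantics as Python's list.insert (index 0 means "run first"); that is an assumption, and add_one, times_ten and run are hypothetical helpers operating on plain lists.

```python
def add_one(batch):
    return [x + 1 for x in batch]


def times_ten(batch):
    return [x * 10 for x in batch]


def run(transforms, batch):
    # Apply the transforms in list order.
    for transform in transforms:
        batch = transform(batch)
    return batch


pipeline = [times_ten]
before = run(pipeline, [1, 2])  # [10, 20]
pipeline.insert(0, add_one)     # assumed to mirror insert_transform(0, add_one)
after = run(pipeline, [1, 2])   # [20, 30]: add_one now runs first
```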
- sample(batch_size: Optional[int] = None, return_info: bool = False) → Any
Samples a batch of data from the replay buffer.
Uses the Sampler to sample indices and retrieves the corresponding data from the Storage.
- Parameters:
batch_size (int, optional) – size of data to be collected. If none is provided, this method samples a batch of the size indicated by the sampler.
return_info (bool) – whether to return info. If True, the result is a tuple (data, info). If False, the result is the data.
- Returns:
A batch of data selected from the replay buffer, or a tuple containing this batch and info if the return_info flag is set to True.
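The two return shapes of sample() can be sketched with a uniform stand-in buffer. SampleSketch is hypothetical, and its precedence rule (a call-time batch_size overrides the constructor value, and omitting both raises an error) is an assumption consistent with the batch_size note above rather than a statement of torchrl's behavior. A prioritized buffer additionally reports a _weight entry in info, as in the example above.

```python
import random


class SampleSketch:
    """Hypothetical uniform buffer illustrating the batch_size / return_info contract."""

    def __init__(self, data, batch_size=None, seed=0):
        self._data = list(data)
        self._batch_size = batch_size
        self._rng = random.Random(seed)

    def sample(self, batch_size=None, return_info=False):
        # Assumed precedence: a call-time batch_size overrides the constructor value.
        if batch_size is None:
            batch_size = self._batch_size
        if batch_size is None:
            raise ValueError("batch_size must be set at construction or passed to sample()")
        index = [self._rng.randrange(len(self._data)) for _ in range(batch_size)]
        batch = [self._data[i] for i in index]
        if return_info:
            return batch, {"index": index}  # tuple (data, info)
        return batch                        # data only


rb = SampleSketch(range(10), batch_size=4)
default_batch = rb.sample()                   # 4 items, size taken from the constructor
batch, info = rb.sample(3, return_info=True)  # 3 items plus the indices they came from
```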