
ReplayBuffer

class torchrl.data.ReplayBuffer(*, storage: Storage | None = None, sampler: Sampler | None = None, writer: Writer | None = None, collate_fn: Callable | None = None, pin_memory: bool = False, prefetch: int | None = None, transform: 'Transform' | None = None, batch_size: int | None = None, dim_extend: int | None = None, checkpointer: 'StorageCheckpointerBase' | None = None, generator: torch.Generator | None = None, shared: bool = False)

A generic, composable replay buffer class.

Keyword Arguments:
  • storage (Storage, optional) – the storage to be used. If none is provided, a default ListStorage with max_size of 1_000 will be created.

  • sampler (Sampler, optional) – the sampler to be used. If none is provided, a default RandomSampler will be used.

  • writer (Writer, optional) – the writer to be used. If none is provided, a default RoundRobinWriter will be used.

  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s)/outputs. Used for batched loading from a map-style dataset. The default value is chosen based on the storage type.

  • pin_memory (bool) – whether pin_memory() should be called on the replay buffer samples. Defaults to False.

  • prefetch (int, optional) – number of next batches to be prefetched using multithreading. Defaults to None (no prefetching).

  • transform (Transform, optional) – Transform to be executed when sample() is called. To chain transforms use the Compose class. Transforms should be used with tensordict.TensorDict content. A generic callable can also be passed if the replay buffer is used with PyTree structures (see example below).

  • batch_size (int, optional) –

    the batch size to be used when sample() is called.

    Note

    The batch-size can be specified at construction time via the batch_size argument, or at sampling time. The former should be preferred whenever the batch-size is consistent across the experiment. If the batch-size is likely to change, it can be passed to the sample() method. Passing it at sampling time is incompatible with prefetching (which requires the batch-size to be known in advance) as well as with samplers that have a drop_last argument.


  • dim_extend (int, optional) –

    indicates the dim to consider for extension when calling extend(). Defaults to storage.ndim-1. When using dim_extend > 0, we recommend using the ndim argument in the storage instantiation if that argument is available, to let storages know that the data is multi-dimensional and keep consistent notions of storage-capacity and batch-size during sampling.

    Note

    This argument has no effect on add() and therefore should be used with caution when both add() and extend() are used in a codebase. For example:

    >>> import torch
    >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
    >>> data = torch.zeros(3, 4)
    >>> rb = ReplayBuffer(
    ...     storage=LazyTensorStorage(10, ndim=2),
    ...     dim_extend=1)
    >>> # these two approaches are equivalent:
    >>> for d in data.unbind(1):
    ...     rb.add(d)
    >>> rb.extend(data)
    

  • generator (torch.Generator, optional) –

    a generator to use for sampling. Using a dedicated generator for the replay buffer allows fine-grained control over seeding, for instance keeping the global seed different but the RB seed identical for distributed jobs. Defaults to None (global default generator).

    Warning

    As of now, the generator has no effect on the transforms.

  • shared (bool, optional) – whether the buffer will be shared using multiprocessing or not. Defaults to False.
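
To make the effect of generator concrete, here is a hypothetical toy model in plain Python, not TorchRL's implementation. It assumes a list storage and a uniform sampler, and uses random.Random as a stand-in for a dedicated torch.Generator. Two buffers built with identically seeded dedicated generators draw the same samples even when the global seed differs, which is the distributed-job scenario described above.

```python
import random


class ToyReplayBuffer:
    """Hypothetical toy buffer: list storage plus uniform sampling with replacement."""

    def __init__(self, generator=None):
        self._storage = []
        # Stand-in for torch.Generator; None falls back to the global RNG.
        self._rng = generator if generator is not None else random

    def extend(self, items):
        self._storage.extend(items)

    def sample(self, batch_size):
        # Draw indices uniformly (with replacement) using the buffer's own RNG.
        return [self._storage[self._rng.randrange(len(self._storage))]
                for _ in range(batch_size)]


random.seed(0)  # global seed of "worker 0"
rb0 = ToyReplayBuffer(generator=random.Random(42))
rb0.extend(range(10))
s0 = rb0.sample(5)

random.seed(1)  # a different global seed for "worker 1"
rb1 = ToyReplayBuffer(generator=random.Random(42))
rb1.extend(range(10))
s1 = rb1.sample(5)

print(s0 == s1)  # True: the dedicated generator, not the global seed, drives sampling
```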

Examples

>>> import torch
>>>
>>> from torchrl.data import ReplayBuffer, ListStorage
>>>
>>> torch.manual_seed(0)
>>> rb = ReplayBuffer(
...     storage=ListStorage(max_size=1000),
...     batch_size=5,
... )
>>> # populate the replay buffer and get the item indices
>>> data = range(10)
>>> indices = rb.extend(data)
>>> # sample will return as many elements as specified in the constructor
>>> sample = rb.sample()
>>> print(sample)
tensor([4, 9, 3, 0, 3])
>>> # Passing the batch-size to the sample method overrides the one in the constructor
>>> sample = rb.sample(batch_size=3)
>>> print(sample)
tensor([9, 7, 3])
>>> # one can sample using the ``sample`` method or iterate over the buffer
>>> for i, batch in enumerate(rb):
...     print(i, batch)
...     if i == 3:
...         break
0 tensor([7, 3, 1, 6, 6])
1 tensor([9, 8, 6, 6, 8])
2 tensor([4, 3, 6, 9, 1])
3 tensor([4, 4, 1, 9, 9])

Replay buffers accept any kind of data. Not all storage types will work, as some expect numerical data only, but the default ListStorage will:

Examples

>>> torch.manual_seed(0)
>>> buffer = ReplayBuffer(storage=ListStorage(100), collate_fn=lambda x: x)
>>> indices = buffer.extend(["a", 1, None])
>>> buffer.sample(3)
[None, 'a', None]

The TensorStorage, LazyMemmapStorage and LazyTensorStorage also work with any PyTree structure (a PyTree is a nested structure of arbitrary depth made of dicts, lists or tuples where the leaves are tensors) provided that it only contains tensor data.

Examples

>>> from torch.utils._pytree import tree_map
>>> from torchrl.data import LazyMemmapStorage
>>> def transform(x):
...     # Zeros all the data in the pytree
...     return tree_map(lambda y: y * 0, x)
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform)
>>> data = {
...     "a": torch.randn(3),
...     "b": {"c": (torch.zeros(2), [torch.ones(1)])},
...     30: -torch.ones(()),
... }
>>> rb.add(data)
>>> # The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
>>> s = rb.sample(10)
>>> # let's check that our transform did its job:
>>> def assert0(x):
...     assert (x == 0).all()
>>> tree_map(assert0, s)

add(data: Any) → int

Add a single element to the replay buffer.

Parameters:

data (Any) – data to be added to the replay buffer

Returns:

index where the data lives in the replay buffer.
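
Where an element lands depends on the writer and the storage capacity. The following hypothetical sketch models a bounded storage written in round-robin order (the behaviour suggested by the default RoundRobinWriter; this toy is not TorchRL code) and shows how the indices returned by successive add() calls wrap around once the buffer is full.

```python
class ToyRoundRobinBuffer:
    """Hypothetical toy model of a fixed-capacity buffer with round-robin writes."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._storage = [None] * max_size
        self._cursor = 0  # next slot to write
        self._len = 0     # number of filled slots

    def add(self, item):
        index = self._cursor
        self._storage[index] = item  # overwrites the oldest item once full
        self._cursor = (self._cursor + 1) % self.max_size
        self._len = min(self._len + 1, self.max_size)
        return index

    def __len__(self):
        return self._len


rb = ToyRoundRobinBuffer(max_size=3)
indices = [rb.add(x) for x in "abcde"]
print(indices)  # [0, 1, 2, 0, 1]
print(len(rb))  # 3
```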

append_transform(transform: Transform, *, invert: bool = False) → ReplayBuffer

Appends a transform at the end of the transform chain.

Transforms are applied in order when sample is called.

Parameters:

transform (Transform) – The transform to be appended

Keyword Arguments:

invert (bool, optional) – if True, the transform will be inverted (forward calls will be called during writing and inverse calls during reading). Defaults to False.

Example

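>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyMemmapStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4)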
>>> data = TensorDict({"a": torch.zeros(10)}, [10])
>>> def t(data):
...     data += 1
...     return data
>>> rb.append_transform(t, invert=True)
>>> rb.extend(data)
>>> assert (data == 1).all()
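
The split between write-time and read-time transforms can be pictured with the toy model below. It is a hypothetical simplification, not TorchRL's implementation: transforms are plain callables on lists, a transform appended with invert=True runs in extend(), the others run in sample(), and inverse passes are ignored.

```python
class ToyTransformBuffer:
    """Hypothetical toy: inverted transforms run on write, the others on sample."""

    def __init__(self):
        self.storage = []
        self._write_transforms = []  # applied in extend()
        self._read_transforms = []   # applied in sample()

    def append_transform(self, transform, invert=False):
        target = self._write_transforms if invert else self._read_transforms
        target.append(transform)
        return self  # mirrors the ReplayBuffer return type in the signature

    def extend(self, items):
        for t in self._write_transforms:
            items = t(items)
        self.storage.extend(items)

    def sample(self):
        out = list(self.storage)
        for t in self._read_transforms:  # applied in order
            out = t(out)
        return out


rb = ToyTransformBuffer()
rb.append_transform(lambda xs: [x + 1 for x in xs], invert=True)  # runs when writing
rb.append_transform(lambda xs: [x * 10 for x in xs])              # runs when sampling
rb.extend([0, 1, 2])
print(rb.storage)   # [1, 2, 3]: stored after the write-time transform
print(rb.sample())  # [10, 20, 30]
```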

dump(*args, **kwargs)

Alias for dumps().

dumps(path)

Saves the replay buffer on disk at the specified path.

Parameters:

path (Path or str) – path where to save the replay buffer.

Examples

>>> import tempfile
>>> import tqdm
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
>>> import torch
>>> from tensordict import TensorDict
>>> # Build and populate the replay buffer
>>> S = 1_000_000
>>> sampler = PrioritizedSampler(S, 1.1, 1.0)
>>> # sampler = RandomSampler()
>>> storage = LazyMemmapStorage(S)
>>> rb = TensorDictReplayBuffer(storage=storage, sampler=sampler)
>>>
>>> for _ in tqdm.tqdm(range(100)):
...     td = TensorDict({"obs": torch.randn(100, 3, 4), "next": {"obs": torch.randn(100, 3, 4)}, "td_error": torch.rand(100)}, [100])
...     rb.extend(td)
...     sample = rb.sample(32)
...     rb.update_tensordict_priority(sample)
>>> # save and load the buffer
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     rb.dumps(tmpdir)
...
...     sampler = PrioritizedSampler(S, 1.1, 1.0)
...     # sampler = RandomSampler()
...     storage = LazyMemmapStorage(S)
...     rb_load = TensorDictReplayBuffer(storage=storage, sampler=sampler)
...     rb_load.loads(tmpdir)
...     assert len(rb) == len(rb_load)

empty()

Empties the replay buffer and resets the cursor to 0.

extend(data: Sequence) → Tensor

Extends the replay buffer with one or more elements contained in an iterable.

If present, the inverse transforms will be called.

Parameters:

data (iterable) – collection of data to be added to the replay buffer.

Returns:

Indices of the data added to the replay buffer.

Warning

extend() can have an ambiguous signature when dealing with lists of values: such a list could be interpreted either as a PyTree (in which case all elements in the list are put in a slice of the PyTree stored in the storage) or as a list of values to add one at a time. To resolve this, TorchRL makes a clear-cut distinction between list and tuple: a tuple is viewed as a PyTree, whereas a list (at the root level) is interpreted as a stack of values to add one at a time to the buffer. For ListStorage instances, only unbound elements can be provided (no PyTrees).
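
The list/tuple rule can be summarised with a hypothetical helper (not part of TorchRL) that mirrors how a root-level list and a tuple are interpreted by extend(). Only these two input types are modelled.

```python
def split_extend_input(data):
    """Hypothetical helper mirroring the list/tuple rule of extend()."""
    if isinstance(data, list):
        # Root-level list: a stack of independent values, added one at a time.
        return [("value", element) for element in data]
    if isinstance(data, tuple):
        # Tuple: treated as one PyTree, whose leaves are written as slices
        # of the stored PyTree.
        return [("pytree", data)]
    raise TypeError("this toy helper only handles list and tuple inputs")


print(split_extend_input([1, 2, 3]))
# [('value', 1), ('value', 2), ('value', 3)]
print(split_extend_input(([1, 2], [3, 4])))
# [('pytree', ([1, 2], [3, 4]))]
```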

insert_transform(index: int, transform: Transform, *, invert: bool = False) → ReplayBuffer

Inserts a transform at the given index.

Transforms are executed in order when sample is called.

Parameters:
  • index (int) – Position to insert the transform.

  • transform (Transform) – The transform to be inserted

Keyword Arguments:

invert (bool, optional) – if True, the transform will be inverted (forward calls will be called during writing and inverse calls during reading). Defaults to False.
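
Assuming index follows the usual Python list.insert convention (an assumption; the text above does not spell it out), the execution order of appended and inserted transforms can be checked with this plain-Python sketch, where callables stand in for Transform objects.

```python
def run_chain(transforms, x):
    # Transforms run in list order, as described for sample().
    for t in transforms:
        x = t(x)
    return x


transforms = []
transforms.append(lambda x: x + 1)     # like append_transform(...)
transforms.append(lambda x: x * 2)     # like append_transform(...)
transforms.insert(0, lambda x: x - 3)  # like insert_transform(0, ...): now runs first
print(run_chain(transforms, 5))  # 6, i.e. ((5 - 3) + 1) * 2
```

Without the inserted transform the same input would give (5 + 1) * 2 = 12, so the position chosen at insertion time changes the result.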

load(*args, **kwargs)

Alias for loads().

loads(path)

Loads a replay buffer state from the given path.

The buffer must have matching components and must have been saved using dumps().

Parameters:

path (Path or str) – path where the replay buffer was saved.

See dumps() for more info.

register_load_hook(hook: Callable[[Any], Any])

Registers a load hook for the storage.

Note

Hooks are currently not serialized when saving a replay buffer: they must be manually re-initialized every time the buffer is created.

register_save_hook(hook: Callable[[Any], Any])

Registers a save hook for the storage.

Note

Hooks are currently not serialized when saving a replay buffer: they must be manually re-initialized every time the buffer is created.

sample(batch_size: Optional[int] = None, return_info: bool = False) → Any

Samples a batch of data from the replay buffer.

Uses Sampler to sample indices, and retrieves them from Storage.

Parameters:
  • batch_size (int, optional) – the number of elements to sample. If none is provided, the batch-size given at construction time is used, or the one indicated by the sampler.

  • return_info (bool) – whether to return info. If True, the result is a tuple (data, info). If False, the result is the data.

Returns:

A batch of data selected in the replay buffer. A tuple containing this batch and info if return_info flag is set to True.
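
The two steps described above (the sampler picks indices, the storage retrieves them) and the effect of return_info can be sketched in plain Python. This is a hypothetical illustration, not TorchRL's code; in particular, the info dictionary holding the sampled indices under an "index" key is an assumption of the toy.

```python
import random


def toy_sample(storage, batch_size, return_info=False, rng=None):
    """Hypothetical sketch of sample(): sampler picks indices, storage returns items."""
    if rng is None:
        rng = random.Random(0)
    # Sampler step: uniform indices, drawn with replacement.
    index = [rng.randrange(len(storage)) for _ in range(batch_size)]
    # Storage step: gather the items living at those indices.
    data = [storage[i] for i in index]
    if return_info:
        # Assumed info layout for this toy: the sampled indices under "index".
        return data, {"index": index}
    return data


storage = list("abcdefghij")
data, info = toy_sample(storage, 4, return_info=True, rng=random.Random(0))
print(all(storage[i] == x for i, x in zip(info["index"], data)))  # True
print(len(toy_sample(storage, 3)))  # 3: only the data when return_info is False
```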

property sampler

The sampler of the replay buffer.

The sampler must be an instance of Sampler.

save(*args, **kwargs)

Alias for dumps().

set_sampler(sampler: Sampler)

Sets a new sampler in the replay buffer and returns the previous sampler.

set_storage(storage: Storage, collate_fn: Optional[Callable] = None)

Sets a new storage in the replay buffer and returns the previous storage.

Parameters:
  • storage (Storage) – the new storage for the buffer.

  • collate_fn (callable, optional) – if provided, the collate_fn is set to this value. Otherwise it is reset to a default value.

set_writer(writer: Writer)

Sets a new writer in the replay buffer and returns the previous writer.

property storage

The storage of the replay buffer.

The storage must be an instance of Storage.

property write_count

The total number of items written so far in the buffer through add and extend.
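
The difference between write_count and the number of stored elements shows up once a bounded storage wraps around. The toy model below is hypothetical, not TorchRL code, and assumes round-robin overwriting of the oldest slot.

```python
class ToyCountingBuffer:
    """Hypothetical toy buffer tracking a write counter next to its bounded length."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._storage = []
        self.write_count = 0

    def extend(self, items):
        for item in items:
            slot = self.write_count % self.max_size
            if slot < len(self._storage):
                self._storage[slot] = item  # overwrite the oldest entry
            else:
                self._storage.append(item)
            self.write_count += 1  # counts every write, including overwrites

    def __len__(self):
        return len(self._storage)


rb = ToyCountingBuffer(max_size=5)
rb.extend(range(8))
print(len(rb), rb.write_count)  # 5 8
```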

property writer

The writer of the replay buffer.

The writer must be an instance of Writer.
