
ReplayBufferEnsemble

class torchrl.data.replay_buffers.ReplayBufferEnsemble(*rbs, storages: StorageEnsemble | None = None, samplers: SamplerEnsemble | None = None, writers: WriterEnsemble | None = None, transform: 'Transform' | None = None, batch_size: int | None = None, collate_fn: Callable | None = None, collate_fns: List[Callable] | None = None, p: Tensor = None, sample_from_all: bool = False, num_buffer_sampled: int | None = None, **kwargs)

An ensemble of replay buffers.

This class allows reading from and sampling from multiple replay buffers at once. It automatically composes an ensemble of storages (StorageEnsemble), writers (WriterEnsemble) and samplers (SamplerEnsemble).

Note

Writing directly to this class is forbidden, but it can be indexed to retrieve a nested buffer, which can then be extended.

There are two distinct ways of constructing a ReplayBufferEnsemble: one can either pass a list of replay buffers, or directly pass the components (storages, writers and samplers), as is done for other replay buffer subclasses.

Parameters:
  • rbs (sequence of ReplayBuffer instances, optional) – the replay buffers to ensemble.

  • storages (StorageEnsemble, optional) – the ensemble of storages, if the replay buffers are not passed.

  • samplers (SamplerEnsemble, optional) – the ensemble of samplers, if the replay buffers are not passed.

  • writers (WriterEnsemble, optional) – the ensemble of writers, if the replay buffers are not passed.

  • transform (Transform, optional) – if passed, this will be the transform of the ensemble of replay buffers. Individual transforms for each replay buffer are retrieved from the corresponding nested replay buffer, or are defined directly in the StorageEnsemble object.

  • batch_size (int, optional) – the batch size to use during sampling.

  • collate_fn (callable, optional) – the function to use to collate the data after each individual collate_fn has been called and the data is placed in a list (along with the buffer id).

  • collate_fns (list of callables, optional) – collate_fn of each nested replay buffer. Retrieved from the ReplayBuffer instances if not provided.

  • p (list of float or Tensor, optional) – a list of floating-point numbers indicating the relative weight of each replay buffer. Can also be passed to torchrl.data.replay_buffers.samplers.SamplerEnsemble if the buffer is built explicitly.

  • sample_from_all (bool, optional) – if True, every replay buffer in the ensemble will be sampled from. This is not compatible with the p argument. Defaults to False. Can also be passed to torchrl.data.replay_buffers.samplers.SamplerEnsemble if the buffer is built explicitly.

  • num_buffer_sampled (int, optional) – the number of buffers to sample. If sample_from_all=True, this has no effect, as it defaults to the number of buffers. If sample_from_all=False, buffers will be sampled according to the probabilities p. Can also be passed to torchrl.data.replay_buffers.samplers.SamplerEnsemble if the buffer is built explicitly.

Examples

>>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform
>>> from torchrl.data import TensorDictReplayBuffer, ReplayBufferEnsemble, LazyMemmapStorage
>>> from tensordict import TensorDict
>>> import torch
>>> rb0 = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(10),
...     transform=Compose(
...         ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
...         Resize(32, in_keys=["pixels", ("next", "pixels")]),
...         RenameTransform([("some", "key")], ["renamed"]),
...     ),
... )
>>> rb1 = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(10),
...     transform=Compose(
...         ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
...         Resize(32, in_keys=["pixels", ("next", "pixels")]),
...         RenameTransform(["another_key"], ["renamed"]),
...     ),
... )
>>> rb = ReplayBufferEnsemble(
...     rb0,
...     rb1,
...     p=[0.5, 0.5],
...     transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]),
... )
>>> print(rb)
ReplayBufferEnsemble(
    storages=StorageEnsemble(
        storages=(<torchrl.data.replay_buffers.storages.LazyMemmapStorage object at 0x13a2ef430>, <torchrl.data.replay_buffers.storages.LazyMemmapStorage object at 0x13a2f9310>),
        transforms=[Compose(
                ToTensorImage(keys=['pixels', ('next', 'pixels')]),
                Resize(w=32, h=32, interpolation=InterpolationMode.BILINEAR, keys=['pixels', ('next', 'pixels')]),
                RenameTransform(keys=[('some', 'key')])), Compose(
                ToTensorImage(keys=['pixels', ('next', 'pixels')]),
                Resize(w=32, h=32, interpolation=InterpolationMode.BILINEAR, keys=['pixels', ('next', 'pixels')]),
                RenameTransform(keys=['another_key']))]),
    samplers=SamplerEnsemble(
        samplers=(<torchrl.data.replay_buffers.samplers.RandomSampler object at 0x13a2f9220>, <torchrl.data.replay_buffers.samplers.RandomSampler object at 0x13a2f9f70>)),
    writers=WriterEnsemble(
        writers=(<torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter object at 0x13a2d9b50>, <torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter object at 0x13a2f95b0>)),
batch_size=None,
transform=Compose(
        Resize(w=33, h=33, interpolation=InterpolationMode.BILINEAR, keys=['pixels'])),
collate_fn=<built-in method stack of type object at 0x128648260>)
>>> data0 = TensorDict(
...     {
...         "pixels": torch.randint(255, (10, 244, 244, 3)),
...         ("next", "pixels"): torch.randint(255, (10, 244, 244, 3)),
...         ("some", "key"): torch.randn(10),
...     },
...     batch_size=[10],
... )
>>> data1 = TensorDict(
...     {
...         "pixels": torch.randint(255, (10, 64, 64, 3)),
...         ("next", "pixels"): torch.randint(255, (10, 64, 64, 3)),
...         "another_key": torch.randn(10),
...     },
...     batch_size=[10],
... )
>>> rb[0].extend(data0)
>>> rb[1].extend(data1)
>>> for _ in range(2):
...     sample = rb.sample(10)
...     assert sample["next", "pixels"].shape == torch.Size([2, 5, 3, 32, 32])
...     assert sample["pixels"].shape == torch.Size([2, 5, 3, 32, 32])
...     assert sample["pixels33"].shape == torch.Size([2, 5, 3, 33, 33])
...     assert sample["renamed"].shape == torch.Size([2, 5])

add(data: Any) → int

Add a single element to the replay buffer.

Parameters:

data (Any) – data to be added to the replay buffer

Returns:

index where the data lives in the replay buffer.

append_transform(transform: Transform, *, invert: bool = False) → ReplayBuffer

Appends a transform at the end.

Transforms are applied in order when sample is called.

Parameters:

transform (Transform) – The transform to be appended

Keyword Arguments:

invert (bool, optional) – if True, the transform will be inverted (forward calls will be called during writing and inverse calls during reading). Defaults to False.

Example

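>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyMemmapStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4)
>>> data = TensorDict({"a": torch.zeros(10)}, [10])
>>> def t(data):
...     data += 1  # in-place update of the incoming data
...     return data
>>> # invert=True: t is applied when writing (extend), not when sampling
>>> rb.append_transform(t, invert=True)
>>> rb.extend(data)
>>> assert (data == 1).all()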

dump(*args, **kwargs)

Alias for dumps().

dumps(path)

Saves the replay buffer on disk at the specified path.

Parameters:

path (Path or str) – path where to save the replay buffer.

Examples

>>> import tempfile
>>> import tqdm
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
>>> import torch
>>> from tensordict import TensorDict
>>> # Build and populate the replay buffer
>>> S = 1_000_000
>>> sampler = PrioritizedSampler(S, 1.1, 1.0)
>>> # sampler = RandomSampler()
>>> storage = LazyMemmapStorage(S)
>>> rb = TensorDictReplayBuffer(storage=storage, sampler=sampler)
>>>
>>> for _ in tqdm.tqdm(range(100)):
...     td = TensorDict({"obs": torch.randn(100, 3, 4), "next": {"obs": torch.randn(100, 3, 4)}, "td_error": torch.rand(100)}, [100])
...     rb.extend(td)
...     sample = rb.sample(32)
...     rb.update_tensordict_priority(sample)
>>> # save and load the buffer
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     rb.dumps(tmpdir)
...
...     sampler = PrioritizedSampler(S, 1.1, 1.0)
...     # sampler = RandomSampler()
...     storage = LazyMemmapStorage(S)
...     rb_load = TensorDictReplayBuffer(storage=storage, sampler=sampler)
...     rb_load.loads(tmpdir)
...     assert len(rb) == len(rb_load)

empty()

Empties the replay buffer and resets the cursor to 0.

extend(data: Sequence) → Tensor

Extends the replay buffer with one or more elements contained in an iterable.

If present, the inverse transforms will be called.

Parameters:

data (iterable) – collection of data to be added to the replay buffer.

Returns:

Indices of the data added to the replay buffer.

Warning

extend() can have an ambiguous signature when dealing with lists of values, which could be interpreted either as a PyTree (in which case all elements in the list will be put in a slice in the stored PyTree in the storage) or as a list of values to add one at a time. To solve this, TorchRL makes a clear-cut distinction between list and tuple: a tuple will be viewed as a PyTree, while a list (at the root level) will be interpreted as a stack of values to add one at a time to the buffer. For ListStorage instances, only unbound elements can be provided (no PyTrees).

insert_transform(index: int, transform: Transform, *, invert: bool = False) → ReplayBuffer

Inserts a transform at the given position.

Transforms are executed in order when sample is called.

Parameters:
  • index (int) – Position to insert the transform.

  • transform (Transform) – The transform to be inserted

Keyword Arguments:

invert (bool, optional) – if True, the transform will be inverted (forward calls will be called during writing and inverse calls during reading). Defaults to False.

load(*args, **kwargs)

Alias for loads().

loads(path)

Loads a replay buffer state at the given path.

The buffer should have matching components and be saved using dumps().

Parameters:

path (Path or str) – path where the replay buffer was saved.

See dumps() for more info.

register_load_hook(hook: Callable[[Any], Any])

Registers a load hook for the storage.

Note

Hooks are currently not serialized when saving a replay buffer: they must be manually re-initialized every time the buffer is created.

register_save_hook(hook: Callable[[Any], Any])

Registers a save hook for the storage.

Note

Hooks are currently not serialized when saving a replay buffer: they must be manually re-initialized every time the buffer is created.

sample(batch_size: int | None = None, return_info: bool = False) → Any

Samples a batch of data from the replay buffer.

Uses Sampler to sample indices, and retrieves them from Storage.

Parameters:
  • batch_size (int, optional) – size of data to be collected. If not provided, this method will sample a batch size as indicated by the sampler.

  • return_info (bool) – whether to return info. If True, the result is a tuple (data, info). If False, the result is the data.

Returns:

A batch of data selected in the replay buffer, or a tuple containing this batch and info if the return_info flag is set to True.

property sampler

The sampler of the replay buffer.

The sampler must be an instance of Sampler.

save(*args, **kwargs)

Alias for dumps().

set_sampler(sampler: Sampler)

Sets a new sampler in the replay buffer and returns the previous sampler.

set_storage(storage: Storage, collate_fn: Callable | None = None)

Sets a new storage in the replay buffer and returns the previous storage.

Parameters:
  • storage (Storage) – the new storage for the buffer.

  • collate_fn (callable, optional) – if provided, the collate_fn is set to this value. Otherwise it is reset to a default value.

set_writer(writer: Writer)

Sets a new writer in the replay buffer and returns the previous writer.

property storage

The storage of the replay buffer.

The storage must be an instance of Storage.

property writer

The writer of the replay buffer.

The writer must be an instance of Writer.
