
Source code for torchrl.data.replay_buffers.writers

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import heapq
import json
import textwrap
from abc import ABC, abstractmethod
from copy import copy
from multiprocessing.context import get_spawning_popen
from pathlib import Path
from typing import Any, Dict, Sequence

import numpy as np
import torch

from tensordict import is_tensor_collection, MemoryMappedTensor, TensorDictBase
from tensordict.utils import _STRDTYPE2DTYPE, expand_as_right, is_tensorclass
from torch import multiprocessing as mp

try:
    from torch.utils._pytree import tree_leaves
except ImportError:
    from torch.utils._pytree import tree_flatten

    def tree_leaves(data):  # noqa: D103
        tree_flat, _ = tree_flatten(data)
        return tree_flat


from torchrl.data.replay_buffers.storages import Storage
from torchrl.data.replay_buffers.utils import _is_int, _reduce


class Writer(ABC):
    """A ReplayBuffer base Writer class."""

    _storage: Storage

    def __init__(self) -> None:
        self._storage = None

    def register_storage(self, storage: Storage) -> None:
        self._storage = storage

    @abstractmethod
    def add(self, data: Any) -> int:
        """Inserts one piece of data at an appropriate index, and returns that index."""
        ...

    @abstractmethod
    def extend(self, data: Sequence) -> torch.Tensor:
        """Inserts a series of data points at appropriate indices, and returns a tensor containing the indices."""
        ...

    @abstractmethod
    def _empty(self):
        ...

    @abstractmethod
    def dumps(self, path):
        ...

    @abstractmethod
    def loads(self, path):
        ...

    @abstractmethod
    def state_dict(self) -> Dict[str, Any]:
        ...

    @abstractmethod
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...

    def _replicate_index(self, index):
        # replicates the index in a non-zero format to have as many indices as
        # elements truly written when the storage is multidim
        if self._storage.ndim == 1:
            return index
        device = (
            index.device if isinstance(index, torch.Tensor) else torch.device("cpu")
        )
        mesh = torch.stack(
            torch.meshgrid(
                *(torch.arange(dim, device=device) for dim in self._storage.shape[1:])
            ),
            -1,
        ).flatten(0, -2)
        if _is_int(index):
            index0 = torch.as_tensor(int(index)).expand(mesh.shape[0], 1)
            return torch.cat([index0, mesh], 1)
        return torch.cat(
            [
                index.repeat_interleave(mesh.shape[0]).unsqueeze(1),
                mesh.repeat(index.numel(), 1),
            ],
            1,
        )

    def __repr__(self):
        return f"{self.__class__.__name__}()"
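
The abstract methods above are the full surface a custom writer has to provide; storage registration and ``_replicate_index`` are inherited. A minimal sketch of a subclass, assuming nothing beyond the interface shown here (the class name and its bookkeeping are purely illustrative):

>>> class AppendOnlyWriter(Writer):   # hypothetical subclass, for illustration only
...     def __init__(self):
...         super().__init__()
...         self._count = 0
...     def add(self, data):
...         index = self._count % self._storage.max_size   # wrap around once full
...         self._storage[index] = data
...         self._count += 1
...         return index
...     def extend(self, data):
...         return torch.stack([torch.as_tensor(self.add(d)) for d in data])
...     def _empty(self):
...         self._count = 0
...     def dumps(self, path): ...
...     def loads(self, path): ...
...     def state_dict(self):
...         return {"_count": self._count}
...     def load_state_dict(self, state_dict):
...         self._count = state_dict["_count"]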

class ImmutableDatasetWriter(Writer):
    """A blocking writer for immutable datasets."""

    WRITING_ERR = "This dataset doesn't allow writing."

    def add(self, data: Any) -> int:
        raise RuntimeError(self.WRITING_ERR)

    def extend(self, data: Sequence) -> torch.Tensor:
        raise RuntimeError(self.WRITING_ERR)

    def _empty(self):
        raise RuntimeError(self.WRITING_ERR)

    def dumps(self, path):
        ...

    def loads(self, path):
        ...

    def state_dict(self) -> Dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        return
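
``ImmutableDatasetWriter`` is the writer used by read-only dataset buffers: any write raises a ``RuntimeError``, while checkpointing and ``state_dict`` are no-ops. A minimal sketch using only the class defined above:

>>> writer = ImmutableDatasetWriter()
>>> writer.state_dict()   # nothing to checkpoint
{}
>>> try:
...     writer.add(torch.zeros(3))
... except RuntimeError as err:
...     print(err)
This dataset doesn't allow writing.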

class RoundRobinWriter(Writer):
    """A RoundRobin Writer class for composable replay buffers."""

    def __init__(self, **kw) -> None:
        super().__init__(**kw)
        self._cursor = 0

    def dumps(self, path):
        path = Path(path).absolute()
        path.mkdir(exist_ok=True)
        with open(path / "metadata.json", "w") as file:
            json.dump({"cursor": self._cursor}, file)

    def loads(self, path):
        path = Path(path).absolute()
        with open(path / "metadata.json", "r") as file:
            metadata = json.load(file)
        self._cursor = metadata["cursor"]

    def add(self, data: Any) -> int | torch.Tensor:
        index = self._cursor
        _cursor = self._cursor
        # we need to update the cursor first to avoid race conditions between workers
        self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
            single_data=data
        )
        self._storage[_cursor] = data
        return self._replicate_index(index)

    def extend(self, data: Sequence) -> torch.Tensor:
        cur_size = self._cursor
        if is_tensor_collection(data) or isinstance(data, torch.Tensor):
            batch_size = len(data)
        elif isinstance(data, list):
            batch_size = len(data)
        else:
            batch_size = len(tree_leaves(data)[0])
        if batch_size == 0:
            raise RuntimeError("Expected at least one element in extend.")
        device = data.device if hasattr(data, "device") else None
        max_size_along0 = self._storage._max_size_along_dim0(batched_data=data)
        index = (
            torch.arange(
                cur_size, batch_size + cur_size, dtype=torch.long, device=device
            )
            % max_size_along0
        )
        # we need to update the cursor first to avoid race conditions between workers
        self._cursor = (batch_size + cur_size) % max_size_along0
        self._storage[index] = data
        return self._replicate_index(index)

    def state_dict(self) -> Dict[str, Any]:
        return {"_cursor": self._cursor}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._cursor = state_dict["_cursor"]

    def _empty(self):
        self._cursor = 0

    @property
    def _cursor(self):
        _cursor_value = self.__dict__.get("_cursor_value", None)
        if _cursor_value is None:
            _cursor_value = self._cursor_value = mp.Value("i", 0)
        return _cursor_value.value

    @_cursor.setter
    def _cursor(self, value):
        _cursor_value = self.__dict__.get("_cursor_value", None)
        if _cursor_value is None:
            _cursor_value = self._cursor_value = mp.Value("i", 0)
        _cursor_value.value = value

    def __getstate__(self):
        state = copy(self.__dict__)
        if get_spawning_popen() is None:
            cursor = self._cursor
            del state["_cursor_value"]
            state["cursor__context"] = cursor
        return state

    def __setstate__(self, state):
        cursor = state.pop("cursor__context", None)
        if cursor is not None:
            _cursor_value = mp.Value("i", cursor)
            state["_cursor_value"] = _cursor_value
        self.__dict__.update(state)

    def __repr__(self):
        return f"{self.__class__.__name__}(cursor={int(self._cursor)}, full_storage={self._storage._is_full})"
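
``RoundRobinWriter`` is the default writer of :class:`~torchrl.data.ReplayBuffer`: the cursor lives in a ``multiprocessing.Value`` so it can be shared across worker processes, and it wraps around once the storage is full, overwriting the oldest entries first. A minimal usage sketch, assuming the public ``ReplayBuffer`` and ``LazyTensorStorage`` APIs from ``torchrl.data``:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> rb = ReplayBuffer(storage=LazyTensorStorage(5), writer=RoundRobinWriter(), batch_size=2)
>>> data = TensorDict({"obs": torch.arange(8).unsqueeze(-1)}, batch_size=[8])
>>> indices = rb.extend(data)   # 8 elements into 5 slots: the 3 oldest are overwritten
>>> len(rb)
5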

class TensorDictRoundRobinWriter(RoundRobinWriter):
    """A RoundRobin Writer class for composable, tensordict-based replay buffers."""

    def add(self, data: Any) -> int | torch.Tensor:
        index = self._cursor
        # we need to update the cursor first to avoid race conditions between workers
        max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
        self._cursor = (index + 1) % max_size_along_dim0
        if not is_tensorclass(data):
            data.set(
                "index",
                expand_as_right(
                    torch.as_tensor(index, device=data.device, dtype=torch.long), data
                ),
            )
        self._storage[index] = data
        return self._replicate_index(index)

    def extend(self, data: Sequence) -> torch.Tensor:
        cur_size = self._cursor
        batch_size = len(data)
        device = data.device if hasattr(data, "device") else None
        max_size_along_dim0 = self._storage._max_size_along_dim0(batched_data=data)
        index = (
            torch.arange(
                cur_size, batch_size + cur_size, dtype=torch.long, device=device
            )
            % max_size_along_dim0
        )
        # we need to update the cursor first to avoid race conditions between workers
        self._cursor = (batch_size + cur_size) % max_size_along_dim0
        # storage must convert the data to the appropriate format if needed
        if not is_tensorclass(data):
            data.set(
                "index",
                expand_as_right(
                    torch.as_tensor(index, device=data.device, dtype=torch.long), data
                ),
            )
        self._storage[index] = data
        return self._replicate_index(index)
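
``TensorDictRoundRobinWriter`` behaves like its parent but also stamps each element it writes with an ``"index"`` entry; it is the default writer of :class:`~torchrl.data.TensorDictReplayBuffer`. A minimal sketch, assuming the public ``TensorDictReplayBuffer`` and ``LazyTensorStorage`` APIs:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
>>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10), batch_size=2)  # uses this writer by default
>>> data = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])
>>> indices = rb.extend(data)
>>> stored_index = rb[:4].get("index")   # each stored element carries its storage index, i.e. tensor([0, 1, 2, 3])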

class TensorDictMaxValueWriter(Writer):
    """A Writer class for composable replay buffers that keeps the top elements based on some ranking key.

    Args:
        rank_key (str or tuple of str): the key to rank the elements by. Defaults to ``("next", "reward")``.
        reduction (str): the reduction method to use if the rank key has more than one element.
            Can be ``"max"``, ``"min"``, ``"mean"``, ``"median"`` or ``"sum"``.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
        >>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyTensorStorage(1),
        ...     sampler=SamplerWithoutReplacement(),
        ...     batch_size=1,
        ...     writer=TensorDictMaxValueWriter(rank_key="key"),
        ... )
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10)),
        ...     "obs": torch.tensor(range(10))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        9
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10, 20)),
        ...     "obs": torch.tensor(range(10, 20))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        19
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10)),
        ...     "obs": torch.tensor(range(10))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        19

    .. note:: This class isn't compatible with storages with more than one dimension.
        This doesn't mean that storing trajectories is prohibited, but that
        the trajectories stored must be stored on a per-trajectory basis.
        Here are some examples of valid and invalid usages of the class.
        First, a flat buffer where we store individual transitions:

            >>> from torchrl.data import TensorStorage
            >>> # Simplest use case: data comes in 1d and is stored as such
            >>> data = TensorDict({
            ...     "obs": torch.zeros(10, 3),
            ...     "reward": torch.zeros(10, 1),
            ... }, batch_size=[10])
            >>> rb = TensorDictReplayBuffer(
            ...     storage=LazyTensorStorage(max_size=100),
            ...     writer=TensorDictMaxValueWriter(rank_key="reward")
            ... )
            >>> # We initialize the buffer: a total of 100 *transitions* can be stored
            >>> rb.extend(data)
            >>> # Samples 5 *transitions* at random
            >>> sample = rb.sample(5)
            >>> assert sample.shape == (5,)

        Second, a buffer where we store trajectories. The max signal is aggregated in each
        batch (e.g. the reward of each rollout is summed):

            >>> # One can also store batches of data, each batch being a sub-trajectory
            >>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
            >>> # Get a batch of [2, 10] -- format is [Batch, Time]
            >>> rollout = env.rollout(max_steps=10)
            >>> rb = TensorDictReplayBuffer(
            ...     storage=LazyTensorStorage(max_size=100),
            ...     writer=TensorDictMaxValueWriter(rank_key="reward")
            ... )
            >>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
            >>> rb.extend(rollout)
            >>> # Sample 5 trajectories at random
            >>> sample = rb.sample(5)
            >>> assert sample.shape == (5, 10)

        If data comes in batch but a flat buffer is needed, we can simply flatten the data
        before extending the buffer:

            >>> rb = TensorDictReplayBuffer(
            ...     storage=LazyTensorStorage(max_size=100),
            ...     writer=TensorDictMaxValueWriter(rank_key="reward")
            ... )
            >>> # We initialize the buffer: a total of 100 *transitions* can be stored
            >>> rb.extend(rollout.reshape(-1))
            >>> # Sample 5 transitions at random
            >>> sample = rb.sample(5)
            >>> assert sample.shape == (5,)

        It is not possible to create a buffer that is extended along the time dimension,
        which is usually the recommended way of using buffers with batches of trajectories.
        Since trajectories are overlapping, it's hard if not impossible to aggregate the
        reward values and compare them. This constructor isn't valid (notice the ndim
        argument):

            >>> rb = TensorDictReplayBuffer(
            ...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
            ...     writer=TensorDictMaxValueWriter(rank_key="reward")
            ... )

    """

    def __init__(self, rank_key=None, reduction: str = "sum", **kwargs) -> None:
        super().__init__(**kwargs)
        self._cursor = 0
        self._current_top_values = []
        self._rank_key = rank_key
        self._reduction = reduction
        if self._rank_key is None:
            self._rank_key = ("next", "reward")

    def register_storage(self, storage: Storage) -> None:
        if storage.ndim > 1:
            raise ValueError(
                "TensorDictMaxValueWriter is not compatible with storages with more than one dimension. "
                "See the docstring constructor note about storing trajectories with TensorDictMaxValueWriter."
            )
        return super().register_storage(storage)

    def get_insert_index(self, data: Any) -> int:
        """Returns the index where the data should be inserted, or ``None`` if it should not be inserted."""
        if not is_tensor_collection(data):
            raise RuntimeError(
                f"{type(self)} expects data to be a tensor collection (tensordict or tensorclass). Found a {type(data)} instead."
            )
        if data.batch_dims > 1:
            raise RuntimeError(
                "Expected input tensordict to have no more than 1 dimension, got "
                f"tensordict.batch_size = {data.batch_size}"
            )

        ret = None
        rank_data = data.get(self._rank_key)

        # If time dimension, sum along it.
        if rank_data.numel() > 1:
            rank_data = _reduce(rank_data.reshape(-1), self._reduction, dim=0)
        else:
            rank_data = rank_data.item()

        if rank_data is None:
            raise KeyError(f"Rank key {self._rank_key} not found in data.")

        # If the buffer is not full, add the data
        if len(self._current_top_values) < self._storage.max_size:
            ret = self._cursor
            self._cursor = (self._cursor + 1) % self._storage.max_size

            # Add new reward to the heap
            heapq.heappush(self._current_top_values, (rank_data, ret))

        # If the buffer is full, check if the new data is better than the worst data in the buffer
        elif rank_data > self._current_top_values[0][0]:
            # retrieve position of the smallest value
            min_sample = heapq.heappop(self._current_top_values)
            ret = min_sample[1]

            # Add new reward to the heap
            heapq.heappush(self._current_top_values, (rank_data, ret))

        return ret

    def add(self, data: Any) -> int | torch.Tensor:
        """Inserts a single element of data at an appropriate index, and returns that index.

        The ``rank_key`` in the data passed to this module should be structured as [].
        If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
        """
        index = self.get_insert_index(data)
        if index is not None:
            data.set("index", index)
            self._storage[index] = data
        return self._replicate_index(index)

    def extend(self, data: TensorDictBase) -> None:
        """Inserts a series of data points at appropriate indices.

        The ``rank_key`` in the data passed to this module should be structured as [B].
        If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
        """
        # a map of [idx_in_storage, idx_in_data]
        data_to_replace = {}
        for data_idx, sample in enumerate(data):
            storage_idx = self.get_insert_index(sample)
            if storage_idx is not None:
                data_to_replace[storage_idx] = data_idx

        # -1 will be interpreted as invalid by prioritized buffers
        # Replace the data in the storage all at once
        if len(data_to_replace) > 0:
            storage_idx, data_idx = zip(*data_to_replace.items())
            index = data.get("index", None)
            dtype = index.dtype if index is not None else torch.long
            device = index.device if index is not None else data.device
            out_index = torch.full(data.shape, -1, dtype=torch.long, device=device)
            data_idx = torch.as_tensor(data_idx, dtype=dtype, device=device)
            storage_idx = torch.as_tensor(storage_idx, dtype=dtype, device=device)
            out_index[data_idx] = storage_idx
            self._storage.set(storage_idx, data[data_idx])
        else:
            device = getattr(self._storage, "device", None)
            out_index = torch.full(data.shape, -1, dtype=torch.long, device=device)
        return self._replicate_index(out_index)

    def _empty(self) -> None:
        self._cursor = 0
        self._current_top_values = []

    def __getstate__(self):
        if get_spawning_popen() is not None:
            raise RuntimeError(
                f"Writers of type {type(self)} cannot be shared between processes."
            )
        state = copy(self.__dict__)
        return state

    def dumps(self, path):
        path = Path(path).absolute()
        path.mkdir(exist_ok=True)
        t = torch.as_tensor(self._current_top_values)
        try:
            MemoryMappedTensor.from_filename(
                filename=path / "current_top_values.memmap",
                shape=t.shape,
                dtype=t.dtype,
            ).copy_(t)
        except FileNotFoundError:
            MemoryMappedTensor.from_tensor(
                t, filename=path / "current_top_values.memmap"
            )
        with open(path / "metadata.json", "w") as file:
            json.dump(
                {
                    "cursor": self._cursor,
                    "rank_key": self._rank_key,
                    "dtype": str(t.dtype),
                    "shape": list(t.shape),
                },
                file,
            )

    def loads(self, path):
        path = Path(path).absolute()
        with open(path / "metadata.json", "r") as file:
            metadata = json.load(file)
        self._cursor = metadata["cursor"]
        self._rank_key = metadata["rank_key"]
        shape = torch.Size(metadata["shape"])
        dtype = metadata["dtype"]
        self._current_top_values = MemoryMappedTensor.from_filename(
            filename=path / "current_top_values.memmap",
            dtype=_STRDTYPE2DTYPE[dtype],
            shape=shape,
        ).tolist()

    def state_dict(self) -> Dict[str, Any]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        raise NotImplementedError

    def __repr__(self):
        return f"{self.__class__.__name__}(cursor={int(self._cursor)}, full_storage={self._storage._is_full}, rank_key={self._rank_key}, reduction={self._reduction})"
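
Beyond the selection logic illustrated in the class docstring, the writer can be checkpointed: ``dumps`` saves the heap of current top values to a memory-mapped tensor plus a ``metadata.json`` file, and ``loads`` restores the cursor, rank key and heap. A minimal round-trip sketch, assuming the public ``TensorDictReplayBuffer`` and ``LazyTensorStorage`` APIs and a temporary directory:

>>> import tempfile
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
>>> writer = TensorDictMaxValueWriter(rank_key="reward")
>>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10), writer=writer, batch_size=2)
>>> data = TensorDict({"obs": torch.randn(4, 3), "reward": torch.randn(4, 1)}, batch_size=[4])
>>> indices = rb.extend(data)
>>> with tempfile.TemporaryDirectory() as tmp:
...     writer.dumps(tmp)   # writes current_top_values.memmap + metadata.json
...     writer.loads(tmp)   # restores cursor, rank key and the heap of top values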

class WriterEnsemble(Writer):
    """An ensemble of writers.

    This class is designed to work with :class:`~torchrl.data.replay_buffers.replay_buffers.ReplayBufferEnsemble`.
    It contains the writers but blocks writing with any of them.

    Args:
        writers (sequence of Writer): the writers to make the composite writer.

    .. warning::
        This class does not support writing. To extend one of the replay buffers,
        simply index the parent :class:`~torchrl.data.ReplayBufferEnsemble` object.

    """

    def __init__(self, *writers):
        self._writers = writers

    def _empty(self):
        raise NotImplementedError

    def dumps(self, path: Path):
        path = Path(path).absolute()
        for i, writer in enumerate(self._writers):
            writer.dumps(path / str(i))

    def loads(self, path: Path):
        path = Path(path).absolute()
        for i, writer in enumerate(self._writers):
            writer.loads(path / str(i))

    def add(self):
        raise NotImplementedError

    def extend(self):
        raise NotImplementedError

    _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead."

    def __getitem__(self, index):
        if isinstance(index, tuple):
            if index[0] is Ellipsis:
                index = (slice(None), index[1:])
            result = self[index[0]]
            if len(index) > 1:
                raise IndexError(
                    f"Tuples of length greater than 1 are not accepted to index writers of type {type(self)}."
                )
            return result
        if isinstance(index, slice) and index == slice(None):
            return self
        if isinstance(index, (list, range, np.ndarray)):
            index = torch.as_tensor(index)
        if isinstance(index, torch.Tensor):
            if index.ndim > 1:
                raise RuntimeError(
                    f"Cannot index a {type(self)} with tensor indices that have more than one dimension."
                )
            if index.is_floating_point():
                raise TypeError(
                    "A floating point index was received when an integer dtype was expected."
                )
        if isinstance(index, int) or (not isinstance(index, slice) and len(index) == 0):
            try:
                index = int(index)
            except Exception:
                raise IndexError(self._INDEX_ERROR.format(type(index)))
            try:
                return self._writers[index]
            except IndexError:
                raise IndexError(self._INDEX_ERROR.format(type(index)))
        if isinstance(index, torch.Tensor):
            index = index.tolist()
            writers = [self._writers[i] for i in index]
        else:
            # slice
            writers = self._writers[index]
        return WriterEnsemble(*writers)

    def __len__(self):
        return len(self._writers)

    def __repr__(self):
        writers = textwrap.indent(f"writers={self._writers}", " " * 4)
        return f"WriterEnsemble(\n{writers})"

    def state_dict(self) -> Dict[str, Any]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        raise NotImplementedError
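
``WriterEnsemble`` never writes; it only dispatches checkpointing to its sub-writers and supports indexing, returning a single sub-writer for an integer index and a new ensemble for a slice or tensor index. It is normally built by :class:`~torchrl.data.ReplayBufferEnsemble`; constructing one by hand as below is purely illustrative:

>>> ensemble = WriterEnsemble(RoundRobinWriter(), RoundRobinWriter())
>>> len(ensemble)
2
>>> isinstance(ensemble[0], RoundRobinWriter)   # integer indexing returns the sub-writer
True
>>> sub = ensemble[:1]                          # slicing returns a new WriterEnsemble
>>> len(sub)
1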
