Source code for ``torchrl.data.replay_buffers.utils``

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# import tree
from __future__ import annotations

import contextlib
import functools
import itertools
import math
import operator
import os
import typing
from pathlib import Path
from typing import Any, Callable, Union

import numpy as np
import torch
from tensordict import (
    LazyStackedTensorDict,
    MemoryMappedTensor,
    NonTensorData,
    TensorDict,
    TensorDictBase,
    unravel_key,
)
from torch import Tensor
from torch.nn import functional as F
from torch.utils._pytree import LeafSpec, tree_flatten, tree_unflatten

from torchrl._utils import implement_for, logger as torchrl_logger

# Key used for the lone entry when a buffer stores a single (non-dict) tensor.
# Overridable through the SINGLE_TENSOR_BUFFER_NAME environment variable.
SINGLE_TENSOR_BUFFER_NAME = os.environ.get(
    "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_"
)


# Scalar integer types accepted for indexing: Python ints and numpy integer scalars.
INT_CLASSES_TYPING = Union[int, np.integer]
if hasattr(typing, "get_args"):
    # typing.get_args exists on python >= 3.8: derive the runtime tuple from the Union.
    INT_CLASSES = typing.get_args(INT_CLASSES_TYPING)
else:
    # python 3.7
    INT_CLASSES = (int, np.integer)


def _to_numpy(data: Tensor) -> np.ndarray:
    return data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data


def _to_torch(
    data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False
) -> torch.Tensor:
    if isinstance(data, np.generic):
        return torch.as_tensor(data, device=device)
    elif isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    elif not isinstance(data, Tensor):
        data = torch.as_tensor(data, device=device)

    if pin_memory:
        data = data.pin_memory()
    if device is not None:
        data = data.to(device, non_blocking=non_blocking)

    return data


def pin_memory_output(fun) -> Callable:
    """Calls pin_memory on outputs of decorated function if they have such method.

    The wrapped method is expected to live on an object exposing a boolean
    ``_pin_memory`` attribute; when it is True, every element of the returned
    value (tuple or single object) is passed through :func:`_pin_memory`.

    Args:
        fun: the method to decorate.

    Returns:
        Callable: the decorated method. Fix: metadata (``__name__``, ``__doc__``,
        ...) is now preserved via :func:`functools.wraps`, so introspection and
        debugging see the original function rather than ``decorated_fun``.
    """

    @functools.wraps(fun)
    def decorated_fun(self, *args, **kwargs):
        output = fun(self, *args, **kwargs)
        if self._pin_memory:
            # Normalize to a tuple so single outputs and tuples share one code path.
            _tuple_out = True
            if not isinstance(output, tuple):
                _tuple_out = False
                output = (output,)
            output = tuple(_pin_memory(_output) for _output in output)
            if _tuple_out:
                return output
            return output[0]
        return output

    return decorated_fun


def _pin_memory(output: Any) -> Any:
    if hasattr(output, "pin_memory") and output.device == torch.device("cpu"):
        return output.pin_memory()
    else:
        return output


def _reduce(
    tensor: torch.Tensor, reduction: str, dim: int | None = None
) -> Union[float, torch.Tensor]:
    """Reduces a tensor given the reduction method."""
    if reduction == "max":
        result = tensor.max(dim=dim)
    elif reduction == "min":
        result = tensor.min(dim=dim)
    elif reduction == "mean":
        result = tensor.mean(dim=dim)
    elif reduction == "median":
        result = tensor.median(dim=dim)
    elif reduction == "sum":
        result = tensor.sum(dim=dim)
    else:
        raise NotImplementedError(f"Unknown reduction method {reduction}")
    if isinstance(result, tuple):
        result = result[0]
    return result.item() if dim is None else result


def _is_int(index):
    """Return True when ``index`` behaves like a scalar integer.

    Accepts Python/numpy integer scalars (``INT_CLASSES``) as well as
    0-dimensional numpy arrays and torch tensors.
    """
    if isinstance(index, INT_CLASSES):
        return True
    return isinstance(index, (np.ndarray, torch.Tensor)) and index.ndim == 0


class TED2Flat:
    """A storage saving hook to serialize TED data in a compact format.

    Args:
        done_key (NestedKey, optional): the key where the done states should be read.
            Defaults to ``("next", "done")``.
        shift_key (NestedKey, optional): the key where the shift will be written.
            Defaults to "shift".
        is_full_key (NestedKey, optional): the key where the is_full attribute will be written.
            Defaults to "is_full".
        done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries.
            Defaults to ("done", "truncated", "terminated")
        reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries.
            Defaults to ("reward",)

    Examples:
        >>> import tempfile
        >>> from tensordict import TensorDict
        >>> from torchrl.collectors import SyncDataCollector
        >>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage
        >>> from torchrl.envs import GymEnv
        >>> import torch
        >>> env = GymEnv("CartPole-v1")
        >>> env.set_seed(0)
        >>> torch.manual_seed(0)
        >>> collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
        >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
        >>> rb.register_save_hook(TED2Flat())
        >>> with tempfile.TemporaryDirectory() as tmpdir:
        ...     for i, data in enumerate(collector):
        ...         rb.extend(data)
        ...         rb.dumps(tmpdir)
        ...     # load the data to represent it: a flat TensorDict with one root
        ...     # entry per key; expanded keys (e.g. "observation") carry extra
        ...     # rows holding the terminal "next" values of each trajectory.
        ...     td = TensorDict.load(tmpdir + "/storage/")
    """

    # Cursor position and buffer-full flag, injected by the owning storage
    # through the ``shift``/``is_full`` properties below.
    _shift: int = None
    _is_full: bool = None

    def __init__(
        self,
        done_key=("next", "done"),
        shift_key="shift",
        is_full_key="is_full",
        done_keys=("done", "truncated", "terminated"),
        reward_keys=("reward",),
    ):
        self.done_key = done_key
        self.shift_key = shift_key
        self.is_full_key = is_full_key
        # unravel_key normalizes nested keys so membership tests are reliable.
        self.done_keys = {unravel_key(key) for key in done_keys}
        self.reward_keys = {unravel_key(key) for key in reward_keys}

    @property
    def shift(self):
        # Cursor position of the circular storage at save time.
        return self._shift

    @shift.setter
    def shift(self, value: int):
        self._shift = value

    @property
    def is_full(self):
        # Whether the circular storage has wrapped around.
        return self._is_full

    @is_full.setter
    def is_full(self, value: int):
        self._is_full = value

    def __call__(self, data: TensorDictBase, path: Path = None):
        """Serialize ``data`` (TED format) into a flat memmapped TensorDict at ``path``."""
        # Get the done state
        shift = self.shift
        is_full = self.is_full

        # Create an output storage with the metadata needed to invert the transform.
        output = TensorDict()
        output.set_non_tensor(self.is_full_key, is_full)
        output.set_non_tensor(self.shift_key, shift)
        output.set_non_tensor("_storage_shape", tuple(data.shape))
        output.memmap_(path)

        # Preallocate the output
        done = data.get(self.done_key).squeeze(-1).clone()
        if not is_full:
            # shift is the cursor place
            done[shift - 1] = True
        else:
            # Re-align the circular buffer so that time starts at index 0.
            done = done.roll(-shift, dims=0)
            done[-1] = True
        ntraj = done.sum()

        # Get the keys that require extra storage: every "next" entry that is
        # neither a done nor a reward key needs one extra row per trajectory
        # to hold the terminal value.
        keys_to_expand = set(data.get("next").keys(True, True)) - (
            self.done_keys.union(self.reward_keys)
        )
        total_keys = data.exclude("next").keys(True, True)
        total_keys = set(total_keys).union(set(data.get("next").keys(True, True)))

        len_with_offset = data.numel() + ntraj  # + done[0].numel()
        for key in total_keys:
            if key in (self.done_keys.union(self.reward_keys)):
                entry = data.get(("next", key))
            else:
                entry = data.get(key)

            if key in keys_to_expand:
                shape = torch.Size([len_with_offset, *entry.shape[data.ndim :]])
                dtype = entry.dtype
                output.make_memmap(key, shape=shape, dtype=dtype)
            else:
                shape = torch.Size([data.numel(), *entry.shape[data.ndim :]])
                output.make_memmap(key, shape=shape, dtype=entry.dtype)

        if data.ndim == 1:
            # Single-dimensional storage: write everything in one pass.
            return self._call(
                data=data,
                output=output,
                is_full=is_full,
                shift=shift,
                done=done,
                total_keys=total_keys,
                keys_to_expand=keys_to_expand,
            )

        # Multi-dimensional storage: flatten all batch dims but the first (time),
        # then serialize each batch column into its own contiguous slice.
        with data.flatten(1, -1) if data.ndim > 2 else contextlib.nullcontext(
            data
        ) as data_flat:
            if data.ndim > 2:
                done = done.flatten(1, -1)
            traj_per_dim = done.sum(0)
            nsteps = data_flat.shape[0]
            start = 0
            start_with_offset = start
            stop_with_offset = 0
            stop = 0

            for data_slice, done_slice, traj_for_dim in zip(
                data_flat.unbind(1), done.unbind(1), traj_per_dim
            ):
                # Offset slice: accounts for the extra terminal row per trajectory.
                stop_with_offset = stop_with_offset + nsteps + traj_for_dim
                cur_slice_offset = slice(start_with_offset, stop_with_offset)
                start_with_offset = stop_with_offset

                stop = stop + data.shape[0]
                cur_slice = slice(start, stop)
                start = stop

                def _index(
                    key,
                    val,
                    keys_to_expand=keys_to_expand,
                    cur_slice=cur_slice,
                    cur_slice_offset=cur_slice_offset,
                ):
                    # Expanded keys use the offset (longer) slice.
                    if key in keys_to_expand:
                        return val[cur_slice_offset]
                    return val[cur_slice]

                out_slice = output.named_apply(_index, nested_keys=True)
                self._call(
                    data=data_slice,
                    output=out_slice,
                    is_full=is_full,
                    shift=shift,
                    done=done_slice,
                    total_keys=total_keys,
                    keys_to_expand=keys_to_expand,
                )
        return output

    def _call(self, *, data, output, is_full, shift, done, total_keys, keys_to_expand):
        """Write a single batch column of ``data`` into the preallocated ``output``."""
        # capture for each item in data where the observation should be written:
        # each completed trajectory pushes subsequent rows one slot further to
        # leave room for the terminal "next" value.
        idx = torch.arange(data.shape[0])
        idx_done = (idx + done.cumsum(0))[done]
        idx += torch.nn.functional.pad(done, [1, 0])[:-1].cumsum(0)

        for key in total_keys:
            if key in (self.done_keys.union(self.reward_keys)):
                entry = data.get(("next", key))
            else:
                entry = data.get(key)

            if key in keys_to_expand:
                mmap = output.get(key)
                shifted_next = data.get(("next", key))
                if is_full:
                    # Un-roll the circular buffer while scattering into place.
                    _roll_inplace(entry, shift=-shift, out=mmap, index_dest=idx)
                    _roll_inplace(
                        shifted_next,
                        shift=-shift,
                        out=mmap,
                        index_dest=idx_done,
                        index_source=done,
                    )
                else:
                    mmap[idx] = entry
                    mmap[idx_done] = shifted_next[done]
            elif is_full:
                mmap = output.get(key)
                _roll_inplace(entry, shift=-shift, out=mmap)
            else:
                mmap = output.get(key)
                mmap.copy_(entry)
        return output
class Flat2TED:
    """A storage loading hook to deserialize flattened TED data to TED format.

    Args:
        done_key (NestedKey, optional): the key where the done states should be read.
            Defaults to ``("next", "done")``.
        shift_key (NestedKey, optional): the key where the shift will be written.
            Defaults to "shift".
        is_full_key (NestedKey, optional): the key where the is_full attribute will be written.
            Defaults to "is_full".
        done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries.
            Defaults to ("done", "truncated", "terminated")
        reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries.
            Defaults to ("reward",)

    Examples:
        >>> import tempfile
        >>> from tensordict import TensorDict
        >>> from torchrl.collectors import SyncDataCollector
        >>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage, Flat2TED
        >>> from torchrl.envs import GymEnv
        >>> import torch
        >>> env = GymEnv("CartPole-v1")
        >>> env.set_seed(0)
        >>> torch.manual_seed(0)
        >>> collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
        >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
        >>> rb.register_save_hook(TED2Flat())
        >>> with tempfile.TemporaryDirectory() as tmpdir:
        ...     for i, data in enumerate(collector):
        ...         rb.extend(data)
        ...         rb.dumps(tmpdir)
        ...     rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
        ...     rb_load.register_load_hook(Flat2TED())
        ...     rb_load.load(tmpdir)
        ...     # rb_load[:] now matches rb[:] in the regular TED layout,
        ...     # including the reconstructed "next" sub-tensordict.
        ...     assert (rb[:] == rb_load[:]).all()
    """

    def __init__(
        self,
        done_key="done",
        shift_key="shift",
        is_full_key="is_full",
        done_keys=("done", "truncated", "terminated"),
        reward_keys=("reward",),
    ):
        self.done_key = done_key
        self.shift_key = shift_key
        self.is_full_key = is_full_key
        # unravel_key normalizes nested keys so membership tests are reliable.
        self.done_keys = {unravel_key(key) for key in done_keys}
        self.reward_keys = {unravel_key(key) for key in reward_keys}

    def __call__(self, data: TensorDictBase, out: TensorDictBase = None):
        """Rebuild a TED tensordict from the flat representation written by TED2Flat."""
        # Metadata written by TED2Flat.
        _storage_shape = data.get_non_tensor("_storage_shape", default=None)
        if isinstance(_storage_shape, int):
            _storage_shape = torch.Size([_storage_shape])
        shift = data.get_non_tensor(self.shift_key, default=None)
        is_full = data.get_non_tensor(self.is_full_key, default=None)
        # NOTE(review): _storage_shape is used unconditionally here although it is
        # checked against None further down — confirm callers always store it.
        done = (
            data.get("done")
            .reshape((*_storage_shape[1:], -1))
            .contiguous()
            .permute(-1, *range(0, len(_storage_shape) - 1))
            .clone()
        )
        if not is_full:
            # shift is the cursor place
            done[shift - 1] = True
        else:
            done[-1] = True

        if _storage_shape is not None and len(_storage_shape) > 1:
            # iterate over data and allocate
            if out is None:
                out = TensorDict(batch_size=_storage_shape)
                for i in range(1, out.ndim):
                    if i >= 2:
                        # FLattening the lazy stack will make the data unavailable - we need to find a way to make this
                        # possible.
                        raise RuntimeError(
                            "Checkpointing an uninitialized buffer with more than 2 dimensions is currently not supported. "
                            "Please file an issue on GitHub to ask for this feature!"
                        )
                    # Turn each batch dim beyond the first into a lazy stack so
                    # sub-tensordicts can be written independently below.
                    out_list = [
                        out._get_sub_tensordict((slice(None),) * i + (j,))
                        for j in range(out.shape[i])
                    ]
                    out = LazyStackedTensorDict(*out_list, stack_dim=i)

            # Create a function that reads slices of the input data
            with out.flatten(1, -1) if out.ndim > 2 else contextlib.nullcontext(
                out
            ) as out_flat:
                nsteps = done.shape[0]
                n_elt_batch = done.shape[1:].numel()
                traj_per_dim = done.sum(0)
                start = 0
                start_with_offset = start
                stop_with_offset = 0
                stop = 0

                for out_unbound, traj_for_dim in zip(out_flat.unbind(-1), traj_per_dim):
                    # Offset slice covers the extra terminal row per trajectory.
                    stop_with_offset = stop_with_offset + nsteps + traj_for_dim
                    cur_slice_offset = slice(start_with_offset, stop_with_offset)
                    start_with_offset = stop_with_offset

                    stop = stop + nsteps
                    cur_slice = slice(start, stop)
                    start = stop

                    def _index(
                        key,
                        val,
                        cur_slice=cur_slice,
                        nsteps=nsteps,
                        n_elt_batch=n_elt_batch,
                        cur_slice_offset=cur_slice_offset,
                    ):
                        # Entries longer than nsteps*n_elt_batch were expanded at
                        # save time and use the offset slice.
                        if val.shape[0] != (nsteps * n_elt_batch):
                            return val[cur_slice_offset]
                        return val[cur_slice]

                    data_slice = data.named_apply(
                        _index, nested_keys=True, batch_size=[]
                    )
                    self._call(
                        data=data_slice,
                        out=out_unbound,
                        is_full=is_full,
                        shift=shift,
                        _storage_shape=_storage_shape,
                    )
                return out
        return self._call(
            data=data,
            out=out,
            is_full=is_full,
            shift=shift,
            _storage_shape=_storage_shape,
        )

    def _call(self, *, data, out, _storage_shape, shift, is_full):
        """Deserialize one flat batch column into the TED layout of ``out``."""
        done = data.get(self.done_key)
        done = done.clone()

        nsteps = done.shape[0]

        # capture for each item in data where the observation should be written:
        # root_idx points at the row holding step t, next_idx at the row holding
        # its successor (terminal rows were interleaved by TED2Flat).
        idx = torch.arange(done.shape[0])
        padded_done = F.pad(done.squeeze(-1), [1, 0])
        root_idx = idx + padded_done[:-1].cumsum(0)
        next_idx = root_idx + 1

        if out is None:
            out = TensorDict({}, [nsteps])

        def maybe_roll(entry, out=None):
            # Undo the roll applied at save time when the buffer was full;
            # writes in place when a destination is provided.
            if is_full and shift is not None:
                if out is not None:
                    _roll_inplace(entry, shift=shift, out=out)
                    return
                else:
                    return entry.roll(shift, dims=0)
            if out is not None:
                out.copy_(entry)
                return
            return entry

        root_idx = maybe_roll(root_idx)
        next_idx = maybe_roll(next_idx)
        if not is_full:
            # The last step has no successor yet when the buffer is not full.
            next_idx = next_idx[:-1]

        for key, entry in data.items(True, True):
            if entry.shape[0] == nsteps:
                # Non-expanded entries: done/reward live under "next", the rest at root.
                if key in (self.done_keys.union(self.reward_keys)):
                    if key != "reward" and key not in out.keys(True, True):
                        # Create a done state at the root full of 0s
                        out.set(key, torch.zeros_like(entry), inplace=True)
                    entry = maybe_roll(entry, out=out.get(("next", key), None))
                    if entry is not None:
                        out.set(("next", key), entry, inplace=True)
                else:
                    # action and similar
                    entry = maybe_roll(entry, out=out.get(key, default=None))
                    if entry is not None:
                        # then out is not locked
                        out.set(key, entry, inplace=True)
            else:
                # Expanded entries (e.g. observation): split the interleaved
                # rows back into root and "next" values.
                dest_next = out.get(("next", key), None)
                if dest_next is not None:
                    if not is_full:
                        dest_next = dest_next[:-1]
                    dest_next.copy_(entry[next_idx])
                else:
                    if not is_full:
                        val = entry[next_idx]
                        # Pad the missing last "next" value with zeros.
                        val = torch.cat([val, torch.zeros_like(val[:1])])
                        out.set(("next", key), val, inplace=True)
                    else:
                        out.set(("next", key), entry[next_idx], inplace=True)

                dest = out.get(key, None)
                if dest is not None:
                    dest.copy_(entry[root_idx])
                else:
                    out.set(key, entry[root_idx], inplace=True)
        return out
class TED2Nested(TED2Flat):
    """Converts a TED-formatted dataset into a tensordict populated with nested tensors where each row is a trajectory."""

    _shift: int = None
    _is_full: bool = None

    def __init__(self, *args, **kwargs):
        # Nested-tensor storage views require torch>=2.4 internals.
        if not hasattr(torch, "_nested_compute_contiguous_strides_offsets"):
            raise ValueError(
                f"Unsupported torch version {torch.__version__}. "
                f"torch>=2.4 is required for {type(self).__name__} to be used."
            )
        return super().__init__(*args, **kwargs)

    def __call__(self, data: TensorDictBase, path: Path = None):
        """Flatten ``data`` with TED2Flat, then reshape each key into per-trajectory nested tensors."""
        data = super().__call__(data, path=path)

        shift = self.shift
        is_full = self.is_full

        storage_shape = data.get_non_tensor("_storage_shape", (-1,))
        # place time at the end
        storage_shape = (*storage_shape[1:], storage_shape[0])

        done = data.get("done")
        done = done.squeeze(-1).clone()
        if not is_full:
            done.view(storage_shape)[..., shift - 1] = True
        # The final step of every batch column always closes a trajectory.
        done.view(storage_shape)[..., -1] = True

        ntraj = done.sum()
        # Trajectory lengths from consecutive done flags.
        nz = done.nonzero(as_tuple=True)[0]
        traj_lengths = torch.cat([nz[:1] + 1, nz.diff()])

        # Keys whose first dim exceeds the step count were expanded at save time.
        keys_to_expand, keys_to_keep = zip(
            *[
                (key, None) if val.shape[0] != done.shape[0] else (None, key)
                for key, val in data.items(True, True)
            ]
        )
        keys_to_expand = [key for key in keys_to_expand if key is not None]
        keys_to_keep = [key for key in keys_to_keep if key is not None]

        out = TensorDict({}, batch_size=[ntraj])
        out.update(dict(data.non_tensor_items()))
        out.memmap_(path)

        traj_lengths = traj_lengths.unsqueeze(-1)
        if not is_full:
            # Increment by one only the trajectories that are not terminal
            traj_lengths_expand = traj_lengths + (
                traj_lengths.cumsum(0) % storage_shape[-1] != 0
            )
        else:
            traj_lengths_expand = traj_lengths + 1
        for key in keys_to_expand:
            val = data.get(key)
            # Per-trajectory ragged shapes: [length(+1), *feature_dims].
            shape = torch.cat(
                [
                    traj_lengths_expand,
                    torch.tensor(val.shape[1:], dtype=torch.long).repeat(
                        traj_lengths.numel(), 1
                    ),
                ],
                -1,
            )
            # This works because the storage location is the same as the previous one - no copy is done
            # but a new shape is written
            out.make_memmap_from_storage(
                key, val.untyped_storage(), dtype=val.dtype, shape=shape
            )
        for key in keys_to_keep:
            val = data.get(key)
            shape = torch.cat(
                [
                    traj_lengths,
                    torch.tensor(val.shape[1:], dtype=torch.long).repeat(
                        traj_lengths.numel(), 1
                    ),
                ],
                -1,
            )
            out.make_memmap_from_storage(
                key, val.untyped_storage(), dtype=val.dtype, shape=shape
            )
        return out
[docs]class Nested2TED(Flat2TED): """Converts a nested tensordict where each row is a trajectory into the TED format.""" def __call__(self, data, out: TensorDictBase = None): # Get a flat representation of data def flatten_het_dim(tensor): shape = [tensor.size(i) for i in range(2, tensor.ndim)] tensor = torch.tensor(tensor.untyped_storage(), dtype=tensor.dtype).view( -1, *shape ) return tensor data = data.apply(flatten_het_dim, batch_size=[]) data.auto_batch_size_() return super().__call__(data, out=out)
[docs]class H5Split(TED2Flat): """Splits a dataset prepared with TED2Nested into a TensorDict where each trajectory is stored as views on their parent nested tensors.""" _shift: int = None _is_full: bool = None def __call__(self, data): nzeros = int(math.ceil(math.log10(data.shape[0]))) result = TensorDict( { f"traj_{str(i).zfill(nzeros)}": _data for i, _data in enumerate(data.filter_non_tensor_data().unbind(0)) } ).update(dict(data.non_tensor_items())) return result
[docs]class H5Combine: """Combines trajectories in a persistent tensordict into a single standing tensordict stored in filesystem.""" def __call__(self, data, out=None): # TODO: this load the entire H5 in memory, which can be problematic # Ideally we would want to load it on a memmap tensordict # We currently ignore out in this call but we should leverage that values = [val for key, val in data.items() if key.startswith("traj")] metadata_keys = [key for key in data.keys() if not key.startswith("traj")] result = TensorDict({key: NonTensorData(data[key]) for key in metadata_keys}) # Create a memmap in file system (no files associated) result.memmap_() # Create each entry def initialize(key, *x): result.make_memmap( key, shape=torch.stack([torch.tensor(_x.shape) for _x in x]), dtype=x[0].dtype, ) return values[0].named_apply( initialize, *values[1:], nested_keys=True, batch_size=[], filter_empty=True, ) # Populate the entries def populate(key, *x): dest = result.get(key) for i, _x in enumerate(x): dest[i].copy_(_x) values[0].named_apply( populate, *values[1:], nested_keys=True, batch_size=[], filter_empty=True, ) return result
@implement_for("torch", "2.3", None)
def _path2str(path, default_name=None):
    """Build a '/'-separated string path from a pytree key-path (torch>=2.3)."""
    # Uses the Keys defined in pytree to build a path
    from torch.utils._pytree import MappingKey, SequenceKey

    if default_name is None:
        default_name = SINGLE_TENSOR_BUFFER_NAME
    if not path:
        return default_name
    if isinstance(path, tuple):
        return "/".join([_path2str(_sub, default_name=default_name) for _sub in path])
    if isinstance(path, MappingKey):
        if not isinstance(path.key, (int, str, bytes)):
            raise ValueError("Values must be of type int, str or bytes in PyTree maps.")
        result = str(path.key)
        if result == default_name:
            # A user key colliding with the sentinel would be indistinguishable
            # from a single-tensor buffer on load.
            raise RuntimeError(
                "A tensor had the same identifier as the default name used when the buffer contains "
                f"a single tensor (name={default_name}). This behavior is not allowed. Please rename your "
                f"tensor in the map/dict or set a new default name with the environment variable SINGLE_TENSOR_BUFFER_NAME."
            )
        return result
    if isinstance(path, SequenceKey):
        return str(path.idx)


@implement_for("torch", None, "2.3")
def _path2str(path, default_name=None):  # noqa: F811
    # Key-path pytree APIs do not exist before torch 2.3.
    raise RuntimeError


def _save_pytree_common(tensor_path, path, tensor, metadata):
    """Persist one flattened pytree leaf as a memmap file and record its metadata."""
    if "." in tensor_path:
        # NOTE(review): str.replace returns a new string; this call discards its
        # result so dots are never actually escaped — TODO confirm and fix upstream.
        tensor_path.replace(".", "_<dot>_")
    total_tensor_path = path / (tensor_path + ".memmap")
    if os.path.exists(total_tensor_path):
        # Reuse the existing file in place.
        MemoryMappedTensor.from_filename(
            shape=tensor.shape,
            filename=total_tensor_path,
            dtype=tensor.dtype,
        ).copy_(tensor)
    else:
        os.makedirs(total_tensor_path.parent, exist_ok=True)
        MemoryMappedTensor.from_tensor(
            tensor,
            filename=total_tensor_path,
            copy_existing=True,
            copy_data=True,
        )
    key = tensor_path.replace("/", ".")
    if key in metadata:
        raise KeyError(
            "At least two values have conflicting representations in "
            f"the data structure to be serialized: {key}."
        )
    metadata[key] = {
        "dtype": str(tensor.dtype),
        "shape": list(tensor.shape),
    }


@implement_for("torch", "2.3", None)
def _save_pytree(_storage, metadata, path):
    """Save every leaf of ``_storage`` under ``path`` using key-path traversal (torch>=2.3)."""
    from torch.utils._pytree import tree_map_with_path

    def save_tensor(
        tensor_path: tuple, tensor: torch.Tensor, metadata=metadata, path=path
    ):
        tensor_path = _path2str(tensor_path)
        _save_pytree_common(tensor_path, path, tensor, metadata)

    tree_map_with_path(save_tensor, _storage)


@implement_for("torch", None, "2.3")
def _save_pytree(_storage, metadata, path):  # noqa: F811
    """Save every leaf of ``_storage`` under ``path`` (pre-2.3 fallback via tree specs)."""
    flat_storage, storage_specs = tree_flatten(_storage)
    storage_paths = _get_paths(storage_specs)

    def save_tensor(
        tensor_path: str, tensor: torch.Tensor, metadata=metadata, path=path
    ):
        _save_pytree_common(tensor_path, path, tensor, metadata)

    for tensor, tensor_path in zip(flat_storage, storage_paths):
        save_tensor(tensor_path, tensor)


def _get_paths(spec, cumulpath=""):
    """Yield one '/'-joined path per leaf of a pytree spec (pre-2.3 key-path substitute)."""
    # alternative way to build a path without the keys
    if isinstance(spec, LeafSpec):
        yield cumulpath if cumulpath else SINGLE_TENSOR_BUFFER_NAME

    contexts = spec.context
    children_specs = spec.children_specs
    if contexts is None:
        # Sequence-like nodes carry no context: index children positionally.
        contexts = range(len(children_specs))

    for context, spec in zip(contexts, children_specs):
        cpath = "/".join((cumulpath, str(context))) if cumulpath else str(context)
        yield from _get_paths(spec, cpath)


def _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor):
    """Allocate the (possibly file-backed) memmap tensor for one pytree leaf."""
    if "." in tensor_path:
        # NOTE(review): result of str.replace is discarded here too — no-op;
        # see _save_pytree_common.
        tensor_path.replace(".", "_<dot>_")
    if scratch_dir is not None:
        total_tensor_path = Path(scratch_dir) / (tensor_path + ".memmap")
        if os.path.exists(total_tensor_path):
            raise RuntimeError(
                f"The storage of tensor {total_tensor_path} already exists. "
                f"To load an existing replay buffer, use storage.loads. "
                f"Choose a different path to store your buffer or delete the existing files."
            )
        os.makedirs(total_tensor_path.parent, exist_ok=True)
    else:
        # No scratch dir: anonymous (memory-only) memmap.
        total_tensor_path = None
    out = MemoryMappedTensor.empty(
        shape=max_size_fn(tensor.shape),
        filename=total_tensor_path,
        dtype=tensor.dtype,
    )
    try:
        filesize = os.path.getsize(tensor.filename) / 1024 / 1024
        torchrl_logger.debug(
            f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
        )
    except (RuntimeError, AttributeError):
        # Best-effort logging only: anonymous tensors have no filename.
        pass
    return out


@implement_for("torch", "2.3", None)
def _init_pytree(scratch_dir, max_size_fn, data):
    """Allocate a memmap pytree mirroring ``data`` using key-path traversal (torch>=2.3)."""
    from torch.utils._pytree import tree_map_with_path

    # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
    # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
    def save_tensor(tensor_path: tuple, tensor: torch.Tensor):
        tensor_path = _path2str(tensor_path)
        return _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor)

    out = tree_map_with_path(save_tensor, data)
    return out


@implement_for("torch", None, "2.3")
def _init_pytree(scratch_dir, max_size, data):  # noqa: F811
    """Allocate a memmap pytree mirroring ``data`` (pre-2.3 fallback via tree specs)."""
    flat_data, data_specs = tree_flatten(data)
    data_paths = _get_paths(data_specs)
    data_paths = list(data_paths)

    # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
    # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
    def save_tensor(tensor_path: str, tensor: torch.Tensor):
        return _init_pytree_common(tensor_path, scratch_dir, max_size, tensor)

    out = []
    for tensor, tensor_path in zip(flat_data, data_paths):
        out.append(save_tensor(tensor_path, tensor))
    return tree_unflatten(out, data_specs)


def _roll_inplace(tensor, shift, out, index_dest=None, index_source=None):
    """Write ``tensor.roll(shift, dims=0)`` into ``out`` without materializing the roll.

    Optional ``index_source``/``index_dest`` restrict which source rows are read
    and where they land in ``out`` (used by TED2Flat's scatter writes).
    """
    # slice 0: the part of the tensor that ends up at the tail of out.
    source0 = tensor[:-shift]
    if index_source is not None:
        source0 = source0[index_source[shift:]]
    slice0_shift = source0.shape[0]
    if index_dest is not None:
        out[index_dest[-slice0_shift:]] = source0
    else:
        slice0 = out[-slice0_shift:]
        slice0.copy_(source0)
    # slice 1: the wrapped-around part that ends up at the head of out.
    source1 = tensor[-shift:]
    if index_source is not None:
        source1 = source1[index_source[:shift]]
    if index_dest is not None:
        out[index_dest[:-slice0_shift]] = source1
    else:
        slice1 = out[:-slice0_shift]
        slice1.copy_(source1)
    return out


# Copy-paste of unravel-index for PT 2.0
def _unravel_index(
    indices: Tensor, shape: Union[int, typing.Sequence[int], torch.Size]
) -> typing.Tuple[Tensor, ...]:
    res_tensor = _unravel_index_impl(indices, shape)
    return res_tensor.unbind(-1)


def _unravel_index_impl(
    indices: Tensor, shape: Union[int, typing.Sequence[int]]
) -> Tensor:
    if isinstance(shape, (int, torch.SymInt)):
        shape = torch.Size([shape])
    else:
        shape = torch.Size(shape)
    # Row-major strides: cumulative products of the trailing dims, reversed.
    coefs = list(
        reversed(
            list(
                itertools.accumulate(
                    reversed(shape[1:] + torch.Size([1])), func=operator.mul
                )
            )
        )
    )
    return indices.unsqueeze(-1).floor_divide(
        torch.tensor(coefs, device=indices.device, dtype=torch.int64)
    ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)


@implement_for("torch", None, "2.2")
def unravel_index(indices, shape):
    """A version-compatible wrapper around torch.unravel_index."""
    return _unravel_index(indices, shape)


@implement_for("torch", "2.2")
def unravel_index(indices, shape):  # noqa: F811
    """A version-compatible wrapper around torch.unravel_index."""
    return torch.unravel_index(indices, shape)


@implement_for("torch", None, "2.3")
def tree_iter(pytree):
    """A version-compatible wrapper around tree_iter."""
    flat_tree, _ = torch.utils._pytree.tree_flatten(pytree)
    yield from flat_tree


@implement_for("torch", "2.3", "2.4")
def tree_iter(pytree):  # noqa: F811
    """A version-compatible wrapper around tree_iter."""
    yield from torch.utils._pytree.tree_leaves(pytree)


@implement_for("torch", "2.4")
def tree_iter(pytree):  # noqa: F811
    """A version-compatible wrapper around tree_iter."""
    yield from torch.utils._pytree.tree_iter(pytree)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources