# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import abc
import math
import warnings
from import Iterable
from copy import deepcopy
from dataclasses import dataclass
from functools import wraps
from textwrap import indent
from typing import (

import numpy as np
import torch
from tensordict import (
from tensordict.utils import _getitem_batch_size, NestedKey

from torchrl._utils import get_binary_env_var

DEVICE_TYPING = Union[torch.device, str, int]

INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]


# By default, we do not check that an obs is in the domain. This should be done when validating the env beforehand
_CHECK_SPEC_ENCODE = get_binary_env_var("CHECK_SPEC_ENCODE")

_DEFAULT_SHAPE = torch.Size((1,))

DEVICE_ERR_MSG = "device of empty CompositeSpec is not defined."
# A single pre-built exception instance, raised as-is by spec methods that
# have no implementation (e.g. several LazyStackedTensorSpec methods).
NOT_IMPLEMENTED_ERROR = NotImplementedError(
    "method is not currently implemented."
    " If you are interested in this feature please submit"
    " an issue at https://github.com/pytorch/rl/issues."
)

# Unique sentinel object; presumably used elsewhere in this module to mean
# "no default supplied" where ``None`` is a meaningful value — confirm at call sites.
NO_DEFAULT = object()

def _default_dtype_and_device(
    dtype: Union[None, torch.dtype],
    device: Union[None, str, int, torch.device],
    allow_none_device: bool = False,
) -> Tuple[torch.dtype, torch.device | None]:
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is not None:
        device = torch.device(device)
    elif not allow_none_device:
        device = torch.zeros(()).device
    return dtype, device

def _validate_idx(shape: list[int], idx: int, axis: int = 0):
    """Raise an IndexError if idx is out of bounds for shape[axis].

        shape (list[int]): Input shape
        idx (int): Index, may be negative
        axis (int): Shape axis to check
    if idx >= shape[axis] or idx < 0 and -idx > shape[axis]:
        raise IndexError(
            f"index {idx} is out of bounds for axis {axis} with size {shape[axis]}"

def _validate_iterable(
    idx: Iterable[Any], expected_type: type, iterable_classname: str
    """Raise an IndexError if the iterable contains a type different from the expected type or Iterable.

        idx (Iterable[Any]): Iterable, may contain nested iterables
        expected_type (type): Required item type in the Iterable (e.g. int)
        iterable_classname (str): Iterable type as a string (e.g. 'List'). Logging purpose only.
    for item in idx:
        if isinstance(item, Iterable):
            _validate_iterable(item, expected_type, iterable_classname)
            if not isinstance(item, expected_type):
                raise IndexError(
                    f"{iterable_classname} indexing expects {expected_type} indices"

def _slice_indexing(shape: list[int], idx: slice) -> List[int]:
    """Given an input shape and a slice index, returns the new indexed shape.

        shape (list[int]): Input shape
        idx (slice): Index
        Indexed shape
        >>> _slice_indexing([3, 4], slice(None, 2))
        [2, 4]
        >>> list(torch.rand(3, 4)[:2].shape)
        [2, 4]
    if idx.step == 0:
        raise ValueError("slice step cannot be zero")
    # Slicing an empty shape returns the shape
    if len(shape) == 0:
        return shape

    if idx.start is None:
        start = 0
        start = idx.start if idx.start >= 0 else max(shape[0] + idx.start, 0)

    if idx.stop is None:
        stop = shape[0]
        stop = idx.stop if idx.stop >= 0 else max(shape[0] + idx.stop, 0)

    step = 1 if idx.step is None else idx.step
    if step > 0:
        if start >= stop:
            n_items = 0
            stop = min(stop, shape[0])
            n_items = math.ceil((stop - start) / step)
        if start <= stop:
            n_items = 0
            start = min(start, shape[0] - 1)
            n_items = math.ceil((stop - start) / step)
    return [n_items] + shape[1:]

def _shape_indexing(
    shape: Union[list[int], torch.Size, tuple[int]], idx: SHAPE_INDEX_TYPING
) -> List[int]:
    """Given an input shape and an index, returns the size of the resulting indexed spec.

    This function includes indexing checks and may raise IndexErrors.
    The input ``shape`` is never modified.

    Args:
        shape (list[int], torch.Size, tuple[int]): Input shape
        idx (SHAPE_INDEX_TYPING): Index

    Returns:
        Shape of the resulting spec, as a list.

    Raises:
        IndexError: if the index is out of bounds, malformed, or of an
            unsupported type.

    Examples:
        >>> idx = (2, ..., None)
        >>> DiscreteTensorSpec(2, shape=(3, 4))[idx].shape
        torch.Size([4, 1])
        >>> _shape_indexing([3, 4], idx)
        [4, 1]
    """
    # Always work on a private copy: the tuple branch below edits it in place.
    shape = list(shape)

    if idx is Ellipsis or (
        isinstance(idx, slice) and (idx.step is idx.start is idx.stop is None)
    ):
        return shape

    if idx is None:
        return [1] + shape

    if len(shape) == 0 and (
        isinstance(idx, (int, range)) or (isinstance(idx, list) and len(idx) > 0)
    ):
        raise IndexError(
            f"cannot use integer indices on 0-dimensional shape. `{idx}` received"
        )

    if isinstance(idx, int):
        _validate_idx(shape, idx)
        return shape[1:]

    if isinstance(idx, range):
        if len(idx) > 0 and (idx.start >= shape[0] or idx.stop > shape[0]):
            raise IndexError(f"index out of bounds for axis 0 with size {shape[0]}")
        return [len(idx)] + shape[1:]

    if isinstance(idx, slice):
        return _slice_indexing(shape, idx)

    if isinstance(idx, tuple):
        # Supports int, None, slice and ellipsis indices
        # Index on the current shape dimension
        shape_idx = 0
        none_dims = 0
        ellipsis = False
        prev_is_list = False
        shape_len = len(shape)
        for item_idx, item in enumerate(idx):
            if item is None:
                shape = shape[:shape_idx] + [1] + shape[shape_idx:]
                shape_idx += 1
                none_dims += 1
            elif isinstance(item, int):
                _validate_idx(shape, item, shape_idx)
                del shape[shape_idx]
            elif isinstance(item, slice):
                shape[shape_idx] = _slice_indexing([shape[shape_idx]], item)[0]
                shape_idx += 1
            elif item is Ellipsis:
                if ellipsis:
                    raise IndexError("an index can only have a single ellipsis (`...`)")
                # Move to the end of the shape, subtracted by the number of future indices impacting the dimensions (i.e. all except None and ...)
                shape_idx = len(shape) - len(
                    [i for i in idx[item_idx + 1 :] if not (i is None or i is Ellipsis)]
                )
                ellipsis = True
            elif isinstance(item, (list, tuple, range, np.ndarray, torch.Tensor)):
                # Nested tuples are handled as a list. Numpy behavior
                if isinstance(item, tuple):
                    item = list(item)

                if prev_is_list and isinstance(item, list):
                    # Consecutive list indices are combined (advanced indexing):
                    # the dimension this list addresses is consumed without
                    # adding a new one. Assumes the lists have matching lengths.
                    del shape[shape_idx]
                    continue

                if isinstance(item, list):
                    prev_is_list = True

                if shape_idx >= len(shape):
                    raise IndexError("Raise IndexError: too many indices for array")

                res = _shape_indexing([shape[shape_idx]], item)
                shape = shape[:shape_idx] + res + shape[shape_idx + 1 :]
                shape_idx += len(res)
            else:
                raise IndexError(
                    f"tuple indexing only supports integers, ranges, slices (`:`), ellipsis (`...`), new axis (`None`), tuples, list, tensor and ndarray indices. {str(type(item))} received"
                )

        if len(idx) - none_dims - int(ellipsis) > shape_len:
            raise IndexError(
                f"shape is {shape_len}-dimensional, but {len(idx) - none_dims - int(ellipsis)} dimensions were indexed"
            )
        return shape

    if isinstance(idx, list):
        # int indexing only
        _validate_iterable(idx, int, "list")
        idx_array = np.array(idx)
        for item in idx_array.reshape(-1):
            _validate_idx(shape, item, 0)
        return list(idx_array.shape) + shape[1:]

    if isinstance(idx, (np.ndarray, torch.Tensor)):
        # Out of bounds check
        for item in idx.reshape(-1):
            _validate_idx(shape, item)
        return list(_getitem_batch_size(shape, idx))

    # Previously fell through and returned ``None``, which only failed later.
    raise IndexError(f"Unsupported index type {type(idx)} received.")

class invertible_dict(dict):
    """An invertible dictionary.

    Every value can be mapped back to its key through :meth:`inverse`; to keep
    that mapping consistent, neither keys nor values may be overwritten once
    inserted.

    Args:
        *args, **kwargs: initial content, forwarded to :class:`dict`.
        inv_dict (dict, optional): pre-built value-to-key mapping. If not
            provided, it is built from the initial content.

    Examples:
        >>> my_dict = invertible_dict(a=3, b=2)
        >>> inv_dict = my_dict.invert()
        >>> assert {2, 3} == set(inv_dict.keys())
    """

    def __init__(self, *args, inv_dict=None, **kwargs):
        super().__init__(*args, **kwargs)
        if inv_dict is None:
            # ``dict.__init__`` does not go through ``__setitem__``, so the
            # reverse mapping of the initial content must be built explicitly.
            inv_dict = {}
            for key, value in self.items():
                if value in inv_dict:
                    raise Exception("overwriting in invertible_dict is not permitted")
                inv_dict[value] = key
        self.inv_dict = inv_dict

    def __setitem__(self, k, v):
        if v in self.inv_dict or k in self:
            raise Exception("overwriting in invertible_dict is not permitted")
        self.inv_dict[v] = k
        return super().__setitem__(k, v)

    def update(self, d):
        # Not supported — presumably because ``dict.update`` would bypass the
        # overwrite checks and the reverse-mapping bookkeeping of ``__setitem__``.
        raise NotImplementedError

    def invert(self):
        """Returns a new invertible_dict mapping each value to its key."""
        d = invertible_dict()
        for k, value in self.items():
            d[value] = k
        return d

    def inverse(self):
        """Returns the (live) value-to-key mapping maintained by this dict."""
        return self.inv_dict

class Box:
    """A box of values."""

    def __iter__(self):
        raise NotImplementedError

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Box:
        """Returns the box cast to ``dest`` (dtype or device); implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def clone(self) -> Box:
        """Returns a deep copy of the box."""
        return deepcopy(self)

@dataclass(repr=False)
class ContinuousBox(Box):
    """A continuous box of values, in between a minimum (self.low) and a maximum (self.high)."""

    _low: torch.Tensor
    _high: torch.Tensor
    device: torch.device | None = None

    # We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used.
    # The public ``low``/``high`` properties move them back to ``self.device`` on access.
    @property
    def low(self):
        return self._low.to(self.device)

    @low.setter
    def low(self, value):
        self.device = value.device
        self._low = value.cpu()

    @property
    def high(self):
        return self._high.to(self.device)

    @high.setter
    def high(self, value):
        self.device = value.device
        self._high = value.cpu()

    def __post_init__(self):
        # Route the constructor tensors through the setters: this records their
        # device and stores private CPU copies.
        self.low = self.low.clone()
        self.high = self.high.clone()

    def __iter__(self):
        yield self.low
        yield self.high

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox:
        return self.__class__(self.low.to(dest), self.high.to(dest))

    def clone(self) -> ContinuousBox:
        return self.__class__(self.low.clone(), self.high.clone())

    def __repr__(self):
        min_str = indent(
            f"\nlow=Tensor(shape={self.low.shape}, device={self.low.device}, dtype={self.low.dtype}, contiguous={self.low.is_contiguous()})",
            " " * 4,
        )
        max_str = indent(
            f"\nhigh=Tensor(shape={self.high.shape}, device={self.high.device}, dtype={self.high.dtype}, contiguous={self.high.is_contiguous()})",
            " " * 4,
        )
        return f"{self.__class__.__name__}({min_str},{max_str})"

    def __eq__(self, other):
        if other is None:
            # A box compares equal to ``None`` when it is effectively unbounded:
            # either it spans the full range of its dtype or has no finite bound.
            minval, maxval = _minmax_dtype(self.low.dtype)
            minval = torch.as_tensor(minval).to(self.low.device, self.low.dtype)
            maxval = torch.as_tensor(maxval).to(self.low.device, self.low.dtype)
            if (
                torch.isclose(self.low, minval).all()
                and torch.isclose(self.high, maxval).all()
            ):
                return True
            if (
                not torch.isfinite(self.low).any()
                and not torch.isfinite(self.high).any()
            ):
                return True
            return False
        return (
            type(self) == type(other)
            and self.low.dtype == other.low.dtype
            and self.high.dtype == other.high.dtype
            and self.device == other.device
            and torch.isclose(self.low, other.low).all()
            and torch.isclose(self.high, other.high).all()
        )

@dataclass
class DiscreteBox(Box):
    """A box of discrete values."""

    n: int
    # NOTE(review): ``register`` has no annotation, so it is a class attribute
    # shared by every DiscreteBox instance (and not copied by ``deepcopy``),
    # not a per-instance dataclass field. Confirm this sharing is intended.
    register = invertible_dict()

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> DiscreteBox:
        # The box stores no tensors, so a plain copy serves any dtype/device.
        return deepcopy(self)

    def __repr__(self):
        return f"{self.__class__.__name__}(n={self.n})"

@dataclass
class BoxList(Box):
    """A list of boxes, e.g. as built from an ``nvec`` tensor by :meth:`from_nvec`."""

    boxes: List

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> BoxList:
        return BoxList([box.to(dest) for box in self.boxes])

    def __iter__(self):
        yield from self.boxes

    def __repr__(self):
        return f"{self.__class__.__name__}(boxes={self.boxes})"

    def __len__(self):
        return len(self.boxes)

    @staticmethod
    def from_nvec(nvec: torch.Tensor):
        """Builds a (possibly nested) box from a tensor of category counts.

        A 0-d ``nvec`` yields a :class:`DiscreteBox` with ``n=nvec.item()``;
        otherwise one box is built per slice along the last dimension,
        recursively.
        """
        if nvec.ndim == 0:
            return DiscreteBox(nvec.item())
        return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)])

@dataclass
class BinaryBox(Box):
    """A box of n binary values."""

    n: int

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> BinaryBox:
        # The box stores no tensors, so a plain copy serves any dtype/device.
        return deepcopy(self)

    def __repr__(self):
        return f"{self.__class__.__name__}(n={self.n})"

@dataclass(repr=False)
class TensorSpec:
    """Parent class of the tensor meta-data containers for observation, actions and rewards.

    Args:
        shape (torch.Size): size of the tensor
        space (Box): Box instance describing what kind of values can be expected
        device (torch.device): device of the tensor
        dtype (torch.dtype): dtype of the tensor
        domain (str): domain of the values (e.g. ``"discrete"``). Defaults to ``""``.

    """

    shape: torch.Size
    space: Union[None, Box]
    device: torch.device | None = None
    dtype: torch.dtype = torch.float
    domain: str = ""

    # Not annotated, hence a plain class attribute rather than a dataclass field:
    # one registry (torch function -> override) shared by all subclasses that
    # do not redefine it.
    SPEC_HANDLED_FUNCTIONS = {}

    @classmethod
    def implements_for_spec(cls, torch_function: Callable) -> Callable:
        """Register a torch function override for TensorSpec."""

        @wraps(torch_function)
        def decorator(func):
            cls.SPEC_HANDLED_FUNCTIONS[torch_function] = func
            return func

        return decorator

    def clear_device_(self):
        """A no-op for all leaf specs (which must have a device)."""
        return self

    def encode(
        self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False
    ) -> torch.Tensor:
        """Encodes a value given the specified spec, and return the corresponding tensor.

        Values that already are tensors are returned without conversion or
        reshaping (only the optional domain check applies).

        Args:
            val (np.ndarray or torch.Tensor): value to be encoded as tensor.

        Keyword Args:
            ignore_device (bool, optional): if ``True``, the spec device will
                be ignored. This is used to group tensor casting within a call
                to ``TensorDict(..., device="cuda")`` which is faster.

        Returns:
            torch.Tensor matching the required tensor specs.

        Raises:
            RuntimeError: if the value cannot be reshaped to the spec shape.

        """
        if not isinstance(val, torch.Tensor):
            if isinstance(val, list):
                if len(val) == 1:
                    # gym used to return lists of images since 0.26.0
                    # We convert these lists in np.array or take the first element
                    # if there is just one.
                    val = val[0]
                else:
                    val = np.array(val)
            if isinstance(val, np.ndarray) and not all(
                stride > 0 for stride in val.strides
            ):
                # Arrays with non-positive strides (e.g. reversed views) are
                # copied to a fresh layout before being handed to torch.
                val = val.copy()
            if not ignore_device:
                val = torch.as_tensor(val, device=self.device, dtype=self.dtype)
            else:
                val = torch.as_tensor(val, dtype=self.dtype)
            if val.shape != self.shape:
                # if val.shape[-len(self.shape) :] != self.shape:
                # option 1: add a singleton dim at the end
                if val.shape == self.shape[:-1] and self.shape[-1] == 1:
                    val = val.unsqueeze(-1)
                else:
                    try:
                        val = val.reshape(self.shape)
                    except Exception as err:
                        raise RuntimeError(
                            f"Shape mismatch: the value has shape {val.shape} which "
                            f"is incompatible with the spec shape {self.shape}."
                        ) from err
        if _CHECK_SPEC_ENCODE:
            self.assert_is_in(val)
        return val

    def __ne__(self, other):
        return not (self == other)

    def __setattr__(self, key, value):
        # Normalize any assigned shape (list, tuple, ...) to ``torch.Size``.
        if key == "shape":
            value = torch.Size(value)
        super().__setattr__(key, value)

    def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
        """Returns the np.ndarray correspondent of an input tensor.

        Args:
            val (torch.Tensor): tensor to be transformed to numpy.
            safe (bool): boolean value indicating whether a check should be
                performed on the value against the domain of the spec.
                Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

        Returns:
            a np.ndarray

        """
        if safe is None:
            safe = _CHECK_SPEC_ENCODE
        if safe:
            self.assert_is_in(val)
        return val.detach().cpu().numpy()

    @property
    def ndim(self):
        return self.ndimension()

    def ndimension(self):
        return len(self.shape)

    @abc.abstractmethod
    def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
        """Indexes the input tensor.

        Args:
            index (int, torch.Tensor, slice or list): index of the tensor
            tensor_to_index: tensor to be indexed

        Returns:
            indexed tensor

        """
        raise NotImplementedError

    @abc.abstractmethod
    def expand(self, *shape):
        """Returns a new Spec with the extended shape.

        Args:
            *shape (tuple or iterable of int): the new shape of the Spec. Must comply with the current shape:
                its length must be at least as long as the current shape length,
                and its last values must be compliant too; ie they can only differ
                from it if the current dimension is a singleton.

        """
        raise NotImplementedError

    def squeeze(self, dim: int | None = None):
        """Returns a new Spec with all the dimensions of size ``1`` removed.

        When ``dim`` is given, a squeeze operation is done only in that dimension.

        Args:
            dim (int or None): the dimension to apply the squeeze operation to

        """
        shape = _squeezed_shape(self.shape, dim)
        if shape is None:
            return self
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def unsqueeze(self, dim: int):
        shape = _unsqueezed_shape(self.shape, dim)
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def reshape(self, *shape):
        """Reshapes a tensorspec.

        Check :func:`~torch.reshape` for more information on this method.

        """
        # Accept both ``reshape(2, 3)`` and ``reshape((2, 3))``.
        if len(shape) == 1 and not isinstance(shape[0], int):
            return self.reshape(*shape[0])
        return self._reshape(shape)

    view = reshape

    @abc.abstractmethod
    def _reshape(self, shape):
        ...

    def unflatten(self, dim, sizes):
        """Unflattens a tensorspec.

        Check :func:`~torch.unflatten` for more information on this method.

        """
        return self._unflatten(dim, sizes)

    def _unflatten(self, dim, sizes):
        # A "meta" tensor computes the resulting shape without allocating data.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self._reshape(shape)

    def flatten(self, start_dim=0, end_dim=-1):
        """Flattens a tensorspec.

        Check :func:`~torch.flatten` for more information on this method.
        ``start_dim`` and ``end_dim`` default to ``0`` and ``-1`` as in torch.

        """
        return self._flatten(start_dim, end_dim)

    def _flatten(self, start_dim, end_dim):
        # A "meta" tensor computes the resulting shape without allocating data.
        shape = torch.zeros(self.shape, device="meta").flatten(start_dim, end_dim).shape
        return self._reshape(shape)

    @abc.abstractmethod
    def _project(self, val: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError(type(self))

    @abc.abstractmethod
    def is_in(self, val: torch.Tensor) -> bool:
        """If the value :obj:`val` is in the box defined by the TensorSpec, returns True, otherwise False.

        Args:
            val (torch.Tensor): value to be checked

        Returns:
            boolean indicating if values belongs to the TensorSpec box

        """
        raise NotImplementedError

    def project(self, val: torch.Tensor) -> torch.Tensor:
        """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic.

        Args:
            val (torch.Tensor): tensor to be mapped to the box.

        Returns:
            a torch.Tensor belonging to the TensorSpec box.

        """
        if not self.is_in(val):
            return self._project(val)
        return val

    def assert_is_in(self, value: torch.Tensor) -> None:
        """Asserts whether a tensor belongs to the box, and raises an exception otherwise.

        Args:
            value (torch.Tensor): value to be checked.

        Raises:
            AssertionError: if ``value`` is not in the spec's box.

        """
        if not self.is_in(value):
            raise AssertionError(
                f"Encoding failed because value is not in space. "
                f"Consider calling project(val) first. value was = {value} "
                f"and spec was {self}."
            )

    def type_check(self, value: torch.Tensor, key: str = None) -> None:
        """Checks the input value dtype against the TensorSpec dtype and raises an exception if they don't match.

        Args:
            value (torch.Tensor): tensor whose dtype has to be checked
            key (str, optional): if the TensorSpec has keys, the value
                dtype will be checked against the spec pointed by the
                indicated key.

        Raises:
            TypeError: if the dtypes differ.

        """
        if value.dtype != self.dtype:
            raise TypeError(
                f"value.dtype={value.dtype} but"
                f" {self.__class__.__name__}.dtype={self.dtype}"
            )

    @abc.abstractmethod
    def rand(self, shape=None) -> torch.Tensor:
        """Returns a random tensor in the box. The sampling will be uniform unless the box is unbounded.

        Args:
            shape (torch.Size): shape of the random tensor

        Returns:
            a random tensor sampled in the TensorSpec box.

        """
        raise NotImplementedError

    def zero(self, shape=None) -> torch.Tensor:
        """Returns a zero-filled tensor in the box.

        Args:
            shape (torch.Size): batch shape prepended to the spec shape.

        Returns:
            a zero-filled tensor sampled in the TensorSpec box.

        """
        if shape is None:
            shape = torch.Size([])
        return torch.zeros((*shape, *self.shape), dtype=self.dtype, device=self.device)

    @abc.abstractmethod
    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec":
        raise NotImplementedError

    def cpu(self):
        return self.to("cpu")

    def cuda(self, device=None):
        if device is None:
            return self.to("cuda")
        return self.to(f"cuda:{device}")

    @abc.abstractmethod
    def clone(self) -> "TensorSpec":
        raise NotImplementedError

    def __repr__(self):
        shape_str = indent("shape=" + str(self.shape), " " * 4)
        space_str = indent("space=" + str(self.space), " " * 4)
        device_str = indent("device=" + str(self.device), " " * 4)
        dtype_str = indent("dtype=" + str(self.dtype), " " * 4)
        domain_str = indent("domain=" + str(self.domain), " " * 4)
        sub_string = ",\n".join(
            [shape_str, space_str, device_str, dtype_str, domain_str]
        )
        string = f"{self.__class__.__name__}(\n{sub_string})"
        return string

    @classmethod
    def __torch_function__(
        cls,
        func: Callable,
        types,
        args: Tuple = (),
        kwargs: Optional[dict] = None,
    ) -> Callable:
        if kwargs is None:
            kwargs = {}
        if func not in cls.SPEC_HANDLED_FUNCTIONS or not all(
            issubclass(t, (TensorSpec,)) for t in types
        ):
            # ``NotImplemented`` is a singleton, not a callable: returning it
            # lets torch raise its standard "no implementation found" TypeError.
            return NotImplemented
        return cls.SPEC_HANDLED_FUNCTIONS[func](*args, **kwargs)

    def unbind(self, dim: int):
        raise NotImplementedError
T = TypeVar("T")


class _LazyStackedMixin(Generic[T]):
    """Shared logic for lazy stacks of specs along a single (stack) dimension."""

    def __init__(self, *specs: tuple[T, ...], dim: int) -> None:
        self._specs = list(specs)
        self.dim = dim
        if self.dim < 0:
            self.dim = len(self.shape) + self.dim

    def clear_device_(self):
        """Clears the device of every stacked spec."""
        for spec in self._specs:
            spec.clear_device_()
        return self

    @staticmethod
    def _is_full_slice(index) -> bool:
        """Returns True if ``index`` is the trivial slice ``:``."""
        return isinstance(index, slice) and (
            index.start is None and index.stop is None and index.step is None
        )

    def _index_stack(self, index):
        """Selects stacked spec(s) along the stack dim, re-stacking if several remain."""
        out = self._specs[index]
        if isinstance(out, TensorSpec):
            return out
        # Sub-specs do not carry the stack dimension: re-insert it where it was.
        return torch.stack(list(out), self.dim)

    def __getitem__(self, item):
        is_key = isinstance(item, str) or (
            isinstance(item, tuple) and all(isinstance(_item, str) for _item in item)
        )
        if is_key:
            return torch.stack(
                [composite_spec[item] for composite_spec in self._specs], dim=self.dim
            )
        elif isinstance(item, tuple):
            # quick check that the index is along the stacked dim
            # case 1: index is a tuple, and the first arg is an ellipsis. Then dim must be the last dim of all composite_specs
            if item[0] is Ellipsis:
                if len(item) == 1:
                    return self
                elif self.dim == len(self.shape) - 1 and len(item) == 2:
                    # we can return
                    return self._index_stack(item[1])
                elif len(item) > 2:
                    # check that there is only one non-slice index
                    assigned = False
                    index = None
                    dim_idx = self.dim
                    # Entries after the ellipsis address the trailing dimensions.
                    first_dim = self.ndim - (len(item) - 1)
                    for i, _item in enumerate(item[1:]):
                        if not self._is_full_slice(_item):
                            if assigned:
                                raise RuntimeError(
                                    "Found more than one meaningful index in a stacked composite spec."
                                )
                            index = _item
                            dim_idx = first_dim + i
                            assigned = True
                    if not assigned:
                        return self
                    if dim_idx != self.dim:
                        raise RuntimeError(
                            f"Indexing occured along dimension {dim_idx} but stacking was done along dim {self.dim}."
                        )
                    return self._index_stack(index)
                else:
                    raise IndexError(
                        f"Indexing a {self.__class__.__name__} with [..., idx] is only permitted if the stack dimension is the last dimension. "
                        f"Got self.dim={self.dim} and self.shape={self.shape}."
                    )
            elif len(item) >= 2 and item[-1] is Ellipsis:
                return self[item[:-1]]
            elif any(_item is Ellipsis for _item in item):
                raise IndexError("Cannot index along multiple dimensions.")
            # Ellipsis is now ruled out
            elif any(_item is None for _item in item):
                raise IndexError(
                    f"Cannot index a {self.__class__.__name__} with None values"
                )
            # Must be an index with slices then
            else:
                for i, _item in enumerate(item):
                    if i == self.dim:
                        return self._index_stack(_item)
                    elif not self._is_full_slice(_item):
                        # Any non-trivial index (slice, int, ...) outside the
                        # stack dim cannot be honoured lazily.
                        raise IndexError(
                            f"Got a non-trivial index at dim {i} when only the dim {self.dim} could be indexed."
                        )
                else:
                    return self
        else:
            if not self.dim == 0:
                raise IndexError(
                    f"Trying to index a {self.__class__.__name__} along dimension 0 when the stack dimension is {self.dim}."
                )
            return self._index_stack(item)

    def clone(self) -> T:
        return torch.stack([spec.clone() for spec in self._specs], self.stack_dim)

    @property
    def stack_dim(self):
        return self.dim

    def zero(self, shape=None) -> TensorDictBase:
        # Leading batch dims shift the position of the stack dimension.
        if shape is not None:
            dim = self.dim + len(shape)
        else:
            dim = self.dim
        return LazyStackedTensorDict.maybe_dense_stack(
            [spec.zero(shape) for spec in self._specs], dim
        )

    def rand(self, shape=None) -> TensorDictBase:
        # Leading batch dims shift the position of the stack dimension.
        if shape is not None:
            dim = self.dim + len(shape)
        else:
            dim = self.dim
        return LazyStackedTensorDict.maybe_dense_stack(
            [spec.rand(shape) for spec in self._specs], dim
        )

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T:
        if dest is None:
            return self
        return torch.stack([spec.to(dest) for spec in self._specs], self.dim)

    def unbind(self, dim: int):
        """Returns the stacked specs as a tuple; only the stack dim can be unbound.

        Raises:
            ValueError: if ``dim`` is invalid or not the stack dimension.
        """
        if dim < 0:
            dim = self.ndim + dim
        if dim == self.stack_dim:
            # A tuple (as torch.unbind returns) avoids leaking the internal list.
            return tuple(self._specs)
        shape = self.shape
        if dim < 0 or dim > self.ndim - 1 or shape[dim] == -1:
            raise ValueError(
                f"Provided dim {dim} is not valid for unbinding shape {shape}"
            )
        raise ValueError(
            f"A {type(self)} instance can only be unbound along its stack dimension. Expected {self.stack_dim}, received {dim} instead."
        )

    def unsqueeze(self, dim: int):
        if dim < 0:
            new_dim = dim + len(self.shape) + 1
        else:
            new_dim = dim
        if new_dim > len(self.shape) or new_dim < 0:
            raise ValueError(f"Cannot unsqueeze along dim {dim}.")
        if new_dim > self.dim:
            # unsqueeze 2, stack is on 1 => unsqueeze 1, stack along 1
            new_stack_dim = self.dim
            new_dim = new_dim - 1
        else:
            # unsqueeze 0, stack is on 1 => unsqueeze 0, stack on 1
            new_stack_dim = self.dim + 1
        return torch.stack(
            [spec.unsqueeze(new_dim) for spec in self._specs], dim=new_stack_dim
        )

    def squeeze(self, dim: int | None = None):
        if dim is None:
            size = self.shape
            if len(size) == 1 or size.count(1) == 0:
                return self
            first_singleton_dim = size.index(1)

            squeezed_dict = self.squeeze(first_singleton_dim)
            return squeezed_dict.squeeze(dim=None)

        if dim < 0:
            new_dim = self.ndim + dim
        else:
            new_dim = dim

        if self.shape and (new_dim >= self.ndim or new_dim < 0):
            raise RuntimeError(
                f"squeezing is allowed for dims comprised between 0 and "
                f"spec.ndim only. Got dim={dim} and shape"
                f"={self.shape}."
            )

        if new_dim >= self.ndim or self.shape[new_dim] != 1:
            return self

        if new_dim == self.dim:
            # Squeezing the (singleton) stack dim leaves the single stacked spec.
            return self._specs[0]
        if new_dim > self.dim:
            # squeeze 2, stack is on 1 => squeeze 1, stack along 1
            new_stack_dim = self.dim
            new_dim = new_dim - 1
        else:
            # squeeze 0, stack is on 1 => squeeze 0, stack on 1
            new_stack_dim = self.dim - 1
        return torch.stack(
            [spec.squeeze(new_dim) for spec in self._specs], dim=new_stack_dim
        )
class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec):
    """A lazy representation of a stack of tensor specs.

    Stacks tensor-specs together along one dimension.
    When random samples are drawn, a stack of samples is returned if possible.
    If not, an error is thrown.

    Indexing is allowed but only along the stack dimension.

    This class is aimed to be used in multi-task and multi-agent settings, where
    heterogeneous specs may occur (same semantic but different shape).

    """

    def __eq__(self, other):
        if not isinstance(other, LazyStackedTensorSpec):
            return False
        if self.device != other.device:
            return False
        # The same specs stacked along different dims describe different shapes.
        if self.dim != other.dim:
            return False
        if len(self._specs) != len(other._specs):
            return False
        for _spec1, _spec2 in zip(self._specs, other._specs):
            if _spec1 != _spec2:
                return False
        return True

    def __len__(self):
        return self.shape[0]

    def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict:
        if safe is None:
            safe = _CHECK_SPEC_ENCODE
        if safe:
            if val.shape[self.dim] != len(self._specs):
                raise ValueError(
                    "Size of LazyStackedTensorSpec and val differ along the stacking "
                    "dimension"
                )
            for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)):
                spec.assert_is_in(v)
        return val.detach().cpu().numpy()

    def __repr__(self):
        shape_str = "shape=" + str(self.shape)
        device_str = "device=" + str(self.device)
        dtype_str = "dtype=" + str(self.dtype)
        domain_str = "domain=" + str(self._specs[0].domain)
        sub_string = ", ".join([shape_str, device_str, dtype_str, domain_str])
        string = f"LazyStacked{self._specs[0].__class__.__name__}(\n    {sub_string})"
        return string

    @property
    def device(self) -> DEVICE_TYPING:
        return self._specs[0].device

    @property
    def ndim(self):
        return self.ndimension()

    def ndimension(self):
        return len(self.shape)

    @property
    def shape(self):
        # Dimensions that differ across the stacked specs are reported as -1.
        first_shape = self._specs[0].shape
        shape = []
        for i in range(len(first_shape)):
            homo_dim = True
            for spec in self._specs:
                if spec.shape[i] != first_shape[i]:
                    homo_dim = False
                    break
            shape.append(first_shape[i] if homo_dim else -1)

        dim = self.dim
        if dim < 0:
            dim = len(shape) + dim + 1
        shape.insert(dim, len(self._specs))
        return torch.Size(shape)

    def expand(self, *shape):
        if len(shape) == 1 and not isinstance(shape[0], (int,)):
            return self.expand(*shape[0])
        expand_shape = shape[: -len(self.shape)]
        existing_shape = self.shape
        shape_check = shape[-len(self.shape) :]
        # Pre-set so that an empty ``shape`` reaches the congruence error below
        # instead of a NameError.
        _i = -1
        for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)):
            if size1 != size2 and size1 != 1:
                raise RuntimeError(
                    f"Expanding a non-singletom dimension: existing shape={size1} vs expand={size2}"
                )
            elif size1 != size2 and size1 == 1 and _i == self.dim:
                # if we're expanding along the stack dim we just need to clone the existing spec
                return torch.stack(
                    [self._specs[0].clone() for _ in range(size2)], self.dim
                ).expand(*shape)
        if _i != len(self.shape) - 1:
            raise RuntimeError(
                f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}."
            )
        # remove the stack dim from the expanded shape, which we know to match
        shape_check = [s for i, s in enumerate(shape_check) if i != self.dim]
        specs = []
        for spec in self._specs:
            spec_shape = []
            for dim_check, spec_dim in zip(shape_check, spec.shape):
                # -1 keeps the (possibly heterogeneous) per-spec size.
                spec_shape.append(dim_check if dim_check != -1 else spec_dim)
            unstack_shape = list(expand_shape) + list(spec_shape)
            specs.append(spec.expand(unstack_shape))
        return torch.stack(
            specs,
            self.dim + len(expand_shape),
        )

    def type_check(self, value: torch.Tensor, key: str = None) -> None:
        raise NOT_IMPLEMENTED_ERROR

    def is_in(self, val) -> bool:
        raise NOT_IMPLEMENTED_ERROR

    @property
    def space(self):
        raise NOT_IMPLEMENTED_ERROR

    def _project(self, val: TensorDictBase) -> TensorDictBase:
        raise NOT_IMPLEMENTED_ERROR

    def encode(
        self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False
    ) -> torch.Tensor:
        raise NOT_IMPLEMENTED_ERROR
@dataclass(repr=False)
class OneHotDiscreteTensorSpec(TensorSpec):
    """A unidimensional, one-hot discrete tensor spec.

    By default, TorchRL assumes that categorical variables are encoded as
    one-hot encodings of the variable. This allows for simple indexing of
    tensors, e.g.

    Examples:
        >>> batch, size = 3, 4
        >>> action_value = torch.arange(batch*size)
        >>> action_value = action_value.view(batch, size).to(torch.float)
        >>> action = (action_value == action_value.max(-1,
        ...    keepdim=True)[0]).to(torch.long)
        >>> chosen_action_value = (action * action_value).sum(-1)
        >>> print(chosen_action_value)
        tensor([ 3.,  7., 11.])

    The last dimension of the shape (variable domain) cannot be indexed.

    Args:
        n (int): number of possible outcomes.
        shape (torch.Size, optional): total shape of the sampled tensors.
            If provided, the last dimension must match n.
        device (str, int or torch.device, optional): device of the tensors.
        dtype (str or torch.dtype, optional): dtype of the tensors.
        use_register (bool): experimental feature. If True, every integer
            will be mapped onto a binary vector in the order in which they
            appear. This feature is designed for environment with no
            a-priori definition of the number of possible outcomes (e.g.
            discrete outcomes are sampled from an arbitrary set, whose
            elements will be mapped in a register to a series of unique
            one-hot binary vectors).
        mask (torch.Tensor or None, optional): boolean mask expandable to
            ``shape``. Outcomes whose mask entry is ``False`` are given zero
            probability by :meth:`rand`.

    """

    shape: torch.Size
    space: DiscreteBox
    device: torch.device | None = None
    dtype: torch.dtype = torch.float
    domain: str = ""

    def __init__(
        self,
        n: int,
        shape: Optional[torch.Size] = None,
        device: Optional[DEVICE_TYPING] = None,
        dtype: Optional[Union[str, torch.dtype]] = torch.bool,
        use_register: bool = False,
        mask: torch.Tensor | None = None,
    ):
        dtype, device = _default_dtype_and_device(dtype, device)
        self.use_register = use_register
        space = DiscreteBox(n)
        if shape is None:
            shape = torch.Size((space.n,))
        else:
            shape = torch.Size(shape)
            if not len(shape) or shape[-1] != space.n:
                raise ValueError(
                    f"The last value of the shape must match n for transform of type {self.__class__}. "
                    f"Got n={space.n} and shape={shape}."
                )
        super().__init__(
            shape=shape, space=space, device=device, dtype=dtype, domain="discrete"
        )
        self.update_mask(mask)

    @property
    def n(self):
        return self.space.n

    def update_mask(self, mask):
        """Sets (or clears, when ``mask`` is ``None``) the sampling mask.

        Raises:
            RuntimeError: if the mask cannot be expanded to the spec shape.
            ValueError: if the mask is not boolean.
        """
        if mask is not None:
            try:
                mask = mask.expand(self.shape)
            except RuntimeError as err:
                raise RuntimeError("Cannot expand mask to the desired shape.") from err
            if mask.dtype != torch.bool:
                raise ValueError("Only boolean masks are accepted.")
        self.mask = mask

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> OneHotDiscreteTensorSpec:
        if dest is None:
            return self
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(
            n=self.space.n,
            shape=self.shape,
            device=dest_device,
            dtype=dest_dtype,
            use_register=self.use_register,
            # Only move the mask: it must stay boolean whatever the spec dtype.
            mask=self.mask.to(dest_device) if self.mask is not None else None,
        )

    def clone(self) -> OneHotDiscreteTensorSpec:
        return self.__class__(
            n=self.space.n,
            shape=self.shape,
            device=self.device,
            dtype=self.dtype,
            use_register=self.use_register,
            mask=self.mask.clone() if self.mask is not None else None,
        )

    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(val < 0 for val in shape):
            raise ValueError(
                f"{self.__class__.__name__}.expand does not support negative shapes."
            )
        if len(shape) < self.ndim or any(
            s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)
        ):
            raise ValueError(
                f"The last {self.ndim} of the expanded shape {shape} must match the "
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        mask = self.mask
        if mask is not None:
            mask = mask.expand(shape)
        return self.__class__(
            n=shape[-1],
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            use_register=self.use_register,
            mask=mask,
        )

    def _reshape(self, shape):
        mask = self.mask
        if mask is not None:
            mask = mask.reshape(shape)
        return self.__class__(
            n=shape[-1],
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            use_register=self.use_register,
            mask=mask,
        )

    def _unflatten(self, dim, sizes):
        mask = self.mask
        if mask is not None:
            mask = mask.unflatten(dim, sizes)
        # A "meta" tensor computes the resulting shape without allocating data.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self.__class__(
            n=shape[-1],
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            use_register=self.use_register,
            mask=mask,
        )

    def squeeze(self, dim=None):
        # The last (one-hot) dimension must never be squeezed away, even when n == 1.
        if self.shape[-1] == 1 and dim in (len(self.shape) - 1, -1, None):
            raise ValueError(f"Final dimension of {type(self)} must remain unchanged")

        shape = _squeezed_shape(self.shape, dim)
        if shape is None:
            return self
        mask = self.mask
        if mask is not None:
            mask = mask.reshape(shape)
        return self.__class__(
            n=shape[-1],
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            use_register=self.use_register,
            mask=mask,
        )

    def unsqueeze(self, dim: int):
        # Inserting a dim after the one-hot dimension would move it off the end.
        if dim in (len(self.shape), -1):
            raise ValueError(f"Final dimension of {type(self)} must remain unchanged")

        shape = _unsqueezed_shape(self.shape, dim)
        mask = self.mask
        if mask is not None:
            mask = mask.reshape(shape)
        return self.__class__(
            n=shape[-1],
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            use_register=self.use_register,
            mask=mask,
        )

    def unbind(self, dim: int):
        """Splits the spec along a batch dimension (never the one-hot dimension).

        Raises:
            ValueError: if ``dim`` is the last dimension or out of range.
        """
        if dim in (len(self.shape) - 1, -1):
            raise ValueError(f"Final dimension of {type(self)} must remain unchanged")
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0 or dim >= len(self.shape):
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        mask = self.mask
        if mask is not None:
            mask = mask.unbind(dim)
        else:
            mask = (None,) * self.shape[dim]
        return tuple(
            self.__class__(
                n=shape[-1],
                shape=shape,
                device=self.device,
                dtype=self.dtype,
                use_register=self.use_register,
                mask=mask[i],
            )
            for i in range(self.shape[dim])
        )

    def rand(self, shape=None) -> torch.Tensor:
        """Samples a one-hot tensor, uniformly over outcomes allowed by the mask.

        Args:
            shape (torch.Size, optional): batch shape prepended to the spec's
                leading (non one-hot) dimensions.
        """
        if shape is None:
            shape = self.shape[:-1]
        else:
            shape = torch.Size([*shape, *self.shape[:-1]])
        mask = self.mask
        if mask is None:
            n = self.space.n
            m = torch.randint(n, shape, device=self.device)
        else:
            mask = mask.expand(*shape, mask.shape[-1])
            # ``torch.multinomial`` only accepts 1-D or 2-D inputs.
            if mask.ndim > 2:
                mask_flat = torch.flatten(mask, 0, -2)
            else:
                mask_flat = mask
            shape_out = mask.shape[:-1]
            m = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
        out = torch.nn.functional.one_hot(m, self.space.n).to(self.dtype)
        return out
[docs] def encode( self, val: Union[np.ndarray, torch.Tensor], space: Optional[DiscreteBox] = None, *, ignore_device: bool = False, ) -> torch.Tensor: if not isinstance(val, torch.Tensor): if ignore_device: val = torch.as_tensor(val) else: val = torch.as_tensor(val, device=self.device) if space is None: space = if self.use_register: if val not in space.register: space.register[val] = len(space.register) val = space.register[val] if (val >= space.n).any(): raise AssertionError("Value must be less than action space.") val = torch.nn.functional.one_hot(val.long(), space.n).to(self.dtype) return val
[docs] def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: if safe is None: safe = _CHECK_SPEC_ENCODE if safe: if not isinstance(val, torch.Tensor): raise NotImplementedError self.assert_is_in(val) val = val.long().argmax(-1).cpu().numpy() if self.use_register: inv_reg = vals = [] for _v in val.view(-1): vals.append(inv_reg[int(_v)]) return np.array(vals).reshape(tuple(val.shape)) return val
[docs] def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: if not isinstance(index, torch.Tensor): raise ValueError( f"Only tensors are allowed for indexing using " f"{self.__class__.__name__}.index(...)" ) index = index.nonzero().squeeze() index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1])) return tensor_to_index.gather(-1, index)
def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index. The last dimension of the spec corresponding to the variable domain cannot be indexed. """ indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__(, shape=torch.Size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, use_register=self.use_register, mask=self.mask[idx] if self.mask is not None else None, ) def _project(self, val: torch.Tensor) -> torch.Tensor: if self.mask is None: out = torch.multinomial(, 1).squeeze(-1) out = torch.nn.functional.one_hot(out, return out shape = self.mask.shape shape = torch.broadcast_shapes(shape, val.shape) mask_expand = self.mask.expand(shape) gathered = mask_expand & val oob = ~gathered.any(-1) new_val = torch.multinomial(mask_expand[oob].float(), 1) val = val.clone() val[oob] = 0 val[oob] = torch.scatter(val[oob], -1, new_val, 1) return val
[docs] def is_in(self, val: torch.Tensor) -> bool: if self.mask is None: return (val.sum(-1) == 1).all() shape = self.mask.shape shape = torch.broadcast_shapes(shape, val.shape) mask_expand = self.mask.expand(shape) gathered = mask_expand & val return gathered.any(-1).all()
def __eq__(self, other): if not hasattr(other, "mask"): return False mask_equal = (self.mask is None and other.mask is None) or ( isinstance(self.mask, torch.Tensor) and isinstance(other.mask, torch.Tensor) and (self.mask.shape == other.mask.shape) and (self.mask == other.mask).all() ) return ( type(self) == type(other) and self.shape == other.shape and == and self.device == other.device and self.dtype == other.dtype and self.domain == other.domain and self.use_register == other.use_register and mask_equal )
[docs] def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: """Converts a given one-hot tensor in categorical format. Args: val (torch.Tensor, optional): One-hot tensor to convert in categorical format. safe (bool): boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable. Returns: The categorical tensor. """ if safe is None: safe = _CHECK_SPEC_ENCODE if safe: self.assert_is_in(val) return val.long().argmax(-1)
[docs] def to_categorical_spec(self) -> DiscreteTensorSpec: """Converts the spec to the equivalent categorical spec.""" return DiscreteTensorSpec(, device=self.device, shape=self.shape[:-1], mask=self.mask, )
@dataclass(repr=False)
class BoundedTensorSpec(TensorSpec):
    """A bounded continuous tensor spec.

    Args:
        low (np.ndarray, torch.Tensor or number): lower bound of the box.
        high (np.ndarray, torch.Tensor or number): upper bound of the box.
        shape (torch.Size or int, optional): shape of the spec. Required when
            both bounds are scalars; otherwise it must match the non-scalar
            bound(s).
        device (str, int or torch.device, optional): device of the tensors.
        dtype (str or torch.dtype, optional): dtype of the tensors.

    Keyword Args:
        domain (str, optional): spec domain. Defaults to ``"continuous"``.
        minimum / maximum: deprecated aliases of ``low`` / ``high``.
    """

    # SPEC_HANDLED_FUNCTIONS = {}
    CONFLICTING_KWARGS = (
        "The keyword arguments {} and {} conflict. Only one of these can be passed."
    )

    def __init__(
        self,
        low: Union[float, torch.Tensor, np.ndarray] = None,
        high: Union[float, torch.Tensor, np.ndarray] = None,
        shape: Optional[Union[torch.Size, int]] = None,
        device: Optional[DEVICE_TYPING] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
        **kwargs,
    ):
        # NOTE(review): DEPRECATED_KWARGS is not defined in this class body and
        # is assumed to be inherited from TensorSpec -- confirm.
        if "maximum" in kwargs:
            if high is not None:
                raise TypeError(self.CONFLICTING_KWARGS.format("high", "maximum"))
            high = kwargs.pop("maximum")
            warnings.warn(self.DEPRECATED_KWARGS, category=DeprecationWarning)
        if "minimum" in kwargs:
            if low is not None:
                raise TypeError(self.CONFLICTING_KWARGS.format("low", "minimum"))
            low = kwargs.pop("minimum")
            warnings.warn(self.DEPRECATED_KWARGS, category=DeprecationWarning)
        domain = kwargs.pop("domain", "continuous")
        if len(kwargs):
            raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.")

        # _default_dtype_and_device never returns a None dtype, so no fallback
        # is needed below.
        dtype, device = _default_dtype_and_device(dtype, device)
        if not isinstance(low, torch.Tensor):
            low = torch.tensor(low, dtype=dtype, device=device)
        if not isinstance(high, torch.Tensor):
            high = torch.tensor(high, dtype=dtype, device=device)
        if high.device != device:
            high = high.to(device)
        if low.device != device:
            low = low.to(device)
        if low.dtype is not dtype:
            low = low.to(dtype)
        if high.dtype is not dtype:
            high = high.to(dtype)
        err_msg = (
            "BoundedTensorSpec requires the shape to be explicitely (via "
            "the shape argument) or implicitely defined (via either the "
            "minimum or the maximum or both). If the maximum and/or the "
            "minimum have a non-singleton shape, they must match the "
            "provided shape if this one is set explicitely."
        )
        if shape is not None and not isinstance(shape, torch.Size):
            if isinstance(shape, int):
                shape = torch.Size([shape])
            else:
                shape = torch.Size(list(shape))

        # Every non-raising branch leaves low.shape == high.shape == shape.
        if high.ndimension():
            if shape is not None and shape != high.shape:
                raise RuntimeError(err_msg)
            shape = high.shape
            low = low.expand(shape).clone()
        elif low.ndimension():
            if shape is not None and shape != low.shape:
                raise RuntimeError(err_msg)
            shape = low.shape
            high = high.expand(shape).clone()
        elif shape is None:
            raise RuntimeError(err_msg)
        else:
            low = low.expand(shape).clone()
            high = high.expand(shape).clone()

        self.shape = shape
        super().__init__(
            shape=shape,
            space=ContinuousBox(low, high, device=device),
            device=device,
            dtype=dtype,
            domain=domain,
        )

    def __eq__(self, other):
        return (
            type(other) == type(self)
            and self.device == other.device
            and self.shape == other.shape
            and self.space == other.space
            and self.dtype == other.dtype
        )

    @property
    def low(self):
        """Lower bound tensor of the box."""
        return self.space.low

    @property
    def high(self):
        """Upper bound tensor of the box."""
        return self.space.high

    def expand(self, *shape):
        """Returns a copy of the spec whose bounds are expanded to ``shape``."""
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(val < 0 for val in shape):
            raise ValueError(
                f"{self.__class__.__name__}.expand does not support negative shapes."
            )
        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
            raise ValueError(
                f"The last {self.ndim} of the expanded shape {shape} must match the "
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        return self.__class__(
            low=self.space.low.expand(shape).clone(),
            high=self.space.high.expand(shape).clone(),
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            domain=self.domain,
        )

    def _reshape(self, shape):
        return self.__class__(
            low=self.space.low.reshape(shape).clone(),
            high=self.space.high.reshape(shape).clone(),
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            domain=self.domain,
        )

    def _unflatten(self, dim, sizes):
        # Compute the resulting shape without allocating real memory.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self.__class__(
            low=self.space.low.unflatten(dim, sizes).clone(),
            high=self.space.high.unflatten(dim, sizes).clone(),
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            domain=self.domain,
        )

    def squeeze(self, dim: int | None = None):
        """Returns a copy of the spec with singleton dimensions removed."""
        shape = _squeezed_shape(self.shape, dim)
        if shape is None:
            return self
        if dim is None:
            low = self.space.low.squeeze().clone()
            high = self.space.high.squeeze().clone()
        else:
            low = self.space.low.squeeze(dim).clone()
            high = self.space.high.squeeze(dim).clone()
        return self.__class__(
            low=low,
            high=high,
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            domain=self.domain,
        )

    def unsqueeze(self, dim: int):
        """Returns a copy of the spec with a singleton dimension inserted at ``dim``."""
        shape = _unsqueezed_shape(self.shape, dim)
        return self.__class__(
            low=self.space.low.unsqueeze(dim).clone(),
            high=self.space.high.unsqueeze(dim).clone(),
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            domain=self.domain,
        )

    def unbind(self, dim: int):
        """Splits the spec along ``dim`` into a tuple of specs.

        Unlike one-hot specs, a box has no special trailing dimension, so any
        valid ``dim`` (including the last one) can be unbound.

        Raises:
            ValueError: if ``dim`` is out of range.
        """
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0 or dim >= len(self.shape):
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        low = self.space.low.unbind(dim)
        high = self.space.high.unbind(dim)
        return tuple(
            self.__class__(
                low=low[i],
                high=high[i],
                shape=shape,
                device=self.device,
                dtype=self.dtype,
                domain=self.domain,
            )
            for i in range(self.shape[dim])
        )

    def rand(self, shape=None) -> torch.Tensor:
        """Samples uniformly inside the box.

        Floating-point specs sample in ``[low, high]``. Integer and bool specs
        sample uniformly over the inclusive integer range ``low..high``.
        """
        if shape is None:
            shape = torch.Size([])
        low, high = self.space.low, self.space.high
        full_shape = torch.Size([*shape, *self.shape])
        if self.dtype.is_floating_point:
            out = (
                torch.zeros(full_shape, dtype=self.dtype, device=self.device).uniform_()
                * (high - low)
                + low
            )
            # Guard against rounding pushing samples slightly outside the box.
            if (out > high).any():
                out[out > high] = high.expand_as(out)[out > high]
            if (out < low).any():
                out[out < low] = low.expand_as(out)[out < low]
            return out
        # Integer / bool: pick floor(u * (high - low + 1)) so that every value,
        # including ``high``, has equal probability. Truncating
        # ``low + u * (high - low)`` never yields ``high`` and is biased toward 0.
        mini = low.to(torch.int64)
        maxi = high.to(torch.int64)
        interval = maxi - mini
        u = torch.rand(full_shape, dtype=torch.float64, device=interval.device)
        offset = torch.minimum((u * (interval + 1)).floor().to(torch.int64), interval)
        return (mini + offset).to(self.dtype).to(self.device)

    def _project(self, val: torch.Tensor) -> torch.Tensor:
        """Clips ``val`` into the box. ``val`` is modified in place and returned."""
        low = self.space.low.to(val.device)
        high = self.space.high.to(val.device)
        try:
            # Fast path: single-element bounds can be passed to clamp_ directly.
            val = val.clamp_(low.item(), high.item())
        except (ValueError, RuntimeError):
            # Multi-element bounds (``.item()`` fails): clip element-wise.
            low = low.expand_as(val)
            high = high.expand_as(val)
            val[val < low] = low[val < low]
            val[val > high] = high[val > high]
        return val

    def is_in(self, val: torch.Tensor) -> bool:
        """Checks that every entry of ``val`` lies within ``[low, high]``.

        A broadcasting shape mismatch produces a warning and ``False``; other
        runtime errors are re-raised.
        """
        try:
            return (val >= self.space.low.to(val.device)).all() and (
                val <= self.space.high.to(val.device)
            ).all()
        except RuntimeError as err:
            if "The size of tensor a" in str(err):
                warnings.warn(f"Got a shape mismatch: {str(err)}")
                return False
            raise

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> BoundedTensorSpec:
        """Returns the spec cast to a dtype or moved to a device (``self`` if unchanged)."""
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(
            low=self.space.low.to(dest),
            high=self.space.high.to(dest),
            shape=self.shape,
            device=dest_device,
            dtype=dest_dtype,
            domain=self.domain,
        )

    def clone(self) -> BoundedTensorSpec:
        """Returns a new spec with cloned bounds."""
        return self.__class__(
            low=self.space.low.clone(),
            high=self.space.high.clone(),
            shape=self.shape,
            device=self.device,
            dtype=self.dtype,
            domain=self.domain,
        )

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        if _is_nested_list(idx):
            raise NotImplementedError(
                "Pending resolution of https://github.com/pytorch/pytorch/issues/100080."
            )

        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
        # Expand the indexed bounds to the shape computed by _shape_indexing.
        # NOTE(review): presumably tensor indexing can return a shape that
        # differs from it -- the original comment was truncated; confirm.
        return self.__class__(
            low=self.space.low[idx].clone().expand(indexed_shape),
            high=self.space.high[idx].clone().expand(indexed_shape),
            shape=indexed_shape,
            device=self.device,
            dtype=self.dtype,
            domain=self.domain,
        )
def _is_nested_list(index, notuple=False): if not notuple and isinstance(index, tuple): for idx in index: if _is_nested_list(idx, notuple=True): return True elif isinstance(index, list): for idx in index: if isinstance(idx, list): return True else: return False return False
class NonTensorSpec(TensorSpec):
    """A spec for non-tensor data.

    Args:
        shape (torch.Size or int, optional): batch shape of the data.
            Defaults to ``torch.Size((1,))``.
        device (str, int or torch.device, optional): device of the data.
        dtype (torch.dtype, optional): stored as-is; defaults to ``None``.
        **kwargs: forwarded to ``TensorSpec.__init__``.
    """

    def __init__(
        self,
        shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
        device: Optional[DEVICE_TYPING] = None,
        dtype: torch.dtype | None = None,
        **kwargs,
    ):
        if isinstance(shape, int):
            shape = torch.Size([shape])

        _, device = _default_dtype_and_device(None, device)
        domain = None
        super().__init__(
            shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
        )

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(shape=self.shape, device=dest_device, dtype=None)

    def clone(self) -> NonTensorSpec:
        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)

    def _placeholder(self, shape=None):
        """Builds an empty ``NonTensorData`` with batch size ``(*shape, *self.shape)``."""
        if shape is None:
            shape = ()
        return NonTensorData(
            data=None,
            batch_size=torch.Size([*shape, *self.shape]),
            device=self.device,
        )

    def rand(self, shape=None):
        """Returns an empty ``NonTensorData`` placeholder.

        ``shape`` adds leading batch dimensions; it used to be a required
        argument that was silently ignored.
        """
        return self._placeholder(shape)

    def zero(self, shape=None):
        """Returns an empty ``NonTensorData`` placeholder (see :meth:`rand`)."""
        return self._placeholder(shape)

    def one(self, shape=None):
        """Returns an empty ``NonTensorData`` placeholder (see :meth:`rand`)."""
        return self._placeholder(shape)

    def is_in(self, val: torch.Tensor) -> bool:
        """Checks that ``val`` is a ``NonTensorData`` whose shape, device and
        dtype match the spec."""
        # Type check first: other objects may not even have a ``shape``.
        if not isinstance(val, NonTensorData):
            return False
        try:
            shape = torch.broadcast_shapes(self.shape, val.shape)
        except RuntimeError:
            return False
        return (
            val.shape == shape
            and val.device == self.device
            and val.dtype == self.dtype
        )

    def expand(self, *shape):
        """Returns a copy of the spec with batch shape ``shape``."""
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        shape = torch.Size(shape)
        if not all(
            (old == 1) or (old == new)
            for old, new in zip(self.shape, shape[-len(self.shape) :])
        ):
            raise ValueError(
                f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}."
            )
        return self.__class__(shape=shape, device=self.device, dtype=None)

    def _reshape(self, shape):
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def _unflatten(self, dim, sizes):
        # Compute the resulting shape without allocating real memory.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self.__class__(
            shape=shape,
            device=self.device,
            dtype=self.dtype,
        )

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)

    def unbind(self, dim: int):
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0:
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        return tuple(
            self.__class__(
                shape=shape,
                device=self.device,
                dtype=self.dtype,
            )
            for i in range(self.shape[dim])
        )
@dataclass(repr=False)
class UnboundedContinuousTensorSpec(TensorSpec):
    """An unbounded continuous tensor spec.

    Args:
        shape (torch.Size or int, optional): shape of the spec. Defaults to
            ``torch.Size((1,))``.
        device (str, int or torch.device, optional): device of the tensors.
        dtype (str or torch.dtype, optional): dtype of the tensors
            (should be an floating point dtype such as float, double etc.)

    Keyword Args:
        domain (str, optional): defaults to ``"continuous"`` for floating-point
            dtypes and ``"discrete"`` otherwise.
    """

    # SPEC_HANDLED_FUNCTIONS = {}

    def __init__(
        self,
        shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
        device: Optional[DEVICE_TYPING] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        if isinstance(shape, int):
            shape = torch.Size([shape])

        dtype, device = _default_dtype_and_device(dtype, device)
        # An explicit (-inf, inf) box is only materialised for the default shape.
        box = (
            ContinuousBox(
                torch.as_tensor(-np.inf, device=device).expand(shape),
                torch.as_tensor(np.inf, device=device).expand(shape),
            )
            if shape == _DEFAULT_SHAPE
            else None
        )
        default_domain = "continuous" if dtype.is_floating_point else "discrete"
        domain = kwargs.pop("domain", default_domain)
        super().__init__(
            shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
        )

    def to(
        self, dest: Union[torch.dtype, DEVICE_TYPING]
    ) -> UnboundedContinuousTensorSpec:
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype)

    def clone(self) -> UnboundedContinuousTensorSpec:
        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)

    def rand(self, shape=None) -> torch.Tensor:
        """Samples from a standard normal (floating dtypes) or with ``random_()``
        (other dtypes)."""
        if shape is None:
            shape = torch.Size([])
        shape = [*shape, *self.shape]
        if self.dtype.is_floating_point:
            return torch.randn(shape, device=self.device, dtype=self.dtype)
        return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

    def is_in(self, val: torch.Tensor) -> bool:
        """Checks that ``val`` has the spec's dtype and a shape that broadcasting
        with the spec shape leaves unchanged."""
        try:
            shape = torch.broadcast_shapes(self.shape, val.shape)
        except RuntimeError:
            # Shapes that cannot broadcast cannot belong to the spec.
            return False
        return val.shape == shape and val.dtype == self.dtype

    def _project(self, val: torch.Tensor) -> torch.Tensor:
        """Casts ``val`` to the spec dtype and reshapes its trailing dims to
        ``self.shape``, keeping any leading batch dims."""
        val = torch.as_tensor(val, dtype=self.dtype)
        batch_shape = val.shape[: val.ndim - self.ndim]
        return val.reshape(batch_shape + self.shape)

    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(val < 0 for val in shape):
            raise ValueError(
                f"{self.__class__.__name__}.expand does not support negative shapes."
            )
        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
            raise ValueError(
                f"The last {self.ndim} of the expanded shape {shape} must match the "
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def _reshape(self, shape):
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def _unflatten(self, dim, sizes):
        # Compute the resulting shape without allocating real memory.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self.__class__(
            shape=shape,
            device=self.device,
            dtype=self.dtype,
        )

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)

    def unbind(self, dim: int):
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0:
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        return tuple(
            self.__class__(
                shape=shape,
                device=self.device,
                dtype=self.dtype,
            )
            for i in range(self.shape[dim])
        )

    def __eq__(self, other):
        # those specs are equivalent to a discrete spec
        if isinstance(other, UnboundedDiscreteTensorSpec):
            return (
                UnboundedDiscreteTensorSpec(
                    shape=self.shape,
                    device=self.device,
                    dtype=self.dtype,
                )
                == other
            )
        if isinstance(other, BoundedTensorSpec):
            # Compare against a box spanning the full range of the dtype.
            minval, maxval = _minmax_dtype(self.dtype)
            minval = torch.as_tensor(minval).to(self.device, self.dtype)
            maxval = torch.as_tensor(maxval).to(self.device, self.dtype)
            return (
                BoundedTensorSpec(
                    shape=self.shape,
                    high=maxval,
                    low=minval,
                    dtype=self.dtype,
                    device=self.device,
                    domain=self.domain,
                )
                == other
            )
        return super().__eq__(other)
@dataclass(repr=False)
class UnboundedDiscreteTensorSpec(TensorSpec):
    """An unbounded discrete tensor spec.

    Args:
        shape (torch.Size or int, optional): shape of the spec. Defaults to
            ``torch.Size((1,))``.
        device (str, int or torch.device, optional): device of the tensors.
        dtype (str or torch.dtype, optional): dtype of the tensors
            (should be an integer dtype such as long, uint8 etc.)
    """

    # SPEC_HANDLED_FUNCTIONS = {}

    def __init__(
        self,
        shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
        device: Optional[DEVICE_TYPING] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
    ):
        if isinstance(shape, int):
            shape = torch.Size([shape])

        dtype, device = _default_dtype_and_device(dtype, device)
        # The box spans the whole representable range of the dtype.
        if dtype == torch.bool:
            min_value = False
            max_value = True
        else:
            if dtype.is_floating_point:
                min_value = torch.finfo(dtype).min
                max_value = torch.finfo(dtype).max
            else:
                min_value = torch.iinfo(dtype).min
                max_value = torch.iinfo(dtype).max
        # Build the bounds in the spec dtype: e.g. finfo(float64).max would
        # become inf if stored with the default float32 dtype.
        space = ContinuousBox(
            torch.full(shape, min_value, device=device, dtype=dtype),
            torch.full(shape, max_value, device=device, dtype=dtype),
        )

        super().__init__(
            shape=shape,
            space=space,
            device=device,
            dtype=dtype,
            domain="discrete",
        )

    def to(
        self, dest: Union[torch.dtype, DEVICE_TYPING]
    ) -> UnboundedDiscreteTensorSpec:
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype)

    def clone(self) -> UnboundedDiscreteTensorSpec:
        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)

    def rand(self, shape=None) -> torch.Tensor:
        """Samples values between the box bounds, with shape ``(*shape, *self.shape)``."""
        if shape is None:
            shape = torch.Size([])
        low = self.space.low
        high = self.space.high
        full_shape = torch.Size([*shape, *low.shape])
        if self.dtype == torch.bool:
            # Bool tensors do not support subtraction: draw fair coin flips.
            r = torch.rand(full_shape, device=low.device) < 0.5
        else:
            u = torch.rand(full_shape, dtype=torch.float64, device=low.device)
            # Interpolate as low * (1 - u) + high * u in float64. Forming
            # ``high - low`` overflows for full-range bounds: it wraps to -1 for
            # int64 and is inf for finfo(float64).
            r = low.to(torch.float64) * (1 - u) + high.to(torch.float64) * u
            r = r.to(self.dtype)
        return r.to(self.device)

    def is_in(self, val: torch.Tensor) -> bool:
        """Checks that ``val`` has the spec's dtype and a shape that broadcasting
        with the spec shape leaves unchanged."""
        try:
            shape = torch.broadcast_shapes(self.shape, val.shape)
        except RuntimeError:
            # Shapes that cannot broadcast cannot belong to the spec.
            return False
        return val.shape == shape and val.dtype == self.dtype

    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(val < 0 for val in shape):
            raise ValueError(
                f"{self.__class__.__name__}.expand does not support negative shapes."
            )
        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
            raise ValueError(
                f"The last {self.ndim} of the expanded shape {shape} must match the "
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def _reshape(self, shape):
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def _unflatten(self, dim, sizes):
        # Compute the resulting shape without allocating real memory.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self.__class__(
            shape=shape,
            device=self.device,
            dtype=self.dtype,
        )

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)

    def unbind(self, dim: int):
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0:
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        return tuple(
            self.__class__(
                shape=shape,
                device=self.device,
                dtype=self.dtype,
            )
            for i in range(self.shape[dim])
        )

    def __eq__(self, other):
        # those specs are equivalent to a discrete spec
        if isinstance(other, UnboundedContinuousTensorSpec):
            return (
                UnboundedContinuousTensorSpec(
                    shape=self.shape,
                    device=self.device,
                    dtype=self.dtype,
                    domain=self.domain,
                )
                == other
            )
        if isinstance(other, BoundedTensorSpec):
            return (
                BoundedTensorSpec(
                    shape=self.shape,
                    low=self.space.low,
                    high=self.space.high,
                    dtype=self.dtype,
                    device=self.device,
                    domain=self.domain,
                )
                == other
            )
        return super().__eq__(other)

    def __ne__(self, other):
        # those specs are equivalent to a discrete spec
        if isinstance(other, UnboundedContinuousTensorSpec):
            return (
                UnboundedContinuousTensorSpec(
                    shape=self.shape,
                    device=self.device,
                    dtype=self.dtype,
                    domain=self.domain,
                )
                != other
            )
        if isinstance(other, BoundedTensorSpec):
            return (
                BoundedTensorSpec(
                    shape=self.shape,
                    low=self.space.low,
                    high=self.space.high,
                    dtype=self.dtype,
                    device=self.device,
                    domain=self.domain,
                )
                != other
            )
        return super().__ne__(other)
@dataclass(repr=False)
class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
    """A concatenation of one-hot discrete tensor spec.

    The last dimension of the shape (domain of the tensor elements) cannot be indexed.

    Args:
        nvec (iterable of integers): cardinality of each of the elements of
            the tensor.
        shape (torch.Size, optional): total shape of the sampled tensors.
            If provided, the last dimension must match sum(nvec).
        device (str, int or torch.device, optional): device of
            the tensors.
        dtype (str or torch.dtype, optional): dtype of the tensors.
        use_register (bool, optional): forwarded to the per-element one-hot
            encoding.
        mask (torch.Tensor, optional): boolean mask of allowed entries,
            expandable to ``shape``.

    Examples:
        >>> ts = MultiOneHotDiscreteTensorSpec((3,2,3))
        >>> ts.is_in(torch.tensor([0,0,1,
        ...                        0,1,
        ...                        1,0,0]))
        True
        >>> ts.is_in(torch.tensor([1,0,1,
        ...                        0,1,
        ...                        1,0,0])) # False
        False

    """

    # SPEC_HANDLED_FUNCTIONS = {}

    def __init__(
        self,
        nvec: Sequence[int],
        shape: Optional[torch.Size] = None,
        device=None,
        dtype=torch.bool,
        use_register=False,
        mask: torch.Tensor | None = None,
    ):
        self.nvec = nvec
        dtype, device = _default_dtype_and_device(dtype, device)
        if shape is None:
            shape = torch.Size((sum(nvec),))
        else:
            shape = torch.Size(shape)
            # An empty shape has no trailing dim: report it as a ValueError
            # instead of an IndexError.
            if not len(shape) or shape[-1] != sum(nvec):
                raise ValueError(
                    f"The last value of the shape must match sum(nvec) for transform of type {self.__class__}. "
                    f"Got sum(nvec)={sum(nvec)} and shape={shape}."
                )
        space = BoxList([DiscreteBox(n) for n in nvec])
        self.use_register = use_register
        # Skip OneHotDiscreteTensorSpec.__init__ (single ``n``) and initialise
        # TensorSpec directly with the BoxList space.
        super(OneHotDiscreteTensorSpec, self).__init__(
            shape,
            space,
            device,
            dtype,
            domain="discrete",
        )
        self.update_mask(mask)

    def update_mask(self, mask):
        """Sets (or clears) the boolean mask of allowed entries.

        Raises:
            RuntimeError: if ``mask`` cannot be expanded to ``self.shape``.
            ValueError: if ``mask`` is not a boolean tensor.
        """
        if mask is not None:
            try:
                mask = mask.expand(*self.shape)
            except RuntimeError as err:
                raise RuntimeError("Cannot expand mask to the desired shape.") from err
            if mask.dtype != torch.bool:
                raise ValueError("Only boolean masks are accepted.")
        self.mask = mask

    def to(
        self, dest: Union[torch.dtype, DEVICE_TYPING]
    ) -> MultiOneHotDiscreteTensorSpec:
        """Returns the spec cast to a dtype or moved to a device (``self`` if unchanged)."""
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        # The mask must stay boolean: only move it across devices.
        return self.__class__(
            nvec=deepcopy(self.nvec),
            shape=self.shape,
            device=dest_device,
            dtype=dest_dtype,
            use_register=self.use_register,
            mask=self.mask.to(dest_device) if self.mask is not None else None,
        )

    def clone(self) -> MultiOneHotDiscreteTensorSpec:
        """Returns a new spec with the same parameters; the mask tensor is cloned."""
        return self.__class__(
            nvec=deepcopy(self.nvec),
            shape=self.shape,
            device=self.device,
            dtype=self.dtype,
            use_register=self.use_register,
            mask=self.mask.clone() if self.mask is not None else None,
        )

    def __eq__(self, other):
        if not hasattr(other, "mask"):
            return False
        mask_equal = (self.mask is None and other.mask is None) or (
            isinstance(self.mask, torch.Tensor)
            and isinstance(other.mask, torch.Tensor)
            and (self.mask.shape == other.mask.shape)
            and (self.mask == other.mask).all()
        )
        return (
            type(self) == type(other)
            and self.shape == other.shape
            and self.space == other.space
            and self.device == other.device
            and self.dtype == other.dtype
            and self.domain == other.domain
            and self.use_register == other.use_register
            and mask_equal
        )
[docs] def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: shape = torch.Size([*shape, *self.shape[:-1]]) mask = self.mask if mask is None: x = [ torch.nn.functional.one_hot( torch.randint( space.n, ( *shape, 1, ), device=self.device, ), space.n, ).to(self.dtype) for space in ], -1, ).squeeze(-2) return x mask = mask.expand(*shape, mask.shape[-1]) mask_splits = torch.split(mask, [space.n for space in], -1) out = [] for _mask in mask_splits: if mask.ndim > 2: mask_flat = torch.flatten(_mask, 0, -2) else: mask_flat = _mask shape_out = _mask.shape[:-1] m = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) m = torch.nn.functional.one_hot(m, _mask.shape[-1]).to(self.dtype) out.append(m) return, -1)
[docs] def encode( self, val: Union[np.ndarray, torch.Tensor], *, ignore_device: bool = False ) -> torch.Tensor: if not isinstance(val, torch.Tensor): if not ignore_device: val = torch.tensor(val, device=self.device) else: val = torch.as_tensor(val) x = [] for v, space in zip(val.unbind(-1), if not (v < space.n).all(): raise RuntimeError( f"value {v} is greater than the allowed max {space.n}" ) x.append( super(MultiOneHotDiscreteTensorSpec, self).encode( v, space, ignore_device=ignore_device ) ) return, -1).reshape(self.shape)
def _split(self, val: torch.Tensor) -> Optional[torch.Tensor]: split_sizes = [space.n for space in] if val.ndim < 1 or val.shape[-1] != sum(split_sizes): return None return val.split(split_sizes, dim=-1)
[docs] def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: if not isinstance(index, torch.Tensor): raise ValueError( f"Only tensors are allowed for indexing using" f" {self.__class__.__name__}.index(...)" ) indices = self._split(index) tensor_to_index = self._split(tensor_to_index) out = [] for _index, _tensor_to_index in zip(indices, tensor_to_index): _index = _index.nonzero().squeeze() _index = _index.expand((*_tensor_to_index.shape[:-1], _index.shape[-1])) out.append(_tensor_to_index.gather(-1, _index)) return, -1)
[docs] def is_in(self, val: torch.Tensor) -> bool: vals = self._split(val) if vals is None: return False return all(spec.is_in(val) for val, spec in zip(vals, self._split_self()))
def _project(self, val: torch.Tensor) -> torch.Tensor: vals = self._split(val) return [spec._project(val) for val, spec in zip(vals, self._split_self())], -1 ) def _split_self(self): result = [] device = self.device dtype = self.dtype use_register = self.use_register mask = ( self.mask.split([space.n for space in], -1) if self.mask is not None else [None] * len( ) for _mask, space in zip(mask, n = space.n shape = self.shape[:-1] + (n,) result.append( OneHotDiscreteTensorSpec( n=n, shape=shape, device=device, dtype=dtype, use_register=use_register, mask=_mask, ) ) return result
[docs] def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: """Converts a given one-hot tensor in categorical format. Args: val (torch.Tensor, optional): One-hot tensor to convert in categorical format. safe (bool): boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable. Returns: The categorical tensor. """ if safe is None: safe = _CHECK_SPEC_ENCODE if safe: self.assert_is_in(val) vals = self._split(val) return torch.stack([val.long().argmax(-1) for val in vals], -1)
[docs] def to_categorical_spec(self) -> MultiDiscreteTensorSpec: """Converts the spec to the equivalent categorical spec.""" return MultiDiscreteTensorSpec( [_space.n for _space in], device=self.device, shape=[*self.shape[:-1], len(], mask=self.mask, )
[docs] def expand(self, *shape): nvecs = [space.n for space in] if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( f"{self.__class__.__name__}.expand does not support negative shapes." ) if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): raise ValueError( f"The last {self.ndim} of the expanded shape {shape} must match the" f"shape of the {self.__class__.__name__} spec in expand()." ) mask = self.mask.expand(shape) if self.mask is not None else None return self.__class__( nvec=nvecs, shape=shape, device=self.device, dtype=self.dtype, mask=mask, )
def _reshape(self, shape): nvecs = [space.n for space in] mask = self.mask.reshape(shape) if self.mask is not None else None return self.__class__( nvec=nvecs, shape=shape, device=self.device, dtype=self.dtype, mask=mask, ) def _unflatten(self, dim, sizes): nvecs = [space.n for space in] shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape mask = self.mask.reshape(shape) if self.mask is not None else None return self.__class__( nvec=nvecs, shape=shape, device=self.device, dtype=self.dtype, mask=mask, )
[docs] def squeeze(self, dim=None): if self.shape[-1] == 1 and dim in (len(self.shape), -1, None): raise ValueError(f"Final dimension of {type(self)} must remain unchanged") shape = _squeezed_shape(self.shape, dim) if shape is None: return self mask = self.mask.reshape(shape) if self.mask is not None else None return self.__class__( nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype, mask=mask, )
def unsqueeze(self, dim: int): if dim in (len(self.shape), -1): raise ValueError(f"Final dimension of {type(self)} must remain unchanged") shape = _unsqueezed_shape(self.shape, dim) mask = self.mask.reshape(shape) if self.mask is not None else None return self.__class__( nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype, mask=mask ) def unbind(self, dim: int): if dim in (len(self.shape), -1): raise ValueError(f"Final dimension of {type(self)} must remain unchanged") orig_dim = dim if dim < 0: dim = len(self.shape) + dim if dim < 0: raise ValueError( f"Cannot unbind along dim {orig_dim} with shape {self.shape}." ) shape = tuple(s for i, s in enumerate(self.shape) if i != dim) mask = self.mask if mask is None: mask = (None,) * self.shape[dim] else: mask = mask.unbind(dim) return tuple( self.__class__( nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype, mask=mask[i], ) for i in range(self.shape[dim]) ) def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index. The last dimension of the spec corresponding to the domain of the tensor elements is non-indexable. """ indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( nvec=self.nvec, shape=torch.Size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, )
class DiscreteTensorSpec(TensorSpec):
    """A discrete tensor spec.

    An alternative to :class:`OneHotDiscreteTensorSpec` for categorical variables in TorchRL.
    Instead of using multiplication, categorical variables perform indexing which can speed up
    computation and reduce memory cost for large categorical variables.

    Each element of a sample is an integer in ``[0, n)``. There is no trailing
    domain dimension, so every dimension of ``shape`` can be indexed.

    Example:
        >>> batch, size = 3, 4
        >>> action_value = torch.arange(batch*size)
        >>> action_value = action_value.view(batch, size).to(torch.float)
        >>> action = torch.argmax(action_value, dim=-1).to(torch.long)
        >>> chosen_action_value = action_value[range(batch), action]
        >>> print(chosen_action_value)
        tensor([ 3.,  7., 11.])

    Args:
        n (int): number of possible outcomes.
        shape (torch.Size, optional): shape of the variable, default is "torch.Size([])".
        device (str, int or torch.device, optional): device of the tensors.
        dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to ``torch.long``.
        mask (torch.Tensor, optional): boolean mask of the allowed outcomes. It must be
            expandable to ``(*shape, n)``.

    """

    shape: torch.Size
    space: DiscreteBox
    device: torch.device | None = None
    dtype: torch.dtype = torch.float
    domain: str = ""

    def __init__(
        self,
        n: int,
        shape: torch.Size | None = None,
        device: DEVICE_TYPING | None = None,
        dtype: str | torch.dtype = torch.long,
        mask: torch.Tensor | None = None,
    ):
        if shape is None:
            shape = torch.Size([])
        dtype, device = _default_dtype_and_device(dtype, device)
        space = DiscreteBox(n)
        super().__init__(
            shape=shape, space=space, device=device, dtype=dtype, domain="discrete"
        )
        self.update_mask(mask)

    @property
    def n(self):
        """Number of possible outcomes (``self.space.n``)."""
        return self.space.n

    def update_mask(self, mask):
        """Sets the mask of allowed outcomes, or clears it when ``mask`` is ``None``.

        Raises:
            RuntimeError: if ``mask`` cannot be expanded to ``(*self.shape, n)``.
            ValueError: if ``mask`` is not boolean.
        """
        if mask is not None:
            try:
                mask = mask.expand(*self.shape, self.space.n)
            except RuntimeError as err:
                raise RuntimeError("Cannot expand mask to the desired shape.") from err
            if mask.dtype != torch.bool:
                raise ValueError("Only boolean masks are accepted.")
        self.mask = mask

    def rand(self, shape=None) -> torch.Tensor:
        """Samples uniformly among the allowed outcomes.

        Args:
            shape (torch.Size, optional): extra leading batch shape.

        Returns:
            A tensor of shape ``(*shape, *self.shape)`` and dtype ``self.dtype``.
        """
        if shape is None:
            shape = torch.Size([])
        if self.mask is None:
            return torch.randint(
                0,
                self.space.n,
                torch.Size([*shape, *self.shape]),
                device=self.device,
                dtype=self.dtype,
            )
        mask = self.mask.expand(*shape, *self.mask.shape)
        # torch.multinomial only accepts 1D/2D inputs: fold the batch dims.
        mask_flat = torch.flatten(mask, 0, -2) if mask.ndim > 2 else mask
        out = torch.multinomial(mask_flat.float(), 1).reshape(mask.shape[:-1])
        # multinomial returns int64. Cast so both branches return self.dtype.
        return out.to(self.dtype)

    def _project(self, val: torch.Tensor) -> torch.Tensor:
        if val.dtype not in (torch.int, torch.long):
            val = torch.round(val)
        if self.mask is None:
            return val.clamp_(min=0, max=self.space.n - 1)
        # Clamp first so that gather cannot index past the mask. Clamped values
        # are then checked against the mask like any other value.
        val = val.clamp(min=0, max=self.space.n - 1)
        shape = self.mask.shape
        shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
        mask_expand = self.mask.expand(shape)
        # gather needs an index with the mask's rank; contiguous() avoids
        # returning an overlapping (expanded) view.
        val = val.expand(shape[:-1]).contiguous()
        gathered = mask_expand.gather(-1, val.long().unsqueeze(-1))
        oob = ~gathered.all(-1)
        if oob.any():
            # Resample the masked-out entries among the allowed outcomes.
            new_val = torch.multinomial(mask_expand[oob].float(), 1).squeeze(-1)
            val = torch.masked_scatter(val, oob, new_val.to(val.dtype))
        return val

    def is_in(self, val: torch.Tensor) -> bool:
        """Checks that ``val`` lies in ``[0, n)`` and, if a mask is set, is allowed by it."""
        in_range = (0 <= val).all() and (val < self.space.n).all()
        if self.mask is None or not in_range:
            # An out-of-range value would make gather fail below: reject it here.
            return in_range
        shape = self.mask.shape
        shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
        mask_expand = self.mask.expand(shape)
        index = val.long().expand(shape[:-1]).unsqueeze(-1)
        gathered = mask_expand.gather(-1, index)
        return gathered.all()

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        # NOTE(review): the mask is not carried over to the indexed spec.
        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
        return self.__class__(
            n=self.space.n,
            shape=indexed_shape,
            device=self.device,
            dtype=self.dtype,
        )

    def __eq__(self, other):
        if not hasattr(other, "mask"):
            return False
        mask_equal = (self.mask is None and other.mask is None) or (
            isinstance(self.mask, torch.Tensor)
            and isinstance(other.mask, torch.Tensor)
            and (self.mask.shape == other.mask.shape)
            and (self.mask == other.mask).all()
        )
        return (
            type(self) == type(other)
            and self.shape == other.shape
            and self.space == other.space
            and self.device == other.device
            and self.dtype == other.dtype
            and self.domain == other.domain
            and mask_equal
        )

    def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
        """Converts ``val`` to numpy; ``safe=None`` resolves to ``CHECK_SPEC_ENCODE``."""
        if safe is None:
            safe = _CHECK_SPEC_ENCODE
        return super().to_numpy(val, safe)

    def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
        """Encodes a discrete tensor from the spec domain into its one-hot correspondent.

        Args:
            val (torch.Tensor): Tensor to one-hot encode.
            safe (bool): boolean value indicating whether a check should be
                performed on the value against the domain of the spec.
                Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

        Returns:
            The boolean one-hot encoded tensor, with a new trailing dimension of size ``n``.
        """
        if safe is None:
            safe = _CHECK_SPEC_ENCODE
        if safe:
            self.assert_is_in(val)
        return torch.nn.functional.one_hot(val, self.space.n).bool()

    def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec:
        """Converts the spec to the equivalent one-hot spec.

        The mask already has shape ``(*self.shape, n)``, which is the one-hot
        spec's shape, so it is forwarded unchanged.
        """
        shape = [*self.shape, self.space.n]
        return OneHotDiscreteTensorSpec(
            n=self.space.n,
            shape=shape,
            device=self.device,
            mask=self.mask,
        )

    def expand(self, *shape):
        """Returns a copy of the spec (and of its mask) expanded to ``shape``.

        Raises:
            ValueError: on negative sizes or mismatching trailing dims.
        """
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(val < 0 for val in shape):
            raise ValueError(
                f"{self.__class__.__name__}.expand does not support negative shapes."
            )
        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
            raise ValueError(
                f"The last {self.ndim} dimensions of the expanded shape {shape} must match the "
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        mask = (
            self.mask.expand(*shape, self.mask.shape[-1])
            if self.mask is not None
            else None
        )
        return self.__class__(
            n=self.space.n, shape=shape, device=self.device, dtype=self.dtype, mask=mask
        )

    def _reshape(self, shape):
        mask = (
            self.mask.reshape(*shape, self.mask.shape[-1])
            if self.mask is not None
            else None
        )
        return self.__class__(
            n=self.space.n, shape=shape, device=self.device, dtype=self.dtype, mask=mask
        )

    def _unflatten(self, dim, sizes):
        # A meta tensor computes the resulting shape without allocating memory.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self._reshape(shape)

    def squeeze(self, dim=None):
        shape = _squeezed_shape(self.shape, dim)
        # Check for "nothing to squeeze" before the mask is viewed with shape.
        if shape is None:
            return self
        mask = self.mask
        if mask is not None:
            mask = mask.view(*shape, mask.shape[-1])
        return self.__class__(
            n=self.space.n,
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            mask=mask,
        )

    def unsqueeze(self, dim: int):
        shape = _unsqueezed_shape(self.shape, dim)
        mask = self.mask
        if mask is not None:
            mask = mask.view(*shape, mask.shape[-1])
        return self.__class__(
            n=self.space.n,
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            mask=mask,
        )

    def unbind(self, dim: int):
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0 or dim >= len(self.shape):
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        mask = self.mask
        if mask is None:
            mask = (None,) * self.shape[dim]
        else:
            mask = mask.unbind(dim)
        return tuple(
            self.__class__(
                n=self.space.n,
                shape=shape,
                device=self.device,
                dtype=self.dtype,
                mask=mask[i],
            )
            for i in range(self.shape[dim])
        )

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> DiscreteTensorSpec:
        """Casts the spec to a dtype or moves it to a device. Returns ``self`` when nothing changes."""
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        # The mask stays boolean: only its device follows dest.
        mask = self.mask.to(dest_device) if self.mask is not None else None
        return self.__class__(
            n=self.space.n,
            shape=self.shape,
            device=dest_device,
            dtype=dest_dtype,
            mask=mask,
        )

    def clone(self) -> DiscreteTensorSpec:
        return self.__class__(
            n=self.space.n,
            shape=self.shape,
            device=self.device,
            dtype=self.dtype,
            mask=self.mask.clone() if self.mask is not None else None,
        )
@dataclass(repr=False)
class BinaryDiscreteTensorSpec(DiscreteTensorSpec):
    """A binary discrete tensor spec.

    Every element takes a value in ``{0, 1}``: the underlying discrete space has
    two outcomes, while ``n`` is the length of the binary vector, i.e. the last
    dimension of ``shape``.

    Args:
        n (int): length of the binary vector.
        shape (torch.Size, optional): total shape of the sampled tensors.
            If provided, the last dimension must match n.
        device (str, int or torch.device, optional): device of the tensors.
        dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to ``torch.int8``.

    Examples:
        >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
        >>> print(spec.shape)
        torch.Size([5, 4])
    """

    def __init__(
        self,
        n: int,
        shape: Optional[torch.Size] = None,
        device: Optional[DEVICE_TYPING] = None,
        dtype: Union[str, torch.dtype] = torch.int8,
    ):
        if shape is None or not len(shape):
            shape = torch.Size((n,))
        else:
            shape = torch.Size(shape)
            if shape[-1] != n:
                raise ValueError(
                    f"The last value of the shape must match n for spec {self.__class__}. "
                    f"Got n={n} and shape={shape}."
                )
        # Two outcomes per element; the vector length is encoded in shape[-1].
        super().__init__(n=2, shape=shape, device=device, dtype=dtype)

    def expand(self, *shape):
        """Returns a copy of the spec expanded to ``shape``.

        Raises:
            ValueError: on negative sizes or mismatching trailing dims.
        """
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(val < 0 for val in shape):
            raise ValueError(
                f"{self.__class__.__name__}.expand does not support negative shapes."
            )
        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
            raise ValueError(
                f"The last {self.ndim} dimensions of the expanded shape {shape} must match the "
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        return self.__class__(
            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
        )

    def _reshape(self, shape):
        return self.__class__(
            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
        )

    def _unflatten(self, dim, sizes):
        # A meta tensor computes the resulting shape without allocating memory.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self.__class__(
            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
        )

    def squeeze(self, dim=None):
        shape = _squeezed_shape(self.shape, dim)
        if shape is None:
            return self
        return self.__class__(
            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
        )

    def unsqueeze(self, dim: int):
        shape = _unsqueezed_shape(self.shape, dim)
        return self.__class__(
            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
        )

    def unbind(self, dim: int):
        """Splits the spec along ``dim``.

        Raises:
            ValueError: if ``dim`` is the last (binary-vector) dim or out of range.
        """
        if dim in (len(self.shape) - 1, -1):
            raise ValueError(f"Final dimension of {type(self)} must remain unchanged")
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0 or dim >= len(self.shape):
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        return tuple(
            self.__class__(
                n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
            )
            for i in range(self.shape[dim])
        )

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> BinaryDiscreteTensorSpec:
        """Casts the spec to a dtype or moves it to a device. Returns ``self`` when nothing changes."""
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(
            n=self.shape[-1], shape=self.shape, device=dest_device, dtype=dest_dtype
        )

    def clone(self) -> BinaryDiscreteTensorSpec:
        return self.__class__(
            n=self.shape[-1],
            shape=self.shape,
            device=self.device,
            dtype=self.dtype,
        )

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index.

        The last dimension of the spec (length n of the binary vector) cannot be indexed.
        """
        indexed_shape = _shape_indexing(self.shape[:-1], idx)
        return self.__class__(
            n=self.shape[-1],
            shape=torch.Size(indexed_shape + [self.shape[-1]]),
            device=self.device,
            dtype=self.dtype,
        )

    def __eq__(self, other):
        # A plain DiscreteTensorSpec with two outcomes and the same shape,
        # device and dtype is considered equal.
        if not isinstance(other, BinaryDiscreteTensorSpec):
            if isinstance(other, DiscreteTensorSpec):
                return (
                    other.n == 2
                    and other.device == self.device
                    and other.shape == self.shape
                    and other.dtype == self.dtype
                )
            return False
        return super().__eq__(other)
@dataclass(repr=False)
class MultiDiscreteTensorSpec(DiscreteTensorSpec):
    """A concatenation of discrete tensor spec.

    Args:
        nvec (iterable of integers or torch.Tensor): cardinality of each of the elements of
            the tensor. Can have several axes.
        shape (torch.Size, optional): total shape of the sampled tensors. If provided,
            ``nvec`` must be expandable to it (in particular ``shape[-1]`` must equal
            ``nvec.shape[-1]``).
        device (str, int or torch.device, optional): device of the tensors.
        dtype (str or torch.dtype, optional): dtype of the tensors.
        mask (torch.Tensor, optional): boolean mask. Its last dimension concatenates the
            masks of the entries of ``nvec``; its leading dims must be expandable
            to ``shape[:-1]``.

    Examples:
        >>> ts = MultiDiscreteTensorSpec((3, 2, 3))
        >>> ts.is_in(torch.tensor([2, 0, 1]))
        True
        >>> ts.is_in(torch.tensor([2, 2, 1]))
        False
    """

    def __init__(
        self,
        nvec: Union[Sequence[int], torch.Tensor, int],
        shape: Optional[torch.Size] = None,
        device: Optional[DEVICE_TYPING] = None,
        dtype: Optional[Union[str, torch.dtype]] = torch.long,
        mask: torch.Tensor | None = None,
    ):
        if not isinstance(nvec, torch.Tensor):
            nvec = torch.tensor(nvec)
        if nvec.ndim < 1:
            nvec = nvec.unsqueeze(0)
        dtype, device = _default_dtype_and_device(dtype, device)
        if shape is None:
            shape = nvec.shape
        else:
            shape = torch.Size(shape)
            if shape[-1] != nvec.shape[-1]:
                raise ValueError(
                    f"The last value of the shape must match nvec.shape[-1] for transform of type {self.__class__}. "
                    f"Got nvec.shape[-1]={nvec.shape[-1]} and shape={shape}."
                )
        # One cardinality per element of the spec.
        self.nvec = nvec.expand(shape)

        space = BoxList.from_nvec(self.nvec)
        # Bypass DiscreteTensorSpec.__init__ (single n) and call TensorSpec's directly.
        super(DiscreteTensorSpec, self).__init__(
            shape, space, device, dtype, domain="discrete"
        )
        self.update_mask(mask)

    def update_mask(self, mask):
        """Sets the mask, or clears it when ``mask`` is ``None``.

        Raises:
            RuntimeError: if the mask cannot be expanded to ``(*shape[:-1], mask.shape[-1])``.
            ValueError: if the mask is not boolean.
        """
        if mask is not None:
            try:
                mask = mask.expand(*self.shape[:-1], mask.shape[-1])
            except RuntimeError as err:
                raise RuntimeError("Cannot expand mask to the desired shape.") from err
            if mask.dtype != torch.bool:
                raise ValueError("Only boolean masks are accepted.")
        self.mask = mask

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiDiscreteTensorSpec:
        """Casts the spec to a dtype or moves it to a device. Returns ``self`` when nothing changes."""
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        # Only the device follows dest. Casting a bool mask to dest_dtype would be
        # rejected by update_mask, and nvec must stay integral.
        mask = self.mask.to(dest_device) if self.mask is not None else None
        return self.__class__(
            nvec=self.nvec.to(dest_device),
            shape=None,
            device=dest_device,
            dtype=dest_dtype,
            mask=mask,
        )

    def __eq__(self, other):
        if not hasattr(other, "mask"):
            return False
        mask_equal = (self.mask is None and other.mask is None) or (
            isinstance(self.mask, torch.Tensor)
            and isinstance(other.mask, torch.Tensor)
            and (self.mask.shape == other.mask.shape)
            and (self.mask == other.mask).all()
        )
        return (
            type(self) == type(other)
            and self.shape == other.shape
            and self.space == other.space
            and self.device == other.device
            and self.dtype == other.dtype
            and self.domain == other.domain
            and mask_equal
        )

    def clone(self) -> MultiDiscreteTensorSpec:
        return self.__class__(
            nvec=self.nvec.clone(),
            shape=None,
            device=self.device,
            dtype=self.dtype,
            mask=self.mask.clone() if self.mask is not None else None,
        )

    def _rand(self, space: Box, shape: torch.Size, i: int):
        """Recursively samples every leaf of a (possibly nested) ``BoxList``."""
        x = []
        for _s in space:
            if isinstance(_s, BoxList):
                x.append(self._rand(_s, shape[:-1], i - 1))
            else:
                x.append(
                    torch.randint(
                        0,
                        _s.n,
                        shape,
                        device=self.device,
                        dtype=self.dtype,
                    )
                )
        return torch.stack(x, -1)

    def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
        if self.mask is not None:
            splits = self._split_self()
            return torch.stack([split.rand(shape) for split in splits], -1)
        if shape is None:
            shape = self.shape[:-1]
        else:
            shape = (
                *shape,
                *self.shape[:-1],
            )
        x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim)
        if self.shape == torch.Size([1]):
            x = x.squeeze(-1)
        return x

    def _split_self(self):
        """Builds one masked :class:`DiscreteTensorSpec` per entry of the last dim of ``nvec``.

        Raises:
            ValueError: if ``nvec`` differs across leading dims (heterogeneous spec).
        """
        nvec = self.nvec
        if nvec.ndim > 1:
            nvec = torch.flatten(nvec, 0, -2)[0]
            if (self.nvec != nvec).any():
                raise ValueError(
                    f"Only homogeneous MultiDiscrete specs can be masked, got nvec={self.nvec}."
                )
        nvec = nvec.tolist()
        masks = (
            self.mask.split(nvec, -1)
            if self.mask is not None
            else [None] * len(nvec)
        )
        shape = self.shape[:-1]
        return [
            DiscreteTensorSpec(
                n=n, shape=shape, device=self.device, dtype=self.dtype, mask=_mask
            )
            for n, _mask in zip(nvec, masks)
        ]

    def _project(self, val: torch.Tensor) -> torch.Tensor:
        if self.mask is not None:
            return torch.stack(
                [
                    spec._project(_val)
                    for (_val, spec) in zip(val.unbind(-1), self._split_self())
                ],
                -1,
            )

        val_is_scalar = val.ndim < 1
        if val_is_scalar:
            val = val.unsqueeze(0)
        if not self.dtype.is_floating_point and val.is_floating_point():
            val = torch.round(val)
        val = val.type(self.dtype)
        # The domain is [0, nvec): clamp values below the lower bound as well.
        val[val < 0] = 0
        above = val >= self.nvec
        val[above] = (self.nvec.expand_as(val)[above] - 1).type(self.dtype)
        return val.squeeze(0) if val_is_scalar else val

    def is_in(self, val: torch.Tensor) -> bool:
        if self.mask is not None:
            return all(
                spec.is_in(_val)
                for (_val, spec) in zip(val.unbind(-1), self._split_self())
            )
        if val.ndim < 1:
            val = val.unsqueeze(0)
        val_have_wrong_dim = (
            self.shape != torch.Size([1])
            and val.shape[-len(self.shape) :] != self.shape
        )
        if self.dtype != val.dtype or len(self.shape) > val.ndim or val_have_wrong_dim:
            return False
        val_device = val.device
        return (
            (
                (val >= torch.zeros(self.nvec.size(), device=val_device))
                & (val < self.nvec.to(val_device))
            )
            .all()
            .item()
        )

    def to_one_hot(
        self, val: torch.Tensor, safe: bool = None
    ) -> Union[MultiOneHotDiscreteTensorSpec, torch.Tensor]:
        """Encodes a discrete tensor from the spec domain into its one-hot correspondent.

        Args:
            val (torch.Tensor): Tensor to one-hot encode.
            safe (bool): boolean value indicating whether a check should be
                performed on the value against the domain of the spec.
                Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.

        Returns:
            The concatenated one-hot encodings of every column, on ``self.device``.
        """
        if safe is None:
            safe = _CHECK_SPEC_ENCODE
        if safe:
            self.assert_is_in(val)
        return torch.cat(
            [
                torch.nn.functional.one_hot(val[..., i], n)
                for i, n in enumerate(self.nvec)
            ],
            -1,
        ).to(self.device)

    def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec:
        """Converts the spec to the equivalent one-hot spec."""
        nvec = [_space.n for _space in self.space]
        return MultiOneHotDiscreteTensorSpec(
            nvec,
            device=self.device,
            shape=[*self.shape[:-1], sum(nvec)],
            mask=self.mask,
        )

    def expand(self, *shape):
        """Returns a copy of the spec (and of its mask) expanded to ``shape``.

        Raises:
            ValueError: on negative sizes or mismatching trailing dims.
        """
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(val < 0 for val in shape):
            raise ValueError(
                f"{self.__class__.__name__}.expand does not support negative shapes."
            )
        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
            raise ValueError(
                f"The last {self.ndim} dimensions of the expanded shape {shape} must match the "
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        # The mask has shape (*shape[:-1], sum(nvec)): the nvec dim is replaced
        # by the concatenated mask dim, so it must not be repeated.
        mask = (
            self.mask.expand(*shape[:-1], self.mask.shape[-1])
            if self.mask is not None
            else None
        )
        return self.__class__(
            nvec=self.nvec,
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            mask=mask,
        )

    def _reshape(self, shape):
        shape = torch.Size(shape)
        mask = (
            self.mask.reshape(*shape[:-1], self.mask.shape[-1])
            if self.mask is not None
            else None
        )
        return self.__class__(
            nvec=self.nvec.reshape(shape),
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            mask=mask,
        )

    def _unflatten(self, dim, sizes):
        # A meta tensor computes the resulting shape without allocating memory.
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self._reshape(shape)

    def squeeze(self, dim: int | None = None):
        if self.shape[-1] == 1 and dim in (len(self.shape) - 1, -1, None):
            raise ValueError(f"Final dimension of {type(self)} must remain unchanged")

        shape = _squeezed_shape(self.shape, dim)
        if shape is None:
            return self

        if dim is None:
            nvec = self.nvec.squeeze()
        else:
            nvec = self.nvec.squeeze(dim)
        mask = self.mask
        if mask is not None:
            mask = mask.view(*shape[:-1], mask.shape[-1])
        return self.__class__(
            nvec=nvec, shape=shape, device=self.device, dtype=self.dtype, mask=mask
        )

    def unsqueeze(self, dim: int):
        if dim in (len(self.shape), -1):
            raise ValueError(f"Final dimension of {type(self)} must remain unchanged")
        shape = _unsqueezed_shape(self.shape, dim)
        nvec = self.nvec.unsqueeze(dim)
        mask = self.mask
        if mask is not None:
            mask = mask.view(*shape[:-1], mask.shape[-1])
        return self.__class__(
            nvec=nvec,
            shape=shape,
            device=self.device,
            dtype=self.dtype,
            mask=mask,
        )

    def unbind(self, dim: int):
        if dim in (len(self.shape) - 1, -1):
            raise ValueError(f"Final dimension of {type(self)} must remain unchanged")
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0 or dim >= len(self.shape):
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        mask = self.mask
        nvec = self.nvec.unbind(dim)
        if mask is not None:
            mask = mask.unbind(dim)
        else:
            mask = (None,) * self.shape[dim]
        return tuple(
            self.__class__(
                nvec=nvec[i],
                shape=shape,
                device=self.device,
                dtype=self.dtype,
                mask=mask[i],
            )
            for i in range(self.shape[dim])
        )

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        if _is_nested_list(idx):
            raise NotImplementedError(
                "Pending resolution of"
            )
        # NOTE(review): the mask is not carried over to the indexed spec.
        return self.__class__(
            nvec=self.nvec[idx].clone(),
            shape=None,
            device=self.device,
            dtype=self.dtype,
        )
class CompositeSpec(TensorSpec):
    """A composition of TensorSpecs.

    Args:
        *args: if an unnamed argument is passed, it must be a dictionary with keys
            matching the expected keys to be found in the :obj:`CompositeSpec` object.
            This is useful to build nested CompositeSpecs with tuple indices.
        **kwargs (key (str): value (TensorSpec)): dictionary of tensorspecs
            to be stored. Values can be None, in which case is_in will be assumed
            to be ``True`` for the corresponding tensors, and :obj:`project()` will have no
            effect. `spec.encode` cannot be used with missing values.

    Attributes:
        device (torch.device or None): if not specified, the device of the composite
            spec is ``None`` (as it is the case for TensorDicts). A non-``None`` device
            constrains all leaves to be on the same device. A ``None`` device, on the
            other hand, allows leaves to live on different devices. Defaults
            to ``None``.
        shape (torch.Size): the leading shape of all the leaves. Equivalent
            to the batch-size of the corresponding tensordicts.

    Examples:
        >>> pixels_spec = BoundedTensorSpec(
        ...    torch.zeros(3,32,32),
        ...    torch.ones(3, 32, 32))
        >>> observation_vector_spec = BoundedTensorSpec(torch.zeros(33),
        ...    torch.ones(33))
        >>> composite_spec = CompositeSpec(
        ...     pixels=pixels_spec,
        ...     observation_vector=observation_vector_spec)
        >>> td = TensorDict({"pixels": torch.rand(10,3,32,32),
        ...    "observation_vector": torch.rand(10,33)}, batch_size=[10])
        >>> print("td (rand) is within bounds: ", composite_spec.is_in(td))
        td (rand) is within bounds: True
        >>> td = TensorDict({"pixels": torch.randn(10,3,32,32),
        ...    "observation_vector": torch.randn(10,33)}, batch_size=[10])
        >>> print("td (randn) is within bounds: ", composite_spec.is_in(td))
        td (randn) is within bounds: False
        >>> td_project = composite_spec.project(td)
        >>> print("td modification done in place: ", td_project is td)
        td modification done in place: True
        >>> print("check td is within bounds after projection: ",
        ...    composite_spec.is_in(td_project))
        check td is within bounds after projection: True
        >>> print("random td: ", composite_spec.rand([3,]))
        random td: TensorDict(
            fields={
                observation_vector: Tensor(torch.Size([3, 33]), dtype=torch.float32),
                pixels: Tensor(torch.Size([3, 3, 32, 32]), dtype=torch.float32)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)


    Examples:
        >>> # we can build a nested composite spec using unnamed arguments
        >>> print(CompositeSpec({("a", "b"): None, ("a", "c"): None}))
        CompositeSpec(
            a: CompositeSpec(
                b: None,
                c: None))

    CompositeSpec supports nested indexing. Missing intermediate levels are
    created as empty CompositeSpecs:

        >>> spec = CompositeSpec(obs=None)
        >>> spec["nested", "x"] = None
        >>> isinstance(spec["nested"], CompositeSpec)
        True
        >>> spec["nested", "x"] is None
        True

    """

    shape: torch.Size
    domain: str = "composite"

    SPEC_HANDLED_FUNCTIONS = {}

    @classmethod
    def __new__(cls, *args, **kwargs):
        # NOTE(review): because of @classmethod, these assignments set *class*
        # attributes that act as shared defaults. __init__ later sets
        # self._device on the instance.
        cls._device = None
        cls._locked = False
        return super().__new__(cls)

    @property
    def shape(self):
        return self._shape

    @shape.setter
    def shape(self, value: torch.Size):
        """Resets the batch shape. Nested composites are updated recursively.

        Raises:
            RuntimeError: if the spec is locked.
            ValueError: if a leaf's leading dims do not match ``value``.
        """
        if self.locked:
            raise RuntimeError("Cannot modify shape of locked composite spec.")
        # Normalise first: a torch.Size never compares equal to a list.
        value = torch.Size(value)
        for key, spec in self.items():
            if isinstance(spec, CompositeSpec):
                if spec.shape[: len(value)] != value:
                    spec.shape = value
            elif spec is not None:
                if spec.shape[: len(value)] != value:
                    raise ValueError(
                        f"The shape of the spec and the CompositeSpec mismatch during shape resetting: the "
                        f"{len(value)} first dimensions should match but got self['{key}'].shape={spec.shape} and "
                        f"CompositeSpec.shape={value}."
                    )
        self._shape = value
[docs] def is_empty(self): """Whether the composite spec contains specs or not.""" return len(self._specs) == 0
@property def ndim(self): return self.ndimension() def ndimension(self): return len(self.shape) def set(self, name, spec): if self.locked: raise RuntimeError("Cannot modify a locked CompositeSpec.") if spec is not None: shape = spec.shape if shape[: self.ndim] != self.shape: raise ValueError( "The shape of the spec and the CompositeSpec mismatch: the first " f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " f"CompositeSpec.shape={self.shape}." ) self._specs[name] = spec def __init__(self, *args, shape=None, device=None, **kwargs): if shape is None: # Should we do this? Other specs have a default empty shape, maybe it would make sense to keep it # optional for composite (for clarity and easiness of use). # warnings.warn("shape=None for CompositeSpec will soon be deprecated. Make sure you set the " # "batch size of your CompositeSpec as you would do for a tensordict.") shape = [] self._shape = torch.Size(shape) self._specs = {} for key, value in kwargs.items(): self.set(key, value) _device = torch.device(device) if device is not None else device if len(kwargs): for key, item in self.items(): if item is None: continue if ( isinstance(item, CompositeSpec) and item.device is None and _device is not None ): item = item.clone().to(_device) elif (_device is not None) and (item.device != _device): raise RuntimeError( f"Setting a new attribute ({key}) on another device " f"({item.device} against {_device}). All devices of " "CompositeSpec must match." ) self._device = _device if len(args): if len(args) > 1: raise RuntimeError( "Got multiple arguments, when at most one is expected for CompositeSpec." ) argdict = args[0] if not isinstance(argdict, (dict, CompositeSpec)): raise RuntimeError( f"Expected a dictionary of specs, but got an argument of type {type(argdict)}." 
) for k, item in argdict.items(): if isinstance(item, dict): item = CompositeSpec(item, shape=shape, device=_device) self[k] = item @property def device(self) -> DEVICE_TYPING: return self._device @device.setter def device(self, device: DEVICE_TYPING): if device is None and self._device is not None: raise RuntimeError( "To erase the device of a composite spec, call " "spec.clear_device_()." ) device = torch.device(device)
[docs] def clear_device_(self): """Clears the device of the CompositeSpec.""" self._device = None for spec in self._specs.values(): spec.clear_device_() return self
def __getitem__(self, idx): """Indexes the current CompositeSpec based on the provided index.""" if isinstance(idx, (str, tuple)): idx_unravel = unravel_key(idx) else: idx_unravel = () if idx_unravel: if isinstance(idx_unravel, tuple): return self[idx[0]][idx[1:]] if idx_unravel in {"shape", "device", "dtype", "space"}: raise AttributeError(f"CompositeSpec has no key {idx_unravel}") return self._specs[idx_unravel] indexed_shape = _shape_indexing(self.shape, idx) indexed_specs = {} for k, v in self._specs.items(): _idx = idx if isinstance(idx, tuple): protected_dims = 0 if any( isinstance(v, spec_class) for spec_class in [ BinaryDiscreteTensorSpec, MultiDiscreteTensorSpec, OneHotDiscreteTensorSpec, ] ): protected_dims = 1 # TensorSpecs dims which are not part of the composite shape cannot be indexed _idx = idx + (slice(None),) * ( len(v.shape) - len(self.shape) - protected_dims ) indexed_specs[k] = v[_idx] if v is not None else None try: device = self.device except RuntimeError: device = self._device return self.__class__( indexed_specs, shape=indexed_shape, device=device, )
[docs] def get(self, item, default=NO_DEFAULT): """Gets an item from the CompositeSpec. If the item is absent, a default value can be passed. """ try: return self[item] except KeyError: if item is not NO_DEFAULT: return default raise
def __setitem__(self, key, value):
    """Stores ``value`` under ``key``, creating intermediate composites for nested keys.

    Dict values become CompositeSpecs with this spec's shape and device.
    Device-less CompositeSpec values are moved to this spec's device; other
    values on a different device are rejected.
    """
    if isinstance(key, tuple):
        if len(key) > 1:
            head, tail = key[0], key[1:]
            if head not in self.keys(True):
                self[head] = CompositeSpec(shape=self.shape, device=self.device)
            self[head][tail] = value
        else:
            self[key[0]] = value
        return
    if not isinstance(key, str):
        raise TypeError(f"Got key of type {type(key)} when a string was expected.")
    if key in {"shape", "device", "dtype", "space"}:
        raise AttributeError(f"CompositeSpec[{key}] cannot be set")
    if isinstance(value, dict):
        value = CompositeSpec(value, device=self._device, shape=self.shape)
    if (
        value is not None
        and self.device is not None
        and value.device != self.device
    ):
        movable = isinstance(value, CompositeSpec) and value.device is None
        if not movable:
            raise RuntimeError(
                f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). "
                f"All devices of CompositeSpec must match."
            )
        value = value.clone().to(self.device)
    self.set(key, value)


def __iter__(self):
    """Iterates over the top-level keys."""
    for name in self._specs:
        yield name


def __delitem__(self, key: str) -> None:
    """Deletes a (possibly nested) entry."""
    if isinstance(key, tuple):
        if len(key) > 1:
            parent = self[key[:-1]]
            del parent[key[-1]]
        else:
            del self._specs[key[0]]
        return
    if not isinstance(key, str):
        raise TypeError(
            f"Got key of type {type(key)} when a string or a tuple of strings was expected."
        )
    if key in {"shape", "device", "dtype", "space"}:
        raise AttributeError(f"CompositeSpec has no key {key}")
    del self._specs[key]
def encode(
    self, vals: Dict[str, Any], *, ignore_device: bool = False
) -> Dict[str, torch.Tensor]:
    """Encodes each entry of ``vals`` with the matching sub-spec.

    Args:
        vals: mapping (or TensorDict) from keys of this spec to raw values.
        ignore_device: forwarded to each sub-spec's ``encode``.

    Returns:
        a TensorDict with the encoded values. If ``vals`` is a TensorDict, the
        output is ``vals.empty()``; otherwise it is a TensorDict with an empty
        batch size.

    Raises:
        RuntimeError: if a value in ``vals`` is ``None``.
        KeyError: if ``vals`` holds a key this spec does not have.
    """
    if isinstance(vals, TensorDict):
        out = vals.empty()  # create and empty tensordict similar to vals
    else:
        out = TensorDict({}, torch.Size([]), _run_checks=False)
    for key, item in vals.items():
        if item is None:
            raise RuntimeError(
                "CompositeSpec.encode cannot be used with missing values."
            )
        # Guard only the lookup, so a KeyError raised while encoding is not
        # misreported as a missing key.
        try:
            spec = self[key]
        except KeyError as err:
            raise KeyError(
                f"The CompositeSpec instance with keys {self.keys()} does not have a '{key}' key."
            ) from err
        out[key] = spec.encode(item, ignore_device=ignore_device)
    return out
def __repr__(self) -> str:
    """Multi-line representation: one indented line per entry, then device and shape."""
    pad = " " * 4
    body = ",\n".join(
        indent(f"{name}: {str(spec)}", pad) for name, spec in self._specs.items()
    )
    return f"CompositeSpec(\n{body}, device={self._device}, shape={self.shape})"
def type_check(
    self,
    value: Union[torch.Tensor, TensorDictBase],
    selected_keys: Union[str, Optional[Sequence[str]]] = None,
):
    """Runs ``type_check`` of each non-None sub-spec on the matching value.

    A bare tensor paired with a single string key is wrapped as
    ``{key: tensor}``. When ``selected_keys`` is given, only those keys are
    checked.
    """
    if isinstance(value, torch.Tensor) and isinstance(selected_keys, str):
        value = {selected_keys: value}
        selected_keys = [selected_keys]
    for name in self.keys():
        if self[name] is None:
            continue
        if selected_keys is not None and name not in selected_keys:
            continue
        self._specs[name].type_check(value[name], name)
def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
    """Returns ``True`` if every relevant entry of ``val`` lies in its sub-spec.

    ``None`` entries and empty sub-composites are ignored.
    """
    for key, spec in self._specs.items():
        if spec is None:
            continue
        if isinstance(spec, CompositeSpec) and spec.is_empty():
            continue
        if not spec.is_in(val.get(key)):
            return False
    return True
def project(self, val: TensorDictBase) -> TensorDictBase:
    """Projects, in place, every out-of-spec entry of ``val`` back into its sub-spec.

    Entries already inside their sub-spec and ``None`` sub-specs are left
    untouched. Returns ``val``.
    """
    for name, spec in self.items():
        if spec is None:
            continue
        current = val.get(name)
        leaf = self._specs[name]
        if leaf.is_in(current):
            continue
        val.set(name, leaf.project(current))
    return val
def rand(self, shape=None) -> TensorDictBase:
    """Draws a random sample from every non-None entry.

    Args:
        shape: optional leading batch shape; defaults to ``torch.Size([])``.

    Returns:
        a TensorDict of batch size ``[*shape, *self.shape]`` on ``self._device``.
    """
    lead = torch.Size([]) if shape is None else shape
    samples = {
        key: spec.rand(lead) for key, spec in self.items() if spec is not None
    }
    return TensorDict(
        samples,
        batch_size=torch.Size([*lead, *self.shape]),
        device=self._device,
        # The composite already guarantees TensorDict-compatible shapes,
        # so the consistency checks can be skipped.
        _run_checks=False,
    )
def keys(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
) -> KeysView:
    """Keys of the CompositeSpec.

    The keys argument reflect those of :class:`tensordict.TensorDict`.
    The returned view is lazy and supports ``key in spec.keys(...)`` with
    nested (tuple) keys.

    Args:
        include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
            represent only the immediate children of the root, and not the whole nested sequence, i.e. a
            :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
            :obj:`["next"]`. Default is ``False``, i.e. nested keys will not
            be returned.
        leaves_only (bool, optional): if ``True``, only keys whose value is not
            itself a CompositeSpec are returned. For instance, with
            ``include_nested=True``, :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
            leads to the keys :obj:`[("next", "obs")]`, whereas ``leaves_only=False``
            also yields :obj:`"next"`. Default is ``False``.
    """
    return _CompositeSpecKeysView(
        self, include_nested=include_nested, leaves_only=leaves_only
    )
def items(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
) -> ItemsView:
    """Yields the ``(key, spec)`` pairs of the CompositeSpec.

    Args:
        include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
            represent only the immediate children of the root, and not the whole nested sequence, i.e. a
            :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
            :obj:`["next"]`. Default is ``False``, i.e. nested keys will not
            be returned.
        leaves_only (bool, optional): if ``True``, only pairs whose value is not
            itself a CompositeSpec are yielded. Default is ``False``.
    """
    if not include_nested and not leaves_only:
        # Fast path: the immediate children, in insertion order.
        yield from self._specs.items()
    else:
        yield from (
            (key, self[key])
            for key in self.keys(
                include_nested=include_nested, leaves_only=leaves_only
            )
        )
def values(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
) -> ValuesView:
    """Yields the values (sub-specs) of the CompositeSpec.

    Args:
        include_nested (bool, optional): if ``False``, only the values of the
            immediate children are yielded, i.e. for
            :obj:`CompositeSpec(next=CompositeSpec(obs=None))` only the value
            stored under :obj:`"next"`. Default is ``False``.
        leaves_only (bool, optional): if ``True``, only values that are not
            themselves CompositeSpecs are yielded. Default is ``False``.
    """
    if not include_nested and not leaves_only:
        yield from self._specs.values()
    else:
        yield from (
            self[key]
            for key in self.keys(
                include_nested=include_nested, leaves_only=leaves_only
            )
        )
def _reshape(self, shape):
    """Reshapes the batch dimensions of every entry to ``shape``.

    The trailing (non-batch) dimensions of each entry are preserved.
    ``None`` placeholders stay ``None`` and the device is carried over.
    """
    batch_ndim = self.ndimension()
    _specs = {
        key: val.reshape((*shape, *val.shape[batch_ndim:]))
        if val is not None
        else None
        for key, val in self._specs.items()
    }
    return CompositeSpec(_specs, shape=shape, device=self._device)


def _unflatten(self, dim, sizes):
    """Unflattens batch dimension ``dim`` into ``sizes``."""
    # A meta tensor computes the resulting shape without allocating memory.
    shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
    return self._reshape(shape)


def __len__(self):
    """Number of top-level keys."""
    return len(self.keys())


def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
    """Returns a copy of the spec with every non-None entry moved to device ``dest``.

    ``self`` is returned unchanged when ``dest`` is ``None`` or already equals
    the spec's device.

    Raises:
        ValueError: if ``dest`` is not a device-like value (e.g. a dtype).
    """
    if dest is None:
        return self
    if not isinstance(dest, (str, int, torch.device)):
        raise ValueError(
            "Only device casting is allowed with specs of type CompositeSpec."
        )
    if self._device and self._device == torch.device(dest):
        return self

    _device = torch.device(dest)
    kwargs = {
        key: value.to(_device) if value is not None else None
        for key, value in self.items()
    }
    return self.__class__(**kwargs, device=_device, shape=self.shape)


def clone(self) -> CompositeSpec:
    """Returns a deep copy of the spec (each non-None entry is cloned)."""
    try:
        device = self.device
    except RuntimeError:
        device = self._device
    return self.__class__(
        {
            key: item.clone() if item is not None else None
            for key, item in self.items()
        },
        device=device,
        shape=self.shape,
    )
def empty(self):
    """Returns a spec of the same class, device and shape as ``self``, with no entries."""
    try:
        spec_device = self.device
    except RuntimeError:
        # Fall back to the stored attribute if the ``device`` property raises.
        spec_device = self._device
    return self.__class__({}, device=spec_device, shape=self.shape)
def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
    """Converts each entry of ``val`` with the matching sub-spec's ``to_numpy``.

    NOTE(review): ``safe`` is accepted but not forwarded to the sub-specs.
    """
    converted = {}
    for key, entry in val.items():
        converted[key] = self[key].to_numpy(entry)
    return converted
def zero(self, shape=None) -> TensorDictBase:
    """Returns a TensorDict of zero-valued samples for every top-level non-None entry.

    Args:
        shape: optional leading batch shape; defaults to ``torch.Size([])``.
    """
    lead = torch.Size([]) if shape is None else shape
    try:
        device = self.device
    except RuntimeError:
        device = self._device
    zeros = {}
    for key in self.keys(True):
        # Nested keys come back as tuples; only top-level (string) keys are
        # filled here.
        if not isinstance(key, str):
            continue
        spec = self[key]
        if spec is None:
            continue
        zeros[key] = spec.zero(lead)
    return TensorDict(
        zeros,
        torch.Size([*lead, *self.shape]),
        device=device,
    )
def __eq__(self, other):
    """Two composites are equal if class, shape, device, keys and entries match."""
    return (
        type(self) is type(other)
        and self.shape == other.shape
        and self._device == other._device
        and set(self._specs.keys()) == set(other._specs.keys())
        and all((self._specs[key] == spec) for (key, spec) in other._specs.items())
    )


def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> CompositeSpec:
    """Updates the spec in place with the entries of ``dict_or_spec``.

    Existing CompositeSpec entries are updated recursively. TensorSpec entries
    on another device are deep-copied and, if this spec has a device, moved
    onto it.

    Returns:
        the spec itself.
    """
    for key, item in dict_or_spec.items():
        if key in self.keys(True) and isinstance(self[key], CompositeSpec):
            self[key].update(item)
            continue
        try:
            if isinstance(item, TensorSpec) and item.device != self.device:
                # Copy so the caller's spec is never mutated or moved.
                item = deepcopy(item)
                if self.device is not None:
                    item = item.to(self.device)
        except RuntimeError as err:
            if DEVICE_ERR_MSG not in str(err):
                raise
            # Device-less composite: adopt the incoming item's device if it has one.
            try:
                self.device = item.device
            except RuntimeError as suberr:
                if DEVICE_ERR_MSG not in str(suberr):
                    raise
        self[key] = item
    return self
def expand(self, *shape):
    """Expands the batch dimensions of the spec, and of every entry, to ``shape``.

    Args:
        *shape: the target batch shape, as unpacked ints or as a single
            tuple/list/torch.Size. Its last ``self.ndim`` sizes must equal the
            current batch sizes, except where the current size is 1.

    Returns:
        a new CompositeSpec with batch shape ``shape``. Trailing (non-batch)
        dimensions of each entry are preserved, and ``None`` entries stay
        ``None``.

    Raises:
        ValueError: if ``shape`` contains negative sizes, has fewer dimensions
            than the spec, or is incompatible with the current batch shape.
    """
    if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
        shape = shape[0]
    if any(val < 0 for val in shape):
        raise ValueError("CompositeSpec.expand does not support negative shapes.")
    if len(shape) < self.ndim:
        raise ValueError(
            f"The expanded shape {shape} has fewer dimensions than the "
            f"{self.__class__.__name__} spec (ndim={self.ndim}) in expand()."
        )
    if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
        raise ValueError(
            f"The last {self.ndim} dimensions of the expanded shape {shape} must match the "
            f"shape of the {self.__class__.__name__} spec in expand()."
        )
    try:
        device = self.device
    except RuntimeError:
        device = self._device
    out = CompositeSpec(
        {
            key: value.expand((*shape, *value.shape[self.ndim :]))
            if value is not None
            else None
            for key, value in tuple(self.items())
        },
        shape=shape,
        device=device,
    )
    return out
def squeeze(self, dim: int | None = None):
    """Removes singleton batch dimensions.

    Args:
        dim: the batch dimension to squeeze. If its size is not 1, ``self`` is
            returned. If ``None``, every size-1 batch dimension is removed.

    Returns:
        a new CompositeSpec (or ``self`` when nothing is squeezed). ``None``
        entries stay ``None``.
    """
    if dim is not None:
        if dim < 0:
            dim += len(self.shape)

        shape = _squeezed_shape(self.shape, dim)
        if shape is None:
            return self

        try:
            device = self.device
        except RuntimeError:
            device = self._device

        return CompositeSpec(
            {
                key: value.squeeze(dim) if value is not None else None
                for key, value in self.items()
            },
            shape=shape,
            device=device,
        )

    if self.shape.count(1) == 0:
        return self

    # we can't just recursively apply squeeze with dim=None because we don't want
    # to squeeze non-batch dims of the values. Instead we find the first dim in
    # the batch dims with size 1, squeeze that, then recurse on the root spec
    out = self.squeeze(self.shape.index(1))
    return out.squeeze()
def unsqueeze(self, dim: int):
    """Inserts a size-1 batch dimension at ``dim`` in the spec and all non-None entries."""
    if dim < 0:
        dim += len(self.shape) + 1

    shape = _unsqueezed_shape(self.shape, dim)

    try:
        device = self.device
    except RuntimeError:
        device = self._device

    return CompositeSpec(
        {
            key: value.unsqueeze(dim) if value is not None else None
            for key, value in self.items()
        },
        shape=shape,
        device=device,
    )


def unbind(self, dim: int):
    """Splits the spec along batch dimension ``dim``.

    Returns:
        a tuple of ``self.shape[dim]`` specs of the same class whose batch
        shape is ``self.shape`` without ``dim``. ``None`` entries stay ``None``.

    Raises:
        ValueError: if a negative ``dim`` is out of range.
    """
    orig_dim = dim
    if dim < 0:
        dim = len(self.shape) + dim
    if dim < 0:
        raise ValueError(
            f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
        )
    # Materialize the shape: a generator would be exhausted by the first
    # sub-spec and the others would silently get an empty shape.
    shape = torch.Size([s for i, s in enumerate(self.shape) if i != dim])
    unbound_vals = {
        key: val.unbind(dim) if val is not None else None
        for key, val in self.items()
    }
    return tuple(
        self.__class__(
            {
                key: val[i] if val is not None else None
                for key, val in unbound_vals.items()
            },
            shape=shape,
            device=self.device,
        )
        for i in range(self.shape[dim])
    )
def lock_(self, recurse=False):
    """Locks the CompositeSpec so its entries can no longer be replaced.

    Only this level is locked unless ``recurse=True``, in which case nested
    CompositeSpecs are locked too. Leaf specs can still be modified in place,
    but they cannot be replaced in their CompositeSpec parent.

    Examples:
        >>> shape = [3, 4, 5]
        >>> spec = CompositeSpec(
        ...         a=CompositeSpec(
        ...         b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2]
        ...     ),
        ...     shape=shape[:1],
        ... )
        >>> spec["a"] = spec["a"].clone()
        >>> recurse = False
        >>> spec.lock_(recurse=recurse)
        >>> try:
        ...     spec["a"] = spec["a"].clone()
        ... except RuntimeError:
        ...     print("failed!")
        failed!
        >>> try:
        ...     spec["a", "b"] = spec["a", "b"].clone()
        ...     print("succeeded!")
        ... except RuntimeError:
        ...     print("failed!")
        succeeded!
        >>> recurse = True
        >>> spec.lock_(recurse=recurse)
        >>> try:
        ...     spec["a", "b"] = spec["a", "b"].clone()
        ...     print("succeeded!")
        ... except RuntimeError:
        ...     print("failed!")
        failed!

    Returns:
        the spec itself.
    """
    self._locked = True
    if not recurse:
        return self
    for child in self.values():
        if isinstance(child, CompositeSpec):
            child.lock_(recurse)
    return self
def unlock_(self, recurse=False):
    """Unlocks the CompositeSpec so its entries can be replaced again.

    Only this level is unlocked unless ``recurse=True``, in which case nested
    CompositeSpecs are unlocked too.

    Returns:
        the spec itself.
    """
    self._locked = False
    if not recurse:
        return self
    for child in self.values():
        if isinstance(child, CompositeSpec):
            child.unlock_(recurse)
    return self
@property
def locked(self):
    """``True`` once ``lock_`` has been called, ``False`` after ``unlock_``."""
    is_locked = self._locked
    return is_locked
class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec):
    """A lazy representation of a stack of composite specs.

    Stacks composite specs together along one dimension.
    When random samples are drawn, a LazyStackedTensorDict is returned.

    Indexing is allowed but only along the stack dimension.

    This class is aimed to be used in multi-task and multi-agent settings, where
    heterogeneous specs may occur (same semantic but different shape).

    """

    def update(self, dict) -> None:
        """Updates the stacked specs in place with the entries of ``dict``.

        Entries whose key already exists and whose value is a composite are
        unbound along the stack dimension and merged into each stacked spec.
        Any other entry goes through ``__setitem__``. Returns ``self``.
        """
        for key, item in dict.items():
            if key in self.keys() and isinstance(
                item, (Dict, CompositeSpec, LazyStackedCompositeSpec)
            ):
                # NOTE(review): a plain ``dict`` has no ``unbind``; only spec
                # instances can take this path successfully.
                for spec, sub_item in zip(self._specs, item.unbind(self.dim)):
                    spec[key].update(sub_item)
                continue
            self[key] = item
        return self

    def __eq__(self, other):
        if not isinstance(other, LazyStackedCompositeSpec):
            return False
        if len(self._specs) != len(other._specs):
            return False
        if self.stack_dim != other.stack_dim:
            return False
        if self.device != other.device:
            return False
        for _spec1, _spec2 in zip(self._specs, other._specs):
            if _spec1 != _spec2:
                return False
        return True

    def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
        """Converts ``val`` to a dict of numpy values, validating it first if ``safe``.

        ``safe`` defaults to the ``CHECK_SPEC_ENCODE`` environment flag.
        """
        if safe is None:
            safe = _CHECK_SPEC_ENCODE
        if safe:
            if val.shape[self.dim] != len(self._specs):
                raise ValueError(
                    "Size of LazyStackedCompositeSpec and val differ along the "
                    "stacking dimension"
                )
            for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)):
                spec.assert_is_in(v)
        return {key: self[key].to_numpy(val) for key, val in val.items()}

    def __len__(self):
        # NOTE(review): this is the leading batch size, whereas
        # CompositeSpec.__len__ counts keys -- confirm this is intended.
        return self.shape[0]

    def values(
        self,
        include_nested: bool = False,
        leaves_only: bool = False,
    ):
        """Yields the (lazily stacked) value of every shared key."""
        for key in self.keys(include_nested=include_nested, leaves_only=leaves_only):
            yield self[key]

    def items(
        self,
        include_nested: bool = False,
        leaves_only: bool = False,
    ):
        """Yields ``(key, value)`` for every key shared by all stacked specs."""
        for key in self.keys(include_nested=include_nested, leaves_only=leaves_only):
            yield key, self[key]

    def keys(
        self,
        include_nested: bool = False,
        leaves_only: bool = False,
    ) -> KeysView:
        """Returns the keys present in every stacked spec, sorted by their string form."""
        keys = self._specs[0].keys(
            include_nested=include_nested, leaves_only=leaves_only
        )
        keys = set(keys)
        for spec in self._specs[1:]:
            keys = keys.intersection(spec.keys(include_nested, leaves_only))
        return sorted(keys, key=str)

    def project(self, val: TensorDictBase) -> TensorDictBase:
        """Projects each slice of ``val`` along the stack dim onto its stacked spec."""
        vals = []
        for spec, subval in zip(self._specs, val.unbind(self.dim)):
            if not spec.is_in(subval):
                vals.append(spec.project(subval))
            else:
                vals.append(subval)
        res = LazyStackedTensorDict.maybe_dense_stack(vals, dim=self.dim)
        if not isinstance(val, LazyStackedTensorDict):
            res = res.to_tensordict()
        return res

    def type_check(
        self,
        value: Union[torch.Tensor, TensorDictBase],
        selected_keys: Union[NestedKey, Optional[Sequence[NestedKey]]] = None,
    ):
        if selected_keys is None:
            if isinstance(value, torch.Tensor):
                raise ValueError(
                    "value must be of type TensorDictBase when key is None"
                )
            for spec, subvalue in zip(self._specs, value.unbind(self.dim)):
                spec.type_check(subvalue)
        else:
            if isinstance(value, torch.Tensor) and isinstance(selected_keys, str):
                value = {selected_keys: value}
                selected_keys = [selected_keys]
            for _key in self.keys():
                if self[_key] is not None and _key in selected_keys:
                    self[_key].type_check(value[_key], _key)

    def __repr__(self) -> str:
        sub_str = ",\n".join(
            [indent(f"{k}: {repr(item)}", 4 * " ") for k, item in self.items()]
        )
        sub_str = indent(f"fields={{\n{', '.join([sub_str])}}}", 4 * " ")
        exclusive_key_str = self.repr_exclusive_keys()
        device_str = indent(f"device={self._specs[0].device}", 4 * " ")
        shape_str = indent(f"shape={self.shape}", 4 * " ")
        stack_dim = indent(f"stack_dim={self.dim}", 4 * " ")
        string = ",\n".join(
            [sub_str, exclusive_key_str, device_str, shape_str, stack_dim]
        )
        return f"LazyStackedCompositeSpec(\n{string})"

    def repr_exclusive_keys(self):
        """Formats, per stacked spec, the keys that are not shared by all specs."""
        keys = set(self.keys())
        exclusive_keys = [
            ",\n".join(
                [
                    indent(f"{k}: {repr(spec[k])}", 4 * " ")
                    for k in spec.keys()
                    if k not in keys
                ]
            )
            for spec in self._specs
        ]
        exclusive_key_str = ",\n".join(
            [
                indent(f"{i} ->\n{line}", 4 * " ")
                for i, line in enumerate(exclusive_keys)
                if line != ""
            ]
        )

        return indent(f"exclusive_fields={{\n{exclusive_key_str}}}", 4 * " ")

    def is_in(self, val) -> bool:
        """``True`` if every slice of ``val`` along the stack dim is in its stacked spec."""
        for spec, subval in zip(self._specs, val.unbind(self.dim)):
            if not spec.is_in(subval):
                return False
        return True

    def __delitem__(self, key: NestedKey):
        """Deletes a key from the stacked composite spec.

        This method will be executed if the key is present in at least one of the stacked specs,
        otherwise it will raise an error.

        Args:
            key (NestedKey): the key to delete.
        """
        at_least_one_deletion = False
        for spec in self._specs:
            try:
                del spec[key]
                at_least_one_deletion = True
            except KeyError:
                continue
        if not at_least_one_deletion:
            raise KeyError(
                f"Key {key} must be present in at least one of the stacked specs"
            )
        return self

    def __iter__(self):
        # NOTE(review): yields values, unlike CompositeSpec.__iter__ which
        # yields keys -- kept as-is for compatibility; confirm intent.
        for k in self.keys():
            yield self[k]

    def __setitem__(self, key: NestedKey, value):
        key = unravel_key(key)
        is_key = isinstance(key, str) or (
            isinstance(key, tuple) and all(isinstance(_item, str) for _item in key)
        )
        if is_key:
            self.set(key, value)
        else:
            raise ValueError(
                f"{self.__class__} expects str or tuple of str as key to set values "
            )

    @property
    def device(self) -> DEVICE_TYPING:
        """Common device of the stacked specs (cached after the first access).

        ``None`` specs are ignored when exactly one other device is present.
        Otherwise mixed devices yield ``None``.
        """
        device = self.__dict__.get("_device", NO_DEFAULT)
        if device is NO_DEFAULT:
            devices = {spec.device for spec in self._specs}
            if len(devices) == 1:
                device = list(devices)[0]
            elif len(devices) == 2:
                device0, device1 = devices
                if device0 is None:
                    device = device1
                elif device1 is None:
                    device = device0
                else:
                    device = None
            else:
                device = None
            self.__dict__["_device"] = device
        return device

    @property
    def ndim(self):
        return self.ndimension()

    def ndimension(self):
        return len(self.shape)

    def set(self, name, spec):
        """Unbinds ``spec`` along the stack dim and sets each slice in its stacked spec.

        A ``None`` spec registers a ``None`` placeholder in every stacked spec.
        """
        if spec is None:
            # ``None`` cannot be unbound; register the placeholder everywhere.
            for sub_spec in self._specs:
                sub_spec[name] = None
            return
        for sub_spec, sub_item in zip(self._specs, spec.unbind(self.dim)):
            sub_spec[name] = sub_item

    @property
    def shape(self):
        """Shape of the first stacked spec with the stack size inserted at ``self.dim``."""
        shape = list(self._specs[0].shape)
        dim = self.dim
        if dim < 0:
            dim = len(shape) + dim + 1
        shape.insert(dim, len(self._specs))
        return torch.Size(shape)

    def expand(self, *shape):
        """Expands the stacked spec to ``shape``.

        Raises:
            RuntimeError: if a non-singleton dimension would be expanded, or if
                ``shape`` is not congruent with ``self.shape``.
        """
        if len(shape) == 1 and not isinstance(shape[0], (int,)):
            return self.expand(*shape[0])
        expand_shape = shape[: -len(self.shape)]
        existing_shape = self.shape
        shape_check = shape[-len(self.shape) :]
        # Stays -1 if ``shape`` is empty, so the congruence check below raises
        # a RuntimeError instead of hitting an unbound variable.
        _i = -1
        for _i, (size1, size2) in enumerate(zip(existing_shape, shape_check)):
            if size1 != size2 and size1 != 1:
                raise RuntimeError(
                    f"Expanding a non-singleton dimension: existing shape={size1} vs expand={size2}"
                )
            elif size1 != size2 and size1 == 1 and _i == self.dim:
                # if we're expanding along the stack dim we just need to clone the existing spec
                return torch.stack(
                    [self._specs[0].clone() for _ in range(size2)], self.dim
                ).expand(*shape)
        if _i != len(self.shape) - 1:
            raise RuntimeError(
                f"Trying to expand non-congruent shapes: received {shape} when the shape is {self.shape}."
            )
        # remove the stack dim from the expanded shape, which we know to match
        unstack_shape = list(expand_shape) + [
            s for i, s in enumerate(shape_check) if i != self.dim
        ]
        return torch.stack(
            [spec.expand(unstack_shape) for spec in self._specs],
            self.dim + len(expand_shape),
        )

    def empty(self):
        """Stacks an entry-less copy of every stacked spec."""
        return torch.stack([spec.empty() for spec in self._specs], dim=self.stack_dim)

    def encode(
        self, vals: Dict[str, Any], ignore_device: bool = False
    ) -> Dict[str, torch.Tensor]:
        raise NOT_IMPLEMENTED_ERROR
@TensorSpec.implements_for_spec(torch.stack)
def _stack_specs(list_of_spec, dim, out=None):
    """Implements ``torch.stack`` for leaf TensorSpecs.

    If all specs are equal, returns an expanded copy of the first one.
    Otherwise returns a ``LazyStackedTensorSpec``.

    Raises:
        NotImplementedError: if ``out`` is given or the first item is not a TensorSpec.
        ValueError: on an empty list.
        RuntimeError: if the specs differ in class, device, dtype or ndim.
    """
    if out is not None:
        raise NotImplementedError(
            "In-place spec modification is not a feature of torchrl, hence "
            "torch.stack(list_of_specs, dim, out=spec) is not implemented."
        )
    if not len(list_of_spec):
        raise ValueError("Cannot stack an empty list of specs.")
    spec0 = list_of_spec[0]
    if isinstance(spec0, TensorSpec):
        device = spec0.device

        all_equal = True
        for spec in list_of_spec[1:]:
            if not isinstance(spec, spec0.__class__):
                raise RuntimeError(
                    "Stacking specs cannot occur: Found more than one type of specs in the list."
                )
            if device != spec.device:
                raise RuntimeError(f"Devices differ, got {device} and {spec.device}")
            if spec.dtype != spec0.dtype:
                raise RuntimeError(f"Dtypes differ, got {spec0.dtype} and {spec.dtype}")
            if spec.ndim != spec0.ndim:
                raise RuntimeError(f"Ndims differ, got {spec0.ndim} and {spec.ndim}")
            all_equal = all_equal and spec == spec0
        if all_equal:
            shape = list(spec0.shape)
            if dim < 0:
                dim += len(shape) + 1
            shape.insert(dim, len(list_of_spec))
            return spec0.clone().unsqueeze(dim).expand(shape)
        return LazyStackedTensorSpec(*list_of_spec, dim=dim)
    else:
        raise NotImplementedError


@CompositeSpec.implements_for_spec(torch.stack)
def _stack_composite_specs(list_of_spec, dim, out=None):
    """Implements ``torch.stack`` for CompositeSpecs.

    The common device is the single non-None device among the specs, if there
    is exactly one; device-less specs are moved onto it. If all (moved) specs
    are equal, an expanded copy of the first is returned. Otherwise a
    ``LazyStackedCompositeSpec`` of the moved specs is returned.

    Raises:
        NotImplementedError: if ``out`` is given or the first item is not a CompositeSpec.
        ValueError: on an empty list.
        RuntimeError: if an item is not a CompositeSpec or the shapes differ.
    """
    if out is not None:
        raise NotImplementedError(
            "In-place spec modification is not a feature of torchrl, hence "
            "torch.stack(list_of_specs, dim, out=spec) is not implemented."
        )
    if not len(list_of_spec):
        raise ValueError("Cannot stack an empty list of specs.")
    spec0 = list_of_spec[0]
    if not isinstance(spec0, CompositeSpec):
        raise NotImplementedError

    devices = {spec.device for spec in list_of_spec}
    # A common device exists only if exactly one non-None device is present;
    # several distinct devices (any count) leave it undefined.
    non_null_devices = devices - {None}
    device = next(iter(non_null_devices)) if len(non_null_devices) == 1 else None

    def _to_common_device(spec):
        # With a common device, any spec on another device is device-less.
        if device is not None and spec.device != device:
            return spec.to(device)
        return spec

    spec0 = _to_common_device(spec0)
    stacked = [spec0]
    all_equal = True
    for spec in list_of_spec[1:]:
        if not isinstance(spec, CompositeSpec):
            raise RuntimeError(
                "Stacking specs cannot occur: Found more than one type of spec in "
                "the list."
            )
        spec = _to_common_device(spec)
        if spec.shape != spec0.shape:
            raise RuntimeError(f"Shapes differ, got {spec.shape} and {spec0.shape}")
        all_equal = all_equal and spec == spec0
        stacked.append(spec)
    if all_equal:
        shape = list(spec0.shape)
        if dim < 0:
            dim += len(shape) + 1
        shape.insert(dim, len(stacked))
        return spec0.clone().unsqueeze(dim).expand(shape)
    # Stack the device-normalized specs, not the caller's originals.
    return LazyStackedCompositeSpec(*stacked, dim=dim)


@TensorSpec.implements_for_spec(torch.squeeze)
def _squeeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec:
    return spec.squeeze(*args, **kwargs)


@CompositeSpec.implements_for_spec(torch.squeeze)
def _squeeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec:
    return spec.squeeze(*args, **kwargs)


@TensorSpec.implements_for_spec(torch.unsqueeze)
def _unsqueeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec:
    return spec.unsqueeze(*args, **kwargs)


@CompositeSpec.implements_for_spec(torch.unsqueeze)
def _unsqueeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec:
    return spec.unsqueeze(*args, **kwargs)


def _keys_to_empty_composite_spec(keys):
    """Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value.

    Returns ``None`` for an empty list of keys.

    Raises:
        RuntimeError: if a key path goes through an existing non-composite entry.
    """
    if not len(keys):
        return
    c = CompositeSpec()
    for key in keys:
        # A 1-tuple names a leaf, exactly like a plain string.
        if isinstance(key, tuple) and len(key) == 1:
            key = key[0]
        if isinstance(key, str):
            c[key] = None
            continue
        head, rest = key[0], key[1:]
        # ``rest`` is ONE nested key; wrap it so it is not read as several keys.
        sub = _keys_to_empty_composite_spec([rest])
        if head not in c.keys() or c[head] is None:
            # new entry, or a None placeholder we can simply replace
            c[head] = sub
        elif isinstance(c[head], CompositeSpec):
            # if the value is Composite, we update it
            if sub is not None:
                c[head].update(sub)
        else:
            raise RuntimeError("Conflicting keys")
    return c


def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None:
    """Returns ``shape`` without its size-1 dim(s), or ``None`` if nothing should be squeezed.

    With ``dim=None``, a 1-d shape is never squeezed.
    """
    if dim is None:
        if len(shape) == 1 or shape.count(1) == 0:
            return None
        new_shape = torch.Size([s for s in shape if s != 1])
    else:
        if dim < 0:
            dim += len(shape)

        if shape[dim] != 1:
            return None

        new_shape = torch.Size([s for i, s in enumerate(shape) if i != dim])
    return new_shape


def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size:
    """Returns ``shape`` with a size-1 dim inserted at ``dim`` (negative dims allowed)."""
    n = len(shape)
    if dim < -(n + 1) or dim > n:
        raise ValueError(
            f"Dimension out of range, expected value in the range [{-(n+1)}, {n}], but "
            f"got {dim}"
        )
    if dim < 0:
        dim += n + 1

    new_shape = list(shape)
    new_shape.insert(dim, 1)
    return torch.Size(new_shape)


class _CompositeSpecKeysView:
    """Wrapper class that enables richer behaviour of `key in tensordict.keys()`."""

    def __init__(
        self,
        composite: CompositeSpec,
        include_nested,
        leaves_only,
    ):
        self.composite = composite
        self.leaves_only = leaves_only
        self.include_nested = include_nested

    def __iter__(self):
        for key, item in self.composite.items():
            if self.include_nested and isinstance(item, CompositeSpec):
                for subkey in item.keys(
                    include_nested=True, leaves_only=self.leaves_only
                ):
                    if not isinstance(subkey, tuple):
                        subkey = (subkey,)
                    yield (key, *subkey)
                if not self.leaves_only:
                    yield key
            elif not isinstance(item, CompositeSpec) or not self.leaves_only:
                yield key

    def __len__(self):
        return sum(1 for _ in self)

    def __repr__(self):
        return f"_CompositeSpecKeysView(keys={list(self)})"

    def __contains__(self, item):
        item = unravel_key(item)
        if len(item) == 1:
            item = item[0]
        return any(key == item for key in self)


def _minmax_dtype(dtype):
    """Returns the ``(min, max)`` representable values of ``dtype`` (``(False, True)`` for bool)."""
    if dtype is torch.bool:
        return False, True
    if dtype.is_floating_point:
        info = torch.finfo(dtype)
    else:
        info = torch.iinfo(dtype)
    return info.min, info.max


