Shortcuts

Source code for torchrl.data.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import typing
from typing import Any, Callable, List, Tuple, Union

import numpy as np
import torch

from torch import Tensor

from torchrl.data.tensor_specs import (
    Binary,
    Categorical,
    Composite,
    MultiCategorical,
    MultiOneHot,
    OneHot,
    Stacked,
    StackedComposite,
    TensorSpec,
)

# numpy dtype -> torch dtype. Each entry pairs ``np.dtype(name)`` with the torch
# dtype attribute of the same ``name``; the tuple below fixes the insertion order.
numpy_to_torch_dtype_dict = {
    np.dtype(_dtype_name): getattr(torch, _dtype_name)
    for _dtype_name in (
        "bool",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
}
# Inverse lookup (torch dtype -> numpy dtype), built by swapping keys and values.
torch_to_numpy_dtype_dict = dict(
    zip(numpy_to_torch_dtype_dict.values(), numpy_to_torch_dtype_dict.keys())
)
DEVICE_TYPING = Union[torch.device, str, int]
if hasattr(typing, "get_args"):
    DEVICE_TYPING_ARGS = typing.get_args(DEVICE_TYPING)
else:
    DEVICE_TYPING_ARGS = (torch.device, str, int)

INDEX_TYPING = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]


ACTION_SPACE_MAP = {
    OneHot: "one_hot",
    MultiOneHot: "mult_one_hot",
    Binary: "binary",
    Categorical: "categorical",
    "one_hot": "one_hot",
    "one-hot": "one_hot",
    "mult_one_hot": "mult_one_hot",
    "mult-one-hot": "mult_one_hot",
    "multi_one_hot": "mult_one_hot",
    "multi-one-hot": "mult_one_hot",
    "binary": "binary",
    "categorical": "categorical",
    MultiCategorical: "multi_categorical",
    "multi_categorical": "multi_categorical",
    "multi-categorical": "multi_categorical",
    "multi_discrete": "multi_categorical",
    "multi-discrete": "multi_categorical",
}


[docs]def consolidate_spec( spec: Composite, recurse_through_entries: bool = True, recurse_through_stack: bool = True, ): """Given a TensorSpec, removes exclusive keys by adding 0 shaped specs. Args: spec (Composite): the spec to be consolidated. recurse_through_entries (bool): if True, call the function recursively on all entries of the spec. Default is True. recurse_through_stack (bool): if True, if the provided spec is lazy, the function recursively on all specs in its list. Default is True. """ spec = spec.clone() if not isinstance(spec, (Composite, StackedComposite)): return spec if isinstance(spec, StackedComposite): keys = set(spec.keys()) # shared keys exclusive_keys_per_spec = [ set() for _ in range(len(spec._specs)) ] # list of exclusive keys per td exclusive_keys_examples = ( {} ) # map of all exclusive keys to a list of their values for spec_index in range(len(spec._specs)): # gather all exclusive keys sub_spec = spec._specs[spec_index] if recurse_through_stack: sub_spec = consolidate_spec( sub_spec, recurse_through_entries, recurse_through_stack ) spec._specs[spec_index] = sub_spec for sub_spec_key in sub_spec.keys(): if sub_spec_key not in keys: # exclusive key exclusive_keys_per_spec[spec_index].add(sub_spec_key) value = sub_spec[sub_spec_key] if sub_spec_key in exclusive_keys_examples: exclusive_keys_examples[sub_spec_key].append(value) else: exclusive_keys_examples.update({sub_spec_key: [value]}) for sub_spec, exclusive_keys in zip( spec._specs, exclusive_keys_per_spec ): # add missing exclusive entries for exclusive_key in set(exclusive_keys_examples.keys()).difference( exclusive_keys ): exclusive_keys_example_list = exclusive_keys_examples[exclusive_key] sub_spec.set( exclusive_key, _empty_like_spec(exclusive_keys_example_list, sub_spec.shape), ) if recurse_through_entries: for key, value in spec.items(): if isinstance(value, (Composite, StackedComposite)): spec.set( key, consolidate_spec( value, recurse_through_entries, recurse_through_stack ), 
) return spec
def _empty_like_spec(specs: List[TensorSpec], shape): for spec in specs[1:]: if spec.__class__ != specs[0].__class__: raise ValueError( "Found same key in lazy specs corresponding to entries with different classes" ) spec = specs[0] if isinstance(spec, (Composite, StackedComposite)): # the exclusive key has values which are CompositeSpecs -> # we create an empty composite spec with same batch size return spec.empty() elif isinstance(spec, Stacked): # the exclusive key has values which are LazyStackedTensorSpecs -> # we create a LazyStackedTensorSpec with the same shape (aka same -1s) as the first in the list. # this will not add any new -1s when they are stacked shape = list(shape[: spec.stack_dim]) + list(shape[spec.stack_dim + 1 :]) return Stacked( *[_empty_like_spec(spec._specs, shape) for _ in spec._specs], dim=spec.stack_dim, ) else: # the exclusive key has values which are TensorSpecs -> # if the shapes of the values are all the same, we create a TensorSpec with leading shape `shape` and following dims 0 (having the same ndims as the values) # if the shapes of the values differ, we create a TensorSpec with 0 size in the differing dims spec_shape = list(spec.shape) for dim_index in range(len(spec_shape)): hetero_dim = False for sub_spec in specs: if sub_spec.shape[dim_index] != spec.shape[dim_index]: hetero_dim = True break if hetero_dim: spec_shape[dim_index] = 0 if 0 not in spec_shape: # the values have all same shape spec_shape = [ dim if i < len(shape) else 0 for i, dim in enumerate(spec_shape) ] spec = spec[(0,) * len(spec.shape)] spec = spec.expand(spec_shape) return spec
[docs]def check_no_exclusive_keys(spec: TensorSpec, recurse: bool = True): """Given a TensorSpec, returns true if there are no exclusive keys. Args: spec (TensorSpec): the spec to check recurse (bool): if True, check recursively in nested specs. Default is True. """ if isinstance(spec, StackedComposite): keys = set(spec.keys()) for inner_td in spec._specs: if recurse and not check_no_exclusive_keys(inner_td): return False if set(inner_td.keys()) != keys: return False elif isinstance(spec, Composite) and recurse: for value in spec.values(): if not check_no_exclusive_keys(value): return False else: return True return True
[docs]def contains_lazy_spec(spec: TensorSpec) -> bool: """Returns true if a spec contains lazy stacked specs. Args: spec (TensorSpec): the spec to check """ if isinstance(spec, (Stacked, StackedComposite)): return True elif isinstance(spec, Composite): for inner_spec in spec.values(): if contains_lazy_spec(inner_spec): return True return False
class CloudpickleWrapper(object): """A wrapper for functions that allow for serialization in multiprocessed settings.""" def __init__(self, fn: Callable, **kwargs): if fn.__class__.__name__ == "EnvCreator": raise RuntimeError( "CloudpickleWrapper usage with EnvCreator class is " "prohibited as it breaks the transmission of shared tensors." ) self.fn = fn self.kwargs = kwargs functools.update_wrapper(self, getattr(fn, "forward", fn)) def __getstate__(self): import cloudpickle return cloudpickle.dumps((self.fn, self.kwargs)) def __setstate__(self, ob: bytes): import pickle self.fn, self.kwargs = pickle.loads(ob) functools.update_wrapper(self, self.fn) def __call__(self, *args, **kwargs) -> Any: kwargs.update(self.kwargs) return self.fn(*args, **kwargs) def _process_action_space_spec(action_space, spec): original_spec = spec composite_spec = False if isinstance(spec, Composite): # this will break whenever our action is more complex than a single tensor try: if "action" in spec.keys(): _key = "action" else: # the first key is the action for _key in spec.keys(True, True): if isinstance(_key, tuple) and _key[-1] == "action": break else: raise KeyError spec = spec[_key] composite_spec = True except KeyError: raise KeyError( "action could not be found in the spec. Make sure " "you pass a spec that is either a native action spec or a composite action spec " "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only." ) if action_space is not None: if isinstance(action_space, Composite): raise ValueError("action_space cannot be of type Composite.") if ( spec is not None and isinstance(action_space, TensorSpec) and action_space is not spec ): raise ValueError( "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match." 
) if isinstance(action_space, TensorSpec): spec = action_space action_space = _find_action_space(action_space) # check that the spec and action_space match if spec is not None and _find_action_space(spec) != action_space: raise ValueError( f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}." ) elif spec is not None: action_space = _find_action_space(spec) else: raise ValueError( "Neither action_space nor spec was defined. The action space cannot be inferred." ) if composite_spec: spec = original_spec return action_space, spec def _find_action_space(action_space) -> str: if isinstance(action_space, TensorSpec): if isinstance(action_space, Composite): if "action" in action_space.keys(): _key = "action" else: # the first key is the action for _key in action_space.keys(True, True): if isinstance(_key, tuple) and _key[-1] == "action": break else: raise KeyError action_space = action_space[_key] action_space = type(action_space) try: action_space = ACTION_SPACE_MAP[action_space] except KeyError: raise ValueError( f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}." ) return action_space

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources