# Source code for tensordict.nn.utils
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import functools
import inspect
import os
from enum import Enum
from typing import Any, Callable
import torch
from tensordict.utils import _ContextManager, strtobool
from torch import nn
from torch.utils._contextlib import _DecoratorContextManager
try:
from torch.compiler import is_compiling
except ImportError: # torch 2.0
from torch._dynamo import is_compiling
__all__ = ["mappings", "inv_softplus", "biased_softplus"]

# Global switch telling tensordict.nn modules whether to go through @dispatch.
# Its initial value is parsed (with strtobool) from the DISPATCH_TDNN_MODULES
# environment variable, falling back to "True" when the variable is unset.
_dispatch_tdnn_modules = _ContextManager(
    default=strtobool(os.getenv("DISPATCH_TDNN_MODULES", default="True"))
)

# Global switch returned by skip_existing(); disabled by default.
_skip_existing = _ContextManager(default=False)
def inv_softplus(bias: float | torch.Tensor) -> float | torch.Tensor:
    """Inverse softplus function.

    Computes ``log(exp(bias) - 1)``, i.e. the value ``y`` such that
    ``softplus(y) == bias``.

    Args:
        bias (float or tensor): the value to be softplus-inverted.

    Returns:
        a Python float when ``bias`` is not a tensor and the result has a single
        element, otherwise a tensor.

    .. note::
        ``exp(bias) - 1`` is clamped to be at least ``1e-6`` before taking the
        log, so non-positive (or tiny) inputs map to ``log(1e-6)`` rather than
        ``-inf``/``nan``. For inputs above 20 the mathematically equivalent,
        overflow-free form ``bias + log(1 - exp(-bias))`` is used, so large
        inputs no longer produce ``inf``.
    """
    is_tensor = isinstance(bias, torch.Tensor)
    if not is_tensor:
        # Same dtype torch.tensor() would pick for a Python float; also makes
        # integer inputs go through floating-point math.
        bias = torch.as_tensor(bias, dtype=torch.get_default_dtype())
    threshold = 20
    # Naive formula. The input is capped at the threshold so that expm1 cannot
    # overflow; above the threshold this branch is discarded by torch.where.
    naive = bias.clamp_max(threshold).expm1().clamp_min(1e-6).log()
    # log(expm1(x)) == x + log(-expm1(-x)), which stays finite for large x.
    # The input is floored at the threshold so the unused branch stays finite
    # (and gradients through torch.where remain nan-free).
    large = bias.clamp_min(threshold)
    stable = large + torch.log(-torch.expm1(-large))
    out = torch.where(bias > threshold, stable, naive)
    if not is_tensor and out.numel() == 1:
        return out.item()
    return out
class biased_softplus(nn.Module):
    """A biased softplus module.

    The bias indicates the value that is to be returned when a zero-tensor is
    passed through the transform. The transform computes
    ``softplus(x + shift) + min_val``, where ``shift = inv_softplus(bias - min_val)``
    so that a zero input is mapped to ``bias``.

    Args:
        bias (scalar): target output of the transform for a zero input. A shift
            is computed such that ``softplus(0.0 + shift) + min_val == bias``.
            When ``bias <= min_val``, :func:`inv_softplus` clamps its internal
            ``exp(.) - 1`` term to ``1e-6``, so the output at zero is then only
            slightly above ``min_val`` instead of ``bias``.
        min_val (scalar): minimum value of the transform (the output never goes
            below it). Default: ``0.01``.
    """

    def __init__(self, bias: float, min_val: float = 0.01) -> None:
        super().__init__()
        # Stored input shift (not the user-facing bias): softplus(0 + self.bias) + min_val == bias.
        self.bias = inv_softplus(bias - min_val)
        self.min_val = min_val

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.softplus(x + self.bias) + self.min_val
def _expln(x: torch.Tensor) -> torch.Tensor:
    """Smooth positive mapping: ``exp(x)`` for ``x <= 0`` and ``log1p(x) + 1`` for ``x > 0``.

    Both pieces equal 1 at 0, so the mapping is continuous. Each branch is
    evaluated on a clamped input so the branch discarded by ``torch.where`` never
    produces ``inf``/``nan`` (which would otherwise poison gradients).
    """
    return torch.where(
        x <= 0,
        x.clamp_max(0).exp(),
        x.clamp_min(0).log1p() + 1,
    )


# Registry of named mappings, looked up by ``mappings(key)`` and extended by
# ``add_custom_mapping``.
_MAPPINGS: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "softplus": torch.nn.functional.softplus,
    "exp": torch.exp,
    "relu": torch.relu,
    "expln": _expln,
    "biased_softplus": biased_softplus(1.0),
    "none": lambda x: x,
}
def mappings(key: str) -> Callable:
    """Given an input string, returns the corresponding element-wise mapping function.

    The built-in mappings are meant to produce positive outputs (e.g. for scale
    parameters), with the exception of ``"none"``, which is the identity.

    Args:
        key (str): the name of a registered mapping. Built-ins include
            ``"softplus"``, ``"exp"``, ``"relu"``, ``"biased_softplus"`` and
            ``"none"`` (no mapping).

    .. note::
        If the key begins with ``"biased_softplus_"``, then it needs to take the
        following form: ``"biased_softplus_{bias}"`` where ``bias`` can be converted
        to a floating point number that will be used to bias the softplus function.
        Alternatively, the ``"biased_softplus_{bias}_{min_val}"`` syntax can be used.
        In that case, the additional ``min_val`` term is a floating point number
        that will be used to encode the minimum value of the softplus transform.
        In practice, the equation used is ``softplus(x + shift) + min_val``, where
        ``shift`` is computed such that a zero input is mapped to ``bias``.

    .. note::
        Custom mappings can be added through ``tensordict.nn.add_custom_mapping``.

    Returns:
        a Callable

    Raises:
        NotImplementedError: if ``key`` is neither a registered mapping nor a
            ``"biased_softplus_..."`` key.
        ValueError: if a ``"biased_softplus_..."`` key has more than two
            arguments, or arguments that cannot be parsed as floats.
    """
    if key in _MAPPINGS:
        return _MAPPINGS[key]
    # Require the separator so that e.g. "biased_softplusfoo_1" is rejected
    # instead of being silently parsed as "biased_softplus_1".
    prefix = "biased_softplus_"
    if not key.startswith(prefix):
        raise NotImplementedError(f"Unknown mapping {key}")
    args = key[len(prefix) :].split("_")
    if len(args) > 2:
        raise ValueError(f"Invalid number of args in {key}")
    try:
        values = [float(arg) for arg in args]
    except ValueError as err:
        raise ValueError(
            f"Could not parse the arguments of {key} as floating point numbers."
        ) from err
    if len(values) == 1:
        return biased_softplus(values[0])
    return biased_softplus(values[0], min_val=values[1])
def add_custom_mapping(name: str, mapping: Callable[[torch.Tensor], torch.Tensor]):
    """Adds a custom mapping to be used in mapping classes.

    The mapping is stored in the module-level ``_MAPPINGS`` registry, from which
    :func:`mappings` retrieves it by name. Registering an existing name
    overwrites the previous entry.

    Args:
        name (str): a mapping name.
        mapping (callable): a callable that takes a tensor as input and outputs a tensor
            with the same shape.

    Raises:
        TypeError: if ``mapping`` is not callable. Checking here surfaces the
            error at registration time rather than when the mapping is first used.

    Examples:
        >>> import torch
        >>> from tensordict.nn import add_custom_mapping, NormalParamExtractor
        >>> add_custom_mapping("my_mapping", lambda x: torch.zeros_like(x))
        >>> npe = NormalParamExtractor(scale_mapping="my_mapping", scale_lb=0.0)
        >>> assert (npe(torch.randn(10))[1] == torch.zeros(5)).all()
    """
    if not callable(mapping):
        raise TypeError(
            f"The mapping registered under {name!r} must be callable, "
            f"got an object of type {type(mapping).__name__}."
        )
    _MAPPINGS[name] = mapping
class set_skip_existing(_DecoratorContextManager):
    """A context manager for skipping existing nodes in a TensorDict graph.

    When used as a context manager, it will set the `skip_existing()` value
    to the ``mode`` indicated, leaving the user able to code up methods that
    will check the global value and execute the code accordingly.

    When used as a method decorator, it will check the tensordict input keys
    and if the ``skip_existing()`` call returns ``True``, it will skip the method
    if all the output keys are already present.
    This is not expected to be used as a decorator for methods that do not
    respect the following signature: ``def fun(self, tensordict, *args, **kwargs)``.

    Args:
        mode (bool, optional):
            If ``True``, it indicates that existing entries in the graph
            won't be overwritten, unless they are only partially present. :func:`~.skip_existing`
            will return ``True``.
            If ``False``, no check will be performed.
            If ``None``, the value of :func:`~.skip_existing` will not be
            changed. This is intended to be used exclusively for decorating
            methods and allow their behaviour to depend on the same class
            when used as a context manager (see example below).
            Defaults to ``True``.
        in_key_attr (str, optional): the name of the input key list attribute
            in the module's method being decorated. Defaults to ``in_keys``.
        out_key_attr (str, optional): the name of the output key list attribute
            in the module's method being decorated. Defaults to ``out_keys``.

    Examples:
        >>> with set_skip_existing():
        ...     if skip_existing():
        ...         print("True")
        ...     else:
        ...         print("False")
        ...
        True
        >>> print("calling from outside:", skip_existing())
        calling from outside: False

    This class can also be used as a decorator:

    Examples:
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase
        >>> class MyModule(TensorDictModuleBase):
        ...     in_keys = []
        ...     out_keys = ["out"]
        ...     @set_skip_existing()
        ...     def forward(self, tensordict):
        ...         print("hello")
        ...         tensordict.set("out", torch.zeros(()))
        ...         return tensordict
        >>> module = MyModule()
        >>> module(TensorDict({"out": torch.zeros(())}, []))  # does not print anything
        TensorDict(
            fields={
                out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> module(TensorDict())  # prints hello
        hello
        TensorDict(
            fields={
                out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)

    Decorating a method with the mode set to ``None`` is useful whenever one
    wants to let the context manager take care of skipping things from the outside:

    Examples:
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase
        >>> class MyModule(TensorDictModuleBase):
        ...     in_keys = []
        ...     out_keys = ["out"]
        ...     @set_skip_existing(None)
        ...     def forward(self, tensordict):
        ...         print("hello")
        ...         tensordict.set("out", torch.zeros(()))
        ...         return tensordict
        >>> module = MyModule()
        >>> _ = module(TensorDict({"out": torch.zeros(())}, []))  # prints "hello"
        hello
        >>> with set_skip_existing(True):
        ...     _ = module(TensorDict({"out": torch.zeros(())}, []))  # no print

    .. note::
        To allow for modules to have the same input and output keys and not
        mistakenly ignore subgraphs, ``@set_skip_existing(True)`` will be
        deactivated whenever the output keys are also the input keys:

            >>> class MyModule(TensorDictModuleBase):
            ...     in_keys = ["out"]
            ...     out_keys = ["out"]
            ...     @set_skip_existing()
            ...     def forward(self, tensordict):
            ...         print("calling the method!")
            ...         return tensordict
            ...
            >>> module = MyModule()
            >>> module(TensorDict({"out": torch.zeros(())}, []))  # prints "calling the method!"
            calling the method!
            TensorDict(
                fields={
                    out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False)
    """

    def __init__(
        self, mode: bool | None = True, in_key_attr="in_keys", out_key_attr="out_keys"
    ):
        self.mode = mode
        self.in_key_attr = in_key_attr
        self.out_key_attr = out_key_attr
        # Set once the instance has been used as a decorator; allows mode=None in __enter__.
        self._called = False

    def clone(self) -> set_skip_existing:
        # Used by the _DecoratorContextManager base to obtain the context manager
        # entered around decorated calls (torch internals; presumably once per call).
        # The constructor only receives ``mode`` so that subclasses with a narrower
        # __init__ keep working; the key attributes are then copied explicitly so
        # the clone is a faithful copy of this instance.
        out = type(self)(self.mode)
        out.in_key_attr = self.in_key_attr
        out.out_key_attr = self.out_key_attr
        out._called = self._called
        return out

    def __call__(self, func: Callable):
        self._called = True
        # sanity check: the argument following ``self`` must be named ``tensordict``
        for i, key in enumerate(inspect.signature(func).parameters):
            if i == 0:
                # skip self
                continue
            if key != "tensordict":
                raise RuntimeError(
                    "the first argument of the wrapped function must be "
                    "named 'tensordict'."
                )
            break

        @functools.wraps(func)
        def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
            in_keys = getattr(_self, self.in_key_attr)
            out_keys = getattr(_self, self.out_key_attr)
            # we use skip_existing to allow users to override the mode internally
            if skip_existing():
                # Build the (nested) key view once instead of once per output key.
                existing_keys = tensordict.keys(True)
                if all(key in existing_keys for key in out_keys) and not any(
                    key in out_keys for key in in_keys
                ):
                    return tensordict
            return func(_self, tensordict, *args, **kwargs)

        return super().__call__(wrapper)

    def __enter__(self) -> None:
        if self.mode and is_compiling():
            raise RuntimeError("skip_existing is not compatible with TorchDynamo.")
        self.prev = _skip_existing.get_mode()
        if self.mode is not None:
            _skip_existing.set_mode(self.mode)
        elif not self._called:
            raise RuntimeError(
                f"It seems you are using {type(self).__name__} as a context manager with ``None`` input. "
                f"This behaviour is not allowed."
            )

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        _skip_existing.set_mode(self.prev)
class _set_skip_existing_None(set_skip_existing):
    """A version of skip_existing that is constant wrt init inputs (for torch.compile compatibility).

    This class should only be used as a decorator, not a context manager: its
    ``__init__`` takes no argument and never sets ``mode``, which the inherited
    ``__enter__`` reads. Decorated calls follow the current global
    :func:`skip_existing` value. Outside of compilation, any change of that value
    made inside the decorated function is undone when it returns.
    """

    def __call__(self, func: Callable):
        self._called = True
        # sanity check: the argument following ``self`` must be named ``tensordict``
        for i, key in enumerate(inspect.signature(func).parameters):
            if i == 0:
                # skip self
                continue
            if key != "tensordict":
                raise RuntimeError(
                    "the first argument of the wrapped function must be "
                    "named 'tensordict'."
                )
            break

        @functools.wraps(func)
        def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
            if skip_existing() and is_compiling():
                raise RuntimeError(
                    "skip_existing is not compatible with torch.compile."
                )
            in_keys = getattr(_self, self.in_key_attr)
            out_keys = getattr(_self, self.out_key_attr)
            # we use skip_existing to allow users to override the mode internally
            if skip_existing():
                # Build the (nested) key view once instead of once per output key.
                existing_keys = tensordict.keys(True)
                if all(key in existing_keys for key in out_keys) and not any(
                    key in out_keys for key in in_keys
                ):
                    return tensordict
            if is_compiling():
                return func(_self, tensordict, *args, **kwargs)
            # Keep the previous mode in a local variable rather than on ``self``:
            # this decorator instance is shared by every call of the decorated
            # method, so an instance attribute would be overwritten by nested or
            # re-entrant calls and the outer call would restore the wrong mode.
            prev = _skip_existing.get_mode()
            try:
                return func(_self, tensordict, *args, **kwargs)
            finally:
                _skip_existing.set_mode(prev)

        return wrapper

    in_key_attr = "in_keys"
    out_key_attr = "out_keys"

    __init__ = object.__init__

    def clone(self) -> _set_skip_existing_None:
        # override this method if your children class takes __init__ parameters
        out = type(self)()
        return out
def skip_existing():
    """Returns whether existing entries in a tensordict should be skipped rather than re-computed.

    When ``True``, methods decorated with :class:`set_skip_existing` return the
    input tensordict untouched if all of the module's output keys are already
    present (and none of the output keys is also an input key). The value is
    changed through :class:`set_skip_existing` and defaults to ``False``.
    """
    return _skip_existing.get_mode()
def _rebuild_buffer(data, requires_grad, backward_hooks):
buffer = Buffer(data, requires_grad)
# NB: This line exists only for backwards compatibility; the
# general expectation is that backward_hooks is an empty
# OrderedDict. See Note [Don't serialize hooks]
buffer._backward_hooks = backward_hooks
return buffer
# For backward compatibility in imports
try:
from torch.nn.parameter import Buffer # noqa
except ImportError:
from tensordict.utils import Buffer # noqa
def _dispatch_td_nn_modules():
    """Tell whether tensordict.nn modules should route their calls through ``@dispatch``.

    Bypassing ``@dispatch`` is faster and more friendly to torch.compile. The
    answer is read from the module-level ``_dispatch_tdnn_modules`` switch, which
    is initialised from the ``DISPATCH_TDNN_MODULES`` environment variable and
    can be changed with ``_set_dispatch_td_nn_modules``.
    """
    use_dispatch = _dispatch_tdnn_modules.get_mode()
    return use_dispatch
class _set_dispatch_td_nn_modules(_DecoratorContextManager):
"""Controls whether @dispatch should be used. Not using dispatch is faster and also better compatible with torch.compile."""
def __init__(self, mode):
self.mode = mode
self._saved_mode = None
def clone(self):
return type(self)(self.mode)
def __enter__(self):
# We want to avoid changing global variables because compile puts guards on them
if _dispatch_tdnn_modules.get_mode() != self.mode:
self._saved_mode = _dispatch_tdnn_modules
_dispatch_tdnn_modules.set_mode(self.mode)
def __exit__(self, exc_type, exc_val, exc_tb):
if self._saved_mode is None:
return
_dispatch_tdnn_modules.set_mode(self._saved_mode)
# Reproduce StrEnum for python<3.11
class StrEnum(str, Enum):  # noqa
    """Enum whose members are also (and must be) strings.

    Member values are built like ``str(*values)``. The 2- and 3-argument forms
    mirror ``str(object, encoding, errors)`` and require string ``encoding`` and
    ``errors`` arguments. Members created with :func:`enum.auto` get their
    lower-cased name as value.

    NOTE(review): unlike :class:`enum.StrEnum` (Python 3.11+), ``__str__`` and
    ``__format__`` are not overridden here, so ``str(member)`` follows
    :class:`enum.Enum`'s default rather than returning the value. Confirm
    whether this difference is intended.
    """

    def __new__(cls, *values):
        # values must already be of type ``str`` (or match a valid ``str()`` call)
        if len(values) > 3:
            raise TypeError("too many arguments for str(): %r" % (values,))
        if len(values) == 1:
            # it must be a string
            if not isinstance(values[0], str):
                raise TypeError("%r is not a string" % (values[0],))
        if len(values) >= 2:
            # check that encoding argument is a string
            if not isinstance(values[1], str):
                raise TypeError("encoding must be a string, not %r" % (values[1],))
        if len(values) == 3:
            # check that errors argument is a string; the 1-tuple makes sure a
            # tuple value is formatted with %r instead of being unpacked by %
            if not isinstance(values[2], str):
                raise TypeError("errors must be a string, not %r" % (values[2],))
        value = str(*values)
        member = str.__new__(cls, value)
        member._value_ = value
        return member

    # Value generator used by ``auto()``: the lower-cased member name.
    # Kept as a plain function (not a staticmethod): Python < 3.10 cannot call
    # staticmethod objects directly, and Enum retrieves this hook by name.
    def _generate_next_value_(name, start, count, last_values):
        return name.lower()