Source code for torch.nn.utils.stateless

import contextlib
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import torch
from torch import Tensor

# Only ``functional_call`` is public; every underscore-prefixed helper below is private.
__all__ = [
    "functional_call",
]

# The ``module`` parameter below is deliberately left without a type annotation:
# nn.Module attributes are declared as Union[Parameter, Tensor] by default, and
# annotating ``module`` with another type causes mypy errors.
def _change_class(module, params_and_buffers) -> None:
    """Replace ``module``'s class with a generated subclass that redirects
    access to re-parametrized attributes.

    The generated subclass overrides ``__getattribute__`` and ``__setattr__``.
    Any attribute name present in ``module._attr_to_path`` is read from, and
    written to, ``params_and_buffers[module._attr_to_path[name]]``. Every other
    attribute is delegated to the original class.

    Args:
        module: the object whose class is replaced. It must already carry an
            ``_attr_to_path`` dict mapping local attribute names to keys of
            ``params_and_buffers``.
        params_and_buffers: mapping from fully-qualified names to the
            replacement values. It is also the write target for assignments
            to swapped attributes.

    Side effects:
        Sets ``module.__class__`` to the generated subclass and stores the
        previous class on ``module._orig_class`` so it can be restored later.
    """
    cls = module.__class__
    # Captured by reference: names added to ``module._attr_to_path`` after this
    # call are also seen by the closures below.
    attr_to_path: Dict[str, str] = module._attr_to_path

    def _getattribute(self, name: str) -> Any:
        if name in attr_to_path:
            return params_and_buffers[attr_to_path[name]]
        return cls.__getattribute__(self, name)

    def _setattr(self, name: str, value: Any) -> None:
        if name in attr_to_path:
            # Writes to a swapped attribute land in the caller's dict; the
            # module's own attribute is left untouched.
            params_and_buffers[attr_to_path[name]] = value
        else:
            cls.__setattr__(self, name, value)

    param_cls = type(
        f"StatelessReplacer{cls.__name__}",
        (cls,),
        {
            "__getattribute__": _getattribute,
            "__setattr__": _setattr,
        },
    )

    module.__class__ = param_cls
    # Runs after the class swap, so it goes through ``_setattr``, which
    # forwards this (non-swapped) name to the original class's __setattr__.
    module._orig_class = cls

def _create_swap_params(params_and_buffers):
    """Build the callback that re-parametrizes one attribute of a module.

    Args:
        params_and_buffers: mapping from fully-qualified names to replacement
            tensors. The returned callback makes attribute access on the target
            module resolve into this mapping.

    Returns:
        A function ``(module, tensor_name, full_path, tensor) -> None`` with
        the callback signature expected by ``_apply_func_submodules``.
    """
    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
        # ``tensor`` itself is not stored: the swapped class looks the value up
        # in ``params_and_buffers`` by ``full_path`` on every access.
        if hasattr(module, "_attr_to_path"):
            # ``_attr_to_path`` is only created by the branch below, so the
            # class was already swapped; just register one more name.
            module._attr_to_path[tensor_name] = full_path
        else:
            # First swap on this module: create the mapping, then change the
            # class so its __getattribute__/__setattr__ consult the mapping.
            module._attr_to_path = {tensor_name: full_path}
            _change_class(module, params_and_buffers)
    return _swap_parameters

def _remove_swap(module, name: str, full_path: str) -> None:
    """Undo ``_change_class`` on ``module`` if it was applied.

    ``name`` and ``full_path`` are unused. They exist only so this function
    fits the ``func(module, leaf_name, full_path, *args)`` callback shape used
    by ``_apply_func_submodules``. A module without ``_orig_class`` is left
    untouched, so calling this once per swapped name on the same module is
    harmless.
    """
    if not hasattr(module, "_orig_class"):
        return
    original_cls = module._orig_class
    module.__class__ = original_cls
    for bookkeeping_attr in ("_orig_class", "_attr_to_path"):
        delattr(module, bookkeeping_attr)

@contextlib.contextmanager
def _reparametrize_module(
    module: 'torch.nn.Module',
    parameters_and_buffers: Dict[str, Tensor],
) -> Iterator[None]:
    """Temporarily make ``module`` read the given parameters/buffers from a dict.

    For each ``name -> tensor`` in ``parameters_and_buffers``, ``name`` is split
    on ``"."`` to find the owning submodule. That submodule is patched so the
    last path component resolves to the dict entry. On exit, whether normal or
    via an exception, every patch that was applied is undone.

    Args:
        module: the root module to patch.
        parameters_and_buffers: fully-qualified names mapped to replacement
            tensors.

    Yields:
        None, while the replacement is active.

    Raises:
        AttributeError: if a name refers to a submodule that does not exist.
            Any swaps already applied are rolled back first.
    """
    swap = _create_swap_params(parameters_and_buffers)
    # Only names whose swap succeeded are undone. A bad key therefore cannot
    # raise again from the cleanup, and earlier patches are still reverted.
    swapped: List[str] = []
    try:
        for name, tensor in parameters_and_buffers.items():
            _apply_func_submodules(swap, module, name.split("."), name, (tensor,))
            swapped.append(name)
        yield
    finally:
        for name in swapped:
            _apply_func_submodules(_remove_swap, module, name.split("."), name, ())

def _apply_func_submodules(
    func: Callable[..., None],
    module: 'torch.nn.Module',
    path: List[str],
    full_path: str,
    args: Tuple,
) -> None:
    """Walk ``path`` down from ``module`` and apply ``func`` to the leaf's owner.

    ``path[:-1]`` is the chain of attribute names to follow from ``module``.
    ``func`` is then called as ``func(owner, path[-1], full_path, *args)``.

    Args:
        func: callback applied to the object that owns the final attribute.
        module: object to start the walk from.
        path: attribute names, e.g. ``"layer.weight".split(".")``; must be
            non-empty.
        full_path: the original dotted name, forwarded unchanged to ``func``.
        args: extra positional arguments forwarded to ``func``.

    Raises:
        AttributeError: if an intermediate attribute does not exist.
        IndexError: if ``path`` is empty.
    """
    owner = module
    for attr in path[:-1]:
        owner = getattr(owner, attr)
    func(owner, path[-1], full_path, *args)

def functional_call(
    module: 'torch.nn.Module',
    parameters_and_buffers: Dict[str, Tensor],
    args: Any,
    kwargs: Optional[Dict[str, Any]] = None,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.

    .. note:: If the module has active parametrizations, passing a value in the
        :attr:`parameters_and_buffers` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.

    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the `parameters_and_buffers` input.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    Args:
        module (torch.nn.Module): the module to call
        parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
            the module call.
        args (Any or tuple): arguments to be passed to the module call. A tuple is
            unpacked into positional arguments; any other value is passed as the
            single positional argument.
        kwargs (dict, optional): keyword arguments to be passed to the module call

    Returns:
        Any: the result of calling ``module``.

    Raises:
        RuntimeError: if called while JIT tracing/scripting, or with a scripted
            module or function.
    """
    # TODO allow kwargs such as unsafe and others for parametrization
    if (
        torch.jit.is_tracing()
        or torch.jit.is_scripting()
        or isinstance(module, (
            torch.jit.RecursiveScriptModule,
            torch.jit.ScriptModule,
            torch.jit.ScriptFunction,
        ))
    ):
        raise RuntimeError("The stateless API can't be used with Jitted modules")
    if kwargs is None:
        kwargs = {}
    # A non-tuple ``args`` is treated as the single positional argument.
    call_args = args if isinstance(args, tuple) else (args,)
    with _reparametrize_module(module, parameters_and_buffers):
        # Returning inside the ``with`` still runs the context manager's
        # cleanup before the value reaches the caller.
        return module(*call_args, **kwargs)


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources