Source code for torch.nn.utils.parametrize

import torch
from torch.nn.modules.container import ModuleList, ModuleDict, Module
from torch.nn.parameter import Parameter
from torch import Tensor
from typing import Union, Optional, Iterable, Dict, Tuple
from contextlib import contextmanager

# Nesting depth of active ``cached()`` contexts. Parametrized tensors are only
# cached while this counter is greater than zero.
_cache_enabled = 0
# Cached parametrized tensors, keyed by a (module id, tensor name) tuple.
# The dict is replaced by an empty one when the outermost ``cached()`` exits.
_cache: Dict[Tuple[int, str], Optional[Tensor]] = dict()

@contextmanager
def cached():
    r"""Context manager that turns on caching for parametrized tensors
    registered with :func:`register_parametrization`.

    While at least one ``cached()`` context is active, the first access to a
    parametrized tensor evaluates its parametrization and stores the result;
    later accesses reuse the stored value. All stored values are discarded
    when the outermost context exits.

    This pays off when a parametrized tensor is read more than once in a
    forward pass, e.g. the recurrent kernel of an RNN or shared weights.

    The simplest usage is to wrap the whole forward pass of the network, both
    in training and in evaluation:

    .. code-block:: python

        import torch.nn.utils.parametrize as P
        ...
        with P.cached():
            output = model(inputs)

    It is also possible to wrap only the part of the model that reuses the
    parametrized tensors, such as the loop of an RNN with a parametrized
    recurrent kernel:

    .. code-block:: python

        with P.cached():
            for x in xs:
                out_rnn = self.rnn_cell(x, out_rnn)
    """
    global _cache, _cache_enabled
    # Contexts may nest; the counter records how deep we currently are.
    _cache_enabled += 1
    try:
        yield
    finally:
        _cache_enabled -= 1
        # Only the outermost exit throws the cached tensors away.
        if _cache_enabled == 0:
            _cache = {}
class ParametrizationList(ModuleList):
    r"""A sequential container that holds and manages the ``original`` parameter
    or buffer of a parametrized :class:`torch.nn.Module`. It is the type of
    ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` has
    been parametrized with :func:`register_parametrization`.

    .. note ::
        This class is used internally by :func:`register_parametrization`. It is
        documented here for completeness. It should not be instantiated by the
        user.

    Args:
        modules (iterable): an iterable of modules representing the parametrizations
        original (Parameter or Tensor): parameter or buffer that is parametrized
    """

    original: Tensor

    def __init__(
        self, modules: Iterable[Module], original: Union[Tensor, Parameter]
    ) -> None:
        super().__init__(modules)
        # Preserve the kind of the original tensor: a Parameter is registered as
        # a parameter, anything else is registered as a buffer.
        if isinstance(original, Parameter):
            self.register_parameter("original", original)
        else:
            self.register_buffer("original", original)

    def set_original_(self, value: Tensor) -> None:
        r"""This method is called when assigning to a parametrized tensor.

        It calls the methods ``right_inverse`` (see
        :func:`register_parametrization`) of the parametrizations in the inverse
        order that they have been registered. Then, it assigns the result to
        ``self.original`` in place.

        Args:
            value (Tensor): Value to which initialize the module

        Raises:
            RuntimeError: if any of the parametrizations do not implement a
                ``right_inverse`` method
        """
        with torch.no_grad():
            # Undo the parametrizations last-to-first, so that each
            # ``right_inverse`` receives the output of the one registered after it.
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    value = module.right_inverse(value)
                else:
                    raise RuntimeError(
                        "The parametrization '{}' does not implement a 'right_inverse' method. "
                        "Assigning to a parametrized tensor is only possible when all the parametrizations "
                        "implement a 'right_inverse' method.".format(module.__class__.__name__)
                    )
            # Copy in place so the registered parameter/buffer object is kept.
            self.original.copy_(value)

    def forward(self) -> Tensor:
        """Apply every parametrization, in registration order, to ``original``.

        Raises:
            RuntimeError: if the result does not have the size of ``original``
        """
        x = self.original
        for module in self:
            x = module(x)
        if x.size() != self.original.size():
            raise RuntimeError(
                "The parametrization may not change the size of the parametrized tensor. "
                "Size of original tensor: {} "
                "Size of parametrized tensor: {}".format(self.original.size(), x.size())
            )
        return x


def _inject_new_class(module: Module) -> None:
    r"""Sets up the parametrization mechanism used by parametrizations.

    This works by substituting the class of the module by a class that extends
    it, so that a property can later be injected into that class without
    affecting other instances of the original class.

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def getstate(self):
        # Pickling (and therefore copy/deepcopy) of the dynamically created
        # class is not supported; point the user at state_dict() instead.
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    param_cls = type(
        "Parametrized{}".format(cls.__name__),
        (cls,),
        {
            "__getstate__": getstate,
        },
    )

    module.__class__ = param_cls


def _inject_property(module: Module, tensor_name: str) -> None:
    r"""Injects a property into module[tensor_name].

    It assumes that the class in the module has already been modified from its
    original one using _inject_new_class and that the tensor under
    :attr:`tensor_name` has already been moved out.

    Args:
        module (nn.Module): module into which to inject the property
        tensor_name (str): name of the property to create
    """
    # We check the precondition.
    # This should never fire if register_parametrization is correctly implemented
    assert not hasattr(module, tensor_name)

    def get_parametrized(self) -> Tensor:
        parametrization = self.parametrizations[tensor_name]
        if _cache_enabled:
            # Key on the instance being accessed (``self``) rather than the
            # closed-over ``module``: this keeps the property free of a strong
            # reference to one specific instance.
            key = (id(self), tensor_name)
            tensor = _cache.get(key)
            if tensor is None:
                tensor = parametrization()
                _cache[key] = tensor
            return tensor
        # If caching is not active, this function just evaluates the parametrization
        return parametrization()

    def set_original(self, value: Tensor) -> None:
        self.parametrizations[tensor_name].set_original_(value)

    setattr(module.__class__, tensor_name, property(get_parametrized, set_original))
def register_parametrization(
    module: Module, tensor_name: str, parametrization: Module
) -> Module:
    r"""Adds a parametrization to a tensor in a module.

    Assume that ``tensor_name="weight"`` for simplicity. When accessing
    ``module.weight``, the module will return the parametrized version
    ``parametrization(module.weight)``. If the original tensor requires a
    gradient, the backward pass will differentiate through the
    :attr:`parametrization`, and the optimizer will update the tensor
    accordingly.

    The first time that a module registers a parametrization, this function will
    add an attribute ``parametrizations`` to the module of type
    :class:`~ParametrizationList`.

    The list of parametrizations on a tensor will be accessible under
    ``module.parametrizations.weight``.

    The original tensor will be accessible under
    ``module.parametrizations.weight.original``.

    Parametrizations may be concatenated by registering several
    parametrizations on the same attribute.

    The training mode of the registered parametrizations is updated on
    registration to match the training mode of the host module.

    Parametrized parameters and buffers have an inbuilt caching system that can
    be activated using the context manager :func:`cached`.

    A :attr:`parametrization` may optionally implement a method with signature

    .. code-block:: python

        def right_inverse(self, X: Tensor) -> Tensor

    If :attr:`parametrization` implements this method, it will be possible to
    assign to the parametrized tensor. This may be used to initialize the
    tensor, as shown in the example.

    In most situations, ``right_inverse`` will be a function such that
    ``forward(right_inverse(X)) == X`` (see
    `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
    Sometimes, when the parametrization is not surjective, it may be reasonable
    to relax this, as shown in the example below.

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (str): name of the parameter or buffer on which to register
            the parametrization
        parametrization (nn.Module): the parametrization to register

    Returns:
        Module: module

    Raises:
        ValueError: if the module does not have a parameter or a buffer named
            :attr:`tensor_name`

    Examples:
        >>> import torch
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Symmetric(torch.nn.Module):
        >>>     def forward(self, X):
        >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
        >>>
        >>>     def right_inverse(self, A):
        >>>         return A.triu()
        >>>
        >>> m = torch.nn.Linear(5, 5)
        >>> P.register_parametrization(m, "weight", Symmetric())
        >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
        True
        >>> A = torch.rand(5, 5)
        >>> A = A + A.T  # A is now symmetric
        >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
        >>> print(torch.allclose(m.weight, A))
        True
    """
    # Match the parametrization's training mode to that of the host module.
    parametrization.train(module.training)
    if is_parametrized(module, tensor_name):
        # Just add the new parametrization to the parametrization list
        module.parametrizations[tensor_name].append(parametrization)  # type: ignore[index, union-attr]
    elif tensor_name in module._buffers or tensor_name in module._parameters:
        # Set the parametrization mechanism
        # Fetch the original buffer or parameter
        original = getattr(module, tensor_name)
        # Delete the previous parameter or buffer
        delattr(module, tensor_name)
        # If this is the first parametrization registered on the module,
        # we prepare the module to inject the property
        if not is_parametrized(module):
            # Change the class
            _inject_new_class(module)
            # Inject a ``ModuleDict`` into the instance under module.parametrizations
            module.parametrizations = ModuleDict()
        # Add a property into the class
        _inject_property(module, tensor_name)
        # Add a ParametrizationList
        module.parametrizations[tensor_name] = ParametrizationList(  # type: ignore[assignment, index, operator]
            [parametrization], original
        )
    else:
        raise ValueError(
            "Module '{}' does not have a parameter, a buffer, or a "
            "parametrized element with name '{}'".format(module, tensor_name)
        )
    return module
def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
    r"""Returns ``True`` if module has an active parametrization.

    If the argument :attr:`tensor_name` is specified, returns ``True`` if
    ``module[tensor_name]`` is parametrized.

    Args:
        module (nn.Module): module to query
        tensor_name (str, optional): attribute in the module to query
            Default: ``None``

    Returns:
        bool: ``False`` if the module has no ``parametrizations`` attribute of
        type :class:`~torch.nn.ModuleDict`. Otherwise, whether that dict is
        non-empty (when :attr:`tensor_name` is ``None``), or whether it contains
        :attr:`tensor_name`.
    """
    parametrizations = getattr(module, "parametrizations", None)
    if parametrizations is None or not isinstance(parametrizations, ModuleDict):
        return False
    if tensor_name is None:
        # Check that there is at least one parametrized buffer or Parameter
        return len(parametrizations) > 0
    return tensor_name in parametrizations
def remove_parametrizations(
    module: Module, tensor_name: str, leave_parametrized: bool = True
) -> Module:
    r"""Removes the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change
      the ``dtype`` of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrized tensor in
      ``module.parametrizations[tensor_name].original``.

    In both cases the original parameter/buffer object is re-registered on the
    module, so its ``id()`` is preserved.

    Args:
        module (nn.Module): module from which remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute
            :attr:`tensor_name` parametrized. Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=True`` and the parametrization
            changes the dtype of the tensor
    """
    if not is_parametrized(module, tensor_name):
        raise ValueError(
            "Module {} does not have a parametrization on {}".format(
                module, tensor_name
            )
        )

    # Fetch the original tensor
    original = module.parametrizations[tensor_name].original  # type: ignore[index, union-attr]
    if leave_parametrized:
        with torch.no_grad():
            t = getattr(module, tensor_name)
            # If they have the same dtype, we reuse the original tensor object so
            # that its id() does not change. This way the user does not need to
            # update the optimizer.
            if t.dtype == original.dtype:
                original.set_(t)
            else:
                raise ValueError(
                    "The parametrization changes the dtype of the tensor from {} to {}. "
                    "It is not supported to leave the tensor parametrized (`leave_parametrized=True`) "
                    "in this case.".format(original.dtype, t.dtype)
                )

    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]  # type: ignore[operator, union-attr]

    # Restore the parameter / buffer into the main class
    if isinstance(original, Parameter):
        module.register_parameter(tensor_name, original)
    else:
        module.register_buffer(tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore the class that was in place before the first registration
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources