
torch.nn.utils.parametrize.register_parametrization

torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization)[source]

Adds a parametrization to a tensor in a module.

Assume that tensor_name="weight" for simplicity. After registration, accessing module.weight returns the parametrized version of the tensor, that is, parametrization applied to the original (unparametrized) weight. If the original tensor requires a gradient, the backward pass will differentiate through the parametrization, and the optimizer will update the original tensor accordingly.
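The gradient flow described above can be illustrated with a single optimizer step. This is a minimal sketch, not part of the reference text: the Symmetric class mirrors the one in the Examples section, and the loss, learning rate, and input are arbitrary choices made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Symmetric(nn.Module):
    def forward(self, X):
        # Build a symmetric matrix from the upper triangle of X
        return X.triu() + X.triu(1).T


torch.manual_seed(0)
m = nn.Linear(3, 3)
P.register_parametrization(m, "weight", Symmetric())

# m.parameters() now yields the original (unconstrained) tensor and the bias
opt = torch.optim.SGD(m.parameters(), lr=0.1)
before = m.parametrizations.weight.original.detach().clone()

loss = m(torch.randn(4, 3)).pow(2).sum()  # arbitrary loss for illustration
loss.backward()  # differentiates through Symmetric.forward
opt.step()       # updates the original tensor, not m.weight directly

after = m.parametrizations.weight.original.detach().clone()
still_symmetric = torch.allclose(m.weight, m.weight.T)
```

In this sketch the strictly lower triangle of the original tensor receives a zero gradient, because Symmetric.forward never reads it, so only the upper triangle changes.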

The first time that a module registers a parametrization, this function will add an attribute parametrizations to the module: a ModuleDict that maps each parametrized tensor name to a ParametrizationList.

The list of parametrizations on a tensor will be accessible under module.parametrizations.weight.

The original tensor will be accessible under module.parametrizations.weight.original.
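As a rough illustration of the attributes described above, the following sketch registers one parametrization and inspects what is stored on the module. The Symmetric class is the same one used in the Examples section; the variable names are chosen only for this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).T


m = nn.Linear(3, 3)
P.register_parametrization(m, "weight", Symmetric())

# Container holding every parametrization registered on "weight"
plist = m.parametrizations.weight
# The unconstrained tensor that the parametrization is applied to
original = m.parametrizations.weight.original

is_parametrized = P.is_parametrized(m, "weight")
num_parametrizations = len(plist)
# m.weight is recomputed from the original tensor on access
matches = torch.equal(m.weight, plist[0](original))
```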

Parametrizations may be composed by registering several parametrizations on the same attribute. They are applied in the order in which they were registered.
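A small sketch of registering two parametrizations on the same tensor follows. AddOne and Double are hypothetical toy modules invented for this illustration; they were picked because the result depends on the order of application.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class AddOne(nn.Module):  # hypothetical toy parametrization
    def forward(self, X):
        return X + 1.0


class Double(nn.Module):  # hypothetical toy parametrization
    def forward(self, X):
        return 2.0 * X


m = nn.Linear(2, 2)
W0 = m.weight.detach().clone()

P.register_parametrization(m, "weight", AddOne())  # applied first
P.register_parametrization(m, "weight", Double())  # applied second

composed = m.weight.detach()
expected = 2.0 * (W0 + 1.0)  # Double(AddOne(W0))
num_registered = len(m.parametrizations.weight)
```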

Parametrized parameters and buffers have a built-in caching system that can be activated using the context manager cached().
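The effect of cached() can be observed by counting how often the parametrization is evaluated. In the sketch below, CountingSymmetric and its calls attribute are hypothetical helpers added purely to make the evaluations visible.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class CountingSymmetric(nn.Module):  # hypothetical helper for this sketch
    def __init__(self):
        super().__init__()
        self.calls = 0  # number of times forward has run

    def forward(self, X):
        self.calls += 1
        return X.triu() + X.triu(1).T


m = nn.Linear(3, 3)
param = CountingSymmetric()
P.register_parametrization(m, "weight", param)
param.calls = 0  # reset, since registration itself may evaluate forward

with P.cached():
    w1 = m.weight  # computed and stored in the cache
    w2 = m.weight  # served from the cache
    calls_cached = param.calls

calls_before = param.calls
_ = m.weight  # outside the context, every access recomputes
_ = m.weight
calls_uncached = param.calls - calls_before
```

Caching is mostly useful when the same parametrized tensor is read several times within one forward pass, for example when a weight is shared or reused across time steps.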

A parametrization may optionally implement a method with signature

def right_inverse(self, X: Tensor) -> Tensor

If the parametrization implements this method, it will be possible to assign to the parametrized tensor. This may be used to initialize the tensor, as shown in the example.

In most situations, right_inverse will be a function such that forward(right_inverse(X)) == X, that is, a right inverse of forward. Sometimes, when the parametrization is not surjective, it may be reasonable to relax this, as shown in the example below.
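To make the relaxed case concrete, the sketch below assigns a matrix that is not symmetric to a weight parametrized by the Symmetric class from the Examples section. Since forward can only produce symmetric matrices, forward(right_inverse(A)) cannot equal A here; the assigned value is instead mapped to a symmetric matrix built from the upper triangle of A. The particular matrix is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).T

    def right_inverse(self, A):
        # Keeps only the upper triangle; exact inverse only for symmetric A
        return A.triu()


m = nn.Linear(3, 3)
P.register_parametrization(m, "weight", Symmetric())

# A is deliberately NOT symmetric
A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])
m.weight = A  # stores right_inverse(A) as the original tensor

result = m.weight.detach()
expected = A.triu() + A.triu(1).T  # symmetric matrix from A's upper triangle
```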

Parameters
  • module (nn.Module) – module on which to register the parametrization

  • tensor_name (str) – name of the parameter or buffer on which to register the parametrization

  • parametrization (nn.Module) – the parametrization to register

Returns

module

Return type

Module

Raises

ValueError – if the module does not have a parameter or a buffer named tensor_name
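The error condition above can be reproduced with a misspelled tensor name. This is a minimal sketch; the name "wieght" is an intentional typo and nn.Identity is used only as a placeholder parametrization.

```python
import torch.nn as nn
import torch.nn.utils.parametrize as P

m = nn.Linear(2, 2)

try:
    # "wieght" is a deliberate typo: no such parameter or buffer exists
    P.register_parametrization(m, "wieght", nn.Identity())
    raised = False
except ValueError:
    raised = True

# The failed call leaves the real weight unparametrized
weight_parametrized = P.is_parametrized(m, "weight")
```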

Examples

>>> import torch
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(torch.nn.Module):
...     def forward(self, X):
...         return X.triu() + X.triu(1).T  # Return a symmetric matrix
...
...     def right_inverse(self, A):
...         return A.triu()
...
>>> m = torch.nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T   # A is now symmetric
>>> m.weight = A  # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
