.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/parametrizations.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_parametrizations.py:


Parametrizations Tutorial
=========================
**Author**: `Mario Lezcano `_

Regularizing deep-learning models is a surprisingly challenging task.
Classical techniques such as penalty methods often fall short when applied
to deep models due to the complexity of the function being optimized.
This is particularly problematic when working with ill-conditioned models,
such as RNNs trained on long sequences and GANs. A number of techniques have
been proposed in recent years to regularize these models and improve their
convergence.

For recurrent models, it has been proposed to control the singular values of
the recurrent kernel so that the RNN is well-conditioned. This can be achieved,
for example, by making the recurrent kernel `orthogonal `_.
Another way to regularize recurrent models is via
"`weight normalization `_".
This approach decouples the learning of the parameters from the learning of
their norms. To do so, the parameter is divided by its
`Frobenius norm `_
and a separate parameter encoding its norm is learned.
A similar regularization was proposed for GANs under the name of
"`spectral normalization `_". This method
controls the Lipschitz constant of the network by dividing its parameters by
their `spectral norm `_,
rather than their Frobenius norm.

All these methods have a common pattern: they all transform a parameter
in an appropriate way before using it. In the first case, they make it
orthogonal by using a function that maps matrices to orthogonal matrices.
In the case of weight and spectral normalization, they divide the original
parameter by its norm.

More generally, all these examples use a function to put extra structure on
the parameters. In other words, they use a function to constrain the parameters.

In this tutorial, you will learn how to implement and use this pattern to put
constraints on your model. Doing so is as easy as writing your own ``nn.Module``.

Requirements: ``torch>=1.9.0``

Implementing parametrizations by hand
-------------------------------------

Assume that we want to have a square linear layer with symmetric weights, that
is, with weights ``X`` such that ``X = Xᵀ``. One way to do so is to copy the
upper-triangular part of the matrix into its lower-triangular part:

.. GENERATED FROM PYTHON SOURCE LINES 49-62

.. code-block:: default


    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    def symmetric(X):
        # Keep the upper triangle (diagonal included) and mirror the
        # strictly upper-triangular part into the lower triangle
        return X.triu() + X.triu(1).transpose(-1, -2)

    X = torch.rand(3, 3)
    A = symmetric(X)
    assert torch.allclose(A, A.T)  # A is symmetric
    print(A)                       # Quick visual check


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[0.8823, 0.9150, 0.3829],
            [0.9150, 0.3904, 0.6009],
            [0.3829, 0.6009, 0.9408]])

.. GENERATED FROM PYTHON SOURCE LINES 63-64

We can then use this idea to implement a linear layer with symmetric weights:

.. GENERATED FROM PYTHON SOURCE LINES 64-73

.. code-block:: default


    class LinearSymmetric(nn.Module):
        def __init__(self, n_features):
            super().__init__()
            self.weight = nn.Parameter(torch.rand(n_features, n_features))

        def forward(self, x):
            # The symmetric matrix is recomputed from the raw weight on every call
            A = symmetric(self.weight)
            return x @ A

.. GENERATED FROM PYTHON SOURCE LINES 74-75

The layer can then be used as a regular linear layer:

.. GENERATED FROM PYTHON SOURCE LINES 75-78

.. code-block:: default


    layer = LinearSymmetric(3)
    out = layer(torch.rand(8, 3))

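
As a quick sanity check of this hand-written approach, the sketch below takes one optimization step with ``LinearSymmetric``. It repeats the definitions above so that it runs on its own. The squared-output loss and the plain SGD optimizer are arbitrary choices made only for illustration. Because ``symmetric`` never reads the strictly lower-triangular entries of ``weight``, those entries receive a gradient of exactly zero. The matrix used in the forward pass therefore stays symmetric however the raw parameter is updated:

.. code-block:: python

    import torch
    import torch.nn as nn


    def symmetric(X):
        # Same map as above: mirror the strictly upper triangle into the lower one
        return X.triu() + X.triu(1).transpose(-1, -2)


    class LinearSymmetric(nn.Module):
        def __init__(self, n_features):
            super().__init__()
            self.weight = nn.Parameter(torch.rand(n_features, n_features))

        def forward(self, x):
            return x @ symmetric(self.weight)


    layer = LinearSymmetric(3)
    # Plain SGD and a squared-output loss, chosen only for illustration
    optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

    x = torch.rand(8, 3)
    loss = layer(x).pow(2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # The strictly lower-triangular entries of the raw weight get no gradient
    print(layer.weight.grad.tril(-1).abs().sum())  # tensor(0.)

    # The matrix actually used in the forward pass is still symmetric
    with torch.no_grad():
        A = symmetric(layer.weight)
    print(torch.equal(A, A.T))  # True

One consequence of this design is that the strictly lower-triangular entries of ``weight`` are dead parameters. They are stored and handed to the optimizer, but they never influence the output.
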

.. GENERATED FROM PYTHON SOURCE LINES 79-98

This implementation, although correct and self-contained, presents a number of
problems:

1) It reimplements the layer. We had to implement the linear layer as ``x @ A``. This is
   not very problematic for a linear layer, but imagine having to reimplement a CNN or a
   Transformer...
2) It does not separate the layer and the parametrization. If the parametrization were
   more complex, we would have to rewrite its code for each layer that we want to use it
   in.
3) It recomputes the parametrization every time we use the layer. If we use the layer
   several times during the forward pass (imagine the recurrent kernel of an RNN), it
   recomputes the same ``A`` every time the layer is called.

Introduction to parametrizations
--------------------------------

Parametrizations can solve all these problems as well as others.

Let's start by reimplementing the code above using ``torch.nn.utils.parametrize``.
The only thing that we have to do is to write the parametrization as a regular ``nn.Module``:

.. GENERATED FROM PYTHON SOURCE LINES 98-102

.. code-block:: default


    class Symmetric(nn.Module):
        def forward(self, X):
            # Same map as symmetric() above, packaged as a module
            return X.triu() + X.triu(1).transpose(-1, -2)

.. GENERATED FROM PYTHON SOURCE LINES 103-105

This is all we need to do. Once we have this, we can transform any regular layer into a
symmetric layer by doing:

.. GENERATED FROM PYTHON SOURCE LINES 105-108

.. code-block:: default


    layer = nn.Linear(3, 3)
    parametrize.register_parametrization(layer, "weight", Symmetric())


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ParametrizedLinear(
      in_features=3, out_features=3, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): Symmetric()
        )
      )
    )

.. GENERATED FROM PYTHON SOURCE LINES 109-110

Now, the matrix of the linear layer is symmetric:

.. GENERATED FROM PYTHON SOURCE LINES 110-114

.. code-block:: default


    A = layer.weight
    assert torch.allclose(A, A.T)  # A is symmetric
    print(A)                       # Quick visual check


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[ 0.2430,  0.5155,  0.3337],
            [ 0.5155,  0.3333,  0.1033],
            [ 0.3337,  0.1033, -0.5715]], grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 115-119

We can do the same thing with any other layer. For example, we can create a CNN with
`skew-symmetric `_ kernels.
We use a similar parametrization, copying the upper-triangular part with signs
reversed into the lower-triangular part:

.. GENERATED FROM PYTHON SOURCE LINES 119-131

.. code-block:: default


    class Skew(nn.Module):
        def forward(self, X):
            # Strictly upper-triangular part minus its transpose: zero diagonal,
            # lower triangle equal to the negated upper triangle
            A = X.triu(1)
            return A - A.transpose(-1, -2)


    cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
    parametrize.register_parametrization(cnn, "weight", Skew())
    # Print a few kernels
    print(cnn.weight[0, 1])
    print(cnn.weight[2, 2])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[ 0.0000,  0.0457, -0.0311],
            [-0.0457,  0.0000, -0.0889],
            [ 0.0311,  0.0889,  0.0000]], grad_fn=)
    tensor([[ 0.0000, -0.1314,  0.0626],
            [ 0.1314,  0.0000,  0.1280],
            [-0.0626, -0.1280,  0.0000]], grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 132-150

Inspecting a parametrized module
--------------------------------

When a module is parametrized, we find that the module has changed in three ways:

1) ``module.weight`` is now a property
2) It has a new ``module.parametrizations`` attribute
3) The unparametrized weight has been moved to ``module.parametrizations.weight.original``

|

After parametrizing ``weight``, ``layer.weight`` is turned into a
`Python property `_.
This property computes ``parametrization(weight)`` every time we request ``layer.weight``,
just as we did in our implementation of ``LinearSymmetric`` above.

Registered parametrizations are stored under a ``parametrizations`` attribute within the module.

.. GENERATED FROM PYTHON SOURCE LINES 150-155

.. code-block:: default


    layer = nn.Linear(3, 3)
    print(f"Unparametrized:\n{layer}")
    parametrize.register_parametrization(layer, "weight", Symmetric())
    print(f"\nParametrized:\n{layer}")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Unparametrized:
    Linear(in_features=3, out_features=3, bias=True)

    Parametrized:
    ParametrizedLinear(
      in_features=3, out_features=3, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): Symmetric()
        )
      )
    )

.. GENERATED FROM PYTHON SOURCE LINES 156-157

This ``parametrizations`` attribute is an ``nn.ModuleDict``, and it can be accessed as such:

.. GENERATED FROM PYTHON SOURCE LINES 157-160

.. code-block:: default


    print(layer.parametrizations)
    print(layer.parametrizations.weight)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ModuleDict(
      (weight): ParametrizationList(
        (0): Symmetric()
      )
    )
    ParametrizationList(
      (0): Symmetric()
    )

.. GENERATED FROM PYTHON SOURCE LINES 161-165

Each element of this ``nn.ModuleDict`` is a ``ParametrizationList``, which behaves like an
``nn.Sequential``. This list allows us to concatenate parametrizations on one weight.
Since this is a list, we can access the parametrizations by indexing it. Here's
where our ``Symmetric`` parametrization sits:

.. GENERATED FROM PYTHON SOURCE LINES 165-167

.. code-block:: default


    print(layer.parametrizations.weight[0])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Symmetric()

.. GENERATED FROM PYTHON SOURCE LINES 168-170

The other thing that we notice is that, if we print the parameters, the
parameter ``weight`` has been moved:

.. GENERATED FROM PYTHON SOURCE LINES 170-172

.. code-block:: default


    print(dict(layer.named_parameters()))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {'bias': Parameter containing:
    tensor([-0.0730, -0.2283,  0.3217], requires_grad=True), 'parametrizations.weight.original': Parameter containing:
    tensor([[-0.4328,  0.3425,  0.4643],
            [ 0.0937, -0.1005, -0.5348],
            [-0.2103,  0.1470,  0.2722]], requires_grad=True)}

.. GENERATED FROM PYTHON SOURCE LINES 173-174

It now sits under ``layer.parametrizations.weight.original``:

.. GENERATED FROM PYTHON SOURCE LINES 174-176


.. code-block:: default


    print(layer.parametrizations.weight.original)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Parameter containing:
    tensor([[-0.4328,  0.3425,  0.4643],
            [ 0.0937, -0.1005, -0.5348],
            [-0.2103,  0.1470,  0.2722]], requires_grad=True)

.. GENERATED FROM PYTHON SOURCE LINES 177-179

Besides these three small differences, the parametrization does exactly the same
as our manual implementation:

.. GENERATED FROM PYTHON SOURCE LINES 179-183

.. code-block:: default


    symmetric = Symmetric()
    weight_orig = layer.parametrizations.weight.original
    print(torch.dist(layer.weight, symmetric(weight_orig)))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor(0., grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 184-198

Parametrizations are first-class citizens
-----------------------------------------

Since ``layer.parametrizations`` is an ``nn.ModuleDict``, the parametrizations
are properly registered as submodules of the original module. As such, the same rules
that apply to registering parameters in a module also apply to registering a
parametrization. For example, if a parametrization has parameters, these will be moved
from CPU to CUDA when calling ``model = model.cuda()``.

Caching the value of a parametrization
--------------------------------------

Parametrizations come with a built-in caching system via the context manager
``parametrize.cached()``:

.. GENERATED FROM PYTHON SOURCE LINES 198-213

.. code-block:: default


    class NoisyParametrization(nn.Module):
        def forward(self, X):
            print("Computing the Parametrization")
            return X

    layer = nn.Linear(4, 4)
    parametrize.register_parametrization(layer, "weight", NoisyParametrization())
    print("Here, layer.weight is recomputed every time we call it")
    foo = layer.weight + layer.weight.T
    bar = layer.weight.sum()
    with parametrize.cached():
        print("Here, it is computed just the first time layer.weight is called")
        foo = layer.weight + layer.weight.T
        bar = layer.weight.sum()


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Computing the Parametrization
    Here, layer.weight is recomputed every time we call it
    Computing the Parametrization
    Computing the Parametrization
    Computing the Parametrization
    Here, it is computed just the first time layer.weight is called
    Computing the Parametrization

.. GENERATED FROM PYTHON SOURCE LINES 214-223

Concatenating parametrizations
------------------------------

Concatenating two parametrizations is as easy as registering them on the same tensor.
We may use this to create more complex parametrizations from simpler ones. For example, the
`Cayley map `_
maps the skew-symmetric matrices to the orthogonal matrices of positive determinant. We can
concatenate ``Skew`` and a parametrization that implements the Cayley map to get a layer with
orthogonal weights:

.. GENERATED FROM PYTHON SOURCE LINES 223-238

.. code-block:: default


    class CayleyMap(nn.Module):
        def __init__(self, n):
            super().__init__()
            self.register_buffer("Id", torch.eye(n))

        def forward(self, X):
            # (I + X)(I - X)^{-1}, computed as (I - X)^{-1}(I + X);
            # both orders agree because the two factors commute
            return torch.linalg.solve(self.Id - X, self.Id + X)

    layer = nn.Linear(3, 3)
    parametrize.register_parametrization(layer, "weight", Skew())
    parametrize.register_parametrization(layer, "weight", CayleyMap(3))
    X = layer.weight
    print(torch.dist(X.T @ X, torch.eye(3)))  # X is orthogonal


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor(1.2991e-07, grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 239-243

This may also be used to prune a parametrized module, or to reuse parametrizations. For example,
the matrix exponential maps the symmetric matrices to the Symmetric Positive Definite (SPD) matrices,
but the matrix exponential also maps the skew-symmetric matrices to the orthogonal matrices.
Using these two facts, we may reuse the parametrizations defined before to our advantage:

.. GENERATED FROM PYTHON SOURCE LINES 243-260

.. code-block:: default


    class MatrixExponential(nn.Module):
        def forward(self, X):
            return torch.matrix_exp(X)

    layer_orthogonal = nn.Linear(3, 3)
    parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
    parametrize.register_parametrization(layer_orthogonal, "weight", MatrixExponential())
    X = layer_orthogonal.weight
    print(torch.dist(X.T @ X, torch.eye(3)))      # X is orthogonal

    layer_spd = nn.Linear(3, 3)
    parametrize.register_parametrization(layer_spd, "weight", Symmetric())
    parametrize.register_parametrization(layer_spd, "weight", MatrixExponential())
    X = layer_spd.weight
    print(torch.dist(X, X.T))                     # X is symmetric
    print((torch.linalg.eigvalsh(X) > 0.).all())  # X is positive definite


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor(1.9066e-07, grad_fn=)
    tensor(4.2147e-08, grad_fn=)
    tensor(True)

.. GENERATED FROM PYTHON SOURCE LINES 261-274

Initializing parametrizations
-----------------------------

Parametrizations come with a mechanism to initialize them. If we implement a method
``right_inverse`` with signature

.. code-block:: python

    def right_inverse(self, X: Tensor) -> Tensor

it will be used when assigning to the parametrized tensor.

Let's upgrade our implementation of the ``Skew`` class to support this:

.. GENERATED FROM PYTHON SOURCE LINES 274-284

.. code-block:: default


    class Skew(nn.Module):
        def forward(self, X):
            A = X.triu(1)
            return A - A.transpose(-1, -2)

        def right_inverse(self, A):
            # We assume that A is skew-symmetric
            # We take the strictly upper-triangular elements, as these are
            # the only ones used in the forward
            return A.triu(1)

.. GENERATED FROM PYTHON SOURCE LINES 285-286

We may now initialize a layer that is parametrized with ``Skew``:

.. GENERATED FROM PYTHON SOURCE LINES 286-293

.. code-block:: default


    layer = nn.Linear(3, 3)
    parametrize.register_parametrization(layer, "weight", Skew())
    X = torch.rand(3, 3)
    X = X - X.T                          # X is now skew-symmetric
    layer.weight = X                     # Initialize layer.weight to be X
    print(torch.dist(layer.weight, X))   # layer.weight == X


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor(0., grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 294-296

This ``right_inverse`` works as expected when we concatenate parametrizations.
To see this, let's upgrade the Cayley parametrization to also support being initialized:

.. GENERATED FROM PYTHON SOURCE LINES 296-323

.. code-block:: default


    class CayleyMap(nn.Module):
        def __init__(self, n):
            super().__init__()
            self.register_buffer("Id", torch.eye(n))

        def forward(self, X):
            # Assume X skew-symmetric
            # (I + X)(I - X)^{-1}
            return torch.linalg.solve(self.Id - X, self.Id + X)

        def right_inverse(self, A):
            # Assume A orthogonal
            # See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
            # (A - I)(A + I)^{-1}, computed as (A + I)^{-1}(A - I);
            # both orders agree because the two factors commute
            return torch.linalg.solve(A + self.Id, A - self.Id)

    layer_orthogonal = nn.Linear(3, 3)
    parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
    parametrize.register_parametrization(layer_orthogonal, "weight", CayleyMap(3))
    # Sample an orthogonal matrix with positive determinant
    X = torch.empty(3, 3)
    nn.init.orthogonal_(X)
    if X.det() < 0.:
        X[0].neg_()
    layer_orthogonal.weight = X
    print(torch.dist(layer_orthogonal.weight, X))  # layer_orthogonal.weight == X, up to rounding

.. GENERATED FROM PYTHON SOURCE LINES 324-325

This initialization step can be written more succinctly as follows. Note that this shorter
version skips the determinant check above. Since the Cayley map only reaches orthogonal
matrices with positive determinant, it only works when the sampled matrix already has
positive determinant:

.. GENERATED FROM PYTHON SOURCE LINES 325-327

.. code-block:: default


    layer_orthogonal.weight = nn.init.orthogonal_(layer_orthogonal.weight)

.. GENERATED FROM PYTHON SOURCE LINES 328-334

The name of this method comes from the fact that we would often expect
that ``forward(right_inverse(X)) == X``.
This is a direct way of saying that, after initializing the tensor with value ``X``,
the forward should return the value ``X``.
This constraint is not strictly enforced in practice. In fact, at times, it might be of
interest to relax this relation. For example, consider the following implementation
of a randomized pruning method:

.. GENERATED FROM PYTHON SOURCE LINES 334-347

.. code-block:: default


    class PruningParametrization(nn.Module):
        def __init__(self, X, p_drop=0.2):
            super().__init__()
            # sample zeros with probability p_drop
            mask = torch.full_like(X, 1.0 - p_drop)
            # Register the mask as a buffer so that it moves with the module
            # (e.g. when calling .cuda())
            self.register_buffer("mask", torch.bernoulli(mask))

        def forward(self, X):
            return X * self.mask

        def right_inverse(self, A):
            return A

.. GENERATED FROM PYTHON SOURCE LINES 348-352

In this case, it is not true that ``forward(right_inverse(A)) == A`` for every matrix ``A``.
This is only true when the matrix ``A`` has zeros in the same positions as the mask.
In general, if we assign a tensor to a pruned parameter, it comes as no surprise
that the tensor will, in fact, be pruned:

.. GENERATED FROM PYTHON SOURCE LINES 352-359

.. code-block:: default


    layer = nn.Linear(3, 4)
    X = torch.rand_like(layer.weight)
    print(f"Initialization matrix:\n{X}")
    parametrize.register_parametrization(layer, "weight", PruningParametrization(layer.weight))
    layer.weight = X
    print(f"\nInitialized weight:\n{layer.weight}")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Initialization matrix:
    tensor([[0.3513, 0.3546, 0.7670],
            [0.2533, 0.2636, 0.8081],
            [0.0643, 0.5611, 0.9417],
            [0.5857, 0.6360, 0.2088]])

    Initialized weight:
    tensor([[0.3513, 0.3546, 0.7670],
            [0.2533, 0.0000, 0.8081],
            [0.0643, 0.5611, 0.9417],
            [0.5857, 0.6360, 0.0000]], grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 360-365

Removing parametrizations
-------------------------

We may remove all the parametrizations from a parameter or a buffer in a module
by using ``parametrize.remove_parametrizations()``:

.. GENERATED FROM PYTHON SOURCE LINES 365-378

.. code-block:: default


    layer = nn.Linear(3, 3)
    print("Before:")
    print(layer)
    print(layer.weight)
    parametrize.register_parametrization(layer, "weight", Skew())
    print("\nParametrized:")
    print(layer)
    print(layer.weight)
    parametrize.remove_parametrizations(layer, "weight")
    print("\nAfter. Weight has skew-symmetric values but it is unconstrained:")
    print(layer)
    print(layer.weight)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Before:
    Linear(in_features=3, out_features=3, bias=True)
    Parameter containing:
    tensor([[ 0.0669, -0.3112,  0.3017],
            [-0.5464, -0.2233, -0.1125],
            [-0.4906, -0.3671, -0.0942]], requires_grad=True)

    Parametrized:
    ParametrizedLinear(
      in_features=3, out_features=3, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): Skew()
        )
      )
    )
    tensor([[ 0.0000, -0.3112,  0.3017],
            [ 0.3112,  0.0000, -0.1125],
            [-0.3017,  0.1125,  0.0000]], grad_fn=)

    After. Weight has skew-symmetric values but it is unconstrained:
    Linear(in_features=3, out_features=3, bias=True)
    Parameter containing:
    tensor([[ 0.0000, -0.3112,  0.3017],
            [ 0.3112,  0.0000, -0.1125],
            [-0.3017,  0.1125,  0.0000]], requires_grad=True)

.. GENERATED FROM PYTHON SOURCE LINES 379-382

When removing a parametrization, we may choose to leave the original parameter (i.e., the one in
``layer.parametrizations.weight.original``) rather than its parametrized version by setting
the flag ``leave_parametrized=False``. Note that, since ``Skew`` implements ``right_inverse``,
``original`` was set to ``right_inverse(weight)`` when the parametrization was registered.
The weight left after the removal is therefore the strictly upper-triangular part of the
initial weight, rather than the initial weight itself:

.. GENERATED FROM PYTHON SOURCE LINES 382-394

.. code-block:: default


    layer = nn.Linear(3, 3)
    print("Before:")
    print(layer)
    print(layer.weight)
    parametrize.register_parametrization(layer, "weight", Skew())
    print("\nParametrized:")
    print(layer)
    print(layer.weight)
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
    print("\nAfter. Same as Before:")
    print(layer)
    print(layer.weight)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Before:
    Linear(in_features=3, out_features=3, bias=True)
    Parameter containing:
    tensor([[-0.3447, -0.3777,  0.5038],
            [ 0.2042,  0.0153,  0.0781],
            [-0.4640, -0.1928,  0.5558]], requires_grad=True)

    Parametrized:
    ParametrizedLinear(
      in_features=3, out_features=3, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): Skew()
        )
      )
    )
    tensor([[ 0.0000, -0.3777,  0.5038],
            [ 0.3777,  0.0000,  0.0781],
            [-0.5038, -0.0781,  0.0000]], grad_fn=)

    After. Same as Before:
    Linear(in_features=3, out_features=3, bias=True)
    Parameter containing:
    tensor([[ 0.0000, -0.3777,  0.5038],
            [ 0.0000,  0.0000,  0.0781],
            [ 0.0000,  0.0000,  0.0000]], requires_grad=True)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.051 seconds)


.. _sphx_glr_download_intermediate_parametrizations.py:


.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

       :download:`Download Python source code: parametrizations.py `


    .. container:: sphx-glr-download sphx-glr-download-jupyter

       :download:`Download Jupyter notebook: parametrizations.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_