orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)
Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.
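Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, the parametrized matrix $Q \in \mathbb{K}^{m \times n}$ is orthogonal as

$$
\begin{cases}
Q^{\mathrm{H}} Q = I_n & \text{if } m \geq n \\
Q Q^{\mathrm{H}} = I_m & \text{if } m < n
\end{cases}
$$

where $Q^{\mathrm{H}}$ is the conjugate transpose when $Q$ is complex and the transpose when $Q$ is real-valued, and $I_n$ is the $n$-dimensional identity matrix. In plain words, $Q$ will have orthonormal columns whenever $m \geq n$ and orthonormal rows otherwise.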
If the tensor has more than two dimensions, we consider it as a batch of matrices of shape (…, m, n).
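The orthonormality conditions above can be checked numerically. The sketch below assumes a recent PyTorch release. It builds a tall and a wide nn.Linear (which stores its weight with shape (out_features, in_features)) plus a hypothetical module Batched whose parameter kernel, a name chosen only for illustration, holds a batch of three 5 x 2 matrices. It then measures how far each parametrized tensor is from satisfying $Q^{\mathrm{H}}Q = I_n$ or $QQ^{\mathrm{H}} = I_m$.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)

# nn.Linear(20, 40) stores a (40, 20) weight: m >= n, so its columns are orthonormal
tall = orthogonal(nn.Linear(20, 40)).weight
# nn.Linear(40, 20) stores a (20, 40) weight: m < n, so its rows are orthonormal
wide = orthogonal(nn.Linear(40, 20)).weight


class Batched(nn.Module):
    """Hypothetical module holding a batch of three 5 x 2 matrices."""

    def __init__(self):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(3, 5, 2))


# The tensor has 3 dimensions, so it is treated as a batch of (5, 2) matrices
batch = orthogonal(Batched(), name="kernel").kernel

with torch.no_grad():
    cols_err = torch.dist(tall.T @ tall, torch.eye(20)).item()     # Q^T Q = I_n
    rows_err = torch.dist(wide @ wide.T, torch.eye(20)).item()     # Q Q^T = I_m
    batch_err = torch.dist(batch.mT @ batch, torch.eye(2)).item()  # Q^T Q = I_2 per matrix

print(cols_err, rows_err, batch_err)
```

All three distances should be on the order of float32 round-off; the batched case is compared matrix by matrix by broadcasting against torch.eye(2).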
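The matrix $Q$ may be parametrized via three different orthogonal_map options, in terms of the original tensor:
"matrix_exp"/"cayley": the matrix exponential $Q = \exp(A)$ (torch.linalg.matrix_exp()) and the Cayley map $Q = (I_n + A/2)(I_n - A/2)^{-1}$ are applied to a skew-symmetric (skew-Hermitian in the complex case) $A$ to give an orthogonal matrix.
"householder": computes a product of Householder reflectors (torch.linalg.householder_product()).
"matrix_exp"/"cayley" often make the parametrized weight converge faster than "householder", but they are slower to compute for very thin or very wide matrices.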
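To see why these maps produce orthogonal matrices, here is a minimal real-valued sketch that applies them directly with torch.linalg in float64. It illustrates the math only and is not necessarily how orthogonal builds $A$ or the reflectors from the stored tensor; the sizes n and k are arbitrary.

```python
import torch

torch.manual_seed(0)
n, k = 5, 3
I = torch.eye(n, dtype=torch.float64)

# Any square X yields a skew-symmetric A (A^T = -A)
X = torch.randn(n, n, dtype=torch.float64)
A = X - X.T

# "matrix_exp": Q = exp(A)
Q_exp = torch.linalg.matrix_exp(A)

# "cayley": Q = (I + A/2)(I - A/2)^{-1}. The two factors commute, so this equals
# (I - A/2)^{-1}(I + A/2), which linalg.solve computes without an explicit inverse
Q_cay = torch.linalg.solve(I - A / 2, I + A / 2)

# "householder": H_i = I - tau_i v_i v_i^T, where householder_product reads v_i from
# column i of V below the diagonal, with an implicit 1 on the diagonal.
# tau_i = 2 / ||v_i||^2 makes each H_i a reflection, hence orthogonal.
V = torch.randn(n, k, dtype=torch.float64).tril(diagonal=-1)
tau = 2.0 / (1.0 + (V * V).sum(dim=-2))
Q_hh = torch.linalg.householder_product(V, tau)  # first k columns of H_1 ... H_k

errors = {
    name: torch.dist(Q.T @ Q, torch.eye(Q.shape[-1], dtype=torch.float64)).item()
    for name, Q in [("matrix_exp", Q_exp), ("cayley", Q_cay), ("householder", Q_hh)]
}
print(errors)
```

For a skew-symmetric $A$, $\exp(A)^{\mathsf{T}}\exp(A) = \exp(-A)\exp(A) = I_n$, and $I_n - A/2$ is always invertible because the eigenvalues of $A$ are purely imaginary, so both maps are well defined.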
If use_trivialization=True (default), the parametrization implements the “Dynamic Trivialization Framework”, where an extra square matrix $B$ is stored under module.parametrizations.weight[0].base. This helps the convergence of the parametrized layer at the expense of some extra memory use. See Trivializations for Gradient-Based Optimization on Manifolds.
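One way to see the extra memory is to inspect the base buffer, which in the PyTorch releases assumed here lives on the first element of the ParametrizationList, i.e. parametrizations.weight[0].base. For the (40, 20) weight of nn.Linear(20, 40) we expect a square 40 x 40 orthogonal buffer (twice as many entries as the weight), and no buffer at all when use_trivialization=False. Treat the shape as an observation about current releases rather than a documented guarantee.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)

# Default use_trivialization=True: the parametrization holds an extra buffer "base"
triv = orthogonal(nn.Linear(20, 40))
base = triv.parametrizations.weight[0].base
print(base.shape, triv.weight.shape)

# The base buffer is itself an orthogonal matrix
with torch.no_grad():
    base_err = torch.dist(base.T @ base, torch.eye(base.shape[0])).item()

# use_trivialization=False: no base buffer is registered
plain = orthogonal(nn.Linear(20, 40), use_trivialization=False)
has_base = hasattr(plain.parametrizations.weight[0], "base")
print(base_err, has_base)
```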
Initial value of $Q$: If the original tensor is not parametrized and use_trivialization=True (default), the initial value of $Q$ is that of the original tensor if it is orthogonal (or unitary in the complex case), and it is orthogonalized via the QR decomposition otherwise (see torch.linalg.qr()). The same happens when it is not parametrized and orthogonal_map="householder", even when use_trivialization=False. Otherwise, the initial value is the result of the composition of all the registered parametrizations applied to the original tensor.
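The two initialization cases can be observed directly. This sketch assumes the defaults for a real 4 x 4 weight (use_trivialization=True, and "matrix_exp" as the map); the tolerances are arbitrary choices for float32.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)

# Case 1: the weight is already orthogonal, so it should be kept as the initial Q
layer = nn.Linear(4, 4)
Q0, _ = torch.linalg.qr(torch.randn(4, 4))
with torch.no_grad():
    layer.weight.copy_(Q0)
orthogonal(layer)
kept = torch.allclose(layer.weight, Q0, atol=1e-6)

# Case 2: a generic weight is not orthogonal, so it is orthogonalized via QR
layer2 = nn.Linear(4, 4)
W = layer2.weight.detach().clone()
orthogonal(layer2)
Q = layer2.weight.detach()
is_orth = torch.allclose(Q.T @ Q, torch.eye(4), atol=1e-5)

# QR of an invertible matrix is unique only up to the sign of each column of Q,
# so compare absolute values against the Q factor from torch.linalg.qr
Q_ref, _ = torch.linalg.qr(W)
matches_qr = torch.allclose(Q.abs(), Q_ref.abs(), atol=1e-5)

print(kept, is_orth, matches_qr)
```

Comparing absolute values keeps the check independent of whichever sign convention the implementation uses for the diagonal of $R$.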
This function is implemented using the parametrization functionality in torch.nn.utils.parametrize.register_parametrization().
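Because orthogonal is built on torch.nn.utils.parametrize, the generic utilities of that module should also work on its result. The sketch below, assuming a recent PyTorch release, checks that a parametrization is registered, looks up the unconstrained tensor stored under parametrizations.weight.original, and then calls remove_parametrizations to freeze the current orthogonal weight into an ordinary nn.Parameter (for example, once training is finished).

```python
import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)
layer = orthogonal(nn.Linear(6, 6))

# The registered parametrization is visible through torch.nn.utils.parametrize
before = parametrize.is_parametrized(layer, "weight")
# The unconstrained tensor that the optimizer actually updates
original = layer.parametrizations.weight.original

# Freeze the current orthogonal matrix into a plain parameter
Q = layer.weight.detach().clone()
parametrize.remove_parametrizations(layer, "weight")  # leave_parametrized=True by default
after = parametrize.is_parametrized(layer, "weight")

print(before, after, type(layer).__name__)
print(torch.allclose(layer.weight, Q))
```

After removal the layer is a plain nn.Linear again, so later optimizer steps no longer keep the weight orthogonal.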
Parameters:
module (nn.Module) – module on which to register the parametrization.
name (str, optional) – name of the tensor to make orthogonal. Default: "weight".
orthogonal_map (str, optional) – One of the following: "matrix_exp", "cayley", "householder". Default: "matrix_exp" if the matrix is square or complex, "householder" otherwise.
use_trivialization (bool, optional) – whether to use the dynamic trivialization framework. Default: True.
Returns: The original module with an orthogonal parametrization registered to the specified weight.
Return type: Module
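The parameters can be combined freely; for instance, each supported orthogonal_map can be requested explicitly for the same kind of layer. This sketch, assuming a recent PyTorch release and a real-valued float32 weight, registers each map on a wide nn.Linear(7, 3), whose weight has shape (3, 7), and records the Frobenius distance between $WW^{\mathsf{T}}$ and $I_3$.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)
errors = {}
for orth_map in ("matrix_exp", "cayley", "householder"):
    # nn.Linear(7, 3) stores a wide (3, 7) weight, so its rows should be orthonormal
    layer = orthogonal(nn.Linear(7, 3), name="weight", orthogonal_map=orth_map)
    with torch.no_grad():
        W = layer.weight
        errors[orth_map] = torch.dist(W @ W.T, torch.eye(3)).item()

print(errors)
```

Passing orthogonal_map=None (the default) instead lets the function pick a map from the shape and dtype of the weight, as described above.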
Example:
>>> import torch
>>> from torch import nn
>>> from torch.nn.utils.parametrizations import orthogonal
>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _Orthogonal()
    )
  )
)
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)