from enum import Enum, auto

import torch
from torch import Tensor
from ..utils import parametrize
from ..modules import Module
from .. import functional as F

from typing import Optional


def _is_orthogonal(Q, eps=None):
    n, k = Q.size(-2), Q.size(-1)
    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    # A reasonable eps, but not too large
    eps = 10. * n * torch.finfo(Q.dtype).eps
    return torch.allclose(Q.transpose(-2, -1).conj() @ Q, Id, atol=eps)


def _make_orthogonal(A):
    """ Assume that A is a tall matrix.
    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative
    """
    X, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(X, tau)
    # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
    return Q


class _OrthMaps(Enum):
    matrix_exp = auto()
    cayley = auto()
    householder = auto()


class _Orthogonal(Module):
    base: Tensor

    def __init__(self,
                 weight,
                 orthogonal_map: _OrthMaps,
                 *,
                 use_trivialization=True) -> None:
        super().__init__()

        # Note [Householder complex]
        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
        # linalg.householder_product from the reflectors.
        # To see this, note that the reflectors have a shape like:
        # 0 0 0
        # * 0 0
        # * * 0
        # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
        # them as independent tensors we would not maintain the constraint.
        # An equivalent reasoning holds for rectangular matrices.
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError("The householder parametrization does not support complex tensors.")

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.transpose(-2, -1)
            n, k = k, n
        # Here n > k and X is a tall matrix
        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            A = X - X.transpose(-2, -1).conj()
            # A is skew-symmetric (or skew-hermitian)
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(A)
            elif self.orthogonal_map == _OrthMaps.cayley:
                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
                Id = torch.eye(n, dtype=A.dtype, device=A.device)
                Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
            # Q is now orthogonal (or unitary) of size (..., n, n)
            if n != k:
                Q = Q[..., :k]
            # Q is now the size of the X (albeit perhaps transposed)
        else:
            # X is real here, as we do not support householder with complex numbers
            A = X.tril(diagonal=-1)
            tau = 2. / (1. + (A * A).sum(dim=-2))
            Q = torch.linalg.householder_product(A, tau)
            # The diagonal of X is 1's and -1's
            # We do not want to differentiate through this or update the diagonal of X hence the casting
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

        if hasattr(self, "base"):
            Q = self.base @ Q
        if transposed:
            Q = Q.transpose(-2, -1)
        return Q

    @torch.autograd.no_grad()
    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
        if Q.shape != self.shape:
            raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
"f"Got a tensor of shape {Q.shape}.")Q_init=Qn,k=Q.size(-2),Q.size(-1)transpose=n<kiftranspose:Q=Q.transpose(-2,-1)n,k=k,n# We always make sure to always copy Q in every pathifnothasattr(self,"base"):# Note [right_inverse expm cayley]# If we do not have use_trivialization=True, we just implement the inverse of the forward# map for the Householder. To see why, think that for the Cayley map,# we would need to find the matrix X \in R^{n x k} such that:# Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)# A = Y - Y.transpose(-2, -1).conj()# cayley(A)[:, :k]# gives the original tensor. It is not clear how to do this.# Perhaps via some algebraic manipulation involving the QR like that of# Corollary 2.2 in Edelman, Arias and Smith?ifself.orthogonal_map==_OrthMaps.cayleyorself.orthogonal_map==_OrthMaps.matrix_exp:raiseNotImplementedError("It is not possible to assign to the matrix exponential ""or the Cayley parametrizations when use_trivialization=False.")# If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.# Here Q is always real because we do not support householder and complex matrices.# See note [Householder complex]A,tau=torch.geqrf(Q)# We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could# decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition# The diagonal of Q is the diagonal of R from the qr decompositionA.diagonal(dim1=-2,dim2=-1).sign_()# Equality with zero is ok because LAPACK returns exactly zero when it does not want# to use a particular reflectionA.diagonal(dim1=-2,dim2=-1)[tau==0.]*=-1returnA.transpose(-2,-1)iftransposeelseAelse:ifn==k:# We check whether Q is orthogonalifnot_is_orthogonal(Q):Q=_make_orthogonal(Q)else:# Is orthogonalQ=Q.clone()else:# Complete Q into a full n x n orthogonal matrixN=torch.randn(*(Q.size()[:-2]+(n,n-k)),dtype=Q.dtype,device=Q.device)Q=torch.cat([Q,N],dim=-1)Q=_make_orthogonal(Q)self.base=Q# It is necessary to return the -Id, as we use the diagonal for the# Householder parametrization. Using -Id makes:# householder(torch.zeros(m,n)) == torch.eye(m,n)# Poor man's version of eye_likeneg_Id=torch.zeros_like(Q_init)neg_Id.diagonal(dim1=-2,dim2=-1).fill_(-1.)returnneg_Id
def orthogonal(module: Module,
               name: str = 'weight',
               orthogonal_map: Optional[str] = None,
               *,
               use_trivialization: bool = True) -> Module:
    r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.

    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
    matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as

    .. math::

        \begin{align*}
            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
        \end{align*}

    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
    and the transpose when :math:`Q` is real-valued, and
    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
    and orthonormal rows otherwise.

    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.

    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor:

    - ``"matrix_exp"``/``"cayley"``:
      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
      :math:`A` to give an orthogonal matrix.
    - ``"householder"``: computes a product of Householder reflectors
      (:func:`~torch.linalg.householder_product`).

    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
    ``"householder"``, but they are slower to compute for very thin or very wide matrices.

    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
    ``module.parametrizations.weight[0].base``. This helps the
    convergence of the parametrized layer at the expense of some extra memory use.
    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .

    Initial value of :math:`Q`:
    If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
    of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
    and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
    Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
    Otherwise, the initial value is the result of the composition of all the registered
    parametrizations applied to the original tensor.

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`.


    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501

    Args:
        module (nn.Module): module on which to register the parametrization.
        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``,
            ``"cayley"``, ``"householder"``.
            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
            Default: ``True``.
    Returns:
        The original module with an orthogonal parametrization registered to the specified
        weight

    Example::

        >>> orth_linear = orthogonal(nn.Linear(20, 40))
        >>> orth_linear
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _Orthogonal()
            )
          )
        )
        >>> Q = orth_linear.weight
        >>> torch.dist(Q.T @ Q, torch.eye(20))
        tensor(4.9332e-07)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
        )

    # We could implement this for 1-dim tensors as the maps on the sphere
    # but I believe it'd bite more people than it'd help
    if weight.ndim < 2:
        raise ValueError("Expected a matrix or batch of matrices. "
                         f"Got a tensor of {weight.ndim} dimensions.")

    if orthogonal_map is None:
        orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"

    orth_enum = getattr(_OrthMaps, orthogonal_map, None)
    if orth_enum is None:
        raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
                         f'Got: {orthogonal_map}')
    orth = _Orthogonal(weight,
                       orth_enum,
                       use_trivialization=use_trivialization)
    parametrize.register_parametrization(module, name, orth, unsafe=True)
    return module
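
# Illustrative usage sketch (not part of the original module): registering the parametrization
# with an explicit map and assigning to the parametrized weight, which is routed through
# `_Orthogonal.right_inverse` and, with use_trivialization=True (default), stores the value in
# the `base` buffer. The helper name `_demo_orthogonal_usage` is made up for illustration.
def _demo_orthogonal_usage() -> torch.Tensor:
    import torch.nn as nn
    layer = orthogonal(nn.Linear(5, 5), orthogonal_map="cayley")
    with torch.no_grad():
        layer.weight = torch.eye(5)  # re-initialise Q through right_inverse
    Q = layer.weight
    # Q has orthonormal columns; this distance is ~0 up to round-off
    return torch.dist(Q.T @ Q, torch.eye(5))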
class _SpectralNorm(Module):
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12
    ) -> None:
        super().__init__()
        ndim = weight.ndim
        if dim >= ndim or dim < -ndim:
            raise IndexError("Dimension out of range (expected to be in range of "
                             f"[-{ndim}, {ndim - 1}] but got {dim})")

        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.dim = dim if dim >= 0 else dim + ndim
        self.eps = eps
        if ndim > 1:
            # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
            self.n_power_iterations = n_power_iterations
            weight_mat = self._reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()

            u = weight_mat.new_empty(h).normal_(0, 1)
            v = weight_mat.new_empty(w).normal_(0, 1)
            self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
            self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))

            # Start with u, v initialized to some reasonable values by performing a number
            # of iterations of the power method
            self._power_method(weight_mat, 15)

    def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        # Precondition
        assert weight.ndim > 1

        if self.dim != 0:
            # permute dim to front
            weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))

        return weight.flatten(1)

    @torch.autograd.no_grad()
    def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
        # See original note at torch/nn/utils/spectral_norm.py
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #     1. `DataParallel` doesn't clone storage if the broadcast tensor
        #        is already on correct device; and it makes sure that the
        #        parallelized module is already on `device[0]`.
        #     2. If the out tensor in `out=` kwarg has correct shape, it will
        #        just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backpropping through two forward passes, e.g., the common pattern in
        #     GAN training: loss = D(real) - D(fake).
        #     Otherwise, the engine will
        #     complain that variables needed to do backward for the first forward
        #     (i.e., the `u` and `v` vectors) are changed in the second forward.

        # Precondition
        assert weight_mat.ndim > 1

        for _ in range(n_power_iterations):
            # Spectral norm of weight equals `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            self._u = F.normalize(torch.mv(weight_mat, self._v),  # type: ignore[has-type]
                                  dim=0, eps=self.eps, out=self._u)  # type: ignore[has-type]
            self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
                                  dim=0, eps=self.eps, out=self._v)  # type: ignore[has-type]

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            return F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # See above on why we need to clone
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.dot(u, torch.mv(weight_mat, v))
            return weight / sigma

    def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
        # we may want to assert here that the passed value already
        # satisfies constraints
        return value
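
# Illustrative sketch (not part of the original module): a standalone power iteration that mirrors
# `_SpectralNorm._power_method` and the sigma computation in `forward`, for comparison against the
# exact largest singular value. The helper name `_demo_power_iteration_sigma` is made up for
# illustration.
def _demo_power_iteration_sigma(weight_mat: torch.Tensor, n_iter: int = 50, eps: float = 1e-12) -> torch.Tensor:
    """Approximate the spectral norm sigma = u^T W v of a 2D matrix via power iteration."""
    h, w = weight_mat.size()
    u = F.normalize(weight_mat.new_empty(h).normal_(0, 1), dim=0, eps=eps)
    v = F.normalize(weight_mat.new_empty(w).normal_(0, 1), dim=0, eps=eps)
    for _ in range(n_iter):
        # Alternate multiplications by W and W^T, renormalizing each time
        u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=eps)
        v = F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=eps)
    return torch.dot(u, torch.mv(weight_mat, v))

# Expected usage (assuming the sketch above is correct):
#   W = torch.randn(20, 40)
#   _demo_power_iteration_sigma(W)   # close to torch.linalg.matrix_norm(W, 2)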
def spectral_norm(module: Module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> Module:
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    When applied on a vector, it simplifies to

    .. math::
        \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
    of the model. :math:`\sigma` is approximated by performing one iteration of the
    `power method`_ every time the weight is accessed. If the dimension of the
    weight tensor is greater than 2, it is reshaped to 2D in the power iteration
    method to get the spectral norm.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
        reimplementation of :func:`torch.nn.utils.spectral_norm`.

    .. note::
        When this constraint is registered, the singular vectors associated to the largest
        singular value are estimated rather than sampled at random. These are then updated
        performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
        is accessed with the module on `training` mode.

    .. note::
        If the `_SpectralNorm` module, i.e., `module.parametrizations.weight[idx]`,
        is in training mode on removal, it will perform another power iteration.
        If you'd like to avoid this iteration, set the module to eval mode
        before its removal.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter. Default: ``"weight"``.
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm. Default: ``1``.
        eps (float, optional): epsilon for numerical stability in
            calculating norms. Default: ``1e-12``.
        dim (int, optional): dimension corresponding to number of outputs.
            Default: ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0000, grad_fn=<CopyBackwards>)
    """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
        )

    if dim is None:
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps))
    return module
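
# Illustrative usage sketch (not part of the original module): for transposed convolutions the
# output-channel dimension is 1, so `dim` defaults to 1 as documented above. The helper name
# `_demo_spectral_norm_convtranspose` is made up for illustration.
def _demo_spectral_norm_convtranspose() -> torch.Tensor:
    import torch.nn as nn
    m = spectral_norm(nn.ConvTranspose2d(16, 32, kernel_size=3), n_power_iterations=5)
    # Reshape the parametrized weight the same way _SpectralNorm does (dim=1 to the front,
    # then flatten); its largest singular value should be approximately 1.
    weight_mat = m.weight.permute(1, 0, 2, 3).flatten(1)
    return torch.linalg.matrix_norm(weight_mat, 2)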