# mypy: allow-untyped-defs
"""Spectral Normalization from https://arxiv.org/abs/1802.05957."""

from typing import Any, Optional, TypeVar

import torch
import torch.nn.functional as F
from torch.nn.modules import Module

__all__ = [
    "SpectralNorm",
    "SpectralNormLoadStateDictPreHook",
    "SpectralNormStateDictHook",
    "spectral_norm",
    "remove_spectral_norm",
]


class SpectralNorm:
    # Invariant before and after each forward call:
    #   u = F.normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version: int = 1
    # At version 1:
    #   made `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
    name: str
    dim: int
    n_power_iterations: int
    eps: float

    def __init__(
        self,
        name: str = "weight",
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12,
    ) -> None:
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                "Expected n_power_iterations to be positive, but "
                f"got n_power_iterations={n_power_iterations}"
            )
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]
            )
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #     1. `DataParallel` doesn't clone storage if the broadcast tensor
        #        is already on the correct device; and it makes sure that the
        #        parallelized module is already on `device[0]`.
        #     2. If the out tensor in the `out=` kwarg has the correct shape,
        #        it will just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backpropagating through two forward passes, e.g., the common
        #     pattern in GAN training: loss = D(real) - D(fake).
        #     Otherwise, the autograd engine will complain that variables
        #     needed to do backward for the first forward (i.e., the `u` and
        #     `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + "_orig")
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`
                    # (a standalone sketch of this estimate appears after the
                    # hook classes below).
                    v = F.normalize(
                        torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v
                    )
                    u = F.normalize(
                        torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u
                    )
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module: Module) -> None:
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + "_u")
        delattr(module, self.name + "_v")
        delattr(module, self.name + "_orig")
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(
            module,
            self.name,
            self.compute_weight(module, do_power_iteration=module.training),
        )

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = F.normalize(W @ v)`
        # (the invariant at the top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case `W^T W` is not invertible.
        v = torch.linalg.multi_dot(
            [weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]
        ).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(
        module: Module, name: str, n_power_iterations: int, dim: int, eps: float
    ) -> "SpectralNorm":
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    f"Cannot register two spectral_norm hooks on the same parameter {name}"
                )

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        if weight is None:
            raise ValueError(
                f"`SpectralNorm` cannot be applied as parameter `{name}` is None"
            )
        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
            raise ValueError(
                "The module passed to `SpectralNorm` can't have uninitialized parameters. "
                "Make sure to run the dummy forward before applying spectral normalization"
            )

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter.
        # Instead, we register weight.data as a plain attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    # For state_dict with version None, (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = F.normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ) -> None:
        fn = self.fn
        version = local_metadata.get("spectral_norm", {}).get(
            fn.name + ".version", None
        )
        if version is None or version < 1:
            weight_key = prefix + fn.name
            if (
                version is None
                and all(weight_key + s in state_dict for s in ("_orig", "_u", "_v"))
                and weight_key not in state_dict
            ):
                # Detect if it is the updated state dict and just missing metadata.
                # This could happen if the users are crafting a state dict themselves,
                # so we just pretend that this is the newest.
                return
            has_missing_keys = False
            for suffix in ("_orig", "", "_u"):
                key = weight_key + suffix
                if key not in state_dict:
                    has_missing_keys = True
                    if strict:
                        missing_keys.append(key)
            if has_missing_keys:
                return
            with torch.no_grad():
                weight_orig = state_dict[weight_key + "_orig"]
                weight = state_dict.pop(weight_key)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[weight_key + "_u"]
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[weight_key + "_v"] = v


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata) -> None:
        if "spectral_norm" not in local_metadata:
            local_metadata["spectral_norm"] = {}
        key = self.fn.name + ".version"
        if key in local_metadata["spectral_norm"]:
            raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}")
        local_metadata["spectral_norm"][key] = self.fn._version


T_module = TypeVar("T_module", bound=Module)
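
# The power iteration in `SpectralNorm.compute_weight` above estimates the largest
# singular value of the weight matrix as `sigma = u^T W v`. The helper below is a
# hypothetical, illustrative addition (not part of torch.nn.utils.spectral_norm):
# a minimal standalone sketch of that estimate, which after enough iterations
# should roughly match the exact spectral norm given by
# `torch.linalg.matrix_norm(weight_mat, ord=2)`.
def _power_iteration_sketch(
    weight_mat: torch.Tensor, n_iterations: int = 50, eps: float = 1e-12
) -> torch.Tensor:
    # Randomly initialize `u` and `v`, mirroring `SpectralNorm.apply`.
    h, w = weight_mat.size()
    u = F.normalize(weight_mat.new_empty(h).normal_(0, 1), dim=0, eps=eps)
    v = F.normalize(weight_mat.new_empty(w).normal_(0, 1), dim=0, eps=eps)
    with torch.no_grad():
        for _ in range(n_iterations):
            # Same two updates as `compute_weight`, without the in-place `out=`.
            v = F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=eps)
            u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=eps)
        # `u^T W v` converges to the largest singular value of `weight_mat`.
        return torch.dot(u, torch.mv(weight_mat, v))
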
def spectral_norm(
    module: T_module,
    name: str = "weight",
    n_power_iterations: int = 1,
    eps: float = 1e-12,
    dim: Optional[int] = None,
) -> T_module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated
    using the power iteration method. If the dimension of the weight tensor is
    greater than 2, it is reshaped to 2D in the power iteration method to get
    the spectral norm. This is implemented via a hook that calculates the
    spectral norm and rescales the weight before every
    :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    .. note::
        This function has been reimplemented as
        :func:`torch.nn.utils.parametrizations.spectral_norm` using the new
        parametrization functionality in
        :func:`torch.nn.utils.parametrize.register_parametrization`. Please use
        the newer version. This function will be deprecated in a future version
        of PyTorch.

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(
            module,
            (
                torch.nn.ConvTranspose1d,
                torch.nn.ConvTranspose2d,
                torch.nn.ConvTranspose3d,
            ),
        ):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
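
# A hypothetical usage sketch (not part of torch.nn.utils.spectral_norm) showing
# what `spectral_norm` adds to a module: `weight_orig` replaces `weight` as the
# trainable parameter, the `weight_u` / `weight_v` buffers hold the power-iteration
# vectors, and every training-mode forward recomputes `weight = weight_orig / sigma`,
# keeping its spectral norm close to 1.
def _spectral_norm_usage_sketch() -> None:
    m = spectral_norm(torch.nn.Linear(20, 40))
    # The original parameter and the buffers registered by `SpectralNorm.apply`.
    assert isinstance(m.weight_orig, torch.nn.Parameter)
    assert m.weight_u.shape == (40,) and m.weight_v.shape == (20,)
    # Each forward call in training mode runs one power iteration and rescales `weight`.
    for _ in range(10):
        m(torch.randn(8, 20))
    sigma = torch.linalg.matrix_norm(m.weight.detach(), ord=2).item()
    print(f"spectral norm of m.weight after 10 forward passes: {sigma:.4f}")  # ~1.0
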
def remove_spectral_norm(module: T_module, name: str = "weight") -> T_module:
    r"""Remove the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError(f"spectral_norm of '{name}' not found in {module}")

    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break

    return module
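
# Hypothetical round-trip sketch (not part of torch.nn.utils.spectral_norm):
# `remove_spectral_norm` folds the current normalized weight back into a single
# plain `weight` parameter and unregisters the hooks, so the module can be used
# without the reparameterization.
if __name__ == "__main__":
    m = spectral_norm(torch.nn.Linear(40, 10))
    m(torch.randn(2, 40))  # one forward so `u` and `v` go through a power iteration
    m = remove_spectral_norm(m)
    # After removal only the plain parameter remains; the hook attributes are gone.
    assert isinstance(m.weight, torch.nn.Parameter)
    assert not hasattr(m, "weight_orig") and not hasattr(m, "weight_u")
    print("removed spectral norm; weight shape:", tuple(m.weight.shape))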