r"""Weight Normalization from https://arxiv.org/abs/1602.07868"""fromtorch.nn.parameterimportParameter,UninitializedParameterfromtorchimport_weight_norm,norm_except_dimfromtypingimportAny,TypeVarfrom..modulesimportModule__all__=['WeightNorm','weight_norm','remove_weight_norm']classWeightNorm(object):name:strdim:intdef__init__(self,name:str,dim:int)->None:ifdimisNone:dim=-1self.name=nameself.dim=dim# TODO Make return type more specificdefcompute_weight(self,module:Module)->Any:g=getattr(module,self.name+'_g')v=getattr(module,self.name+'_v')return_weight_norm(v,g,self.dim)@staticmethoddefapply(module,name:str,dim:int)->'WeightNorm':fork,hookinmodule._forward_pre_hooks.items():ifisinstance(hook,WeightNorm)andhook.name==name:raiseRuntimeError("Cannot register two weight_norm hooks on ""the same parameter {}".format(name))ifdimisNone:dim=-1fn=WeightNorm(name,dim)weight=getattr(module,name)ifisinstance(weight,UninitializedParameter):raiseValueError('The module passed to `WeightNorm` can\'t have uninitialized parameters. ''Make sure to run the dummy forward before applying weight normalization')# remove w from parameter listdelmodule._parameters[name]# add g and v as new parameters and express w as g/||v|| * vmodule.register_parameter(name+'_g',Parameter(norm_except_dim(weight,2,dim).data))module.register_parameter(name+'_v',Parameter(weight.data))setattr(module,name,fn.compute_weight(module))# recompute weight before every forward()module.register_forward_pre_hook(fn)returnfndefremove(self,module:Module)->None:weight=self.compute_weight(module)delattr(module,self.name)delmodule._parameters[self.name+'_g']delmodule._parameters[self.name+'_v']setattr(module,self.name,Parameter(weight.data))def__call__(self,module:Module,inputs:Any)->None:setattr(module,self.name,self.compute_weight(module))T_module=TypeVar('T_module',bound=Module)
def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module
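
# Illustrative sanity check (not part of the original module): for a 2D weight
# and the default ``dim=0``, the weight recomputed by the hook equals
# ``weight_g * weight_v / ||weight_v||`` with the norm taken per output row.
#
#     >>> m = weight_norm(nn.Linear(20, 40))
#     >>> expected = m.weight_g * m.weight_v / m.weight_v.norm(dim=1, keepdim=True)
#     >>> torch.allclose(m.weight, expected)
#     True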
def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
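
# Illustrative sketch (not part of the original module): removing the hook
# folds the current ``g``/``v`` values back into a single plain parameter, so
# ``weight_g``/``weight_v`` disappear and ``weight`` is an ordinary
# ``nn.Parameter`` again.
#
#     >>> m = weight_norm(nn.Linear(20, 40))
#     >>> m = remove_weight_norm(m)
#     >>> hasattr(m, 'weight_g')
#     False
#     >>> isinstance(m.weight, nn.Parameter)
#     True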