# mypy: allow-untyped-defs
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
from typing import Any, TypeVar

from typing_extensions import deprecated

from torch import _weight_norm, norm_except_dim
from torch.nn.modules import Module
from torch.nn.parameter import Parameter, UninitializedParameter


__all__ = ["WeightNorm", "weight_norm", "remove_weight_norm"]


class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        g = getattr(module, self.name + "_g")
        v = getattr(module, self.name + "_v")
        return _weight_norm(v, g, self.dim)

    @staticmethod
    @deprecated(
        "`torch.nn.utils.weight_norm` is deprecated "
        "in favor of `torch.nn.utils.parametrizations.weight_norm`.",
        category=FutureWarning,
    )
    def apply(module, name: str, dim: int) -> "WeightNorm":
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(
                    f"Cannot register two weight_norm hooks on the same parameter {name}"
                )

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                "The module passed to `WeightNorm` can't have uninitialized parameters. "
                "Make sure to run the dummy forward before applying weight normalization"
            )
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(
            name + "_g", Parameter(norm_except_dim(weight, 2, dim).data)
        )
        module.register_parameter(name + "_v", Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + "_g"]
        del module._parameters[self.name + "_v"]
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module))


T_module = TypeVar("T_module", bound=Module)
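The hook above can be exercised directly. The following is a minimal sketch (illustration only, not part of this module) showing that the registered forward pre-hook recomputes ``weight`` from ``weight_g`` and ``weight_v`` before each forward call; it assumes a plain ``nn.Linear`` and the default ``dim=0``:

# Illustration only (not part of this module); assumes a plain nn.Linear.
import torch
from torch import nn
from torch.nn.utils import weight_norm

m = weight_norm(nn.Linear(3, 2), name="weight", dim=0)

# With dim=0 the effective weight is g * v / ||v||, one norm per output row.
expected = m.weight_g * m.weight_v / m.weight_v.norm(dim=1, keepdim=True)
torch.testing.assert_close(m.weight, expected)

# Because the weight is recomputed by the forward pre-hook, scaling
# weight_g scales the effective weight on the next forward call.
x = torch.randn(1, 3)
y1 = m(x)
with torch.no_grad():
    m.weight_g.mul_(2.0)
y2 = m(x)
torch.testing.assert_close(y2 - m.bias, 2.0 * (y1 - m.bias))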
def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module:
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    .. warning::

        This function is deprecated.  Use :func:`torch.nn.utils.parametrizations.weight_norm`
        which uses the modern parametrization API.  The new ``weight_norm`` is compatible
        with ``state_dict`` generated from old ``weight_norm``.

        Migration guide:

        * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
          as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
          respectively.  If this is bothering you, please comment on
          https://github.com/pytorch/pytorch/issues/102999

        * To remove the weight normalization reparametrization, use
          :func:`torch.nn.utils.parametrize.remove_parametrizations`.

        * The weight is no longer recomputed once at module forward; instead, it will
          be recomputed on every access.  To restore the old behavior, use
          :func:`torch.nn.utils.parametrize.cached` before invoking the module
          in question.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module
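As a companion to the migration guide in the docstring above, here is a brief sketch of the replacement API; it only uses the public functions named in that guide (``torch.nn.utils.parametrizations.weight_norm`` and ``torch.nn.utils.parametrize``) and is an illustration, not part of this module:

# Illustration only: the parametrization-based replacement for this function.
import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import weight_norm as p_weight_norm

m = p_weight_norm(nn.Linear(20, 40), name="weight", dim=0)

# weight_g / weight_v are now stored as original0 / original1.
print(m.parametrizations.weight.original0.shape)  # magnitude: torch.Size([40, 1])
print(m.parametrizations.weight.original1.shape)  # direction: torch.Size([40, 20])

# The weight is recomputed on every access; to compute it once and reuse it
# (closer to the old per-forward behavior), wrap the call in parametrize.cached().
with parametrize.cached():
    y = m(torch.randn(8, 20))

# Removing the reparametrization now goes through remove_parametrizations.
parametrize.remove_parametrizations(m, "weight")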
def remove_weight_norm(module: T_module, name: str = "weight") -> T_module:
    r"""Remove the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError(f"weight_norm of '{name}' not found in {module}")
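For completeness, a small sketch (illustration only, not part of this module) of what ``remove_weight_norm`` leaves behind: the ``_g``/``_v`` parameters and the pre-hook are gone, and ``weight`` is an ordinary ``nn.Parameter`` again with the same values:

# Illustration only (not part of this module).
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm

m = weight_norm(nn.Linear(20, 40))
w_before = m.weight.detach().clone()

remove_weight_norm(m)
assert isinstance(m.weight, nn.Parameter)
assert not hasattr(m, "weight_g") and not hasattr(m, "weight_v")
torch.testing.assert_close(m.weight.detach(), w_before)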