
torch.nn.utils.parametrizations.weight_norm

torch.nn.utils.parametrizations.weight_norm(module, name='weight', dim=0)

Apply weight normalization to a parameter in the given module.

\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. It replaces the parameter named by the name argument with two parameters: one holding the magnitude g (stored as original0) and one holding the direction v (stored as original1).

By default, with dim=0, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use dim=None.

See https://arxiv.org/abs/1602.07868

Parameters
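  • module (Module) – the module containing the parameter to reparameterize

  • name (str, optional) – name of the weight parameter. Default: 'weight'

  • dim (int, optional) – dimension that keeps a separate norm per index; the norm is computed over all remaining dimensions. Default: 0. Use None to compute a single norm over the entire tensor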

Returns

The original module, with a weight-norm parametrization registered on the named parameter

Example:

>>> import torch.nn as nn
>>> from torch.nn.utils.parametrizations import weight_norm
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _WeightNorm()
    )
  )
)
>>> m.parametrizations.weight.original0.size()
torch.Size([40, 1])
>>> m.parametrizations.weight.original1.size()
torch.Size([40, 20])
