torch.nn.utils.parametrizations.weight_norm
- torch.nn.utils.parametrizations.weight_norm(module, name='weight', dim=0)
Apply weight normalization to a parameter in the given module.
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by name with two parameters: one specifying the magnitude and one specifying the direction.
By default, with dim=0, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use dim=None.
See https://arxiv.org/abs/1602.07868
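- Parameters
module (Module): the module containing the parameter to reparameterize
name (str, optional): name of the weight parameter. Default: 'weight'
dim (int, optional): dimension along which a separate norm is kept; None computes a single norm over the entire tensor. Default: 0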
- Returns
The original module with the weight norm hook
Example:
>>> import torch.nn as nn
>>> from torch.nn.utils.parametrizations import weight_norm
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _WeightNorm()
    )
  )
)
>>> m.parametrizations.weight.original0.size()
torch.Size([40, 1])
>>> m.parametrizations.weight.original1.size()
torch.Size([40, 20])