torch.nn.utils.weight_norm
- torch.nn.utils.weight_norm(module, name='weight', dim=0)
Applies weight normalization to a parameter in the given module.
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v'). Weight normalization is implemented via a hook that recomputes the weight tensor from the magnitude and direction before every forward() call.

By default, with dim=0, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use dim=None.

See https://arxiv.org/abs/1602.07868
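As a minimal sketch of the reparameterization described above, the following checks that the recomputed weight equals the magnitude times the unit-norm direction, with one norm per output row for dim=0. The variable names (v_norm, reconstructed, m_all) are illustrative only, and the tolerance is an assumption.

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm  # deprecated API documented on this page

torch.manual_seed(0)
m = weight_norm(nn.Linear(20, 40), name="weight", dim=0)

# Run one forward pass so the pre-forward hook recomputes m.weight.
_ = m(torch.randn(3, 20))

# With dim=0, each output row of weight_v gets its own L2 norm: shape (40, 1).
v_norm = m.weight_v.norm(p=2, dim=1, keepdim=True)
reconstructed = m.weight_g * m.weight_v / v_norm
matches = torch.allclose(m.weight, reconstructed, atol=1e-6)
print(matches)

# With dim=None, a single norm is taken over the whole tensor,
# so the magnitude holds exactly one value.
m_all = weight_norm(nn.Linear(20, 40), name="weight", dim=None)
print(m_all.weight_g.numel())
```

With dim=0 the magnitude weight_g keeps one entry per output row, which is why it broadcasts against the (40, 20) direction tensor.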
Warning
This function is deprecated. Use torch.nn.utils.parametrizations.weight_norm(), which uses the modern parametrization API. The new weight_norm is compatible with a state_dict generated from the old weight_norm.

Migration guide:

- The magnitude (weight_g) and direction (weight_v) are now expressed as parametrizations.weight.original0 and parametrizations.weight.original1 respectively. If this is bothering you, please comment on https://github.com/pytorch/pytorch/issues/102999
- To remove the weight normalization reparametrization, use torch.nn.utils.parametrize.remove_parametrizations().
- The weight is no longer recomputed once at module forward; instead, it will be recomputed on every access. To restore the old behavior, use torch.nn.utils.parametrize.cached() before invoking the module in question.
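- Parameters
module (Module) – containing module
name (str, optional) – name of weight parameter
dim (int, optional) – dimension over which to compute the norm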
- Returns
The original module with the weight norm hook
- Return type
T_module
Example:
>>> import torch.nn as nn
>>> from torch.nn.utils import weight_norm
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])