torch.nn.utils.prune.identity
- torch.nn.utils.prune.identity(module, name)
Apply pruning reparametrization without pruning any units.
Applies pruning reparametrization to the tensor corresponding to the parameter called name in module without actually pruning any units. Modifies module in place (and also returns the modified module) by:
1) adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method;
2) replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
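The two steps above can be observed directly by inspecting a module after the call. The following is a minimal sketch, assuming a recent PyTorch release; the variable names (layer, returned, param_names, buffer_names) are illustrative only. It checks that the module is modified in place, that the original tensor moves to bias_orig, that the mask is registered as the buffer bias_mask, and that bias becomes a plain tensor computed from the two.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Apply the identity reparametrization to the bias of a small linear layer.
layer = nn.Linear(2, 3)
returned = prune.identity(layer, "bias")

# The module is modified in place; the return value is the same object.
assert returned is layer

# The original (unpruned) tensor now lives in a new parameter "bias_orig".
param_names = [n for n, _ in layer.named_parameters()]
print(param_names)  # ['weight', 'bias_orig']

# The binary mask is registered as a buffer named "bias_mask".
buffer_names = [n for n, _ in layer.named_buffers()]
print(buffer_names)  # ['bias_mask']

# "bias" is no longer an nn.Parameter: it is a plain tensor equal to
# bias_orig multiplied elementwise by bias_mask.
assert not isinstance(layer.bias, nn.Parameter)
assert torch.equal(layer.bias, layer.bias_orig * layer.bias_mask)
```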
Note
The mask is a tensor of ones.
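Because the mask is all ones, the reparametrized module computes the same outputs as before, even though it is now reported as pruned. A minimal sketch of that check, assuming CPU execution and using the existing helper prune.is_pruned; the input shapes and seed are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 2)
x = torch.randn(5, 4)

# Forward pass before any reparametrization.
before = layer(x)
assert not prune.is_pruned(layer)

prune.identity(layer, "weight")

# Every entry of the mask is 1, so no units are actually removed ...
assert torch.all(layer.weight_mask == 1)

# ... and the layer still produces the same outputs.
after = layer(x)
assert torch.equal(before, after)

# The pruning forward pre-hook is registered, so the module counts as pruned.
assert prune.is_pruned(layer)
```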
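- Parameters
module (nn.Module): module containing the tensor to prune.
name (str): parameter name within module on which pruning will act.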
- Returns
modified (i.e. pruned) version of the input module
- Return type
module (nn.Module)
Examples
>>> import torch.nn as nn
>>> import torch.nn.utils.prune as prune
>>> m = prune.identity(nn.Linear(2, 3), 'bias')
>>> print(m.bias_mask)
tensor([1., 1., 1.])
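The same call can be applied layer by layer across a larger model. The following is a sketch under the assumption that only the weight of each nn.Linear layer should be reparametrized; the model architecture here is a hypothetical example. Each matching layer then carries a weight_orig parameter and a weight_mask buffer, while other layers such as nn.ReLU are left untouched.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical small model used only for illustration.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))

# Register the identity reparametrization on every Linear layer's weight.
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.identity(module, "weight")

# Each Linear layer now exposes a "weight_mask" buffer (prefixed by its index).
mask_names = [name for name, _ in model.named_buffers()]
print(mask_names)  # ['0.weight_mask', '2.weight_mask']
```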