torch.nn.utils.prune.identity
- torch.nn.utils.prune.identity(module, name)
Applies pruning reparametrization to the tensor corresponding to the parameter called name in module, without actually pruning any units. Modifies module in place (and also returns the modified module) by:
1) adding a named buffer called name+'_mask', corresponding to the binary mask applied to the parameter name by the pruning method;
2) replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
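A minimal sketch of the reparametrization described above. It assumes an arbitrary nn.Linear(4, 2) layer and prunes its 'weight' parameter, then lists the layer's parameters and buffers before and after the call. The layer sizes and variable names are illustrative only.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)  # illustrative layer; any module with a 'weight' works

# Before pruning: 'weight' and 'bias' are ordinary parameters and there are no buffers.
params_before = sorted(name for name, _ in layer.named_parameters())
buffers_before = sorted(name for name, _ in layer.named_buffers())

prune.identity(layer, "weight")

# After pruning: the original tensor lives on as the parameter 'weight_orig',
# the mask is registered as the buffer 'weight_mask', and 'weight' is now a
# plain tensor attribute (orig * mask), not a registered parameter.
params_after = sorted(name for name, _ in layer.named_parameters())
buffers_after = sorted(name for name, _ in layer.named_buffers())
weight_is_parameter = isinstance(layer.weight, nn.Parameter)

print(params_before, buffers_before)  # ['bias', 'weight'] []
print(params_after, buffers_after)    # ['bias', 'weight_orig'] ['weight_mask']
```

Because 'weight' is no longer in the parameter list, an optimizer built from layer.parameters() after this call updates weight_orig rather than weight.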
Note
The mask is a tensor of ones.
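Since the mask is all ones, the pruned tensor has the same values as the original and the module's output does not change. A small check of that consequence, assuming an arbitrary nn.Linear(3, 2) layer, a random input, and pruning of 'bias' (all illustrative choices):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(3, 2)   # illustrative layer
x = torch.randn(5, 3)     # illustrative input batch

# Reference output before any reparametrization.
out_before = layer(x)

prune.identity(layer, "bias")

# The mask contains only ones, so bias == bias_orig * 1 == bias_orig.
mask_all_ones = bool(torch.all(layer.bias_mask == 1))
bias_unchanged = torch.equal(layer.bias, layer.bias_orig)

# The forward pass therefore produces the same result as before.
out_after = layer(x)
outputs_match = torch.allclose(out_before, out_after)
```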
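- Parameters:
module (nn.Module) – module containing the tensor to prune.
name (str) – parameter name within module on which pruning will act.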
- Returns:
modified (i.e. pruned) version of the input module
- Return type:
module (nn.Module)
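Because the module is modified in place, the returned value is the same object that was passed in, so the return value may be ignored or used for chaining. A short check, assuming an nn.Conv2d layer chosen only for illustration; it also shows that the mask buffer has the same shape as the parameter it reparametrizes:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 4, kernel_size=3)  # illustrative layer

# identity() mutates `conv` and hands the same object back.
returned = prune.identity(conv, "weight")
same_object = returned is conv

# Conv2d weight shape is (out_channels, in_channels, kH, kW) = (4, 1, 3, 3),
# and the mask matches it element for element.
mask_shape = tuple(conv.weight_mask.shape)
```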
Examples
>>> import torch.nn as nn
>>> import torch.nn.utils.prune as prune
>>> m = prune.identity(nn.Linear(2, 3), 'bias')
>>> print(m.bias_mask)
tensor([1., 1., 1.])