class torch.nn.utils.prune.BasePruningMethod

Abstract base class for creation of new pruning techniques.

Provides a skeleton for customization: a new pruning method is created by subclassing it and overriding methods such as compute_mask() and apply().
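A minimal sketch of such a subclass is shown below. ThresholdPruning and its threshold argument are hypothetical and exist only for illustration. The sketch assumes the usual pattern of implementing compute_mask() and overriding the apply() classmethod so that method-specific arguments are forwarded to the constructor through the base-class apply().

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical method: prune entries whose magnitude is below `threshold`."""

    # Used when this method is combined with others in a PruningContainer.
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Keep previously pruned entries pruned (default_mask == 0) and
        # additionally zero out entries with |t| < threshold.
        keep = (t.abs() >= self.threshold).to(default_mask.dtype)
        return default_mask * keep

    @classmethod
    def apply(cls, module, name, threshold):
        # The base-class apply() builds cls(threshold=...) and installs the hook.
        return super().apply(module, name, threshold=threshold)


layer = nn.Linear(3, 2)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.1, -2.0, 0.5], [3.0, -0.2, 0.0]]))

ThresholdPruning.apply(layer, "weight", threshold=0.4)
print(layer.weight_mask)
print(layer.weight)
```

Only compute_mask() is strictly required because it is the abstract method; overriding apply() just gives the subclass a typed, method-specific signature.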

classmethod apply(module, name, *args, **kwargs)

Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.

Parameters

  • module (nn.Module) – module containing the tensor to prune

  • name (str) – parameter name within module on which pruning will act.

  • args – arguments passed on to a subclass of BasePruningMethod

  • kwargs – keyword arguments passed on to a subclass of BasePruningMethod
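The reparametrization that apply() performs can be observed on a small layer. This sketch uses the built-in subclass prune.L1Unstructured, whose apply() classmethod forwards to this base-class apply(); it reads module._forward_pre_hooks, which is an internal attribute, only to confirm that the hook was registered.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # weight has 3 * 4 = 12 entries

# Prune 50% of the weight entries with the smallest absolute value.
method = prune.L1Unstructured.apply(layer, "weight", amount=0.5)

# The original tensor now lives in the parameter "weight_orig" ...
print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
# ... the mask in the buffer "weight_mask" ...
print([n for n, _ in layer.named_buffers()])            # ['weight_mask']
# ... and "weight" is a plain tensor recomputed by a forward pre-hook.
print(isinstance(layer.weight, nn.Parameter))           # False
print(any(h is method for h in layer._forward_pre_hooks.values()))  # True
```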


apply_mask(module)

Handles the multiplication between the parameter being pruned and the generated mask. Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.

Parameters

module (nn.Module) – module containing the tensor to prune

Returns

pruned version of the input tensor

Return type

torch.Tensor
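The sketch below checks that apply_mask() is the elementwise product of the "_orig" parameter and the "_mask" buffer, and that the forward pre-hook recomputes the pruned tensor on each forward pass. It assumes the built-in prune.L1Unstructured subclass; the amount and the in-place update are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
method = prune.L1Unstructured.apply(layer, "weight", amount=0.25)

# apply_mask() fetches weight_orig and weight_mask and multiplies them.
pruned = method.apply_mask(layer)
print(torch.equal(pruned, layer.weight_orig * layer.weight_mask))  # True

# Change the original parameter in place; the mask is left untouched.
with torch.no_grad():
    layer.weight_orig.add_(1.0)

# The forward pre-hook refreshes `weight` from weight_orig and weight_mask.
layer(torch.randn(2, 4))
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True
```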

abstract compute_mask(t, default_mask)

Computes and returns a mask for the input tensor t. Starting from a base default_mask (which should be a mask of ones if the tensor has not been pruned yet), generates a mask to apply on top of the default_mask according to the specific pruning method recipe.

Parameters

  • t (torch.Tensor) – tensor representing the parameter to prune

  • default_mask (torch.Tensor) – base mask from previous pruning iterations, which must be respected after the new mask is applied. Same dims as t.


Returns

mask to apply to t, of same dims as t

Return type

torch.Tensor
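A minimal compute_mask() that respects default_mask can start from a copy of it and only switch additional entries off. EveryOtherPruning below is a hypothetical method that prunes every other entry of the flattened tensor; the previous mask is made up to stand in for an earlier pruning iteration.

```python
import torch
import torch.nn.utils.prune as prune


class EveryOtherPruning(prune.BasePruningMethod):
    """Hypothetical method: prune every other entry of the flattened tensor."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # Copy the previous mask so entries pruned earlier stay pruned.
        mask = default_mask.clone()
        # Additionally zero flat indices 0, 2, 4, ...
        mask.view(-1)[::2] = 0
        return mask


t = torch.arange(1.0, 7.0).reshape(2, 3)
# Pretend an earlier iteration already pruned the entry at flat index 1.
previous = torch.tensor([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])

mask = EveryOtherPruning().compute_mask(t, default_mask=previous)
print(mask)  # values [[0., 0., 0.], [1., 0., 1.]]
```

Cloning default_mask keeps the caller's mask unchanged, and the entry at flat index 1 stays pruned even though this method alone would have kept it.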

prune(t, default_mask=None)

Computes and returns a pruned version of input tensor t according to the pruning rule specified in compute_mask().

Parameters

  • t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).

  • default_mask (torch.Tensor, optional) – mask from the previous pruning iteration, if any, taken into account when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.


Returns

pruned version of tensor t.
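prune() works on a bare tensor, without touching any module. This sketch assumes the built-in prune.L1Unstructured, where an integer amount is the absolute number of entries to prune; the input values are arbitrary.

```python
import torch
import torch.nn.utils.prune as prune

t = torch.tensor([[0.5, -3.0, 0.1], [2.0, -0.2, 1.0]])

# Prune the 2 entries with the smallest absolute value (0.1 and -0.2).
method = prune.L1Unstructured(amount=2)

# With default_mask=None, prune() starts from a mask of ones.
pruned = method.prune(t)
print(pruned)

# Same result as multiplying t by the mask returned from compute_mask().
same = t * method.compute_mask(t, default_mask=torch.ones_like(t))
print(torch.equal(pruned, same))  # True
```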


remove(module)

Removes the pruning reparameterization from a module. The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.


Note

Pruning itself is NOT undone or reversed!
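The sketch below makes pruning permanent and checks that the zeros survive. It goes through the functional prune.remove(module, name), which looks up the pruning hook for name, calls this remove(module) method, and then also deletes the forward pre-hook; calling the method directly would leave the hook registered. The layer size and amount are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.L1Unstructured.apply(layer, "weight", amount=0.5)
zeros_before = int((layer.weight == 0).sum())

# Fold the mask into the weight and drop weight_orig / weight_mask.
prune.remove(layer, "weight")

print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight']
print(list(layer.named_buffers()))                     # []
print(int((layer.weight == 0).sum()) == zeros_before)  # True: zeros are kept

# A forward pass no longer depends on the pruning reparametrization.
out = layer(torch.randn(2, 4))
```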

