RandomUnstructured

class torch.nn.utils.prune.RandomUnstructured(amount)

Prune (currently unpruned) units in a tensor at random.

Parameters
  • amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.

classmethod apply(module, name, amount)

Add pruning on the fly and reparametrization of a tensor.

Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.

Parameters
  • module (nn.Module) – module containing the tensor to prune

  • name (str) – parameter name within module on which pruning will act.

  • amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
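As a minimal sketch of the reparametrization described above, the snippet below applies the method to two small nn.Linear layers, once with a float amount and once with an int amount. The layer sizes and the variable names (layer_frac, layer_abs) are illustrative assumptions, not part of the API.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # only so the random mask is reproducible

# Float amount: prune a fraction of the 3 x 4 = 12 weight entries.
layer_frac = nn.Linear(4, 3)
prune.RandomUnstructured.apply(layer_frac, name="weight", amount=0.5)

# Int amount: prune an absolute number of entries.
layer_abs = nn.Linear(4, 3)
prune.RandomUnstructured.apply(layer_abs, name="weight", amount=3)

# The original tensor now lives in `weight_orig`, the mask in `weight_mask`,
# and `weight` is recomputed from the two.
param_names = sorted(n for n, _ in layer_frac.named_parameters())
buffer_names = sorted(n for n, _ in layer_frac.named_buffers())
zeros_frac = int((layer_frac.weight_mask == 0).sum())
zeros_abs = int((layer_abs.weight_mask == 0).sum())

print(param_names)            # ['bias', 'weight_orig']
print(buffer_names)           # ['weight_mask']
print(zeros_frac, zeros_abs)  # 6 3
```

Which entries are zeroed depends on the random draw; only the number of masked entries is fixed by amount.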

apply_mask(module)

Multiply the parameter being pruned by the generated mask.

Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.

Parameters
  • module (nn.Module) – module containing the tensor to prune

Returns

pruned version of the input tensor

Return type

torch.Tensor
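A short sketch of what apply_mask computes, assuming the method object returned by apply() is kept around. The layer and the variable names are hypothetical; the comparison simply rebuilds the product of weight_orig and weight_mask by hand.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)

# apply() returns the pruning method instance that was attached to the module.
method = prune.RandomUnstructured.apply(layer, name="weight", amount=4)

# apply_mask fetches `weight_orig` and `weight_mask` from the module
# and returns their elementwise product.
pruned = method.apply_mask(layer)
manual = layer.weight_orig * layer.weight_mask

same = torch.equal(pruned, manual)
print(same)  # True
```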

prune(t, default_mask=None, importance_scores=None)

Compute and return a pruned version of the input tensor t, according to the pruning rule specified in compute_mask().

Parameters
  • t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).

  • importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in the tensor t being pruned. If unspecified or None, the tensor t will be used in its place.

  • default_mask (torch.Tensor, optional) – mask from previous pruning iteration, if any. Taken into account when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.

Returns

pruned version of tensor t.
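The sketch below calls prune() directly on a standalone tensor, without any module. The tensor values and the name previous_mask are illustrative assumptions. Note that in this direct call the random selection may overlap with entries already zeroed in default_mask, so the second call only guarantees that previously masked entries stay masked.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # only so the random mask is reproducible

t = torch.arange(1.0, 11.0)  # ten nonzero entries: 1.0 ... 10.0
method = prune.RandomUnstructured(amount=3)

# With no default_mask, a mask of ones is assumed and 3 entries are zeroed.
pruned = method.prune(t)
zeros = int((pruned == 0).sum())
print(zeros)  # 3

# Hypothetical mask from an earlier pruning round: first two entries removed.
previous_mask = torch.ones_like(t)
previous_mask[:2] = 0

pruned_again = method.prune(t, default_mask=previous_mask)
zeros_again = int((pruned_again == 0).sum())

# Entries masked earlier remain zero; surviving entries keep their values.
kept = pruned_again != 0
```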

remove(module)

Remove the pruning reparameterization from a module.

The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.

Note

Pruning itself is NOT undone or reversed!
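A minimal sketch of making pruning permanent. It goes through the module-level helper prune.remove(module, name) rather than calling the method directly, on the understanding that the helper calls this method and also detaches the forward pre-hook; the layer and variable names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)
prune.RandomUnstructured.apply(layer, name="weight", amount=0.5)

# Snapshot of the masked weight before the reparametrization is removed.
pruned_weight = layer.weight.detach().clone()

prune.remove(layer, "weight")

param_names = sorted(n for n, _ in layer.named_parameters())
buffer_names = [n for n, _ in layer.named_buffers()]

# `weight` is a plain parameter again, and the zeros are still in place.
still_pruned = torch.equal(layer.weight.detach(), pruned_weight)

print(param_names)   # ['bias', 'weight']
print(buffer_names)  # []
print(still_pruned)  # True

# The module is usable as before.
out = layer(torch.ones(2, 4))
```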
