class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.

    An instance is registered on a module as a forward pre-hook (see
    :meth:`apply`); calling it recomputes ``module[name]`` from
    ``module[name + '_orig']`` and ``module[name + '_mask']``.
    """

    # Name of the module attribute this method prunes; assigned in ``apply``.
    _tensor_name: str

    def __call__(self, module, inputs):
        r"""Multiply the mask into the original tensor and store the result.

        The mask is read from ``module[name + '_mask']``, the original tensor
        from ``module[name + '_orig']``, and the product (computed by
        :meth:`apply_mask`) is written to ``module[name]``.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` according to the specific pruning method
        recipe.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): base mask from previous pruning
                iterations, that needs to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """

    def apply_mask(self, module):
        r"""Return the pruned tensor ``mask * orig`` for ``module``.

        Fetches ``module[name + '_mask']`` and ``module[name + '_orig']``,
        casts the mask to the dtype of the original tensor and multiplies
        them.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor
        """
        # The mask must already exist, so the method has to know which tensor
        # it operates on; ``_tensor_name`` is set in apply().
        assert self._tensor_name is not None, f"Module {module} has to be pruned"
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add the forward pre-hook that enables pruning on the fly.

        Also reparametrizes the tensor in terms of the original tensor
        (``name + '_orig'``) and the pruning mask (``name + '_mask'``).

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for
                pruning. If unspecified or None, the parameter will be used in
                its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`

        Returns:
            The pruning method registered as the hook: either the new ``cls``
            instance or the :class:`PruningContainer` that wraps it together
            with previously applied methods.

        Raises:
            AssertionError: if more than one pruning hook already targets
                ``name``, or if ``importance_scores`` has a different shape
                than the parameter.
        """

        def _get_composite_method(cls, module, name, *args, **kwargs):
            # Look for a pruning hook already attached to `module[name]`.
            old_method = None
            found = 0
            hooks_to_remove = []
            for k, hook in module._forward_pre_hooks.items():
                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                    old_method = hook
                    hooks_to_remove.append(k)
                    found += 1
            # There should be at most one hook per tensor name.
            assert found <= 1, (
                f"Avoid adding multiple pruning hooks to the "
                f"same tensor {name} of module {module}. Use a PruningContainer."
            )

            # The existing hook (if any) is replaced by the composite below.
            for k in hooks_to_remove:
                del module._forward_pre_hooks[k]

            # Instantiate the new pruning method and have it remember which
            # tensor it has been applied to.
            method = cls(*args, **kwargs)
            method._tensor_name = name

            if old_method is not None:
                if isinstance(old_method, PruningContainer):
                    # Already a container: append the new method to it.
                    old_method.add_pruning_method(method)
                    method = old_method
                elif isinstance(old_method, BasePruningMethod):
                    # A single previous method: wrap old and new together.
                    container = PruningContainer(old_method)
                    container.add_pruning_method(method)
                    method = container
            return method

        method = _get_composite_method(cls, module, name, *args, **kwargs)
        # At this point no pruning hook for `name` is registered, but the
        # tensor may still be reparametrized from an earlier pruning call (in
        # which case `method` is a PruningContainer).

        # Pruning starts from the tensor as it is found before this iteration.
        orig = getattr(module, name)
        if importance_scores is not None:
            assert (
                importance_scores.shape == orig.shape
            ), f"importance_scores should have the same shape as parameter {name} of {module}"
        else:
            importance_scores = orig

        if not isinstance(method, PruningContainer):
            # First pruning of this tensor: move the parameter to
            # `name + '_orig'` and temporarily drop `module[name]`.
            module.register_parameter(name + "_orig", orig)
            del module._parameters[name]
            default_mask = torch.ones_like(orig)
        else:
            # Already reparametrized: continue from the stored mask.
            default_mask = (
                getattr(module, name + "_mask")
                .detach()
                .clone(memory_format=torch.contiguous_format)
            )

        # Roll the parameter back if anything goes wrong while computing or
        # installing the mask.
        try:
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # Save the mask to `module[name + '_mask']` ...
            module.register_buffer(name + "_mask", mask)
            # ... and the pruned tensor to `module[name]`.
            setattr(module, name, method.apply_mask(module))
            # Recompute the pruned tensor before every forward().
            module.register_forward_pre_hook(method)
        except Exception:
            if not isinstance(method, PruningContainer):
                orig = getattr(module, name + "_orig")
                module.register_parameter(name, orig)
                del module._parameters[name + "_orig"]
            # Bare raise keeps the original traceback intact.
            raise

        return method

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Compute and return a pruned version of input tensor ``t``.

        The pruning rule is the one specified in :meth:`compute_mask`.

        Args:
            t (torch.Tensor): tensor to prune (of same dimensions as
                ``default_mask``).
            default_mask (torch.Tensor, optional): mask from previous pruning
                iteration, if any. If None, defaults to a mask of ones.
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as ``t``) used to compute the mask. If unspecified
                or None, ``t`` itself is used.

        Returns:
            pruned version of tensor ``t``.

        Raises:
            AssertionError: if ``importance_scores`` and ``t`` differ in shape.
        """
        if importance_scores is not None:
            assert (
                importance_scores.shape == t.shape
            ), "importance_scores should have the same shape as tensor t"
        else:
            importance_scores = t
        default_mask = default_mask if default_mask is not None else torch.ones_like(t)
        return t * self.compute_mask(importance_scores, default_mask=default_mask)

    def remove(self, module):
        r"""Remove the pruning reparameterization from a module.

        The pruned parameter named ``name`` remains permanently pruned, and
        the parameter named ``name + '_orig'`` is removed from the parameter
        list. Similarly, the buffer named ``name + '_mask'`` is removed from
        the buffers.

        Note:
            Pruning itself is NOT undone or reversed!
        """
        # Pruning must have been applied first; `_tensor_name` is set in apply().
        assert (
            self._tensor_name is not None
        ), f"Module {module} has to be pruned before pruning can be removed"

        # Masked version of the latest weights.
        weight = self.apply_mask(module)

        # Drop the plain attribute, then re-register the parameter under its
        # original name with the masked data.
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)
class PruningContainer(BasePruningMethod):
    """Container holding a sequence of pruning methods for iterative pruning.

    Keeps track of the order in which pruning methods are applied and handles
    combining successive pruning calls.

    Accepts zero, one or several :class:`BasePruningMethod` instances as
    positional arguments; all of them must act on the same tensor name.
    """

    def __init__(self, *args):
        self._pruning_methods: Tuple[BasePruningMethod, ...] = ()
        if len(args) == 1:
            # Single method: the container inherits its tensor name.
            self._tensor_name = args[0]._tensor_name
            self.add_pruning_method(args[0])
        else:
            # Several methods (or none). `add_pruning_method` compares against
            # `self._tensor_name`, so adopt the name of the first pruning
            # method before adding it; with no args the name stays unset and
            # the caller is expected to assign it.
            for method in args:
                if not hasattr(self, "_tensor_name") and isinstance(
                    method, BasePruningMethod
                ):
                    self._tensor_name = method._tensor_name
                self.add_pruning_method(method)

    def add_pruning_method(self, method):
        r"""Add a child pruning ``method`` to the container.

        Args:
            method (subclass of BasePruningMethod): child pruning method
                to be added to the container.

        Raises:
            TypeError: if ``method`` is neither None nor a
                :class:`BasePruningMethod`.
            ValueError: if ``method`` acts on a different tensor name than
                the container.
        """
        # NOTE: ``None`` passes both checks and is appended as-is.
        if not isinstance(method, BasePruningMethod) and method is not None:
            raise TypeError(f"{type(method)} is not a BasePruningMethod subclass")
        elif method is not None and self._tensor_name != method._tensor_name:
            raise ValueError(
                "Can only add pruning methods acting on "
                f"the parameter named '{self._tensor_name}' to PruningContainer {self}."
                + f" Found '{method._tensor_name}'"
            )
        # All checks passed: extend the tuple of methods.
        self._pruning_methods += (method,)  # type: ignore[operator]

    def compute_mask(self, t, default_mask):
        r"""Apply the latest ``method`` and combine its mask with ``default_mask``.

        The new partial mask is computed only on the entries or channels that
        were not zeroed out by the ``default_mask``. Which portion of ``t``
        that is depends on the method's ``PRUNING_TYPE``:

        * for 'unstructured', the raveled list of nonmasked entries;

        * for 'structured', the nonmasked channels in the tensor;

        * for 'global', all entries.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
                (of same dimensions as ``default_mask``).
            default_mask (torch.Tensor): mask from previous pruning iteration.

        Returns:
            mask (torch.Tensor): new mask that combines the effects
            of the ``default_mask`` and the new mask from the current
            pruning ``method`` (of same dimensions as ``default_mask`` and
            ``t``).

        Raises:
            AttributeError: if a 'structured' method has no ``dim`` attribute.
            IndexError: if a negative ``dim`` is out of bounds for ``t``.
            ValueError: if ``PRUNING_TYPE`` is not recognized.
        """

        def _combine_masks(method, t, mask):
            r"""Combine ``mask`` with the mask produced by ``method``.

            Args:
                method (a BasePruningMethod subclass): pruning method
                    currently being applied.
                t (torch.Tensor): tensor representing the parameter to prune
                    (of same dimensions as mask).
                mask (torch.Tensor): mask from previous pruning iteration

            Returns:
                new_mask (torch.Tensor): new mask that combines the effects
                    of the old mask and the new mask from the current
                    pruning method (of same dimensions as mask and t).
            """
            # Start from the existing mask, cast to the dtype of t.
            # NOTE(review): `.to` may return `mask` itself when the dtype
            # already matches, in which case the write below also updates the
            # caller's tensor - confirm callers pass a private copy.
            new_mask = mask
            new_mask = new_mask.to(dtype=t.dtype)

            # Select the slice of t the new pruning method will operate on.
            if method.PRUNING_TYPE == "unstructured":
                # Only entries that are still unpruned.
                slc = mask == 1

            elif method.PRUNING_TYPE == "structured":
                if not hasattr(method, "dim"):
                    raise AttributeError(
                        "Pruning methods of PRUNING_TYPE "
                        '"structured" need to have the attribute `dim` defined.'
                    )

                # Keep the channels that have not been entirely zeroed out
                # yet (i.e. whose mask entries do not sum to 0).
                n_dims = t.dim()
                dim = method.dim
                # Convert negative indexing.
                if dim < 0:
                    dim = n_dims + dim
                # Still negative after the conversion: out of range.
                if dim < 0:
                    raise IndexError(
                        f"Index is out of bounds for tensor with dimensions {n_dims}"
                    )
                keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0
                index = [slice(None)] * n_dims
                index[dim] = keep_channel
                # Index with a tuple: indexing with a list of slices is
                # deprecated non-tuple sequence indexing.
                slc = tuple(index)

            elif method.PRUNING_TYPE == "global":
                n_dims = len(t.shape)
                slc = tuple([slice(None)] * n_dims)

            else:
                raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}")

            # Compute the new mask on the unpruned slice of t and write it
            # back into the combined mask.
            partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
            new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)

            return new_mask

        method = self._pruning_methods[-1]
        mask = _combine_masks(method, t, default_mask)
        return mask
class Identity(BasePruningMethod):
    r"""No-op pruning method.

    Sets up the pruning parametrization without pruning any units: the mask
    it produces is the ``default_mask`` it was given (a mask of ones on a
    tensor that has not been pruned before).
    """

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # Nothing new is pruned, so the incoming mask is returned untouched.
        return default_mask

    @classmethod
    def apply(cls, module, name):
        r"""Register the pruning reparametrization and forward pre-hook.

        Delegates to the parent ``apply`` with no extra arguments.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
        """
        return super().apply(module, name)
class RandomUnstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        r"""Return a copy of ``default_mask`` with random entries zeroed.

        Args:
            t (torch.Tensor): tensor whose number of elements determines how
                many units are pruned.
            default_mask (torch.Tensor): base mask; it is cloned, not
                modified.

        Returns:
            mask (torch.Tensor): contiguous clone of ``default_mask`` with the
            selected entries set to 0.
        """
        tensor_size = t.nelement()
        # Number of units to prune: amount if int, else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        # Skip the selection entirely when nothing is to be pruned.
        if nparams_toprune != 0:
            prob = torch.rand_like(t)
            # reshape (not view): `rand_like` may keep the strides of a
            # non-contiguous `t`, for which `.view(-1)` raises.
            topk = torch.topk(prob.reshape(-1), k=nparams_toprune)
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Add the forward pre-hook and the pruning reparametrization.

        The tensor is reparametrized in terms of the original tensor and the
        pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super().apply(module, name, amount=amount)
class L1Unstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones
    with the lowest L1-norm.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        r"""Return a copy of ``default_mask`` with the smallest-magnitude
        entries of ``t`` zeroed.

        Args:
            t (torch.Tensor): importance scores; entries with the smallest
                absolute value are pruned.
            default_mask (torch.Tensor): base mask; it is cloned, not
                modified.

        Returns:
            mask (torch.Tensor): contiguous clone of ``default_mask`` with the
            selected entries set to 0.
        """
        tensor_size = t.nelement()
        # Number of units to prune: amount if int, else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        # Skip the selection entirely when nothing is to be pruned.
        if nparams_toprune != 0:
            # largest=False --> bottom k, i.e. prune the smallest k entries.
            # reshape (not view): `torch.abs(t)` may be non-contiguous when
            # `t` is, and `.view(-1)` raises in that case.
            topk = torch.topk(
                torch.abs(t).reshape(-1), k=nparams_toprune, largest=False
            )
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Add the forward pre-hook and the pruning reparametrization.

        The tensor is reparametrized in terms of the original tensor and the
        pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for
                pruning. If unspecified or None, the module parameter will be
                used in its place.
        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )
class RandomStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor at random.

    Args:
        amount (int or float): quantity of channels to prune along ``dim``.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of channels to prune. If ``int``, it represents the
            absolute number of channels to prune.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` by randomly zeroing out channels
        along the specified dim of the tensor.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): base mask from previous pruning
                iterations, that needs to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)

        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Number of channels along the pruned dim.
        tensor_size = t.shape[self.dim]
        # Number of channels to prune: amount if int, else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        def make_mask(t, dim, nchannels, nchannels_toprune):
            # Draw one random number per channel and zero out the channels
            # that received the `nchannels_toprune` lowest values.
            prob = torch.rand(nchannels)
            threshold = torch.kthvalue(prob, k=nchannels_toprune).values
            channel_mask = prob > threshold

            mask = torch.zeros_like(t)
            slc = [slice(None)] * len(t.shape)
            slc[dim] = channel_mask
            # Index with a tuple: indexing with a list of slices is
            # deprecated non-tuple sequence indexing.
            mask[tuple(slc)] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            # Apply the new structured mask on top of the prior (potentially
            # unstructured) mask.
            mask = make_mask(t, self.dim, tensor_size, nparams_toprune)
            mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):
        r"""Add the forward pre-hook and the pruning reparametrization.

        The tensor is reparametrized in terms of the original tensor and the
        pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of channels to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of channels to prune. If ``int``, it represents the
                absolute number of channels to prune.
            dim (int, optional): index of the dim along which we define
                channels to prune. Default: -1.
        """
        return super().apply(module, name, amount=amount, dim=dim)
class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their
    L\ ``n``-norm.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of channels to prune. If ``int``, it represents the
            absolute number of channels to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` by zeroing out the channels along the
        specified dim with the lowest L\ ``n``-norm.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): base mask from previous pruning
                iterations, that needs to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)
        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Number of channels along the pruned dim.
        tensor_size = t.shape[self.dim]
        # Number of channels to prune: amount if int, else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        nparams_tokeep = tensor_size - nparams_toprune
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Structured pruning removes whole channels, so rank the channels by
        # their L_n norm.
        norm = _compute_norm(t, self.n, self.dim)
        # largest=True --> top k: keep the k channels with the largest norm.
        topk = torch.topk(norm, k=nparams_tokeep, largest=True)

        def make_mask(t, dim, indices):
            # Start from all zeros and set the kept channels to 1,
            # e.g. mask[:, :, [0, 2, 3]] = 1 for dim=2, indices=[0, 2, 3].
            mask = torch.zeros_like(t)
            slc = [slice(None)] * len(t.shape)
            slc[dim] = indices
            # Index with a tuple: indexing with a list of slices is
            # deprecated non-tuple sequence indexing.
            mask[tuple(slc)] = 1
            return mask

        if nparams_toprune == 0:  # nothing to prune: keep the prior mask
            mask = default_mask
        else:
            mask = make_mask(t, self.dim, topk.indices)
            mask *= default_mask.to(dtype=mask.dtype)

        return mask

    @classmethod
    def apply(cls, module, name, amount, n, dim, importance_scores=None):
        r"""Add the forward pre-hook and the pruning reparametrization.

        The tensor is reparametrized in terms of the original tensor and the
        pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of channels to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of channels to prune. If ``int``, it represents the
                absolute number of channels to prune.
            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
                entries for argument ``p`` in :func:`torch.norm`.
            dim (int): index of the dim along which we define channels to
                prune.
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for
                pruning. If unspecified or None, the module parameter will be
                used in its place.
        """
        return super().apply(
            module,
            name,
            amount=amount,
            n=n,
            dim=dim,
            importance_scores=importance_scores,
        )
# NOTE(review): the header of the class this method belongs to is not part of
# this excerpt; the method must live inside that class for ``super()`` to work.
@classmethod
def apply(cls, module, name, mask):
    r"""Adds the forward pre-hook that enables pruning on the fly and
    the reparametrization of a tensor in terms of the original tensor
    and the pruning mask.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask: forwarded unchanged as the ``mask`` keyword argument to the
            parent class's ``apply`` (presumably the pre-computed pruning
            mask - confirm against the enclosing class).
    """
    return super().apply(module, name, mask=mask)
def identity(module, name):
    r"""Set up the pruning reparametrization on ``module[name]`` without
    pruning any units.

    The module is modified in place (and also returned):

    1) a buffer named ``name+'_mask'`` holding the binary mask is added;
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module containing the tensor to prune.
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), 'bias')
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    # The method object returned by apply() is not needed here.
    Identity.apply(module, name)
    return module
def random_unstructured(module, name, amount):
    r"""Randomly prune ``amount`` (currently unpruned) units of the tensor
    ``module[name]``.

    The module is modified in place (and also returned):

    1) a buffer named ``name+'_mask'`` holding the binary mask is added;
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)
    """
    RandomUnstructured.apply(module, name, amount=amount)
    return module
def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune the ``amount`` (currently unpruned) units of ``module[name]``
    that have the lowest L1-norm.

    The module is modified in place (and also returned):

    1) a buffer named ``name+'_mask'`` holding the binary mask is added;
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            If unspecified or None, the module parameter will be used in its
            place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    L1Unstructured.apply(
        module,
        name,
        amount=amount,
        importance_scores=importance_scores,
    )
    return module
def random_structured(module, name, amount, dim):
    r"""Randomly prune ``amount`` (currently unpruned) channels of
    ``module[name]`` along ``dim``.

    The module is modified in place (and also returned):

    1) a buffer named ``name+'_mask'`` holding the binary mask is added;
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction to prune. If ``int``, it represents the absolute number
            to prune.
        dim (int): index of the dim along which we define channels to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(
        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1
        ... )
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    RandomStructured.apply(module, name, amount=amount, dim=dim)
    return module
def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune the ``amount`` (currently unpruned) channels of ``module[name]``
    along ``dim`` that have the lowest L\ ``n``-norm.

    The module is modified in place (and also returned):

    1) a buffer named ``name+'_mask'`` holding the binary mask is added;
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction to prune. If ``int``, it represents the absolute number
            to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): index of the dim along which we define channels to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            If unspecified or None, the module parameter will be used in its
            place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.ln_structured(
        ...     nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
        ... )
    """
    LnStructured.apply(
        module,
        name,
        amount=amount,
        n=n,
        dim=dim,
        importance_scores=importance_scores,
    )
    return module
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
    r"""Globally prune the tensors of all ``parameters`` with ``pruning_method``.

    All importance scores are flattened into a single vector so that the
    pruning decision is taken across every listed parameter at once. The
    resulting mask is then sliced back and applied to each parameter.

    Modifies modules in place by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        parameters (Iterable of (module, name) tuples): parameters of
            the model to prune in a global fashion, i.e. by aggregating all
            weights prior to deciding which ones to prune. module must be of
            type :class:`nn.Module`, and name must be a string. Any iterable
            is accepted, including single-pass ones such as generators.
        pruning_method (callable): a valid pruning method from this module,
            or a custom one implemented by the user that satisfies the
            implementation guidelines and has
            ``PRUNING_TYPE='unstructured'``. It is instantiated as
            ``pruning_method(**kwargs)``.
        importance_scores (dict): a dictionary mapping (module, name) tuples
            to the corresponding parameter's importance scores tensor. The
            tensor should be the same shape as the parameter, and is used for
            computing mask for pruning.
            If unspecified or None, the parameter will be used in place of
            its importance scores.
        kwargs: other keyword arguments such as:
            amount (int or float): quantity of parameters to prune across the
            specified parameters.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Raises:
        TypeError: if ``parameters`` is not an Iterable, if
            ``importance_scores`` is not a dict, or if
            ``PRUNING_TYPE != 'unstructured'``

    Note:
        Since global structured pruning doesn't make much sense unless the
        norm is normalized by the size of the parameter, we now limit the
        scope of global pruning to unstructured methods.

    Examples:
        >>> from torch.nn.utils import prune
        >>> from collections import OrderedDict
        >>> net = nn.Sequential(OrderedDict([
        ...     ('first', nn.Linear(10, 4)),
        ...     ('second', nn.Linear(4, 1)),
        ... ]))
        >>> parameters_to_prune = (
        ...     (net.first, 'weight'),
        ...     (net.second, 'weight'),
        ... )
        >>> prune.global_unstructured(
        ...     parameters_to_prune,
        ...     pruning_method=prune.L1Unstructured,
        ...     amount=10,
        ... )
        >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
        tensor(10)
    """
    # ensure parameters is a list or generator of tuples
    if not isinstance(parameters, Iterable):
        raise TypeError("global_unstructured(): parameters is not an Iterable")
    # `parameters` is walked three times below (scores, masks, assignment).
    # Materialize it once so that a generator is not silently exhausted after
    # the first pass.
    parameters = list(parameters)

    importance_scores = importance_scores if importance_scores is not None else {}
    if not isinstance(importance_scores, dict):
        raise TypeError("global_unstructured(): importance_scores must be of type dict")

    # flatten importance scores to consider them all at once in global pruning
    relevant_importance_scores = torch.nn.utils.parameters_to_vector(
        [
            importance_scores.get((module, name), getattr(module, name))
            for (module, name) in parameters
        ]
    )

    # similarly, flatten the masks (if they exist), or use a flattened vector
    # of 1s of the same dimensions as the parameter. The all-ones tensor is
    # only allocated when no mask is present yet.
    default_mask = torch.nn.utils.parameters_to_vector(
        [
            getattr(module, name + "_mask")
            if hasattr(module, name + "_mask")
            else torch.ones_like(getattr(module, name))
            for (module, name) in parameters
        ]
    )

    # use the canonical pruning methods to compute the new mask, even if the
    # parameter is now a flattened out version of `parameters`
    container = PruningContainer()
    container._tensor_name = "temp"  # to make it match that of `method`
    method = pruning_method(**kwargs)
    method._tensor_name = "temp"  # to make it match that of `container`
    if method.PRUNING_TYPE != "unstructured":
        raise TypeError(
            'Only "unstructured" PRUNING_TYPE supported for '
            f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}"
        )

    container.add_pruning_method(method)

    # use the `compute_mask` method from `PruningContainer` to combine the
    # mask computed by the new method with the pre-existing mask
    final_mask = container.compute_mask(relevant_importance_scores, default_mask)

    # Pointer for slicing the mask to match the shape of each parameter
    pointer = 0
    for module, name in parameters:
        param = getattr(module, name)
        # The length of the parameter
        num_param = param.numel()
        # Slice the mask, reshape it
        param_mask = final_mask[pointer : pointer + num_param].view_as(param)
        # Assign the correct pre-computed mask to each parameter and add it
        # to the forward_pre_hooks like any other pruning method
        custom_from_mask(module, name, mask=param_mask)

        # Increment the pointer to continue slicing the final_mask
        pointer += num_param
def custom_from_mask(module, name, mask):
    r"""Prune the parameter ``name`` in ``module`` with a pre-computed ``mask``.

    The module is modified in place (and is also returned) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask (Tensor): binary mask to be applied to the parameter.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.custom_from_mask(
        ...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
        ... )
        >>> print(m.bias_mask)
        tensor([0., 1., 0.])
    """
    # Functional front-end: the pruning-method class installs the mask,
    # this wrapper only forwards the arguments and hands the module back.
    CustomFromMask.apply(module, name, mask)
    return module
def remove(module, name):
    r"""Strip the pruning reparametrization of ``name`` from ``module``.

    The pruning method is taken off the forward pre-hooks. The pruned
    parameter named ``name`` remains permanently pruned, and the parameter
    named ``name+'_orig'`` is removed from the parameter list. Similarly,
    the buffer named ``name+'_mask'`` is removed from the buffers.

    Note:
        Pruning itself is NOT undone or reversed!

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): the input module, once the matching hook has
        been removed.

    Raises:
        ValueError: if no pruning hook acting on ``name`` is registered on
            ``module``.

    Examples:
        >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
        >>> m = remove(m, name='weight')
    """
    for hook_id, candidate in module._forward_pre_hooks.items():
        # Skip hooks that are not pruning methods ...
        if not isinstance(candidate, BasePruningMethod):
            continue
        # ... and pruning methods that act on a different tensor.
        if candidate._tensor_name != name:
            continue

        candidate.remove(module)
        # Deleting while iterating is safe here only because we return
        # immediately afterwards and never advance the iterator again.
        del module._forward_pre_hooks[hook_id]
        return module

    raise ValueError(
        f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed"
    )
def is_pruned(module):
    r"""Tell whether ``module`` (or any of its submodules) is pruned.

    A module counts as pruned when at least one of the modules yielded by
    ``module.named_modules()`` has a forward pre-hook that inherits from
    :class:`BasePruningMethod`.

    Args:
        module (nn.Module): object that is either pruned or unpruned

    Returns:
        binary answer to whether ``module`` is pruned.

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = nn.Linear(5, 7)
        >>> print(prune.is_pruned(m))
        False
        >>> prune.random_unstructured(m, name='weight', amount=0.2)
        >>> print(prune.is_pruned(m))
        True
    """
    # `any` stops at the first pruning hook it encounters, mirroring an
    # early return from nested loops.
    return any(
        isinstance(hook, BasePruningMethod)
        for _, submodule in module.named_modules()
        for hook in submodule._forward_pre_hooks.values()
    )
def_validate_pruning_amount_init(amount):r"""Validation helper to check the range of amount at init. Args: amount (int or float): quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune. Raises: ValueError: if amount is a float not in [0, 1], or if it's a negative integer. TypeError: if amount is neither a float nor an integer. Note: This does not take into account the number of parameters in the tensor to be pruned, which is known only at prune. """ifnotisinstance(amount,numbers.Real):raiseTypeError(f"Invalid type for amount: {amount}. Must be int or float.")if(isinstance(amount,numbers.Integral)andamount<0)or(notisinstance(amount,numbers.Integral)# so it's a floatand(float(amount)>1.0orfloat(amount)<0.0)):raiseValueError(f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer")def_validate_pruning_amount(amount,tensor_size):r"""Validation helper to check that the amount of parameters to prune is meaningful wrt to the size of the data (`tensor_size`). Args: amount (int or float): quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune. tensor_size (int): absolute number of parameters in the tensor to prune. """# TODO: consider removing this check and allowing users to specify# a number of units to prune that is greater than the number of units# left to prune. In this case, the tensor will just be fully pruned.ifisinstance(amount,numbers.Integral)andamount>tensor_size:raiseValueError(f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}")def_validate_structured_pruning(t):r"""Validation helper to check that the tensor to be pruned is multi- dimensional, such that the concept of "channels" is well-defined. 
Args: t (torch.Tensor): tensor representing the parameter to prune Raises: ValueError: if the tensor `t` is not at least 2D. """shape=t.shapeiflen(shape)<=1:raiseValueError("Structured pruning can only be applied to ""multidimensional tensors. Found tensor of shape "f"{shape} with {len(shape)} dims")def_compute_nparams_toprune(amount,tensor_size):r"""Since amount can be expressed either in absolute value or as a percentage of the number of units/channels in a tensor, this utility function converts the percentage to absolute value to standardize the handling of pruning. Args: amount (int or float): quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune. tensor_size (int): absolute number of parameters in the tensor to prune. Returns: int: the number of units to prune in the tensor """# incorrect type already checked in _validate_pruning_amount_initifisinstance(amount,numbers.Integral):returnamountelse:returnround(amount*tensor_size)def_validate_pruning_dim(t,dim):r""" Args: t (torch.Tensor): tensor representing the parameter to prune dim (int): index of the dim along which we define channels to prune """ifdim>=t.dim():raiseIndexError(f"Invalid index {dim} for tensor of size {t.shape}")def_compute_norm(t,n,dim):r"""Compute the L_n-norm across all entries in tensor `t` along all dimension except for the one identified by dim. Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim), then norm will have Size [4], and each entry will represent the `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels. Args: t (torch.Tensor): tensor representing the parameter to prune n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid entries for argument p in torch.norm dim (int): dim identifying the channels to prune Returns: norm (torch.Tensor): L_n norm computed across all dimensions except for `dim`. 
By construction, `norm.shape = t.shape[-1]`. """# dims = all axes, except for the one identified by `dim`dims=list(range(t.dim()))# convert negative indexingifdim<0:dim=dims[dim]dims.remove(dim)norm=torch.norm(t,p=n,dim=dims)returnnorm
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.