[docs]@contextmanagerdefcached():r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`. The value of the parametrized objects is computed and cached the first time they are required when this context manager is active. The cached values are discarded when leaving the context manager. This is useful when using a parametrized parameter more than once in the forward pass. An example of this is when parametrizing the recurrent kernel of an RNN or when sharing weights. The simplest way to activate the cache is by wrapping the forward pass of the neural network .. code-block:: python import torch.nn.utils.parametrize as P ... with P.cached(): output = model(inputs) in training and evaluation. One may also wrap the parts of the modules that use several times the parametrized tensors. For example, the loop of an RNN with a parametrized recurrent kernel: .. code-block:: python with P.cached(): for x in xs: out_rnn = self.rnn_cell(x, out_rnn) """global_cacheglobal_cache_enabled_cache_enabled+=1try:yieldfinally:_cache_enabled-=1ifnot_cache_enabled:_cache={}
[docs]classParametrizationList(ModuleList):r"""A sequential container that holds and manages the ``original`` or ``original0``, ``original1``, ... parameters or buffers of a parametrized :class:`torch.nn.Module`. It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` has been parametrized with :func:`register_parametrization`. If the first registered parametrization has a ``right_inverse`` that returns one tensor or does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity), it will hold the tensor under the name ``original``. If it has a ``right_inverse`` that returns more than one tensor, these will be registered as ``original0``, ``original1``, ... .. warning:: This class is used internally by :func:`register_parametrization`. It is documented here for completeness. It shall not be instantiated by the user. Args: modules (sequence): sequence of modules representing the parametrizations original (Parameter or Tensor): parameter or buffer that is parametrized unsafe (bool): a boolean flag that denotes whether the parametrization may change the dtype and shape of the tensor. Default: `False` Warning: the parametrization is not checked for consistency upon registration. Enable this flag at your own risk. """original:Tensorunsafe:booldef__init__(self,modules:Sequence[Module],original:Union[Tensor,Parameter],unsafe:bool=False)->None:# We require this because we need to treat differently the first parametrization# This should never throw, unless this class is used from the outsideiflen(modules)==0:raiseValueError("ParametrizationList requires one or more modules.")super().__init__(modules)self.unsafe=unsafe# In plain words:# module.weight must keep its dtype and shape.# Furthermore, if there is no right_inverse or the right_inverse returns a tensor,# this should be of the same dtype as the original tensor## We check that the following invariants hold:# X = module.weight# Y = param.right_inverse(X)# assert isinstance(Y, Tensor) or# (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))# Z = param(Y) if isinstance(Y, Tensor) else param(*Y)# # Consistency checks# assert X.dtype == Z.dtype and X.shape == Z.shape# # If it has one input, this allows to be able to use set_ to be able to# # move data to/from the original tensor without changing its id (which is what the# # optimizer uses to track parameters)# if isinstance(Y, Tensor)# assert X.dtype == Y.dtype# Below we use original = X, new = Yoriginal_shape=original.shapeoriginal_dtype=original.dtype# Compute newwithtorch.no_grad():new=originalformoduleinreversed(self):# type: ignore[call-overload]ifhasattr(module,"right_inverse"):try:new=module.right_inverse(new)exceptNotImplementedError:pass# else, or if it throws, we assume that right_inverse is the identityifnotisinstance(new,Tensor)andnotisinstance(new,collections.abc.Sequence):raiseValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). 
"f"Got {type(new).__name__}")# Set the number of original tensorsself.is_tensor=isinstance(new,Tensor)self.ntensors=1ifself.is_tensorelselen(new)# Register the tensor(s)ifself.is_tensor:iforiginal.dtype!=new.dtype:raiseValueError("When `right_inverse` outputs one tensor, it may not change the dtype.\n"f"original.dtype: {original.dtype}\n"f"right_inverse(original).dtype: {new.dtype}")# Set the original to original so that the user does not need to re-register the parameter# manually in the optimiserwithtorch.no_grad():original.set_(new)# type: ignore[call-overload]_register_parameter_or_buffer(self,"original",original)else:fori,originaliinenumerate(new):ifnotisinstance(originali,Tensor):raiseValueError("'right_inverse' must return a Tensor or a Sequence of tensors ""(list, tuple...). "f"Got element {i} of the sequence with type {type(originali).__name__}.")# If the original tensor was a Parameter that required grad, we expect the user to# add the new parameters to the optimizer after registering the parametrization# (this is documented)ifisinstance(original,Parameter):originali=Parameter(originali)originali.requires_grad_(original.requires_grad)_register_parameter_or_buffer(self,f"original{i}",originali)ifnotself.unsafe:# Consistency checks:# Since f : A -> B, right_inverse : B -> A, Z and original should live in B# Z = forward(right_inverse(original))Z=self()ifnotisinstance(Z,Tensor):raiseValueError(f"A parametrization must return a tensor. Got {type(Z).__name__}.")ifZ.dtype!=original_dtype:raiseValueError("Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"f"unparametrized dtype: {original_dtype}\n"f"parametrized dtype: {Z.dtype}")ifZ.shape!=original_shape:raiseValueError("Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"f"unparametrized shape: {original_shape}\n"f"parametrized shape: {Z.shape}")
    def right_inverse(self, value: Tensor) -> None:
        r"""Calls the methods ``right_inverse`` (see :func:`register_parametrization`)
        of the parametrizations in the inverse order they were registered in.
        Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor
        or in ``self.original0``, ``self.original1``, ... if it outputs several.

        Args:
            value (Tensor): Value with which to initialize the module
        """
        # All the exceptions in this function should almost never throw.
        # They could throw if, for example, right_inverse function returns a different
        # dtype when given a different input, which should most likely be caused by a
        # bug in the user's code
        with torch.no_grad():
            # See https://github.com/pytorch/pytorch/issues/53103
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    value = module.right_inverse(value)
                else:
                    raise RuntimeError(
                        f"parametrization {type(module).__name__} does not implement "
                        "right_inverse."
                    )
            if self.is_tensor:
                # These exceptions should only throw when a right_inverse function does not
                # return the same dtype for every input, which should most likely be caused by a bug
                if not isinstance(value, Tensor):
                    raise ValueError(
                        f"`right_inverse` should return a tensor. Got {type(value).__name__}"
                    )
                if value.dtype != self.original.dtype:
                    raise ValueError(
                        f"The tensor returned by `right_inverse` has dtype {value.dtype} "
                        f"while `original` has dtype {self.original.dtype}"
                    )
                # We know that the result is going to have the same dtype
                self.original.set_(value)  # type: ignore[call-overload]
            else:
                if not isinstance(value, collections.abc.Sequence):
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors. "
                        f"Got {type(value).__name__}."
                    )
                if len(value) != self.ntensors:
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors of length "
                        f"{self.ntensors}. Got a sequence of length {len(value)}."
                    )
                for i, tensor in enumerate(value):
                    original_i = getattr(self, f"original{i}")
                    if not isinstance(tensor, Tensor):
                        raise ValueError(
                            f"`right_inverse` must return a sequence of tensors. "
                            f"Got element {i} of type {type(tensor).__name__}"
                        )
                    if original_i.dtype != tensor.dtype:
                        raise ValueError(
                            f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
                            f"while `original{i}` has dtype {original_i.dtype}"
                        )
                    original_i.set_(tensor)
    def forward(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        # Unpack the originals for the first parametrization
        if self.is_tensor:
            x = self[0](self.original)
        else:
            originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
            x = self[0](*originals)
        # It's not possible to call self[1:] here, so we have to be a bit more cryptic
        # Also we want to skip all non-integer keys
        curr_idx = 1
        while hasattr(self, str(curr_idx)):
            x = self[curr_idx](x)
            curr_idx += 1
        return x
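# Illustrative sketch (not part of the module source): forward() above feeds ``original``
# through the registered parametrizations in registration order, so stacking two
# parametrizations on the same tensor composes them. `Double` and `AddOne` are
# hypothetical parametrizations introduced only for this example.
#
#     import torch
#     import torch.nn as nn
#     import torch.nn.utils.parametrize as P
#
#     class Double(nn.Module):
#         def forward(self, X):
#             return 2 * X
#
#     class AddOne(nn.Module):
#         def forward(self, X):
#             return X + 1
#
#     m = nn.Linear(3, 3)
#     P.register_parametrization(m, "weight", Double())
#     P.register_parametrization(m, "weight", AddOne())
#     original = m.parametrizations.weight.original
#     print(torch.allclose(m.weight, 2 * original + 1))  # True: AddOne(Double(original))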
def _inject_new_class(module: Module) -> None:
    r"""Sets up a module to be parametrized.

    This works by substituting the class of the module by a class
    that extends it to be able to inject a property

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        replica = self.__new__(self.__class__)
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls,),
        dct,
    )

    module.__class__ = param_cls


def _inject_property(module: Module, tensor_name: str) -> None:
    r"""Injects a property into module[tensor_name].

    It assumes that the class in the module has already been modified from its
    original one using _inject_new_class and that the tensor under :attr:`tensor_name`
    has already been moved out

    Args:
        module (nn.Module): module into which to inject the property
        tensor_name (str): name of the property to create
    """
    # We check the precondition.
    # This should never fire if register_parametrization is correctly implemented
    assert not hasattr(module, tensor_name)

    @torch.jit.unused
    def get_cached_parametrization(parametrization) -> Tensor:
        global _cache
        key = (id(module), tensor_name)
        tensor = _cache.get(key)
        if tensor is None:
            tensor = parametrization()
            _cache[key] = tensor
        return tensor

    def get_parametrized(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        parametrization = self.parametrizations[tensor_name]
        if _cache_enabled:
            if torch.jit.is_scripting():
                # Scripting
                raise RuntimeError('Caching is not implemented for scripting. '
                                   'Either disable caching or avoid scripting.')
            elif torch._C._get_tracing_state() is not None:
                # Tracing
                raise RuntimeError('Cannot trace a model while caching parametrizations.')
            else:
                return get_cached_parametrization(parametrization)
        else:
            # If caching is not active, this function just evaluates the parametrization
            return parametrization()

    def set_original(self, value: Tensor) -> None:
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        self.parametrizations[tensor_name].right_inverse(value)

    setattr(module.__class__, tensor_name, property(get_parametrized, set_original))
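# Illustrative sketch (not part of the module source): after the first registration,
# the module's class is replaced by a dynamically created subclass and the attribute
# becomes a property backed by the parametrization. Uses nn.Identity as a trivial
# parametrization for the sake of the example.
#
#     import torch.nn as nn
#     import torch.nn.utils.parametrize as P
#
#     m = nn.Linear(3, 3)
#     print(type(m).__name__)                 # Linear
#     P.register_parametrization(m, "weight", nn.Identity())
#     print(type(m).__name__)                 # ParametrizedLinear
#     print(type(m).__mro__[1].__name__)      # Linear (the original class is the base)
#     print("weight" in m._parameters)        # False: weight is now a property
#     print("original" in dict(m.parametrizations.weight.named_parameters()))  # True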
def register_parametrization(
    module: Module, tensor_name: str, parametrization: Module, *, unsafe: bool = False,
) -> Module:
    r"""Adds a parametrization to a tensor in a module.

    Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
    the module will return the parametrized version ``parametrization(module.weight)``.
    If the original tensor requires a gradient, the backward pass will differentiate
    through :attr:`parametrization`, and the optimizer will update the tensor accordingly.

    The first time that a module registers a parametrization, this function will add an attribute
    ``parametrizations`` to the module of type :class:`~ParametrizationList`.

    The list of parametrizations on the tensor ``weight`` will be accessible under
    ``module.parametrizations.weight``.

    The original tensor will be accessible under
    ``module.parametrizations.weight.original``.

    Parametrizations may be concatenated by registering several parametrizations
    on the same attribute.

    The training mode of a registered parametrization is updated on registration
    to match the training mode of the host module.

    Parametrized parameters and buffers have an inbuilt caching system that can be activated
    using the context manager :func:`cached`.

    A :attr:`parametrization` may optionally implement a method with signature

    .. code-block:: python

        def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

    This method is called on the unparametrized tensor when the first parametrization
    is registered to compute the initial value of the original tensor.
    If this method is not implemented, the original tensor will be just the unparametrized tensor.

    If all the parametrizations registered on a tensor implement `right_inverse` it is possible
    to initialize a parametrized tensor by assigning to it, as shown in the example below.

    It is possible for the first parametrization to depend on several inputs.
    This may be implemented by returning a tuple of tensors from ``right_inverse``
    (see the example implementation of a ``RankOne`` parametrization below).

    In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
    with names ``original0``, ``original1``, ...

    .. note::

        If unsafe=False (default) both the forward and right_inverse methods will be called
        once to perform a number of consistency checks.
        If unsafe=True, then right_inverse will be called if the tensor is not parametrized,
        and nothing will be called otherwise.

    .. note::

        In most situations, ``right_inverse`` will be a function such that
        ``forward(right_inverse(X)) == X`` (see
        `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
        Sometimes, when the parametrization is not surjective, it may be reasonable
        to relax this.

    .. warning::

        If a parametrization depends on several inputs, :func:`~register_parametrization`
        will register a number of new parameters. If such a parametrization is registered
        after the optimizer is created, these new parameters will need to be added manually
        to the optimizer. See :meth:`torch.Optimizer.add_param_group`.

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (str): name of the parameter or buffer on which to register
            the parametrization
        parametrization (nn.Module): the parametrization to register
    Keyword args:
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: `False`
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.

    Raises:
        ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Symmetric(nn.Module):
        >>>     def forward(self, X):
        >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
        >>>
        >>>     def right_inverse(self, A):
        >>>         return A.triu()
        >>>
        >>> m = nn.Linear(5, 5)
        >>> P.register_parametrization(m, "weight", Symmetric())
        >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
        True
        >>> A = torch.rand(5, 5)
        >>> A = A + A.T  # A is now symmetric
        >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
        >>> print(torch.allclose(m.weight, A))
        True

        >>> class RankOne(nn.Module):
        >>>     def forward(self, x, y):
        >>>         # Form a rank 1 matrix multiplying two vectors
        >>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
        >>>
        >>>     def right_inverse(self, Z):
        >>>         # Project Z onto the rank 1 matrices
        >>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
        >>>         # Return rescaled singular vectors
        >>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
        >>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
        >>>
        >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
        >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
        1
    """
    parametrization.train(module.training)
    if is_parametrized(module, tensor_name):
        # Correctness checks.
        # If A is the space of tensors with shape and dtype equal to module.weight
        # we check that parametrization.forward and parametrization.right_inverse are
        # functions from A to A
        if not unsafe:
            Y = getattr(module, tensor_name)
            X = parametrization(Y)
            if not isinstance(X, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(X).__name__}."
                )
            if X.dtype != Y.dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.dtype: {Y.dtype}\n"
                    f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
                )
            if X.shape != Y.shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.shape: {Y.shape}\n"
                    f"parametrization(module.{tensor_name}).shape: {X.shape}"
                )

            if hasattr(parametrization, "right_inverse"):
                try:
                    Z = parametrization.right_inverse(X)  # type: ignore[operator]
                except NotImplementedError:
                    pass
                else:
                    if not isinstance(Z, Tensor):
                        raise ValueError(
                            f"parametrization.right_inverse must return a tensor. "
Got: {type(Z).__name__}")ifZ.dtype!=Y.dtype:raiseValueError("The tensor returned by parametrization.right_inverse must have the same dtype "f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"f"module.{tensor_name}.dtype: {Y.dtype}\n"f"returned dtype: {Z.dtype}")ifZ.shape!=Y.shape:raiseValueError("The tensor returned by parametrization.right_inverse must have the same shape "f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"f"module.{tensor_name}.shape: {Y.shape}\n"f"returned shape: {Z.shape}")# else right_inverse is assumed to be the identity# add the new parametrization to the parametrization listassertisinstance(module.parametrizations,ModuleDict)# Make mypy happymodule.parametrizations[tensor_name].append(parametrization)# If unsafe was True in previous parametrization, keep it enabledmodule.parametrizations[tensor_name].unsafe|=unsafe# type: ignore[index, union-attr]eliftensor_nameinmodule._buffersortensor_nameinmodule._parameters:# Set the parametrization mechanism# Fetch the original buffer or parameteroriginal=getattr(module,tensor_name)# We create this early to check for possible errorsparametrizations=ParametrizationList([parametrization],original,unsafe=unsafe)# Delete the previous parameter or bufferdelattr(module,tensor_name)# If this is the first parametrization registered on the module,# we prepare the module to inject the propertyifnotis_parametrized(module):# Change the class_inject_new_class(module)# Inject a ``ModuleDict`` into the instance under module.parametrizationsmodule.parametrizations=ModuleDict()# Add a property into the class_inject_property(module,tensor_name)# Add a ParametrizationListassertisinstance(module.parametrizations,ModuleDict)# Make mypy happymodule.parametrizations[tensor_name]=parametrizationselse:raiseValueError(f"Module '{module}' does not have a parameter, a buffer, or a "f"parametrized element with name '{tensor_name}'")returnmodule
def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
    r"""Returns ``True`` if module has an active parametrization.

    If the argument :attr:`tensor_name` is specified, returns ``True`` if
    ``module[tensor_name]`` is parametrized.

    Args:
        module (nn.Module): module to query
        tensor_name (str, optional): attribute in the module to query
            Default: ``None``
    """
    parametrizations = getattr(module, "parametrizations", None)
    if parametrizations is None or not isinstance(parametrizations, ModuleDict):
        return False
    if tensor_name is None:
        # Check that there is at least one parametrized buffer or Parameter
        return len(parametrizations) > 0
    else:
        return tensor_name in parametrizations
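# Illustrative sketch (not part of the module source): querying the whole module versus
# a specific tensor name. Uses nn.Identity as a trivial parametrization.
#
#     import torch.nn as nn
#     import torch.nn.utils.parametrize as P
#
#     m = nn.Linear(3, 3)
#     print(P.is_parametrized(m))                  # False
#     P.register_parametrization(m, "weight", nn.Identity())
#     print(P.is_parametrized(m))                  # True
#     print(P.is_parametrized(m, "weight"))        # True
#     print(P.is_parametrized(m, "bias"))          # False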
def remove_parametrizations(
    module: Module, tensor_name: str, leave_parametrized: bool = True
) -> Module:
    r"""Removes the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
      This is only possible when the parametrization depends on just one tensor.

    Args:
        module (nn.Module): module from which to remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
    """
    if not is_parametrized(module, tensor_name):
        raise ValueError(f"Module {module} does not have a parametrization on {tensor_name}")

    # Fetch the original tensor
    assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
    parametrizations = module.parametrizations[tensor_name]
    if parametrizations.is_tensor:
        original = parametrizations.original
        if leave_parametrized:
            with torch.no_grad():
                t = getattr(module, tensor_name)
            # We know they have the same dtype because we have checked this when registering the
            # parametrizations. As such, we can use set_
            # We do this so that the parameter does not change the id()
            # This way the user does not need to update the optimizer
            with torch.no_grad():
                if type(original) is torch.Tensor:
                    original.set_(t)
                else:
                    try:
                        original.set_(t)
                    except RuntimeError as e:
                        # TODO: Fix this for tensor subclasses that are parameters:
                        # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                        raise RuntimeError(
                            "Calling remove_parametrizations() with leave_parametrized=True "
                            "for a parameter that is an instance of a tensor subclass requires "
                            "set_() to be implemented correctly for the tensor subclass. Either "
                            "set leave_parametrized=False or provide a working implementation for "
                            "set_() in the tensor subclass."
                        ) from e
    else:
        if leave_parametrized:
            # We cannot use no_grad because we need to know whether one or more
            # original tensors required grad
            t = getattr(module, tensor_name)
            # We'll have to trust the user to add it to the optimizer
            original = Parameter(t) if t.requires_grad else t
        else:
            raise ValueError(
                "Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                "that is parametrized in terms of a sequence of tensors."
            )

    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]

    # Restore the parameter / buffer into the main class
    _register_parameter_or_buffer(module, tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module
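# Illustrative sketch (not part of the module source): the effect of leave_parametrized
# on the value that gets restored into the module. `Symmetric` is a hypothetical
# parametrization introduced only for this example.
#
#     import torch
#     import torch.nn as nn
#     import torch.nn.utils.parametrize as P
#
#     class Symmetric(nn.Module):
#         def forward(self, X):
#             return X.triu() + X.triu(1).T
#
#     m = nn.Linear(5, 5)
#     P.register_parametrization(m, "weight", Symmetric())
#     W = m.weight.detach().clone()               # current parametrized value
#     original = m.parametrizations.weight.original.detach().clone()
#
#     P.remove_parametrizations(m, "weight", leave_parametrized=True)
#     print(torch.allclose(m.weight, W))          # True: keeps the parametrized value
#     print(type(m).__name__)                     # Linear: the original class is restored
#     # With leave_parametrized=False, m.weight would have been restored to `original` instead.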
def type_before_parametrizations(module: Module) -> type:
    r"""Returns the module type before parametrizations were applied if the module is
    parametrized; otherwise, returns the module type.

    Args:
        module (nn.Module): module to get type of
    """
    if is_parametrized(module):
        return module.__class__.__bases__[0]
    else:
        return type(module)


def transfer_parametrizations_and_params(
    from_module: Module, to_module: Module, tensor_name: Optional[str] = None
) -> Module:
    r"""Transfers parametrizations and the parameters they parametrize from from_module
    to to_module. If tensor_name is specified, only transfers the specified parameter, otherwise
    transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them.
    Does nothing if from_module is not parametrized.

    Args:
        from_module (nn.Module): module to transfer from
        to_module (nn.Module): module to transfer to
        tensor_name (str, optional): parameter to transfer

    Returns:
        Module: to_module
    """
    if is_parametrized(from_module):
        assert isinstance(from_module.parametrizations, ModuleDict)  # for mypy

        # get list of all params or the single param to transfer
        parameters_to_transfer: Union[list, ModuleDict] = (
            from_module.parametrizations if tensor_name is None else [tensor_name]
        )

        assert hasattr(parameters_to_transfer, "__iter__")  # for mypy
        for parameter_name in parameters_to_transfer:

            # initialize the to-be-transferred param in to_module if it doesn't exist already
            if not hasattr(to_module, parameter_name):
                setattr(
                    to_module,
                    parameter_name,
                    Parameter(getattr(from_module, parameter_name)),
                )

            # apply the param's parametrizations to to_module
            for param_func in from_module.parametrizations[parameter_name]:
                register_parametrization(to_module, parameter_name, param_func)
            assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy

            # make values match, original values can be stored in either original or
            # original0, original1..., need to check both cases
            if hasattr(from_module.parametrizations[parameter_name], "original"):
                to_module.parametrizations[parameter_name].original = \
                    from_module.parametrizations[parameter_name].original
            else:
                num = 0
                orig_num = "original" + str(num)
                # loop through each original
                # until all values have been set
                while hasattr(from_module.parametrizations[parameter_name], orig_num):
                    setattr(
                        to_module.parametrizations[parameter_name],
                        orig_num,
                        getattr(from_module.parametrizations[parameter_name], orig_num),
                    )
                    num = num + 1
                    orig_num = "original" + str(num)

    return to_module
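# Illustrative sketch (not part of the module source): copying a parametrization and its
# original tensor(s) from one module onto another module with matching shapes. `Symmetric`
# is a hypothetical parametrization introduced only for this example.
#
#     import torch
#     import torch.nn as nn
#     import torch.nn.utils.parametrize as P
#
#     class Symmetric(nn.Module):
#         def forward(self, X):
#             return X.triu() + X.triu(1).T
#
#     src = nn.Linear(5, 5)
#     dst = nn.Linear(5, 5)
#     P.register_parametrization(src, "weight", Symmetric())
#
#     P.transfer_parametrizations_and_params(src, dst, "weight")
#     print(P.is_parametrized(dst, "weight"))           # True
#     print(torch.allclose(src.weight, dst.weight))     # True: values now match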