Shortcuts

Source code for torch.nn.utils.parametrize

import torch
from torch.nn.modules.container import ModuleList, ModuleDict, Module
from torch.nn.parameter import Parameter
from torch import Tensor

import collections
import copyreg
from copy import deepcopy
from contextlib import contextmanager
from typing import Union, Optional, Dict, Tuple, Sequence

# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "cached",
    "ParametrizationList",
    "register_parametrization",
    "is_parametrized",
    "remove_parametrizations",
    "type_before_parametrizations",
    "transfer_parametrizations_and_params",
]

# Nesting depth of the currently active ``cached()`` contexts; parametrized values are
# only cached while this is non-zero.
_cache_enabled: int = 0
# Parametrized values computed while caching is enabled, keyed by
# ``(id of the parametrized module, tensor name)``.
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}


@contextmanager
def cached():
    r"""Context manager that enables the caching system within parametrizations
    registered with :func:`register_parametrization`.

    The value of the parametrized objects is computed and cached the first time
    they are required when this context manager is active. The cached values are
    discarded when leaving the context manager. Nested uses are allowed: the cache
    is only cleared when the outermost context exits.

    This is useful when using a parametrized parameter more than once in the forward pass.
    An example of this is when parametrizing the recurrent kernel of an RNN or when
    sharing weights.

    The simplest way to activate the cache is by wrapping the forward pass of the neural network

    .. code-block:: python

        import torch.nn.utils.parametrize as P
        ...
        with P.cached():
            output = model(inputs)

    in training and evaluation. One may also wrap the parts of the modules that use
    the parametrized tensors several times. For example, the loop of an RNN with a
    parametrized recurrent kernel:

    .. code-block:: python

        with P.cached():
            for x in xs:
                out_rnn = self.rnn_cell(x, out_rnn)
    """
    global _cache, _cache_enabled
    _cache_enabled += 1
    try:
        yield
    finally:
        # Runs even if the body raised, so the nesting counter never leaks.
        _cache_enabled -= 1
        if _cache_enabled == 0:
            # Leaving the outermost context: drop every cached value.
            _cache = {}
def _register_parameter_or_buffer(module, name, X):
    r"""Registers ``X`` on ``module`` under ``name``.

    ``X`` is registered as a parameter when it is a :class:`~torch.nn.Parameter`
    and as a buffer otherwise.
    """
    if isinstance(X, Parameter):
        module.register_parameter(name, X)
    else:
        module.register_buffer(name, X)


class ParametrizationList(ModuleList):
    r"""A sequential container that holds and manages the ``original`` or ``original0``, ``original1``, ...
    parameters or buffers of a parametrized :class:`torch.nn.Module`.

    It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]``
    has been parametrized with :func:`register_parametrization`.

    If the first registered parametrization has a ``right_inverse`` that returns one tensor or
    does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity),
    it will hold the tensor under the name ``original``.
    If it has a ``right_inverse`` that returns more than one tensor,
    these will be registered as ``original0``, ``original1``, ...

    .. warning::
        This class is used internally by :func:`register_parametrization`. It is documented
        here for completeness. It shall not be instantiated by the user.

    Args:
        modules (sequence): sequence of modules representing the parametrizations
        original (Parameter or Tensor): parameter or buffer that is parametrized
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: `False`
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.

    Raises:
        ValueError: if ``modules`` is empty, if ``right_inverse`` returns something other than
            a tensor or a sequence of tensors, if a single returned tensor changes the dtype,
            or (when ``unsafe=False``) if the parametrized tensor does not keep the dtype and
            shape of ``original``. When the consistency checks fail, ``original`` is restored
            to the data it held before construction.
    """
    original: Tensor
    unsafe: bool

    def __init__(
        self, modules: Sequence[Module], original: Union[Tensor, Parameter], unsafe: bool = False
    ) -> None:
        # We require this because we need to treat differently the first parametrization
        # This should never throw, unless this class is used from the outside
        if len(modules) == 0:
            raise ValueError("ParametrizationList requires one or more modules.")

        super().__init__(modules)
        self.unsafe = unsafe

        # In plain words:
        # module.weight must keep its dtype and shape.
        # Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
        # this should be of the same dtype as the original tensor
        #
        # We check that the following invariants hold:
        #    X = module.weight
        #    Y = param.right_inverse(X)
        #    assert isinstance(Y, Tensor) or
        #           (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
        #    Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
        #    # Consistency checks
        #    assert X.dtype == Z.dtype and X.shape == Z.shape
        #    # If it has one input, this allows to be able to use set_ to be able to
        #    # move data to/from the original tensor without changing its id (which is what the
        #    # optimizer uses to track parameters)
        #    if isinstance(Y, Tensor)
        #      assert X.dtype == Y.dtype
        # Below we use original = X, new = Y

        original_shape = original.shape
        original_dtype = original.dtype

        # Compute new
        with torch.no_grad():
            new = original
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    try:
                        new = module.right_inverse(new)
                    except NotImplementedError:
                        pass
                # else, or if it throws, we assume that right_inverse is the identity

        if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence):
            raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
                             f"Got {type(new).__name__}")

        # Set the number of original tensors
        self.is_tensor = isinstance(new, Tensor)
        self.ntensors = 1 if self.is_tensor else len(new)

        # Alias of the data ``original`` holds on entry. Only the single-tensor path below
        # modifies ``original`` in place (via ``set_``), so only that path fills this in. It
        # lets us undo that modification if the consistency checks reject the parametrization.
        previous = None

        # Register the tensor(s)
        if self.is_tensor:
            if original.dtype != new.dtype:
                raise ValueError(
                    "When `right_inverse` outputs one tensor, it may not change the dtype.\n"
                    f"original.dtype: {original.dtype}\n"
                    f"right_inverse(original).dtype: {new.dtype}"
                )
            # Set the original to original so that the user does not need to re-register the parameter
            # manually in the optimiser
            with torch.no_grad():
                previous = original.detach()
                original.set_(new)  # type: ignore[call-overload]
            _register_parameter_or_buffer(self, "original", original)
        else:
            for i, originali in enumerate(new):
                if not isinstance(originali, Tensor):
                    raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors "
                                     "(list, tuple...). "
                                     f"Got element {i} of the sequence with type {type(originali).__name__}.")

                # If the original tensor was a Parameter that required grad, we expect the user to
                # add the new parameters to the optimizer after registering the parametrization
                # (this is documented)
                if isinstance(original, Parameter):
                    originali = Parameter(originali)
                originali.requires_grad_(original.requires_grad)
                _register_parameter_or_buffer(self, f"original{i}", originali)

        if not self.unsafe:
            try:
                self._check_consistency(original_shape, original_dtype)
            except Exception:
                # The registration is being rejected: do not leave the caller's tensor
                # pointing at the output of ``right_inverse``.
                if previous is not None:
                    with torch.no_grad():
                        original.set_(previous)  # type: ignore[call-overload]
                raise

    def _check_consistency(self, original_shape: torch.Size, original_dtype: torch.dtype) -> None:
        r"""Checks that ``forward(right_inverse(original))`` keeps the dtype and shape of ``original``.

        Raises:
            ValueError: if the parametrization does not return a tensor, or if the returned
                tensor's dtype or shape differs from ``original_dtype`` / ``original_shape``.
        """
        # Consistency checks:
        # Since f : A -> B, right_inverse : B -> A, Z and original should live in B
        # Z = forward(right_inverse(original))
        Z = self()
        if not isinstance(Z, Tensor):
            raise ValueError(
                f"A parametrization must return a tensor. Got {type(Z).__name__}."
            )
        if Z.dtype != original_dtype:
            raise ValueError(
                "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"
                f"unparametrized dtype: {original_dtype}\n"
                f"parametrized dtype: {Z.dtype}"
            )
        if Z.shape != original_shape:
            raise ValueError(
                "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"
                f"unparametrized shape: {original_shape}\n"
                f"parametrized shape: {Z.shape}"
            )

    def right_inverse(self, value: Tensor) -> None:
        r"""Calls the methods ``right_inverse`` (see :func:`register_parametrization`)
        of the parametrizations in the inverse order they were registered in.
        Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor
        or in ``self.original0``, ``self.original1``, ... if it outputs several.

        Every returned tensor is validated before any of the stored tensors is modified,
        so a failing call leaves all of them untouched.

        Args:
            value (Tensor): Value to which initialize the module

        Raises:
            RuntimeError: if a registered parametrization has no ``right_inverse``.
            ValueError: if the result is not of the expected kind, length or dtype.
        """
        # All the exceptions in this function should almost never throw.
        # They could throw if, for example, right_inverse function returns a different
        # dtype when given a different input, which should most likely be caused by a
        # bug in the user's code

        with torch.no_grad():
            # See https://github.com/pytorch/pytorch/issues/53103
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    value = module.right_inverse(value)
                else:
                    raise RuntimeError(f"parametrization {type(module).__name__} does not implement "
                                       "right_inverse.")
            if self.is_tensor:
                # These exceptions should only throw when a right_inverse function does not
                # return the same dtype for every input, which should most likely be caused by a bug
                if not isinstance(value, Tensor):
                    raise ValueError(
                        f"`right_inverse` should return a tensor. Got {type(value).__name__}"
                    )
                if value.dtype != self.original.dtype:
                    raise ValueError(
                        f"The tensor returned by `right_inverse` has dtype {value.dtype} "
                        f"while `original` has dtype {self.original.dtype}"
                    )
                # We know that the result is going to have the same dtype
                self.original.set_(value)  # type: ignore[call-overload]
            else:
                if not isinstance(value, collections.abc.Sequence):
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors. "
                        f"Got {type(value).__name__}."
                    )
                if len(value) != self.ntensors:
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors of length "
                        f"{self.ntensors}. Got a sequence of length {len(value)}."
                    )
                # First validate every element, so that an invalid element does not leave
                # the earlier ``original{i}`` tensors already overwritten.
                originals = []
                for i, tensor in enumerate(value):
                    original_i = getattr(self, f"original{i}")
                    if not isinstance(tensor, Tensor):
                        raise ValueError(
                            f"`right_inverse` must return a sequence of tensors. "
                            f"Got element {i} of type {type(tensor).__name__}"
                        )
                    if original_i.dtype != tensor.dtype:
                        raise ValueError(
                            f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
                            f"while `original{i}` has dtype {original_i.dtype}"
                        )
                    originals.append(original_i)
                for original_i, tensor in zip(originals, value):
                    original_i.set_(tensor)

    def forward(self) -> Tensor:
        r"""Evaluates the chain of parametrizations on the stored original tensor(s).

        The first parametrization receives ``original`` (or ``original0``, ``original1``, ...
        unpacked as positional arguments); each following one receives the previous output.
        """
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        # Unpack the originals for the first parametrization
        if self.is_tensor:
            x = self[0](self.original)
        else:
            originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
            x = self[0](*originals)
        # It's not possible to call self[1:] here, so we have to be a bit more cryptic
        # Also we want to skip all non-integer keys
        curr_idx = 1
        while hasattr(self, str(curr_idx)):
            x = self[curr_idx](x)
            curr_idx += 1
        return x
def _inject_new_class(module: Module) -> None:
    r"""Sets up a module to be parametrized.

    This works by substituting the class of the module by a class
    that extends it to be able to inject a property

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        # NOTE: the replica is created with ``self.__class__``, i.e. it shares this injected
        # class (and therefore the injected properties) with the module it was copied from.
        replica = self.__new__(self.__class__)
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls,),
        dct,
    )

    module.__class__ = param_cls


def _inject_property(module: Module, tensor_name: str) -> None:
    r"""Injects a property into module[tensor_name].

    It assumes that the class in the module has already been modified from its
    original one using _inject_new_class and that the tensor under :attr:`tensor_name`
    has already been moved out

    Reading the property evaluates ``self.parametrizations[tensor_name]`` (through the
    cache while :func:`cached` is active); assigning to it calls that list's
    ``right_inverse``.

    Args:
        module (nn.Module): module into which to inject the property
        tensor_name (str): name of the name of the property to create
    """
    # We check the precondition.
    # This should never fire if register_parametrization is correctly implemented
    assert not hasattr(module, tensor_name)

    @torch.jit.unused
    def get_cached_parametrization(owner: Module, parametrization) -> Tensor:
        # Key on the instance the property is read from, not on the closed-over ``module``:
        # the property lives on the class, which deep copies share, so keying on ``module``
        # would make a copy return the original module's cached value.
        key = (id(owner), tensor_name)
        tensor = _cache.get(key)
        if tensor is None:
            tensor = parametrization()
            _cache[key] = tensor
        return tensor

    def get_parametrized(self) -> Tensor:
        # Also makes a separate scripting check in the caching branch unnecessary.
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        parametrization = self.parametrizations[tensor_name]
        if not _cache_enabled:
            # If caching is not active, this function just evaluates the parametrization
            return parametrization()
        if torch._C._get_tracing_state() is not None:
            # Tracing
            raise RuntimeError('Cannot trace a model while caching parametrizations.')
        return get_cached_parametrization(self, parametrization)

    def set_original(self, value: Tensor) -> None:
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        self.parametrizations[tensor_name].right_inverse(value)

    setattr(module.__class__, tensor_name, property(get_parametrized, set_original))
def register_parametrization(
    module: Module, tensor_name: str, parametrization: Module, *, unsafe: bool = False,
) -> Module:
    r"""Adds a parametrization to a tensor in a module.

    Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
    the module will return the parametrized version ``parametrization(module.weight)``.
    If the original tensor requires a gradient, the backward pass will differentiate
    through :attr:`parametrization`, and the optimizer will update the tensor accordingly.

    The first time that a module registers a parametrization, this function will add an attribute
    ``parametrizations`` to the module of type :class:`~ParametrizationList`.

    The list of parametrizations on the tensor ``weight`` will be accessible under
    ``module.parametrizations.weight``.

    The original tensor will be accessible under
    ``module.parametrizations.weight.original``.

    Parametrizations may be concatenated by registering several parametrizations
    on the same attribute.

    The training mode of a registered parametrization is updated on registration
    to match the training mode of the host module.

    Parametrized parameters and buffers have an inbuilt caching system that can be activated
    using the context manager :func:`cached`.

    A :attr:`parametrization` may optionally implement a method with signature

    .. code-block:: python

        def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

    This method is called on the unparametrized tensor when the first parametrization
    is registered to compute the initial value of the original tensor.
    If this method is not implemented, the original tensor will be just the unparametrized tensor.

    If all the parametrizations registered on a tensor implement `right_inverse` it is possible
    to initialize a parametrized tensor by assigning to it, as shown in the example below.

    It is possible for the first parametrization to depend on several inputs.
    This may be implemented by returning a tuple of tensors from ``right_inverse``
    (see the example implementation of a ``RankOne`` parametrization below).

    In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
    with names ``original0``, ``original1``,...

    .. note::

        If unsafe=False (default) both the forward and right_inverse methods will be called
        once to perform a number of consistency checks.
        If unsafe=True, then right_inverse will be called if the tensor is not parametrized,
        and nothing will be called otherwise.

    .. note::

        In most situations, ``right_inverse`` will be a function such that
        ``forward(right_inverse(X)) == X`` (see
        `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
        Sometimes, when the parametrization is not surjective, it may be reasonable
        to relax this.

    .. warning::

        If a parametrization depends on several inputs, :func:`~register_parametrization`
        will register a number of new parameters. If such parametrization is registered
        after the optimizer is created, these new parameters will need to be added manually
        to the optimizer. See :meth:`torch.Optimizer.add_param_group`.

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (str): name of the parameter or buffer on which to register
            the parametrization
        parametrization (nn.Module): the parametrization to register
    Keyword args:
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: `False`
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.

    Raises:
        ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`
            (an entry registered with a ``None`` value counts as missing)

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Symmetric(nn.Module):
        >>>     def forward(self, X):
        >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
        >>>
        >>>     def right_inverse(self, A):
        >>>         return A.triu()
        >>>
        >>> m = nn.Linear(5, 5)
        >>> P.register_parametrization(m, "weight", Symmetric())
        >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
        True
        >>> A = torch.rand(5, 5)
        >>> A = A + A.T   # A is now symmetric
        >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
        >>> print(torch.allclose(m.weight, A))
        True

        >>> class RankOne(nn.Module):
        >>>     def forward(self, x, y):
        >>>         # Form a rank 1 matrix multiplying two vectors
        >>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
        >>>
        >>>     def right_inverse(self, Z):
        >>>         # Project Z onto the rank 1 matrices
        >>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
        >>>         # Return rescaled singular vectors
        >>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
        >>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
        >>>
        >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
        >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
        1

    """
    parametrization.train(module.training)
    if is_parametrized(module, tensor_name):
        # Correctness checks.
        # If A is the space of tensors with shape and dtype equal to module.weight
        # we check that parametrization.forward and parametrization.right_inverse are
        # functions from A to A
        if not unsafe:
            Y = getattr(module, tensor_name)
            X = parametrization(Y)
            if not isinstance(X, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(X).__name__}."
                )
            if X.dtype != Y.dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.dtype: {Y.dtype}\n"
                    f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
                )
            if X.shape != Y.shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.shape: {Y.shape}\n"
                    f"parametrization(module.{tensor_name}).shape: {X.shape}"
                )
            if hasattr(parametrization, "right_inverse"):
                try:
                    Z = parametrization.right_inverse(X)  # type: ignore[operator]
                except NotImplementedError:
                    pass
                else:
                    if not isinstance(Z, Tensor):
                        raise ValueError(
                            f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
                        )
                    if Z.dtype != Y.dtype:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same dtype "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.dtype: {Y.dtype}\n"
                            f"returned dtype: {Z.dtype}"
                        )
                    if Z.shape != Y.shape:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same shape "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.shape: {Y.shape}\n"
                            f"returned shape: {Z.shape}"
                        )
            # else right_inverse is assumed to be the identity

        # add the new parametrization to the parametrization list
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name].append(parametrization)
        # If unsafe was True in previous parametrization, keep it enabled
        module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr]
    elif (
        # Entries registered with a ``None`` value hold no tensor to parametrize; let them
        # reach the ValueError below instead of failing later on ``None.shape``.
        module._buffers.get(tensor_name) is not None
        or module._parameters.get(tensor_name) is not None
    ):
        # Set the parametrization mechanism
        # Fetch the original buffer or parameter
        original = getattr(module, tensor_name)
        # We create this early to check for possible errors
        parametrizations = ParametrizationList([parametrization], original, unsafe=unsafe)
        # Delete the previous parameter or buffer
        delattr(module, tensor_name)
        # If this is the first parametrization registered on the module,
        # we prepare the module to inject the property
        if not is_parametrized(module):
            # Change the class
            _inject_new_class(module)
            # Inject a ``ModuleDict`` into the instance under module.parametrizations
            module.parametrizations = ModuleDict()
        # Add a property into the class
        _inject_property(module, tensor_name)
        # Add a ParametrizationList
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name] = parametrizations
    else:
        raise ValueError(
            f"Module '{module}' does not have a parameter, a buffer, or a "
            f"parametrized element with name '{tensor_name}'"
        )
    return module
def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
    r"""Returns ``True`` if module has an active parametrization.

    If the argument :attr:`tensor_name` is specified, returns ``True`` if
    ``module[tensor_name]`` is parametrized.

    A module only counts as parametrized when its ``parametrizations`` attribute
    is a :class:`~torch.nn.ModuleDict`.

    Args:
        module (nn.Module): module to query
        tensor_name (str, optional): attribute in the module to query
            Default: ``None``
    """
    registry = getattr(module, "parametrizations", None)
    # A missing attribute (None) also fails this isinstance check.
    if not isinstance(registry, ModuleDict):
        return False
    if tensor_name is not None:
        return tensor_name in registry
    # No specific tensor requested: any registered parametrization counts.
    return len(registry) > 0
def remove_parametrizations(
    module: Module, tensor_name: str, leave_parametrized: bool = True
) -> Module:
    r"""Removes the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrised tensor in ``module.parametrizations[tensor_name].original``.
      This is only possible when the parametrization depends on just one tensor.

    When the last parametrization of the module is removed, the module's class is
    restored to the one it had before :func:`register_parametrization`.

    Args:
        module (nn.Module): module from which remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
    """
    if not is_parametrized(module, tensor_name):
        raise ValueError(f"Module {module} does not have a parametrization on {tensor_name}")

    # Fetch the original tensor
    assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
    plist = module.parametrizations[tensor_name]

    # Only a single ``original`` tensor can be restored as-is.
    if not plist.is_tensor and not leave_parametrized:
        raise ValueError("Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                         "that is parametrized in terms of a sequence of tensors.")

    if plist.is_tensor:
        original = plist.original
        if leave_parametrized:
            with torch.no_grad():
                current = getattr(module, tensor_name)
                # We know they have the same dtype because we have checked this when registering the
                # parametrizations. As such, we can use set_
                # We do this so that the parameter does not to change the id()
                # This way the user does not need to update the optimizer
                try:
                    original.set_(current)
                except RuntimeError as e:
                    if type(original) is torch.Tensor:
                        # Plain tensors: surface the error unchanged.
                        raise
                    # TODO: Fix this for tensor subclasses that are parameters:
                    # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                    raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
                                       "for a parameter that is an instance of a tensor subclass requires "
                                       "set_() to be implemented correctly for the tensor subclass. Either "
                                       "set leave_parametrized=False or provide a working implementation for "
                                       "set_() in the tensor subclass.") from e
    else:
        # Here leave_parametrized is necessarily True (rejected above otherwise).
        # We cannot use no_grad because we need to know whether one or more
        # original tensors required grad
        current = getattr(module, tensor_name)
        # We'll have to trust the user to add it to the optimizer
        original = Parameter(current) if current.requires_grad else current

    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]

    # Restore the parameter / buffer into the main class
    _register_parameter_or_buffer(module, tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        module.__class__ = module.__class__.__bases__[0]

    return module
def type_before_parametrizations(module: Module) -> type:
    r"""Returns the module type before parametrizations were applied and if not,
    then it returns the module type.

    Args:
        module (nn.Module): module to get type of
    """
    if is_parametrized(module):
        # The injected ``Parametrized*`` class has the original class as its only base.
        return module.__class__.__bases__[0]
    else:
        return type(module)


def transfer_parametrizations_and_params(
    from_module: Module, to_module: Module, tensor_name: Optional[str] = None
) -> Module:
    r"""Transfers parametrizations and the parameters they parametrize from :attr:`from_module`
    to :attr:`to_module`.

    If :attr:`tensor_name` is specified, only transfers the specified parameter, otherwise
    transfers all parametrized parameters. If those parameters do not exist in to_module,
    it will create them. Does nothing if :attr:`from_module` is not parametrized or, when
    :attr:`tensor_name` is given, if ``from_module[tensor_name]`` is not parametrized.

    The parametrization modules and the ``original`` / ``original0``, ``original1``, ...
    tensors are assigned to :attr:`to_module` as-is (not copied).

    Args:
        from_module (nn.Module): module to transfer from
        to_module (nn.Module): module to transfer to
        tensor_name (str, optional): parameter to transfer

    Returns:
        Module: to_module
    """
    # Checking the requested tensor (not just "any tensor") avoids a KeyError on
    # ``from_module.parametrizations[tensor_name]`` below, which previously could happen
    # after a parameter had already been created on ``to_module``.
    if not is_parametrized(from_module, tensor_name):
        return to_module

    assert isinstance(from_module.parametrizations, ModuleDict)  # for mypy

    # get list of all params or the single param to transfer
    parameters_to_transfer: Union[list, ModuleDict] = (
        from_module.parametrizations if tensor_name is None else [tensor_name]
    )

    assert hasattr(parameters_to_transfer, "__iter__")  # for mypy
    for parameter_name in parameters_to_transfer:

        # initialize the to-be-transferred param in to_module if it doesn't exist already
        if not hasattr(to_module, parameter_name):
            setattr(
                to_module,
                parameter_name,
                Parameter(getattr(from_module, parameter_name)),
            )

        # apply the params's parametrizations to to_module
        for param_func in from_module.parametrizations[parameter_name]:
            register_parametrization(to_module, parameter_name, param_func)
        assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy

        from_plist = from_module.parametrizations[parameter_name]
        to_plist = to_module.parametrizations[parameter_name]

        # make values match, original values can be stored in either original or
        # original0, original1..., need to check both cases
        if hasattr(from_plist, "original"):
            to_plist.original = from_plist.original
        else:
            num = 0
            orig_num = f"original{num}"
            # loop through each original# until all values have been set
            while hasattr(from_plist, orig_num):
                setattr(to_plist, orig_num, getattr(from_plist, orig_num))
                num += 1
                orig_num = f"original{num}"

    return to_module

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources