Extension points in nn.Module for load_state_dict and tensor subclasses

Author: Mikayla Gawarecki

This recipe introduces a new utility function torch.utils.swap_tensors as well as two new extension points where it has been integrated in nn.Module:

  • nn.Module.to() and related methods

  • nn.Module.load_state_dict()


This recipe requires PyTorch 2.3.0 or later.


torch.utils.swap_tensors (hereafter referred to as swap_tensors) is a utility function that takes in two Python tensors and swaps them.

import torch
import torch.nn as nn
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")
Before swapping, t1: tensor([0, 1]), t2: tensor([0, 1, 2])
After swapping, t1: tensor([0, 1, 2]), t2: tensor([0, 1])

More specifically, swap_tensors swaps the Python __class__, __dict__ and __slots__ of the two tensors, as well as their associated at::Tensor.

Application to nn.Module

This utility is pertinent to nn.Module when a Python object outside of the module holds a reference to parameters of the module. If an nn.Module modifies any of its parameters out of place, the object holding references to the parameters will not see the change. A classic example of this is the optimizer, which holds a reference to the parameters of the nn.Module. This leads to a silent correctness issue: optimizer.step() runs without error, but the weights of the nn.Module are not updated.

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
mod.weight = torch.nn.Parameter(2 * mod.weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
weight in mod: Parameter containing:
        [0.8300]], requires_grad=True)
weight in optimizer: [Parameter containing:
        [0.8300]], requires_grad=True)]
weight in mod: Parameter containing:
        [1.6600]], requires_grad=True)
weight in optimizer: [Parameter containing:
        [0.8300]], requires_grad=True)]


Depending on the value of the assign keyword argument passed to load_state_dict(), there are two ways to load the state_dict:

  • assign=False: preserves the properties of module.param and only takes the values from state_dict['param_name']

  • assign=True: preserves the properties and values of state_dict['param_name'].

Previously, these were implemented with in-place copy_ and __setattr__ respectively. Each approach had its own limitation: assign=False requires that the type of the parameter in the state_dict match the type of the parameter in the module, while assign=True requires that anything holding references to the module's parameters be initialized after nn.Module.load_state_dict().

Now, we address both constraints by adding a swap_tensors path to load_state_dict() and introducing a new extension point torch.Tensor.module_load(self, other, assign=False). When the swap_tensors path is enabled via torch.__future__.set_swap_module_params_on_conversion(True), we can use a __torch_function__ handler for module_load to apply a custom transformation to the value in the state_dict. The result of this transformation will be swapped with the parameter in the module.

In the following example, we will use the MyQuantizedLinearWeight subclass defined above to illustrate how we can use these features to apply a custom quantization scheme to the weights of a linear layer when loading the state_dict.

Recall that the __torch_function__ handler for module_load will be invoked if either self or other (in this case param or state_dict[param_key]) are MyQuantizedLinearWeight subclasses.

Assume that we expect the state_dict to contain plain tensors and the module to contain MyQuantizedLinearWeight parameters, and that we want the tensors in the state_dict to be transformed into the subclass. Then we can define a __torch_function__ handler for torch.Tensor.module_load as follows:

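@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    if func is torch.Tensor.module_load:
        # dest is the parameter in the module, src is the value from the state_dict
        dest, src = args[0], args[1]
        assert type(dest) == cls and type(src) == torch.Tensor
        # Wrap the plain tensor in the subclass, reusing the scale of the module's parameter
        return MyQuantizedLinearWeight(src, dest.scale)
    else:
        # Fall back to the default behavior for every other function
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)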

MyQuantizedLinearWeight.__torch_function__ = custom_torch_function

First, let us create a skeleton of a model on the meta device to avoid materializing storages. We convert all weights in the modules to MyQuantizedLinearWeight subclasses while leaving biases intact.

def fn(m):
    if isinstance(m, nn.Linear):
        requires_grad = m.weight.requires_grad
        m.weight = torch.nn.Parameter(
            MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad
        )

with torch.device("meta"):
    m = nn.Linear(3, 5)
    # Convert the weight of every nn.Linear to the subclass
    m.apply(fn)

We can then load the state_dict. Observe that we use assign=True because for biases, we want to preserve the properties of the tensor in the state_dict (for example, we do not want the bias to be on the meta device after loading).

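# Enable the swap_tensors path so that module_load is used during loading
torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")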
print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")
state_dict = nn.Linear(3, 5).state_dict()
print(f"state_dict:\n {state_dict}")
m.load_state_dict(state_dict, assign=True)
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")
Before: id(weight)=140309112167600, id(bias)=140309349403920
m.state_dict() before load_state_dict():
 OrderedDict([('weight', MyQuantizedLinearWeight(tensor(..., device='meta', size=(5, 3)), scale=0.5)), ('bias', tensor(..., device='meta', size=(5,)))])
state_dict:
 OrderedDict([('weight', tensor([[ 0.2430,  0.5155,  0.3337],
        [-0.2524,  0.3333,  0.1033],
        [ 0.2932, -0.3519, -0.5715],
        [-0.2231, -0.4428,  0.4737],
        [ 0.1663,  0.2391,  0.1826]])), ('bias', tensor([-0.0100,  0.4518, -0.4102,  0.0364, -0.3941]))])
After: id(weight)=140309112167600, id(bias)=140309349403920
m.state_dict() after load_state_dict():
 OrderedDict([('weight', MyQuantizedLinearWeight(tensor([[ 0.2430,  0.5155,  0.3337],
        [-0.2524,  0.3333,  0.1033],
        [ 0.2932, -0.3519, -0.5715],
        [-0.2231, -0.4428,  0.4737],
        [ 0.1663,  0.2391,  0.1826]]), scale=0.5)), ('bias', tensor([-0.0100,  0.4518, -0.4102,  0.0364, -0.3941]))])

The above is a toy example of how we can use the new extension point in nn.Module.load_state_dict(). One can also imagine alternate scenarios, such as when the state_dict holds tensor subclasses and the module holds plain nn.Parameters or tensors, or when both are tensor subclasses. Depending on the use case, we can define the __torch_function__ handler for module_load to apply the transforms as needed.


In this recipe, we learned about swap_tensors, the importance of preserving references for parameters in nn.Module as well as how to use the two new extension points that are gated by torch.__future__.set_swap_module_params_on_conversion.
