Tips for Loading an nn.Module from a Checkpoint
Author: Mikayla Gawarecki
If you’re loading a checkpoint and want to reduce compute and memory as much as possible, this tutorial shares some recommended practices. In particular, we will discuss:
The mmap keyword argument on torch.load
The torch.device() context manager
The assign keyword argument on nn.Module.load_state_dict()
Note
This recipe requires PyTorch 2.1.0 or later.
Let us consider a simple nn.Module that contains a list of Linear layers:
import torch
from torch import nn
import time
class SomeModule(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])

    def forward(self, x):
        # nn.ModuleList is not callable, so apply each layer in turn
        for linear in self.linears:
            x = linear(x)
        return x
m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')
The following snippet demonstrates the use of the mmap keyword argument to torch.load, the torch.device() context manager, and the assign keyword argument to nn.Module.load_state_dict().
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
with torch.device('meta'):
    meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)
<All keys matched successfully>
Compare the snippet below to the one above:
state_dict = torch.load('checkpoint.pth', weights_only=True)
m = SomeModule(1000)
m.load_state_dict(state_dict)
<All keys matched successfully>
The second example does not use any of the features listed above and will be less compute and memory efficient for loading a checkpoint. In the following sections, we will discuss each of the features in further detail.
Using torch.load(mmap=True)
First, let us consider what happens when we load the checkpoint with torch.load. When we save a checkpoint with torch.save, tensor storages are tagged with the device they are saved on. With torch.load, tensor storages will be loaded to the device they were tagged with (unless this behavior is overridden using the map_location argument). For ease of explanation, let us assume that the tensors were saved on CPU. This means that the torch.load call loads all tensor storages into CPU RAM, which can be undesirable when:
CPU RAM is smaller than the size of the checkpoint.
You want to perform, for example, some per-tensor processing without waiting for the entire checkpoint to be loaded into RAM.
start_time = time.time()
state_dict = torch.load('checkpoint.pth', weights_only=True)
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
loading time without mmap=0.011675357818603516
The mmap keyword argument to torch.load attempts to solve the above two problems. As its name implies, it makes use of an mmap call, which maps a file on disk into virtual memory and lets the OS handle loading and unloading into physical memory automatically. When this flag is passed, tensor storages will be memory-mapped.
start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")
loading time with mmap=0.002198457717895508
As mentioned above, one can use this argument to do per-tensor processing on a checkpoint without loading all tensor storages into CPU memory upfront. For example:
def my_special_routine(t, device):
    # this could be a much fancier operation
    return t.to(dtype=torch.bfloat16, device=device)

def my_processing_function(key, device):
    t = state_dict[key]
    processed_t = my_special_routine(t, device)
    del t
    state_dict[key] = processed_t

for key in state_dict.keys():
    # this assumes a CUDA device is available
    device = torch.device('cuda')
    my_processing_function(key, device)
Using torch.device('meta')
Next, let’s consider the creation of the module.
m = SomeModule(1000)
This allocates memory for all parameters/buffers and initializes them per the default initialization schemes defined in SomeModule.__init__(), which is wasteful when we want to load a checkpoint for the following reasons:
The result of the initialization kernels will be overwritten by load_state_dict() without ever being used, so the compute spent on initialization is wasted.
We are allocating memory for these parameters/buffers in RAM while torch.load of the saved state dictionary also allocates memory in RAM for the parameters/buffers in the checkpoint.
In order to solve these two problems, we can use the torch.device() context manager with device='meta' when we instantiate the nn.Module().
The torch.device() context manager makes sure that factory calls will be performed as if they were passed the specified device as an argument. Tensors on torch.device('meta') do not carry data. However, they possess all other metadata a tensor carries such as .size(), .stride(), .requires_grad, and others.
with torch.device('meta'):
    new_m = SomeModule(1000)
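To make the metadata point concrete, here is a minimal sketch (not part of the original recipe) that builds a single nn.Linear under the meta device and inspects its weight. The layer size is an arbitrary choice for illustration.

```python
import torch
from torch import nn

# Factory calls inside the context behave as if device='meta' was passed
with torch.device("meta"):
    layer = nn.Linear(1000, 1000)

weight = layer.weight

# No data is allocated, but shape, stride and requires_grad are all available
print(weight.is_meta)
print(tuple(weight.shape))
print(weight.stride())
print(weight.requires_grad)
```

Because no storage is allocated and no initialization kernel runs, constructing the module this way should cost very little regardless of its size.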
Using load_state_dict(assign=True)
Next, we consider the loading of the state dictionary.
m.load_state_dict(state_dict)
<All keys matched successfully>
nn.Module.load_state_dict() is usually implemented via an in-place param_in_model.copy_(param_in_state_dict). This means that the parameter/buffer with the corresponding key in the state dictionary is copied into the parameter/buffer in the nn.Module.
However, an in-place copy into a tensor on the meta device is a no-op. In order to avoid this, we can pass the assign=True keyword argument to load_state_dict(), which assigns the tensors from the state dictionary to the module instead of copying into the existing ones.
A caveat here is that since optimizers hold a reference to nn.Module.parameters(), the optimizer must be initialized after the module is loaded from state dict if assign=True is passed.
# As of PyTorch 2.3.0, one can use torch.__future__.set_swap_module_params_on_conversion to
# avoid this caveat. This recipe provides more details:
# https://pytorch.org/tutorials/recipes/recipes/swap_tensors.html
new_m.load_state_dict(state_dict, assign=True)
# Before 2.3.0, this MUST be done AFTER the load_state_dict with assign.
# In versions >= 2.3.0, one can consider setting ``torch.__future__.set_swap_module_params_on_conversion``
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)
Conclusion
To recap, in this tutorial we learned about torch.load(mmap=True), the torch.device() context manager with device='meta', and nn.Module.load_state_dict(assign=True), as well as how these tools can be used to reduce compute and memory when loading a model from a checkpoint.