
Deferred Module Initialization

TL;DR

The Deferred Module Initialization feature consists of a deferred_init() function that constructs Module instances without allocating storage for their tensors, and the accompanying materialize_module() and materialize_tensor() functions that can fully or partially materialize modules constructed by deferred_init(). The feature is meant to be used when a module is too large in memory or too computationally expensive to construct on a single machine, but needs to be inspected for various reasons before being initialized.

Problem

With ever-increasing model sizes, it is becoming common for models to exceed the memory or compute capacity of a single machine or accelerator. This means that training such models requires some sharding (a.k.a. partitioning) strategy to distribute parts of the model across different compute nodes. However, techniques such as 3D parallelism that are used to apply these strategies often need access to the model architecture to decide on the optimal strategy, which creates a chicken-and-egg problem: the model has to be constructed before it can be inspected, yet it may be too large to construct before it is sharded.

Automated parallelism libraries (e.g. FSDP, DeepSpeed) either ignore this problem entirely, expecting the model to fit on a single machine, or offer rudimentary workarounds that partially overcome it. For instance, they sequentially initialize model parameters while sharding them on the fly based on a predefined memory-size threshold. The limitation of such workarounds, however, is that these libraries cannot see the whole architecture of the model, which would enable them to make smarter sharding decisions.

What is Deferred Module Initialization?

Deferred Module Initialization addresses the problem mentioned above by offering three functions. deferred_init() is a non-intrusive function that enables users to defer the initialization of a Module by skipping storage allocation for its parameters and buffers, while keeping a record of the operations performed on them in an in-memory graph. materialize_module() and materialize_tensor() are the accompanying functions that materialize (i.e. initialize) tensors or modules constructed within a previous deferred_init() call by replaying the operations recorded at that time.
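To make the record-and-replay idea concrete, here is a toy model in plain Python. It is a hypothetical sketch, not how torchdistx is implemented: `FakeTensor` and `materialize` are invented for illustration, and a "tensor" is just a list of floats. The fake object stores only its size and the operations applied to it; materialization allocates storage and replays those operations in order.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

Op = Tuple[str, Callable[[List[float]], List[float]]]


@dataclass
class FakeTensor:
    # Hypothetical toy: only the size and a log of operations are kept;
    # no element storage is allocated while "deferred".
    size: int
    ops: List[Op] = field(default_factory=list)

    def apply(self, name: str, fn: Callable[[List[float]], List[float]]) -> "FakeTensor":
        # Record the operation instead of executing it.
        self.ops.append((name, fn))
        return self


def materialize(fake: FakeTensor) -> List[float]:
    # Allocate storage only now, then replay the recorded operations in order.
    data = [0.0] * fake.size
    for _name, fn in fake.ops:
        data = fn(data)
    return data


# "Construct" a parameter: fill it with ones, then scale it by 0.5.
w = FakeTensor(3)
w.apply("fill_", lambda d: [1.0] * len(d)).apply("mul_", lambda d: [x * 0.5 for x in d])

print([name for name, _ in w.ops])  # ['fill_', 'mul_']
print(materialize(w))               # [0.5, 0.5, 0.5]
```

Because only the log is kept, the object can be inspected (here, its size and recorded operations) before any real memory is committed.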

API

Initialization

As mentioned above, deferred_init() is the “entry point” of the API and has the following signature:

torchdistx.deferred_init.deferred_init(module_fn, *args, **kwargs)

Defers the initialization of a Module.

This function forces all tensors constructed within module_fn to be fake while also recording all operations performed on them. The modules and tensors returned from module_fn can later be instantiated using the materialize_tensor() and materialize_module() functions.

Parameters
  • module_fn (Callable[[...], torch.nn.modules.module.Module]) – A callable that takes an arbitrary number of arguments and returns a Module instance.

  • args – The positional arguments to be passed to module_fn.

  • kwargs – The keyword arguments to be passed to module_fn.

Return type

torch.nn.modules.module.Module

Warning

The operations performed on the parameters and buffers of a module are only recorded while inside deferred_init(). Avoid making changes to a module after it is returned from deferred_init(); otherwise it cannot be correctly materialized.

Note

The graph structure generated by deferred_init() is fairly simple, although it holds information that is specifically meant to materialize in-memory tensors as if they had been initialized without deferral. In that sense, its implementation and purpose diverge from much larger and more feature-rich solutions such as torch.fx and TorchScript.

Materialization

Modules, parameters, and buffers constructed within a deferred_init() call can later be materialized using the materialize_module() and materialize_tensor() functions.

torchdistx.deferred_init.materialize_module(module, buffers_only=False, check_fn=None)

Materializes module and its descendant modules.

Parameters
  • module (torch.nn.modules.module.Module) – The module instance to materialize.

  • buffers_only (bool) – A boolean value indicating whether to materialize the buffer tensors only.

  • check_fn (Optional[Callable[[torch.nn.modules.module.Module], bool]]) – An optional callable which takes a Module instance and returns a boolean value indicating whether to materialize it.

Return type

None
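The roles of buffers_only and check_fn can be modeled with a small tree walk. This is a hypothetical sketch of the documented semantics only: `ToyModule` and `toy_materialize_module` are invented names, boolean flags stand in for real storage, and whether the real function still descends into a module rejected by check_fn is an assumption made here, not something stated above.

```python
from typing import Callable, List, Optional


class ToyModule:
    # Hypothetical stand-in for a Module; the flags replace real storage.
    def __init__(self, name: str, children: Optional[List["ToyModule"]] = None) -> None:
        self.name = name
        self.children = children or []
        self.params_materialized = False
        self.buffers_materialized = False


def toy_materialize_module(
    module: ToyModule,
    buffers_only: bool = False,
    check_fn: Optional[Callable[[ToyModule], bool]] = None,
) -> None:
    # Visit the module and all of its descendants.
    stack = [module]
    while stack:
        m = stack.pop()
        stack.extend(m.children)
        # Assumption: a module rejected by check_fn is skipped, but its
        # descendants are still visited and checked individually.
        if check_fn is not None and not check_fn(m):
            continue
        m.buffers_materialized = True
        if not buffers_only:
            m.params_materialized = True


root = ToyModule("root", [ToyModule("sublayer1"), ToyModule("sublayer2")])
toy_materialize_module(root, check_fn=lambda m: m.name == "sublayer1")
print({m.name: m.params_materialized for m in [root, *root.children]})
# {'root': False, 'sublayer1': True, 'sublayer2': False}

stats = ToyModule("stats")
toy_materialize_module(stats, buffers_only=True)
print(stats.buffers_materialized, stats.params_materialized)  # True False
```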

torchdistx.deferred_init.materialize_tensor(tensor)

Materializes tensor.

Parameters

tensor (torch.Tensor) – The tensor instance to materialize.

Return type

torch.Tensor

Warning

Once materialized, a fake tensor holds a reference to its materialized version. To avoid memory leaks, make sure to dispose of the fake tensor when it is no longer needed.
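The warning above can be illustrated with a toy model in which the fake object caches its materialized counterpart. In this hypothetical sketch (`FakeTensor`, `Materialized`, and `materialize_tensor` are plain-Python stand-ins, not torchdistx code), dropping the returned value is not enough to free the storage; the fake itself must also go away. The weak reference is only used to observe when the storage is released.

```python
import weakref
from typing import List, Optional


class Materialized:
    # Stand-in for a real tensor that owns (potentially large) storage.
    def __init__(self, data: List[float]) -> None:
        self.data = data


class FakeTensor:
    def __init__(self, size: int) -> None:
        self.size = size
        self._materialized: Optional[Materialized] = None  # cached result


def materialize_tensor(fake: FakeTensor) -> Materialized:
    # Cache the materialized version on the fake, so the fake keeps it alive.
    if fake._materialized is None:
        fake._materialized = Materialized([0.0] * fake.size)
    return fake._materialized


fake = FakeTensor(1_000)
real = materialize_tensor(fake)
probe = weakref.ref(real)

del real                # the fake still keeps the storage alive...
print(probe() is None)  # False
del fake                # ...until the fake itself is disposed of
print(probe() is None)  # True (CPython frees it immediately via reference counting)
```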

Examples

The simplest use case is to construct a module using deferred_init() and then later materialize it after some form of inspection using materialize_module():

>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init, materialize_module
>>>
>>> # Notice that `m` does not have any storage even though it appears to
>>> # be a module allocated on CPU.
>>> m = deferred_init(torch.nn.Linear, 5, 1)
>>> m.weight
Parameter containing:
tensor(..., device='cpu', requires_grad=True, fake=True)
>>>
>>> # Do some form of inspection.
>>> ...
>>>
>>> # At the end materialize the module.
>>> materialize_module(m)
>>> m.weight
Parameter containing:
tensor([[..., ..., ..., ..., ...]], requires_grad=True)

It is also possible to materialize only a subset of modules, parameters, or buffers of a large model:

>>> import torch
>>>
>>> from torchdistx.deferred_init import (
...     deferred_init,
...     materialize_module,
...     materialize_tensor,
... )
>>>
>>> class MyLargeModel(torch.nn.Module):
...     ...
>>>
>>> m = deferred_init(MyLargeModel)
>>>
>>> # Do some form of inspection (e.g. determine sharding strategy).
>>> ...
>>>
>>> # Only materialize `sublayer1` and `sublayer2`.
>>> materialize_module(m.sublayer1)
>>> materialize_module(m.sublayer2)
>>>
>>> # Or materialize an individual parameter or buffer.
>>> materialized_param = materialize_tensor(m.sublayer1.param1)

deferred_init() skips storage allocation even for explicitly passed device arguments:

>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init, materialize_module
>>>
>>> class MyModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.param = torch.nn.Parameter(torch.ones([3], device="cpu"))
...
>>> m = deferred_init(MyModule)
>>> m.param
Parameter containing:
tensor(..., device='cpu', size=(3,), requires_grad=True, fake=True)
>>>
>>> materialize_module(m)
>>> m.param
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)

Lazy modules can be used along with deferred_init() by wrapping the module construction and the dry-run call in a single function as demonstrated below:

>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init
>>>
>>> def MyLazyModule(out_features: int):
...     lazy_m = torch.nn.LazyLinear(out_features)
...
...     # Dry-run the module to infer the parameter and buffer shapes.
...     lazy_m(torch.ones([10, 10]))
...
...     return lazy_m
...
>>>
>>> m = deferred_init(MyLazyModule, 10)

Note, however, that deferred_init() and the materialize functions use a “best effort” approach and are not guaranteed to always succeed. See the Common Failure Patterns section below to learn more.

Common Failure Patterns

A module using an operator that is not supported by the meta backend: Internally, deferred_init() relies on the meta backend. If the module being constructed by deferred_init() uses an operator that the meta backend does not yet support, the operator call will fail. Fortunately, such failures are easy to spot, since the error message clearly indicates which operator was the culprit. The solution in such a case is to add meta backend support for the failing operator.
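A meta backend computes output metadata such as shapes without allocating or reading any data. The toy registry below is a hypothetical plain-Python sketch of that idea and of the failure mode just described; `META_KERNELS`, `meta_call`, and the operator name `fancy_op` are invented, and real PyTorch meta kernels are registered very differently.

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]


def _matmul_meta(a: Shape, b: Shape) -> Shape:
    # Shape rule only; no data is read or written.
    if a[1] != b[0]:
        raise ValueError(f"matmul: inner dimensions {a[1]} and {b[0]} differ")
    return (a[0], b[1])


# Hypothetical registry of "meta kernels" that compute output shapes only.
META_KERNELS: Dict[str, Callable[..., Shape]] = {
    "ones": lambda shape: shape,
    "matmul": _matmul_meta,
}


def meta_call(op: str, *shapes: Shape) -> Shape:
    kernel = META_KERNELS.get(op)
    if kernel is None:
        # As with the failure described above, the error names the offending operator.
        raise NotImplementedError(f"operator '{op}' has no meta implementation")
    return kernel(*shapes)


print(meta_call("matmul", (4, 5), (5, 2)))  # (4, 2)
try:
    meta_call("fancy_op", (3,))
except NotImplementedError as e:
    print(e)  # operator 'fancy_op' has no meta implementation
```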

Mutable operator arguments: Although almost all PyTorch operators take either primitives (e.g. integers, floating-point numbers) or tensors as parameters, if an operator accepts a mutable argument other than a Tensor (e.g. a storage, blob, or future), deferred_init() will deliberately fail the operation, since there is no guarantee that the argument will have the same state during materialization.
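This rule can be sketched as an argument check performed at recording time. The following is a hypothetical plain-Python model, not the library's code: `record_op` and `PRIMITIVES` are invented, and a `bytearray` plays the role of a mutable non-tensor argument such as a storage.

```python
from typing import Any, List


class FakeTensor:
    """Toy stand-in for a fake tensor, the one mutable argument type that is allowed."""


# Immutable primitive types that can safely be replayed later.
PRIMITIVES = (bool, int, float, complex, str, type(None))


def record_op(name: str, *args: Any) -> List[Any]:
    # Hypothetical check: accept primitives and tensors; reject any other object,
    # since its state at materialization time cannot be guaranteed.
    for arg in args:
        if not isinstance(arg, PRIMITIVES + (FakeTensor,)):
            raise TypeError(
                f"{name}: argument of type {type(arg).__name__} may be mutated "
                "before materialization and cannot be recorded"
            )
    return [name, *args]


record_op("add", FakeTensor(), 2.0)  # accepted
try:
    record_op("set_storage", bytearray(b"abc"))
except TypeError as e:
    print(e)
```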

In-place updated external tensors and inference tensors: As a follow-up to mutable arguments, if a tensor constructed from external data (e.g. via torch.load() or torch.from_numpy()) is used as an argument to a meta operation within deferred_init(), its version counter will be tracked, similarly to Autograd. A change to the version counter, which in practice means an in-place update to the tensor, is checked during materialization and, if detected, raises an error, since it would prevent correct materialization. The rules are stricter for inference tensors: since in-place updates cannot be tracked for them, any materialization call that uses an inference tensor as an argument will raise an error.
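Version-counter tracking can be modeled in a few lines. The sketch below is hypothetical (`ExternalTensor`, `record`, and `materialize` are illustrative names, not torchdistx code), although real PyTorch tensors do expose a comparable internal counter as `Tensor._version`. The recorded operation remembers the version it observed, and replay refuses to proceed if an in-place update has bumped it since.

```python
from typing import List, Tuple


class ExternalTensor:
    """Toy tensor built from external data with an in-place version counter."""

    def __init__(self, data: List[float]) -> None:
        self.data = data
        self.version = 0

    def add_(self, value: float) -> None:
        self.data = [x + value for x in self.data]
        self.version += 1  # every in-place update bumps the counter


def record(arg: ExternalTensor) -> Tuple[ExternalTensor, int]:
    # Remember the argument together with the version observed at record time.
    return (arg, arg.version)


def materialize(recorded: Tuple[ExternalTensor, int]) -> List[float]:
    arg, seen_version = recorded
    if arg.version != seen_version:
        raise RuntimeError("external tensor was modified in place after being recorded")
    return list(arg.data)  # the "replayed" op here is simply a copy


src = ExternalTensor([1.0, 2.0])
rec = record(src)
print(materialize(rec))  # [1.0, 2.0]

src.add_(1.0)  # in-place update after recording
try:
    materialize(rec)
except RuntimeError as e:
    print(e)  # external tensor was modified in place after being recorded
```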

A module using tolist() or numpy() functions in its constructor: Currently, Deferred Module Initialization does not support tracing calls to tolist() and numpy() functions. We consider this a temporary limitation and will work with the PyTorch core team to mitigate it in future releases.
