Deferred Module Initialization
TL;DR
The Deferred Module Initialization feature consists of a deferred_init() function that constructs Module instances without allocating storage for their tensors, and the accompanying materialize_module() and materialize_tensor() functions that can fully or partially materialize modules constructed by deferred_init(). The feature is meant to be used when a module is too large or too computationally expensive to construct on a single machine, but needs to be inspected for various reasons before being initialized.
Problem
With ever-increasing model sizes, it is becoming common for models to exceed the memory or compute capacity of a single machine or accelerator. Training such models therefore requires a sharding (a.k.a. partitioning) strategy to distribute parts of the model onto different computing nodes. However, techniques such as 3D parallelism that are used to apply these strategies often need access to the model architecture to decide on the optimal strategy, which presents a chicken-and-egg problem.
Automated parallelism libraries (e.g. FSDP, DeepSpeed) either ignore this problem entirely, meaning they expect the model to fit on a single machine, or they have rudimentary workarounds to partially overcome it. For instance, they may initialize model parameters sequentially while sharding them on the fly based on some predefined memory-size threshold. The limitation of such workarounds is that these libraries never see the whole architecture of the model, which would enable them to make smarter sharding decisions.
What is Deferred Module Initialization?
Deferred Module Initialization addresses the problem described above by offering three functions. deferred_init() is a non-intrusive function that enables users to defer the initialization of a Module by skipping storage allocation for its parameters and buffers, while also keeping a record of the operations performed on them in an in-memory graph. materialize_module() and materialize_tensor() are the accompanying functions that materialize (i.e. initialize) tensors or modules constructed within a previous deferred_init() call by replaying the operations recorded at that time.
API
Initialization
As mentioned above, deferred_init() is the “entry point” of the API and has the following signature:
- torchdistx.deferred_init.deferred_init(module_fn, *args, **kwargs)
Defers the initialization of a Module.
This function forces all tensors constructed within module_fn to be fake while also recording all operations performed on them. The modules and tensors returned from module_fn can later be instantiated using the materialize_tensor() and materialize_module() functions.
- Parameters
module_fn (Callable[..., torch.nn.modules.module.Module]) – A callable that takes an arbitrary number of arguments and returns a Module instance.
args – The positional arguments to be passed to module_fn.
kwargs – The keyword arguments to be passed to module_fn.
- Return type
torch.nn.modules.module.Module
Warning
The operations performed on the parameters and buffers of a module are only recorded while inside deferred_init(). Avoid making changes to a module after it is returned from deferred_init(); otherwise it cannot be correctly materialized.
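To illustrate, here is a minimal sketch of the anti-pattern this warning describes (the fill_() update is a hypothetical post-construction change, not something the library suggests):
>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init
>>>
>>> m = deferred_init(torch.nn.Linear, 5, 1)
>>>
>>> # Anti-pattern: an update like the following happens outside of
>>> # `deferred_init()`, is therefore not recorded, and leaves `m` in a
>>> # state that cannot be correctly materialized:
>>> #
>>> #   with torch.no_grad():
>>> #       m.weight.fill_(1.0)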
Note
The graph generated by deferred_init() is fairly simple, but it holds information that is specifically meant to materialize in-memory tensors as if they had been initialized without deferral. In that sense its implementation and purpose diverge from much larger and feature-rich solutions such as torch.fx and TorchScript.
Materialization
Modules, parameters, and buffers constructed within a deferred_init() call can later be materialized using the materialize_module() and materialize_tensor() functions.
- torchdistx.deferred_init.materialize_module(module, buffers_only=False, check_fn=None)
Materializes module and its descendant modules.
- Parameters
module (torch.nn.modules.module.Module) – The module instance to materialize.
buffers_only (bool) – A boolean value indicating whether to materialize only the buffer tensors.
check_fn (Optional[Callable[[torch.nn.modules.module.Module], bool]]) – An optional callable that takes a Module instance and returns a boolean value indicating whether to materialize it.
- Return type
None
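For instance, buffers_only makes it possible to materialize just the buffers of a module while its parameters stay fake; check_fn similarly restricts materialization to modules for which the given callable returns True. A minimal sketch of the former (the build helper is hypothetical):
>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init, materialize_module
>>>
>>> def build():
...     return torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4))
>>>
>>> m = deferred_init(build)
>>>
>>> # Materialize only the buffers (e.g. the running statistics of the
>>> # BatchNorm layer); all parameters remain fake.
>>> materialize_module(m, buffers_only=True)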
- torchdistx.deferred_init.materialize_tensor(tensor)
Materializes tensor.
- Parameters
tensor (torch.Tensor) – The tensor instance to materialize.
- Return type
torch.Tensor
Warning
Once materialized, a fake tensor will hold a reference to its materialized version. To avoid memory leaks, make sure to dispose of it when it is no longer required.
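A sketch of the disposal pattern implied by this warning (which references to drop depends, of course, on the surrounding code):
>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init, materialize_tensor
>>>
>>> m = deferred_init(torch.nn.Linear, 5, 1)
>>>
>>> weight = materialize_tensor(m.weight)
>>> bias = materialize_tensor(m.bias)
>>>
>>> # The fake parameters of `m` now hold references to `weight` and `bias`;
>>> # dispose of the fake module once it is no longer required so that it
>>> # does not keep the materialized tensors alive unnecessarily.
>>> del m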
Examples
The simplest use case is to construct a module using deferred_init() and then materialize it, after some form of inspection, using materialize_module():
>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init, materialize_module
>>>
>>> # Notice that `m` does not have any storage even though it appears to
>>> # be a module allocated on CPU.
>>> m = deferred_init(torch.nn.Linear, 5, 1)
>>> m.weight
Parameter containing:
tensor(..., device='cpu', requires_grad=True, fake=True)
>>>
>>> # Do some form of inspection.
>>> ...
>>>
>>> # At the end materialize the module.
>>> materialize_module(m)
>>> m.weight
Parameter containing:
tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00,
         -1.4677e+24]], requires_grad=True)
It is also possible to materialize only a subset of modules, parameters, or buffers of a large model:
>>> import torch
>>>
>>> from torchdistx.deferred_init import (
... deferred_init,
... materialize_module,
... materialize_tensor,
... )
>>>
>>> class MyLargeModel(torch.nn.Module):
... ...
>>>
>>> m = deferred_init(MyLargeModel)
>>>
>>> # Do some form of inspection (e.g. determine sharding strategy).
>>> ...
>>>
>>> # Only materialize `sublayer1` and `sublayer2`.
>>> materialize_module(m.sublayer1)
>>> materialize_module(m.sublayer2)
>>>
>>> # Or materialize an individual parameter or buffer.
>>> materialized_param = materialize_tensor(m.sublayer1.param1)
deferred_init() skips storage allocation even for explicitly passed device arguments:
>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init, materialize_module
>>>
>>> class MyModule(torch.nn.Module):
... def __init__(self):
... super().__init__()
... self.param = torch.nn.Parameter(torch.ones([3], device="cpu"))
...
>>> m = deferred_init(MyModule)
>>> m.param
Parameter containing:
tensor(..., device='cpu', size=(3,), requires_grad=True, fake=True)
>>>
>>> materialize_module(m)
>>> m.param
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
Lazy modules can be used along with deferred_init() by wrapping the module construction and the dry-run call in a single function, as demonstrated below:
>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init
>>>
>>> def MyLazyModule(out_features: int):
... lazy_m = torch.nn.LazyLinear(out_features)
...
... # Dry-run the module to infer the parameter and buffer shapes.
... lazy_m(torch.ones([10, 10]))
...
... return lazy_m
>>>
>>> m = deferred_init(MyLazyModule, 10)
However, note that deferred_init() and the materialize functions use a “best effort” approach and are not guaranteed to always succeed. See the Common Failure Patterns section below to learn more.
Common Failure Patterns
A module using an operator that is not supported by the meta backend: Internally, deferred_init() relies on the meta backend. If the module to be constructed by deferred_init() uses an operator that is not yet supported by the meta backend, the operator call will fail. Fortunately such failures are easy to spot, since the returned error message will clearly indicate which operator was the culprit. The solution in such a case is to introduce meta backend support for the failing operator.
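One way to check ahead of time whether an operator is covered by the meta backend is to call it on meta tensors directly; note that this standalone check is our own suggestion rather than part of the torchdistx API:
>>> import torch
>>>
>>> t = torch.ones(4, 4, device="meta")
>>>
>>> # A supported operator runs and returns a meta tensor of the right shape.
>>> torch.nn.functional.linear(t, torch.ones(2, 4, device="meta")).shape
torch.Size([4, 2])
>>>
>>> # An operator without a meta kernel would instead raise a
>>> # NotImplementedError naming the offending operator.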
Mutable operator arguments: Almost all PyTorch operators accept only primitives (e.g. integers, floating-point numbers) or tensors as parameter types. However, if an operator accepts a mutable argument other than a Tensor (e.g. a storage, blob, or future), deferred_init() will deliberately fail the operation, since we cannot guarantee that the argument will have the same state during materialization.
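As an illustration, Tensor.set_() accepts a mutable torch.Storage argument; the following hypothetical module shows the kind of constructor call that deferred_init() rejects for this reason:
>>> import torch
>>>
>>> class StorageModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         t = torch.empty(3)
...         # `set_()` takes a `torch.Storage`, a mutable non-tensor argument,
...         # so `deferred_init()` deliberately fails this call: the storage's
...         # state at materialization time cannot be guaranteed.
...         t.set_(torch.ones(3).storage())
...         self.param = torch.nn.Parameter(t)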
In-place updated external tensors and inference tensors: Related to mutable arguments, if a tensor constructed from external data (e.g. via torch.load() or torch.from_numpy()) is used as an argument to a meta operation within deferred_init(), its version counter will be tracked, similar to autograd. A change to the version counter, which in practice means an in-place update to the tensor, is checked during materialization and, if detected, causes an error to be raised, since it would prevent correct materialization. The rules are stricter for inference tensors: since in-place updates cannot be tracked for them, any materialization call using an inference tensor as an argument will raise an error.
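A sketch of this failure pattern, using a hypothetical module that captures an externally constructed tensor:
>>> import numpy as np
>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init, materialize_module
>>>
>>> external = torch.from_numpy(np.ones(3, dtype=np.float32))
>>>
>>> class ScaledModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         # `external` becomes an argument of a recorded operation, so its
...         # version counter is tracked from here on.
...         self.param = torch.nn.Parameter(torch.zeros(3) + external)
>>>
>>> m = deferred_init(ScaledModule)
>>>
>>> # The in-place update bumps the version counter of `external`...
>>> external.add_(1.0)
>>>
>>> # ...so, per the rules above, this materialization raises an error since
>>> # the recorded operations can no longer be replayed faithfully.
>>> materialize_module(m)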
A module using tolist() or numpy() functions in its constructor: Currently, Deferred Module Initialization does not support tracing calls to the tolist() and numpy() functions. We consider this a temporary limitation and will work with the PyTorch core team to mitigate it in future releases.
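For reference, a hypothetical constructor that runs into this limitation:
>>> import torch
>>>
>>> from torchdistx.deferred_init import deferred_init
>>>
>>> class BoundsModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         t = torch.arange(4.0)
...         # `tolist()` needs the actual tensor data, which a fake tensor does
...         # not have, and the call is not traceable by `deferred_init()`.
...         bounds = t.tolist()
...         self.param = torch.nn.Parameter(torch.tensor(bounds))
>>>
>>> # deferred_init(BoundsModule) fails due to the `tolist()` call above.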