.. currentmodule:: torchdistx.deferred_init

Deferred Module Initialization
==============================

TL;DR
-----
The Deferred Module Initialization feature consists of a :func:`deferred_init`
function that constructs ``Module`` instances without allocating storage for
their tensors, and the accompanying :func:`materialize_module` and
:func:`materialize_tensor` functions that can fully or partially materialize
modules constructed by :func:`deferred_init`. The feature is meant to be used
when a module is too large memory-wise or too expensive computationally to
construct on a single machine, but still needs to be inspected for various
reasons before being initialized.

Problem
-------
With ever-increasing model sizes, it is becoming increasingly common for models
to exceed the memory or compute capacity of a single machine or accelerator.
Training such models therefore requires a sharding (a.k.a. partitioning)
strategy that distributes parts of the model onto different computing nodes.
However, techniques such as 3D parallelism that are used to apply these
strategies often need access to the model architecture to decide on the optimal
strategy, which represents a chicken-and-egg problem.

Automated parallelism libraries (e.g. FSDP, DeepSpeed) either ignore this
problem completely, meaning they expect the model to fit on a single machine,
or they have rudimentary workarounds to partially overcome it. For instance,
they use a technique that sequentially initializes model parameters while
sharding them on-the-fly based on a predefined memory-size threshold. The
limitation of such workarounds is that these libraries never see the whole
architecture of the model, which would enable them to make smarter sharding
decisions.

What is Deferred Module Initialization?
---------------------------------------
Deferred Module Initialization addresses the problem mentioned above by
offering three functions. :func:`deferred_init` is a non-intrusive function
that enables users to defer the initialization of a ``Module`` by skipping
storage allocation for its parameters and buffers, while also keeping a record
of the operations performed on them in an in-memory graph.
:func:`materialize_module` and :func:`materialize_tensor` are the accompanying
functions that materialize (i.e. initialize) tensors or modules constructed
within a previous :func:`deferred_init` call by replaying the operations
recorded at that time.

API
---

Initialization
^^^^^^^^^^^^^^
As mentioned above, ``deferred_init()`` is the "entry point" of the API and has
the following signature:

.. autofunction:: deferred_init

.. note::
    The graph structure generated by ``deferred_init()`` is fairly simple,
    although it holds information that is specifically meant to materialize
    in-memory tensors as if they were initialized without deferral. In that
    sense its implementation and purpose diverge from the much larger and
    feature-rich solutions such as torch.fx and TorchScript.

Materialization
^^^^^^^^^^^^^^^
Modules, parameters, and buffers constructed within a :func:`deferred_init`
call can later be materialized using the ``materialize_module()`` and
``materialize_tensor()`` functions.

.. autofunction:: materialize_module

.. autofunction:: materialize_tensor
Examples
--------
The simplest use case is to construct a module using :func:`deferred_init` and
then later materialize it, after some form of inspection, using
:func:`materialize_module`: ::

    >>> import torch
    >>>
    >>> from torchdistx.deferred_init import deferred_init, materialize_module
    >>>
    >>> # Notice that `m` does not have any storage even though it appears to
    >>> # be a module allocated on CPU.
    >>> m = deferred_init(torch.nn.Linear, 5, 1)
    >>> m.weight
    Parameter containing:
    tensor(..., device='cpu', requires_grad=True, fake=True)
    >>>
    >>> # Do some form of inspection.
    >>> ...
    >>>
    >>> # At the end materialize the module.
    >>> materialize_module(m)
    >>> m.weight
    Parameter containing:
    tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24]],
           requires_grad=True)

It is also possible to materialize only a subset of the modules, parameters, or
buffers of a large model: ::

    >>> import torch
    >>>
    >>> from torchdistx.deferred_init import (
    ...     deferred_init,
    ...     materialize_module,
    ...     materialize_tensor,
    ... )
    >>>
    >>> class MyLargeModel(torch.nn.Module):
    ...     ...
    >>>
    >>> m = deferred_init(MyLargeModel)
    >>>
    >>> # Do some form of inspection (e.g. determine the sharding strategy).
    >>> ...
    >>>
    >>> # Only materialize `sublayer1` and `sublayer2`.
    >>> materialize_module(m.sublayer1)
    >>> materialize_module(m.sublayer2)
    >>>
    >>> # Or materialize an individual parameter or buffer.
    >>> materialized_param = materialize_tensor(m.sublayer1.param1)

:func:`deferred_init` skips storage allocation even for explicitly passed
device arguments: ::

    >>> import torch
    >>>
    >>> from torchdistx.deferred_init import deferred_init, materialize_module
    >>>
    >>> class MyModule(torch.nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.param = torch.nn.Parameter(torch.ones([3], device="cpu"))
    ...
    >>> m = deferred_init(MyModule)
    >>> m.param
    Parameter containing:
    tensor(..., device='cpu', size=(3,), requires_grad=True, fake=True)
    >>>
    >>> materialize_module(m)
    >>> m.param
    Parameter containing:
    tensor([1., 1., 1.], requires_grad=True)

Lazy modules can be used along with :func:`deferred_init` by wrapping the
module construction and the dry-run call in a single function, as demonstrated
below: ::

    >>> import torch
    >>>
    >>> from torchdistx.deferred_init import deferred_init
    >>>
    >>> def MyLazyModule(out_features: int):
    ...     lazy_m = torch.nn.LazyLinear(out_features)
    ...
    ...     # Dry-run the module to infer the parameter and buffer shapes.
    ...     lazy_m(torch.ones([10, 10]))
    ...
    ...     return lazy_m
    >>>
    >>> m = deferred_init(MyLazyModule, 10)

However, note that :func:`deferred_init` and the materialize functions use a
"best effort" approach and are not guaranteed to always succeed. See the
`Common Failure Patterns`_ section below to learn more.

Common Failure Patterns
-----------------------
**A module using an operator that is not supported by the meta backend:**
Internally :func:`deferred_init` relies on the meta backend. If the module to
be constructed by :func:`deferred_init` uses an operator that is not yet
supported by the meta backend, the operator call will fail. Fortunately such
failures are easy to spot, since the returned error message clearly indicates
which operator was the culprit. The solution in such a case is to introduce
meta backend support for the failed operator.
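When debugging such a failure, one simple way (illustrative only, not part of
the torchdistx API) to check whether a particular operator works on the meta
backend is to call it directly with tensors allocated on the ``meta`` device: ::

    >>> import torch
    >>>
    >>> # `linear` has meta backend support (at the time of writing), so the
    >>> # call only computes shapes and dtypes and allocates no real storage.
    >>> torch.nn.functional.linear(
    ...     torch.ones([2, 3], device="meta"), torch.ones([4, 3], device="meta")
    ... )
    tensor(..., device='meta', size=(2, 4))

An operator without meta backend support typically raises an error that names
the failed operator, which is the same symptom described above.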
**Mutable operator arguments:** Almost all PyTorch operators take either
primitives (e.g. integers, floating-point numbers) or tensors as parameters.
However, if an operator accepts a mutable argument other than a ``Tensor``
(e.g. a storage, blob, or future), :func:`deferred_init` will deliberately fail
the operation, since there is no guarantee that the argument will be in the
same state during materialization.

**In-place updated external tensors and inference tensors:** As a follow-up to
mutable arguments, if a tensor constructed from external data (e.g. via
``torch.load()`` or ``torch.from_numpy()``) is used as an argument to a meta
operation within :func:`deferred_init`, its version counter is tracked
similarly to Autograd. A change to the version counter, which effectively means
an in-place update to the tensor, is checked during materialization; if such a
change is detected, an error is raised since it would prevent correct
materialization (see the sketch at the end of this section). The rules are
stricter for inference tensors: since in-place updates cannot be tracked for
them, any materialization call that uses an inference tensor as an argument
will raise an error.

**A module using the tolist() or numpy() functions in its constructor:**
Currently Deferred Module Initialization does not support tracing calls to the
``tolist()`` and ``numpy()`` functions (see the last sketch at the end of this
section). We consider this a temporary limitation and will work with the
PyTorch core team to mitigate it in future releases.
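For illustration, the following is a minimal sketch of how the in-place update
check described under "In-place updated external tensors and inference tensors"
can surface. The module and variable names are hypothetical; only the failure
mode follows the behavior documented above: ::

    >>> import numpy as np
    >>> import torch
    >>>
    >>> from torchdistx.deferred_init import deferred_init, materialize_module
    >>>
    >>> # A tensor backed by external data.
    >>> external = torch.from_numpy(np.ones(3, dtype=np.float32))
    >>>
    >>> class MyExternalDataModule(torch.nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         # `external` participates in the recorded initialization.
    ...         self.param = torch.nn.Parameter(external * 2)
    ...
    >>> m = deferred_init(MyExternalDataModule)
    >>>
    >>> # An in-place update bumps the version counter of `external`.
    >>> external.add_(1)
    >>>
    >>> # Materialization detects the version counter change and raises an
    >>> # error, since replaying the recorded operations with the mutated
    >>> # tensor would not reproduce the original initialization.
    >>> materialize_module(m)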
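Similarly, the sketch below shows a hypothetical constructor pattern that falls
under the ``tolist()``/``numpy()`` limitation described above, since both calls
need to read real tensor data that a deferred (fake) tensor does not have: ::

    >>> import torch
    >>>
    >>> from torchdistx.deferred_init import deferred_init
    >>>
    >>> class MyDataDependentModule(torch.nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         # Reading tensor values via `tolist()` cannot be traced, so
    ...         # constructing this module with `deferred_init()` is expected
    ...         # to fail.
    ...         sizes = torch.arange(1, 4).tolist()
    ...         self.layers = torch.nn.ModuleList(
    ...             [torch.nn.Linear(size, size) for size in sizes]
    ...         )
    ...
    >>> m = deferred_init(MyDataDependentModule)  # raises an error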