Shortcuts

Source code for torchdistx.deferred_init

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, Optional

from torch import Tensor
from torch.nn import Module

# We import `fake` to monkey-patch `repr()` of `Tensor`.
from . import fake  # noqa: F401
from . import _C


def deferred_init(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
    """Constructs a ``Module`` while deferring its initialization.

    For the duration of the ``module_fn`` call, all tensors constructed within
    it are forced to be fake and the operations performed on them are recorded.
    The modules and tensors returned from ``module_fn`` can be instantiated
    later with :func:`materialize_tensor` and :func:`materialize_module`.

    Args:
        module_fn:
            A callable that takes an arbitrary number of arguments and returns
            a ``Module`` instance.
        args, kwargs:
            The positional and keyword arguments to forward to ``module_fn``.

    Returns:
        The ``Module`` instance returned by ``module_fn``.

    .. Warning::
        Operations performed on the parameters and buffers of a module are
        only recorded while inside ``deferred_init()``. Avoid making changes
        to a module after it is returned from ``deferred_init()``; otherwise
        it cannot be correctly materialized.
    """
    _C.enable_deferred_init(True)
    try:
        module = module_fn(*args, **kwargs)
    finally:
        # Turn deferred-init mode back off even when ``module_fn`` raises.
        _C.enable_deferred_init(False)
    return module
def materialize_tensor(tensor: Tensor) -> Tensor:
    """Materializes ``tensor``.

    Args:
        tensor:
            The tensor instance to materialize.

    Returns:
        The tensor produced by the native ``_C.materialize_tensor`` call.

    .. Warning::
        Once materialized, a fake tensor holds a reference to its materialized
        version. To avoid memory leaks, dispose of the fake tensor when it is
        no longer required.
    """
    materialized = _C.materialize_tensor(tensor)
    return materialized
def materialize_module(
    module: Module,
    buffers_only: bool = False,
    check_fn: Optional[Callable[[Module], bool]] = None,
) -> None:
    """Materializes ``module`` together with all of its descendant modules.

    Descendants are processed first (depth-first), then ``module`` itself.

    Args:
        module:
            The module instance to materialize.
        buffers_only:
            If ``True``, only the buffer tensors are materialized and the
            parameters are left untouched.
        check_fn:
            An optional callable that receives a ``Module`` instance and
            returns a boolean indicating whether that module's own tensors
            should be materialized. Child modules are visited regardless of
            the result for their parent.

    Raises:
        ValueError:
            If materializing one of the tensors raises ``ValueError``; the
            error is re-raised with a message naming the offending entry.
    """

    def _materialize_in_place(entries: Dict[str, Optional[Tensor]]) -> None:
        # Replace every non-``None`` entry with its materialized counterpart.
        # Only values of existing keys are reassigned, so mutating the mapping
        # during iteration is safe.
        for name, fake_tensor in entries.items():
            if fake_tensor is not None:
                try:
                    entries[name] = _C.materialize_tensor(fake_tensor)
                except ValueError:
                    # Drop the original traceback context in favor of a
                    # message that identifies the entry by name.
                    raise ValueError(
                        f"'{name}' has already been materialized."
                    ) from None

    # Recurse into the children before handling this module.
    for child in module.children():
        materialize_module(child, buffers_only, check_fn)

    # Skip this module's own tensors when the caller's check rejects it.
    if check_fn is not None and not check_fn(module):
        return

    if not buffers_only:
        _materialize_in_place(module._parameters)  # type: ignore[arg-type]
    _materialize_in_place(module._buffers)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources