# Source code for torch.distributed.tensor.parallel.api
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle
# Public API of this module: ``from ... import *`` exposes only ``parallelize_module``.
__all__ = ["parallelize_module"]
def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains
    :class:`ParallelStyle`, which indicates how user wants the module or sub_module
    to be parallelized.

    User can also specify different parallel style per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D
    :class:`DeviceMesh`, slice it to a 1-D sub DeviceMesh first and then pass that to this API
    (e.g. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`, optional):
            Object which describes the mesh topology of devices for the DTensor.
            If not specified, the call must be under a DeviceMesh context.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how we prepare
            input/output for Tensor Parallelism or it can be a dict of module
            FQN and its corresponding :class:`ParallelStyle` object. Each FQN
            is a ``.``-separated path relative to ``module``; every component
            is matched against child names with :func:`fnmatch.fnmatch`, so
            wildcards such as ``"layers.*.w1"`` are allowed, and the style is
            applied to every submodule the full path matches. If not
            specified, the call will do nothing at the moment.

    Return:
        A :class:`nn.Module` object parallelized.

    Raises:
        ValueError: if a key of a dict ``parallelize_plan`` is an empty string
            or contains an empty path component (e.g. ``"a..b"`` or ``"a."``).
        TypeError: if ``parallelize_plan`` is neither a :class:`ParallelStyle`
            nor a dict.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>

    .. note:: For complex module architecture like Attention, MLP layers, we recommend composing
        different ParallelStyles together (e.g. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
        them as a parallelize_plan, to achieve the desired sharding computation.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    if device_mesh is None:
        # No explicit mesh: fall back to the mesh of the current DeviceMesh context.
        device_mesh = _mesh_resources.get_current_mesh()

    _validate_tp_mesh_dim(device_mesh)

    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided and auto-parallel is not supported "
            "at the moment, so this parallelize_module call will do nothing."
        )
        return module

    # note: The RNG tracker will be initialized in distribute_tensor() call if it hasn't
    # been initialized.
    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            # ``str.split`` never returns an empty list (``"".split(".") == [""]``),
            # so the path string and each of its components are validated directly.
            if not module_path:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            path_splits = module_path.split(".")
            if not all(path_splits):
                # An empty component could only match an empty child name,
                # which ``nn.Module.add_module`` does not allow.
                raise ValueError(
                    f"Expect every component of module path {module_path!r} "
                    "to be non-empty, but found an empty component!"
                )

            # Resolve ONLY the first component at this level; the remainder is
            # handled by the recursive call on each matched child. Matching the
            # later components against ``module``'s own children as well would
            # parallelize unrelated direct children that happen to share a name.
            atom, rest = path_splits[0], path_splits[1:]

            # Materialize the matches before recursing so the iteration over
            # ``named_children()`` is not interleaved with the recursive calls.
            matched_children = [
                submodule
                for child_name, submodule in module.named_children()
                if fnmatch(child_name, atom)
            ]

            # apply the plan to all matched submodules
            for submodule in matched_children:
                if rest:
                    # we haven't reached the leaf, apply in dict style with the
                    # rest of the path after `atom`
                    parallelize_module(
                        submodule, device_mesh, {".".join(rest): parallelize_style}
                    )
                else:
                    # otherwise, directly apply style to this submodule
                    parallelize_module(submodule, device_mesh, parallelize_style)
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )