Source code for torch.distributed.tensor.parallel.api
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle


__all__ = ["parallelize_module"]
def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize a module or its sub-modules based on a ``parallelize_plan``. The plan contains
    :class:`ParallelStyle` objects, which indicate how the user wants the module or sub-module
    to be parallelized.

    The user can also specify a different parallel style per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D
    :class:`DeviceMesh`, slice it to a 1-D sub DeviceMesh first and then pass that to this API
    (i.e. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`, optional):
            Object which describes the mesh topology of devices for the DTensor.
            If not specified, the call must be under a DeviceMesh context.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object, which contains how we prepare
            input/output for Tensor Parallelism, or a dict mapping module FQNs to their
            corresponding :class:`ParallelStyle` objects. If not specified, the
            call will do nothing at the moment.

    Return:
        A :class:`nn.Module` object parallelized.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>

    .. note:: For complex module architectures like Attention and MLP layers, we recommend composing
        different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
        them as a parallelize_plan to achieve the desired sharding computation.
"""torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")device_mesh=device_meshor_mesh_resources.get_current_mesh()_validate_tp_mesh_dim(device_mesh)ifparallelize_planisNone:warnings.warn("No parallelize_plan is provided and auto-parallel is not supported ""at the moment, so this parallelize_module call will do nothing.")returnmodule# note: The RNG tracker will be initialized in distribute_tensor() call if it hasn't# been initialized.ifisinstance(parallelize_plan,ParallelStyle):returnparallelize_plan._apply(module,device_mesh)elifisinstance(parallelize_plan,dict):formodule_path,parallelize_styleinparallelize_plan.items():path_splits=module_path.split(".")iflen(path_splits)==0:raiseValueError("Expect module path to be non-empty, but got empty string!")whilepath_splits:atom=path_splits.pop(0)matched_children=filter(# `t[0]` is child namelambdat:fnmatch(t[0],atom),module.named_children(),)# apply the plan to all matched submodulesfor_,submoduleinmatched_children:ifpath_splits:# we haven't reached the leaf, apply in dict styleleaf_path=".".join(path_splits)# rest of the path after `atom`parallelize_module(submodule,device_mesh,{leaf_path:parallelize_style})else:# otherwise, directly apply style to this submoduleparallelize_module(submodule,device_mesh,parallelize_style)returnmoduleelse:raiseTypeError(# pyre-ignore[7]"Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"f" parallelize_plan, {type(parallelize_plan)} found!")