Source code for torch.distributed.tensor.parallel.api
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union
from fnmatch import fnmatch

import torch
import torch.distributed._tensor.random as random
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
)
from torch.distributed._tensor.random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import (
    ParallelStyle,
)

__all__ = [
    "parallelize_module",
]
def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    We parallelize the module or sub_modules based on a parallelize_plan. The
    parallelize_plan contains :class:`ParallelStyle`, which indicates how
    the user wants the module or sub_module to be parallelized.

    The user can also specify a different parallel style per module fully qualified name (FQN).

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D
    or N-D :class:`DeviceMesh`, slice the DeviceMesh to a 1-D sub DeviceMesh first, then pass it
    to this API (i.e. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how we prepare
            input/output for Tensor Parallelism, or it can be a dict of module
            FQNs and their corresponding :class:`ParallelStyle` objects.

    Return:
        A :class:`nn.Module` object parallelized.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        >>>

    .. note:: For complex module architectures like Attention and MLP layers, we recommend
        composing different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``)
        and passing them as a parallelize_plan to achieve the desired sharding computation.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    _validate_tp_mesh_dim(device_mesh)

    # instantiate a TP RNG state tracker if it's not there
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
        # By default we execute random ops in non-tensor-parallel region.
        # If users want to execute in tensor-parallel region, they can manually
        # set this field to True after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            path_splits = module_path.split(".")
            if len(path_splits) == 0:
                raise ValueError(
                    "Expect module path to be non-empty, but got empty string!"
                )
            while path_splits:
                atom = path_splits.pop(0)
                matched_children = filter(
                    # `t[0]` is child name
                    lambda t: fnmatch(t[0], atom),
                    module.named_children(),
                )
                # apply the plan to all matched submodules
                for _, submodule in matched_children:
                    if path_splits:
                        # we haven't reached the leaf, apply in dict style
                        leaf_path = ".".join(path_splits)  # rest of the path after `atom`
                        parallelize_module(
                            submodule, device_mesh, {leaf_path: parallelize_style}
                        )
                    else:
                        # otherwise, directly apply style to this submodule
                        parallelize_module(submodule, device_mesh, parallelize_style)
                # only the first atom is matched against this module's children;
                # the rest of the path is handled by the recursive calls above
                break
        return module
    else:
        raise TypeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )
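For context, a fuller usage sketch than the docstring example (the model, dimensions, and launch setup below are illustrative assumptions, not part of this module): it builds a 1-D device mesh over 8 GPUs, applies a dict plan keyed by FQN, and then opts back into running random ops in the tensor-parallel region, i.e. flipping the ``distribute_region_enabled`` field that the function above initializes to ``False``. Each dict key is resolved one atom at a time against ``named_children()``, which is what makes fnmatch-style wildcard keys work.

    # Usage sketch (illustrative; assumes 8 GPUs and a torchrun/NCCL launch).
    import torch
    import torch.nn as nn
    import torch.distributed._tensor.random as random
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class ToyMLP(nn.Module):
        def __init__(self, dim: int = 1024):
            super().__init__()
            self.w1 = nn.Linear(dim, 4 * dim)
            self.w2 = nn.Linear(4 * dim, dim)

        def forward(self, x):
            return self.w2(torch.relu(self.w1(x)))


    tp_mesh = init_device_mesh("cuda", (8,))  # 1-D mesh, as required above
    model = ToyMLP().cuda()

    # Dict plan keyed by FQN; wildcard keys such as "layers.*.w1" are also accepted.
    model = parallelize_module(
        model,
        tp_mesh,
        {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )

    # Optional: run random ops (e.g. dropout) in the tensor-parallel region,
    # which parallelize_module leaves disabled by default.
    if random._rng_tracker is not None:
        random._rng_tracker.distribute_region_enabled = True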