Source code for torch.distributed.tensor.parallel.api
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.distributed._tensor.random as random
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.distributed._tensor.random import (
    is_rng_supported_mesh,
    TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PairwiseParallel,
    ParallelStyle,
    RowwiseParallel,
)

__all__ = [
    "parallelize_module",
]
def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    The API to apply Tensor Parallelism (TP) in PyTorch. We parallelize the module
    or sub_modules based on a parallelize_plan. The parallelize_plan contains
    :class:`ParallelStyle`, which indicates how the user wants the module or
    sub_module to be parallelized.

    Users can also specify a different parallel style per module fully qualified name (FQN).
    The API supports 2D parallelism natively by accepting an n-dimensional device_mesh,
    and users just need to specify the dimension on which tensor parallelism is performed.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object which contains how we prepare
            input/output for Tensor Parallelism, or it can be a dict of module
            FQN and its corresponding :class:`ParallelStyle` object.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform Tensor Parallelism.

    Return:
        A :class:`nn.Module` object parallelized.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> m = parallelize_module(m, device_mesh, PairwiseParallel())
        >>>

    .. warning::
        ``PairwiseParallel`` comes with constraints for now. If you need finer
        granularity, you need to pass in a dict of module FQN and parallel style instead.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    # instantiate a TP RNG state tracker if it's not there
    if (
        is_rng_supported_mesh(device_mesh)
        and not isinstance(random._rng_tracker, TensorParallelRNGTracker)
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234, tp_dim=tp_mesh_dim)
        # By default we execute random ops in non-tensor-parallel region. If users want
        # to execute in tensor-parallel region, they can manually set this field to True
        # after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallelize_plan, ParallelStyle):
        # RowwiseParallel or ColwiseParallel
        if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
            return _parallelize_linear(module, device_mesh, parallelize_plan)
        # PairwiseParallel
        if _is_mlp_for_pairwise_parallel(module):
            return _parallelize_mlp(module, device_mesh, parallelize_plan)
        else:
            for n, m in module.named_children():
                module.register_module(
                    n, parallelize_module(m, device_mesh, parallelize_plan)
                )
            return module
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            sub_module = module.get_submodule(module_path)
            parent_module = module
            if "." in module_path:
                parent_module_path = ".".join(module_path.split(".")[:-1])
                parent_module = module.get_submodule(parent_module_path)
                module_path = module_path.split(".")[-1]
            parent_module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
                module_path,
                parallelize_module(  # type: ignore[arg-type]
                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
                ),
            )
        return module
    else:
        raise RuntimeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )
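
# Illustrative sketch (not part of the upstream module): one way to use the
# per-FQN dict plan handled in the ``elif isinstance(parallelize_plan, dict)``
# branch above. The helper name and the child module names "net1"/"net2" are
# hypothetical placeholders for a model with two nn.Linear sub-modules.
def _example_per_fqn_plan(model: nn.Module, mesh: DeviceMesh) -> nn.Module:
    # Shard the first linear column-wise and the second row-wise, the same
    # pairing PairwiseParallel applies to consecutive linear layers.
    plan = {
        "net1": ColwiseParallel(),
        "net2": RowwiseParallel(),
    }
    return parallelize_module(model, mesh, plan)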
def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Traverse all the immediate children of the given module and count the
    number of Linear modules. If the number is more than one, we return True.

    Args:
        module (:class:`nn.Module`):
            Module to be traversed and counted.

    Return:
        A bool which specifies whether the module is a supported MLP or not.

    .. warning:: The traversal is not recursive for now.
    """
    linear_submodules = list(
        filter(lambda x: isinstance(x, nn.Linear), module.children())
    )
    return len(linear_submodules) > 1


def _rowwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`RowwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """
    for name, param in module.named_parameters():
        dist_spec = (
            [Shard(1)] if name == "weight" else [Replicate()]  # type: ignore[list-item]
        )
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        module.register_parameter(name, dist_param)


def _colwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`ColwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
    """
    for name, param in module.named_parameters():
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, [Shard(0)])
        )
        module.register_parameter(name, dist_param)


def _parallelize_linear(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = ColwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function requires that the input module be an object of :class:`nn.Linear`.
    The module will be parallelized over a 1-d :class:`DeviceMesh`
    based on the :class:`ParallelStyle`.

    Args:
        module (:class:`nn.Module`):
            The module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the :class:`DTensor`.
            If the mesh is more than 1-dimensional, we will use the mesh dim of
            ``device_mesh`` specified by ``tp_mesh_dim``.
        parallel_style (:class:`ParallelStyle`, optional):
            The object which describes how the :class:`nn.Linear` module
            should be distributed over the :class:`DeviceMesh` and how the input
            and output should be prepared for Tensor Parallelism.
            :class:`RowwiseParallel`: weight is sharded on dim 1 and bias is replicated.
            :class:`ColwiseParallel`: weight and bias are both sharded on dim 0.
            Default: :class:`ColwiseParallel`
        tp_mesh_dim (int):
            The dimension of the :class:`DeviceMesh` on which we
            perform Tensor Parallelism.
            Default: 0

    Return:
        A :class:`nn.Module` object parallelized.
    """
    if not isinstance(module, nn.Linear):
        raise RuntimeError(
            f"Expect a torch.nn.Linear module but received {type(module)}!"
        )

    if not isinstance(parallel_style, ParallelStyle):
        raise RuntimeError(
            "Expect a ParallelStyle object but received"
            f" {type(parallel_style)}!"
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallel_style, RowwiseParallel):
        distribute_module(
            module,
            device_mesh,
            _rowwise_parallelize_linear_fn,
            input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
        )
    elif isinstance(parallel_style, ColwiseParallel):
        distribute_module(
            module,
            device_mesh,
            _colwise_parallelize_linear_fn,
            input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
        )
    else:
        raise RuntimeError(f"{type(parallel_style)} is not supported!")

    return module


def _parallelize_mlp(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function assumes the input module is a sequence of nn.Linear layers
    and parallelizes the module based on the given parallel style.
    We don't change the FQN of each sub-module and replace each parameter in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A :class:`nn.Module` object parallelized.

    .. warning:: We only support ``PairwiseParallel`` right now.
    """
    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only support PairwiseParallel for MLP parallelization."
        )

    if not _is_mlp_for_pairwise_parallel(module):
        raise RuntimeError("More than one nn.Linear needed for an MLP.")

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    linear_submodules = list(
        filter(lambda x: isinstance(x, nn.Linear), module.children())
    )
    mlp_last_even_layer = (len(linear_submodules) // 2) * 2
    for i in range(mlp_last_even_layer):
        m = linear_submodules[i]
        if i % 2 == 0:
            # Col-wise parallelize the linear layer
            distribute_module(
                m,
                device_mesh,
                _colwise_parallelize_linear_fn,
                input_fn=parallel_style._prepare_input  # type: ignore[arg-type, misc] # pyre-ignore[6]
                if i == 0
                else None,
            )
        else:
            # Row-wise parallelize the linear layer
            distribute_module(
                m,
                device_mesh,
                _rowwise_parallelize_linear_fn,
                output_fn=parallel_style._prepare_output  # type: ignore[arg-type, misc] # pyre-ignore[6]
                if i == (mlp_last_even_layer - 1)
                else None,
            )
    return module