Source code for torch.distributed.tensor.parallel.style
# mypy: allow-untyped-defs# Copyright (c) Meta Platforms, Inc. and affiliatesfromabcimportABC,abstractmethodfromfunctoolsimportpartialfromtypingimportAny,Dict,Optional,Tuple,Unionimporttorchimporttorch.nnasnnfromtorch.distributed.tensorimport(DeviceMesh,distribute_module,distribute_tensor,DTensor,Replicate,Shard,)fromtorch.distributed.tensor.placement_typesimportPlacement__all__=["ParallelStyle","RowwiseParallel","SequenceParallel","ColwiseParallel","PrepareModuleInput","PrepareModuleOutput",]classParallelStyle(ABC):""" The parallel style contract defines how the module or submodule should be parallelized. It only defines the ``apply`` method for ``parallelize_module`` to use, this allows maximum flexibility for different kind of style implementations. """@abstractmethoddef_apply(self,module:nn.Module,device_mesh:DeviceMesh)->nn.Module:...
[docs]classColwiseParallel(ParallelStyle):""" Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules. (i.e. MLP, Attention) Keyword Args: input_layouts (Placement, optional): The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be replicated. output_layouts (Placement, optional): The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module with the user desired layout. If not specified, the output tensor is sharded on the last dimension. use_local_output (bool, optional): Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. Returns: A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module. Example:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) >>> ... .. note:: By default ``ColwiseParallel`` output is sharded on the last dimension if the ``output_layouts`` not specified, if there're operators that require specific tensor shape (i.e. before the paired ``RowwiseParallel``), keep in mind that if the output is sharded the operator might need to be adjusted to the sharded size. """def__init__(self,*,input_layouts:Optional[Placement]=None,output_layouts:Optional[Placement]=None,use_local_output:bool=True,):super().__init__()self.input_layouts=(input_layoutsorReplicate(),)self.output_layouts=(output_layoutsorShard(-1),)# colwise linear runtime sharding (desired sharding):# 1. requires replicate input# 2. shard output on last dimself.desired_input_layouts=(Replicate(),)self.use_local_output=use_local_output@staticmethoddef_prepare_input_fn(input_layouts,desired_input_layouts,mod,inputs,device_mesh):# TODO: figure out dynamo support for instance method and switch this to instance method# annotate module input placements/sharding with input_layoutsinput_tensor=inputs[0]ifnotisinstance(input_tensor,DTensor):input_tensor=DTensor.from_local(input_tensor,device_mesh,input_layouts,run_check=False)# transform the input layouts to the desired layouts of ColwiseParallelifinput_layouts!=desired_input_layouts:input_tensor=input_tensor.redistribute(placements=desired_input_layouts,async_op=True)returninput_tensordef_partition_linear_fn(self,name,module,device_mesh):# colwise shard weight/bias to Shard(0), weight be Shard(0)# means Colwise as Linear is input * weight^T + bias, where# weight would become Shard(1)forname,paraminmodule.named_parameters():dist_param=nn.Parameter(distribute_tensor(param,device_mesh,[Shard(0)]))module.register_parameter(name,dist_param)def_partition_embedding_fn(self,name,module,device_mesh):# colwise shard embedding.weight is straight forward as Shard(1)forname,paraminmodule.named_parameters():dist_param=nn.Parameter(distribute_tensor(param,device_mesh,[Shard(1)]))module.register_parameter(name,dist_param)@staticmethoddef_prepare_output_fn(output_layouts,use_local_output,mod,outputs,device_mesh):# outputs is a shard on last dimension DTensor, i.e. Shard(-1)ifoutputs.placements!=output_layouts:outputs=outputs.redistribute(placements=output_layouts,async_op=True)# back to local tensorreturnoutputs.to_local()ifuse_local_outputelseoutputsdef_apply(self,module:nn.Module,device_mesh:DeviceMesh)->nn.Module:ifisinstance(module,nn.Linear):partition_fn=self._partition_linear_fnelifisinstance(module,nn.Embedding):partition_fn=self._partition_embedding_fnelse:raiseNotImplementedError("ColwiseParallel currently only support nn.Linear and nn.Embedding!")returndistribute_module(module,device_mesh,partition_fn,partial(self._prepare_input_fn,self.input_layouts,self.desired_input_layouts),partial(self._prepare_output_fn,self.output_layouts,self.use_local_output),)
[docs]classRowwiseParallel(ParallelStyle):""" Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules. (i.e. MLP, Attention) Keyword Args: input_layouts (Placement, optional): The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension. output_layouts (Placement, optional): The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module with the user desired layout. If not specified, the output tensor is replicated. use_local_output (bool, optional): Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. Returns: A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module. Example:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>> ... """def__init__(self,*,input_layouts:Optional[Placement]=None,output_layouts:Optional[Placement]=None,use_local_output:bool=True,):super().__init__()self.input_layouts=(input_layoutsorShard(-1),)self.output_layouts=(output_layoutsorReplicate(),)self.use_local_output=use_local_output@staticmethoddef_prepare_input_fn(input_layouts,desired_input_layouts,mod,inputs,device_mesh):input_tensor=inputs[0]ifnotisinstance(input_tensor,DTensor):input_tensor=DTensor.from_local(input_tensor,device_mesh,input_layouts,run_check=False)ifinput_layouts!=desired_input_layouts:input_tensor=input_tensor.redistribute(placements=desired_input_layouts,async_op=True)returninput_tensordef_partition_linear_fn(self,name,module,device_mesh):# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)# means Rowwise as nn.Linear is input * weight^T + bias, where# weight would become Shard(0)module.register_parameter("weight",nn.Parameter(distribute_tensor(module.weight,device_mesh,[Shard(1)])),)ifgetattr(module,"bias",None)isnotNone:# The Linear module has biasmodule.register_parameter("bias",nn.Parameter(distribute_tensor(module.bias,device_mesh,[Replicate()])),)def_partition_embedding_fn(self,name,module,device_mesh):# rowwise shard embedding.weight is Shard(0)forname,paraminmodule.named_parameters():dist_param=nn.Parameter(distribute_tensor(param,device_mesh,[Shard(0)]))module.register_parameter(name,dist_param)@staticmethoddef_prepare_output_fn(output_layouts,use_local_output,mod,outputs,device_mesh):# Rowwise sharding produces partial output, depending on output layouts:# 1. to replicate -> allreduce# 2. to shard -> reduce_scatterifoutputs.placements!=output_layouts:outputs=outputs.redistribute(placements=output_layouts,async_op=True)# back to local tensor if use_local_output is Truereturnoutputs.to_local()ifuse_local_outputelseoutputsdef_apply(self,module:nn.Module,device_mesh:DeviceMesh)->nn.Module:ifisinstance(module,nn.Linear):partition_fn=self._partition_linear_fn# rowwise linear runtime sharding requires input tensor shard on last dimself.desired_input_layouts:Tuple[Placement,...]=(Shard(-1),)elifisinstance(module,nn.Embedding):partition_fn=self._partition_embedding_fn# rowwise embedding runtime sharding requires input tensor replicatedself.desired_input_layouts=(Replicate(),)else:raiseNotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!")returndistribute_module(module,device_mesh,partition_fn,partial(self._prepare_input_fn,self.input_layouts,self.desired_input_layouts),partial(self._prepare_output_fn,self.output_layouts,self.use_local_output),)
[docs]classSequenceParallel(ParallelStyle):""" SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__ This style implements the operation that is described in the paper `Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__ If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would redistribute the input to be sharded on the sequence dimension. The output of the ``nn.Module`` will be sharded on the sequence dimension. Keyword Args: sequence_dim (int, optional): The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to become a DTensor that is sharded on the sequence dimension, default: 1. use_local_output (bool, optional): Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False. Returns: A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``. Example:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>> ... .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e. ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom inits for the weights on those modules, you need to broadcast the weights before/after parallelizing to ensure that they are replicated. """def__init__(self,*,sequence_dim:int=1,use_local_output:bool=False):super().__init__()self.sequence_sharding=(Shard(sequence_dim),)self.use_local_output=use_local_outputdef_replicate_module_fn(self,name:str,module:nn.Module,device_mesh:DeviceMesh):forp_name,paraminmodule.named_parameters():# simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow# us to simply just use from_localreplicated_param=torch.nn.Parameter(DTensor.from_local(param,device_mesh,[Replicate()],run_check=False))module.register_parameter(p_name,replicated_param)@staticmethoddef_prepare_input_fn(sequence_sharding,mod,inputs,device_mesh):input_tensor=inputs[0]ifisinstance(input_tensor,DTensor):# if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute itifinput_tensor.placements!=sequence_sharding:input_tensor=input_tensor.redistribute(placements=sequence_sharding,async_op=True)returninput_tensorelifisinstance(input_tensor,torch.Tensor):# assume the input passed in already sharded on the sequence dim and create the DTensorreturnDTensor.from_local(input_tensor,device_mesh,sequence_sharding,run_check=False)else:raiseValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}")@staticmethoddef_prepare_output_fn(use_local_output,mod,outputs,device_mesh):returnoutputs.to_local()ifuse_local_outputelseoutputsdef_apply(self,module:nn.Module,device_mesh:DeviceMesh)->nn.Module:returndistribute_module(module,device_mesh,self._replicate_module_fn,partial(self._prepare_input_fn,self.sequence_sharding),partial(self._prepare_output_fn,self.use_local_output),)
[docs]classPrepareModuleInput(ParallelStyle):""" Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``. Keyword Args: input_layouts (Union[Placement, Tuple[Optional[Placement]]]): The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified as a placeholder. default: None. desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]): The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None. input_kwarg_layouts (Dict[str, Placement]): The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. default: None desired_input_kwarg_layouts: (Dict[str, Placement]): The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module have the desired DTensor layouts. default: None. use_local_output (bool, optional): Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False. Returns: A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs. Example:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor >>> # and then redistributed to Replicated DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan={ >>> "attn": PrepareModuleInput( >>> input_layouts=(Shard(0), None, None, ...), >>> desired_input_layouts=(Replicate(), None, None, ...) >>> ), >>> } >>> ) """def__init__(self,*,input_layouts:Optional[Union[Placement,Tuple[Optional[Placement]]]]=None,desired_input_layouts:Optional[Union[Placement,Tuple[Optional[Placement]]]]=None,input_kwarg_layouts:Optional[Dict[str,Placement]]=None,desired_input_kwarg_layouts:Optional[Dict[str,Placement]]=None,use_local_output:bool=False,):self.input_layouts=((input_layouts,)ifisinstance(input_layouts,Placement)elseinput_layouts)self.desired_input_layouts=((desired_input_layouts,)ifisinstance(desired_input_layouts,Placement)elsedesired_input_layouts)self.use_local_output=use_local_outputifself.input_layoutsisnotNone:assert(self.desired_input_layoutsisnotNone),"desired module inputs should not be None!"assertlen(self.input_layouts)==len(self.desired_input_layouts),"input_layouts and desired_input_layouts should have same length!"self.with_kwargs=input_kwarg_layoutsisnotNoneself.input_kwarg_layouts=input_kwarg_layoutsor{}self.desired_input_kwarg_layouts=desired_input_kwarg_layoutsor{}ifself.with_kwargs:assertlen(self.input_kwarg_layouts)==len(self.desired_input_kwarg_layouts),"input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"def_prepare_input_arg(self,input:Any,mesh:DeviceMesh,input_layout:Optional[Placement],desired_layout:Optional[Placement],):ifinput_layoutisnotNone:ifisinstance(input,DTensor):# TODO: re-enable the check once we fix the compile path# assert inp.placements[0] == input_layoutdt_inp=inputelse:assertisinstance(input,torch.Tensor),"expecting input to be a torch.Tensor!"dt_inp=DTensor.from_local(input,mesh,(input_layout,),run_check=False)ifdesired_layoutisnotNoneandinput_layout!=desired_layout:dt_inp=dt_inp.redistribute(placements=(desired_layout,))returndt_inp.to_local()ifself.use_local_outputelsedt_inpelse:returninputdef_prepare_input_fn(self,inputs,device_mesh):ifself.input_layoutsisNone:returninputsprepared_inputs=[]ifnotisinstance(inputs,tuple):inputs=(inputs,)iflen(inputs)!=len(self.input_layouts):raiseValueError("module inputs and input_layouts should have same length!")assert(self.desired_input_layoutsisnotNone),"desired module inputs should not be None!"forinp,input_layout,desired_layoutinzip(inputs,self.input_layouts,self.desired_input_layouts):prepared_inputs.append(self._prepare_input_arg(inp,device_mesh,input_layout,desired_layout))returntuple(prepared_inputs)def_prepare_input_kwarg_fn(self,inputs,kwarg_inputs,device_mesh):prepared_arg_inputs=self._prepare_input_fn(inputs,device_mesh)prepared_kwarg_inputs={}forkwarg_keyinkwarg_inputs.keys():kwarg_val=kwarg_inputs[kwarg_key]input_layout=self.input_kwarg_layouts.get(kwarg_key)desired_input_layout=self.desired_input_kwarg_layouts.get(kwarg_key)prepared_kwarg_inputs[kwarg_key]=self._prepare_input_arg(kwarg_val,device_mesh,input_layout,desired_input_layout)return(prepared_arg_inputs,prepared_kwarg_inputs)def_apply(self,module:nn.Module,device_mesh:DeviceMesh)->nn.Module:ifself.with_kwargs:module.register_forward_pre_hook(lambda_,inputs,kwargs:self._prepare_input_kwarg_fn(inputs,kwargs,device_mesh),with_kwargs=True,)# type: ignore[misc]else:module.register_forward_pre_hook(lambda_,inputs:self._prepare_input_fn(inputs,device_mesh))# type: ignore[misc, call-arg]returnmodule
[docs]classPrepareModuleOutput(ParallelStyle):""" Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``. Keyword Args: output_layouts (Union[Placement, Tuple[Placement]]): The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified as a placeholder. desired_output_layouts (Union[Placement, Tuple[Placement]]): The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module have the desired DTensor layouts. use_local_output (bool, optional): Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True. Returns: A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs. Example:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor >>> # and then redistributed to Sharded DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan = PrepareModuleOutput( >>> output_layouts=Replicate(), >>> desired_output_layouts=Shard(0) >>> ) >>> ) """def__init__(self,*,output_layouts:Union[Placement,Tuple[Placement]],desired_output_layouts:Union[Placement,Tuple[Placement]],use_local_output:bool=True,):self.output_layouts=((output_layouts,)ifisinstance(output_layouts,Placement)elseoutput_layouts)self.desired_output_layouts=((desired_output_layouts,)ifisinstance(desired_output_layouts,Placement)elsedesired_output_layouts)self.use_local_output=use_local_outputassertlen(self.output_layouts)==len(self.desired_output_layouts),"output_layouts and desired_output_layouts should have same length!"def_prepare_out_fn(self,outputs,device_mesh):prepared_outputs=[]ifnotisinstance(outputs,tuple):outputs=(outputs,)iflen(outputs)!=len(self.output_layouts):raiseValueError("module outputs and output_layouts should have same length!")forout,out_layout,desired_out_layoutinzip(outputs,self.output_layouts,self.desired_output_layouts):ifout_layoutisnotNone:ifisinstance(out,DTensor):# TODO: re-enable the check once we fix the compile path# assert out.placements[0] == out_layoutdt_out=outelse:dt_out=DTensor.from_local(out,device_mesh,(out_layout,),run_check=False)ifout_layout!=desired_out_layout:dt_out=dt_out.redistribute(placements=(desired_out_layout,))prepared_outputs.append(dt_out.to_local()ifself.use_local_outputelsedt_out)else:prepared_outputs.append(out)iflen(prepared_outputs)==1:returnprepared_outputs[0]else:returntuple(prepared_outputs)def_apply(self,module:nn.Module,device_mesh:DeviceMesh)->nn.Module:module.register_forward_hook(lambda_,inputs,outputs:self._prepare_out_fn(outputs,device_mesh))# type: ignore[misc, call-arg]returnmodule
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.