# Source code for torch.distributed.tensor.parallel.style
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple
from functools import partial

import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    Placement,
    Replicate,
    Shard,
    distribute_tensor,
    distribute_module,
)

__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "ColwiseParallel",
    "PrepareModuleInput",
    "PrepareModuleOutput",
]


class ParallelStyle(ABC):
    """
    The parallel style contract defines how the module or submodule should be parallelized.

    It only defines the abstract ``_apply`` method for ``parallelize_module`` to use; this
    allows maximum flexibility for different kinds of style implementations.
    """

    @abstractmethod
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """
        Parallelize ``module`` over ``device_mesh``.

        Args:
            module (nn.Module): the module (or submodule) this style is applied to.
            device_mesh (DeviceMesh): the device mesh the module is parallelized over.

        Returns:
            The parallelized :class:`nn.Module`.
        """
        ...
class ColwiseParallel(ParallelStyle):
    """
    Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding.
    Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules.
    (i.e. MLP, Attention)

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
            become a DTensor. If not specified, we assume the input tensor to be replicated.
        output_layouts (Placement, optional):
            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
            has the user desired layout. If not specified, the output tensor is sharded on the last dimension.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
        >>> ...
        >>> # By default, the input of the "w1" Linear will be annotated to Replicated DTensor
        >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
        >>>
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan={"w1": ColwiseParallel()},
        >>> )
        >>> ...

    .. note:: By default ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is not
        specified. If there are operators that require a specific tensor shape (e.g. before the paired
        ``RowwiseParallel``), keep in mind that if the output is sharded the operator might need to be adjusted
        to the sharded size.
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(),)
        self.output_layouts = (output_layouts or Shard(-1),)
        # colwise linear runtime sharding (desired sharding):
        # 1. requires replicate input
        # 2. shard output on last dim
        self.desired_input_layouts = (Replicate(),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, inputs, device_mesh):
        """
        Turn the module's first input into a DTensor laid out as ``desired_input_layouts``.

        A plain tensor is first annotated as a DTensor with ``input_layouts`` (no cross-rank
        check), then redistributed if the annotated layout differs from the desired one.
        Only ``inputs[0]`` is used.
        """
        # TODO: figure out dynamo support for instance method and switch this to instance method
        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

        # transform the input layouts to the desired layouts of ColwiseParallel
        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts)
        return input_tensor

    def _partition_fn(self, name, module, device_mesh):
        """
        Replace every parameter of ``module`` with a column-wise sharded DTensor parameter.

        Args:
            name: presumably the submodule's name as passed by ``distribute_module``; unused here.
            module: the module to shard; must be ``nn.Linear`` or ``nn.Embedding``.
            device_mesh: the mesh the parameters are distributed over.

        Raises:
            NotImplementedError: for any other module type (raised before any parameter is touched).
        """
        if isinstance(module, nn.Linear):
            # colwise shard weight/bias to Shard(0), weight be Shard(0)
            # means Colwise as Linear is input * weight^T + bias, where
            # weight would become Shard(1)
            placement = Shard(0)
        elif isinstance(module, nn.Embedding):
            # colwise shard embedding.weight is straight forward as Shard(1)
            placement = Shard(1)
        else:
            raise NotImplementedError(
                "ColwiseParallel only supports nn.Linear "
                f"and nn.Embedding for now, but found {type(module)}!"
            )

        # Snapshot the parameters first: register_parameter rebinds entries of the
        # very collection named_parameters() walks lazily.
        for param_name, param in list(module.named_parameters()):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [placement]))
            module.register_parameter(param_name, dist_param)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, outputs, device_mesh):
        """Redistribute the module output to ``output_layouts``, optionally converting it to a local tensor."""
        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
        outputs = outputs.redistribute(placements=output_layouts)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """Shard ``module``'s parameters and install the input/output layout hooks via ``distribute_module``."""
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
        )
class RowwiseParallel(ParallelStyle):
    """
    Partition a compatible nn.Module row-wise. Only nn.Linear is supported at the moment.
    Compose it with ColwiseParallel to shard more complicated modules (e.g. MLP, Attention).

    Keyword Args:
        input_layouts (Placement, optional):
            Layout used to annotate the module's input tensor as a DTensor. Defaults to
            sharding on the last dimension.
        output_layouts (Placement, optional):
            Layout the module's output is redistributed to, so the output matches what the
            user wants. Defaults to replicated.
        use_local_output (bool, optional):
            When True (the default), the module output is converted back to a local
            :class:`torch.Tensor` instead of being returned as a :class:`DTensor`.
    Returns:
        A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
        >>> ...
        >>> # By default, the input of the "w2" Linear will be annotated to DTensor that shards on the last dim
        >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
        >>>
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan={"w2": RowwiseParallel()},
        >>> )
        >>> ...
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True
    ):
        super().__init__()
        self.use_local_output = use_local_output
        self.input_layouts = (input_layouts or Shard(-1),)
        self.output_layouts = (output_layouts or Replicate(),)
        # Runtime sharding of a rowwise linear: the input must be sharded on its last dim,
        # and the partial output becomes replicated via allreduce or sharded via reduce_scatter.
        self.desired_input_layouts = (Shard(-1),)

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, inputs, device_mesh):
        """Annotate ``inputs[0]`` as a DTensor (if needed) and move it to ``desired_input_layouts``."""
        first_input = inputs[0]
        dtensor_input = (
            first_input
            if isinstance(first_input, DTensor)
            else DTensor.from_local(first_input, device_mesh, input_layouts, run_check=False)
        )
        if input_layouts == desired_input_layouts:
            return dtensor_input
        return dtensor_input.redistribute(placements=desired_input_layouts)

    def _partition_fn(self, name, module, device_mesh):
        """Shard an nn.Linear's weight as Shard(1) and replicate its bias; other modules are rejected."""
        if not isinstance(module, nn.Linear):
            raise NotImplementedError("RowwiseParallel currently only support nn.Linear!")

        # The weight is stored as (out, in); Shard(1) on it is Shard(0) on weight^T in
        # input * weight^T + bias, i.e. a row-wise split.
        sharded_weight = distribute_tensor(module.weight, device_mesh, [Shard(1)])
        module.register_parameter("weight", nn.Parameter(sharded_weight))

        if module.bias is not None:
            replicated_bias = distribute_tensor(module.bias, device_mesh, [Replicate()])
            module.register_parameter("bias", nn.Parameter(replicated_bias))

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, outputs, device_mesh):
        """Move the module output to ``output_layouts``; unwrap to a local tensor when requested."""
        redistributed = outputs.redistribute(placements=output_layouts)
        if use_local_output:
            return redistributed.to_local()
        return redistributed

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """Distribute ``module`` with the rowwise partition function and layout hooks."""
        input_fn = partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts)
        output_fn = partial(self._prepare_output_fn, self.output_layouts, self.use_local_output)
        return distribute_module(module, device_mesh, self._partition_fn, input_fn, output_fn)
class PrepareModuleInput(ParallelStyle):
    """
    Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
    ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.

    Keyword Args:
        input_layouts (Union[Placement, Tuple[Optional[Placement], ...]]):
            The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to
            DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None`` needs
            to be specified as a placeholder. Supplying more module inputs than ``input_layouts`` entries raises
            ``ValueError`` at forward time instead of silently dropping the extra inputs.
        desired_input_layouts (Union[Placement, Tuple[Optional[Placement], ...]]):
            The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the
            nn.Module have the desired DTensor layouts. This argument needs to have the same length as
            ``input_layouts``.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
        >>> ...
        >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
        >>> # and then redistributed to Replicated DTensor.
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan={
        >>>         "attn": PrepareModuleInput(
        >>>             input_layouts=(Shard(0), None, None, ...),
        >>>             desired_input_layouts=(Replicate(), None, None, ...)
        >>>         ),
        >>>     }
        >>> )
    """

    def __init__(
        self,
        *,
        input_layouts: Union[Placement, Tuple[Optional[Placement], ...]],
        desired_input_layouts: Union[Placement, Tuple[Optional[Placement], ...]],
        use_local_output: bool = False
    ):
        super().__init__()
        # A bare Placement is shorthand for a single-input module.
        self.input_layouts = (
            (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
        )
        self.desired_input_layouts = (
            (desired_input_layouts,)
            if isinstance(desired_input_layouts, Placement)
            else desired_input_layouts
        )
        self.use_local_output = use_local_output
        assert len(self.input_layouts) == len(self.desired_input_layouts), \
            "input_layouts and desired_input_layouts should have same length!"

    def _prepare_input_fn(self, inputs, device_mesh):
        """
        Convert/redistribute the module inputs according to the configured layouts.

        Entries whose ``input_layout`` is ``None`` are passed through untouched. Tensors are
        annotated as DTensors with ``input_layout`` (no cross-rank check), DTensors must
        already carry ``input_layout``, and both are redistributed to the desired layout
        when it differs.

        Returns:
            A tuple of the prepared inputs (a non-tuple ``inputs`` is treated as one input).

        Raises:
            ValueError: if there are more inputs than ``input_layouts`` entries.
        """
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        # zip() below would otherwise silently drop the inputs that have no layout entry.
        if len(inputs) > len(self.input_layouts):
            raise ValueError(
                f"PrepareModuleInput received {len(inputs)} module inputs but only "
                f"{len(self.input_layouts)} input_layouts; use None as a placeholder "
                "for inputs that should be left untouched."
            )

        prepared_inputs = []
        for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
            if input_layout is None:
                prepared_inputs.append(inp)
                continue

            if isinstance(inp, DTensor):
                assert inp.placements[0] == input_layout, (
                    f"DTensor input has placements {inp.placements}, expected {input_layout}"
                )
                dt_inp = inp
            else:
                dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)

            if input_layout != desired_layout:
                dt_inp = dt_inp.redistribute(placements=(desired_layout,))
            prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp)
        return tuple(prepared_inputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """Register a forward pre-hook on ``module`` that prepares its inputs; returns ``module``."""
        module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh))  # type: ignore[misc, call-arg]
        return module
class PrepareModuleOutput(ParallelStyle):
    """
    Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
    ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.

    Keyword Args:
        output_layouts (Union[Placement, Tuple[Optional[Placement], ...]]):
            The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to
            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be
            converted to DTensors, ``None`` needs to be specified as a placeholder. Returning more module outputs
            than ``output_layouts`` entries raises ``ValueError`` instead of silently dropping the extra outputs.
        desired_output_layouts (Union[Placement, Tuple[Optional[Placement], ...]]):
            The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of
            the nn.Module have the desired DTensor layouts. This argument needs to have the same length as
            ``output_layouts``.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.
    Returns:
        A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
        >>> ...
        >>> # According to the style specified below, the output of "submodule" will be annotated as a
        >>> # Replicated DTensor and then redistributed to a DTensor sharded on dim 0 (and, since
        >>> # use_local_output defaults to True, returned as a local torch.Tensor).
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan={
        >>>         "submodule": PrepareModuleOutput(
        >>>             output_layouts=Replicate(),
        >>>             desired_output_layouts=Shard(0)
        >>>         ),
        >>>     }
        >>> )
    """

    def __init__(
        self,
        *,
        output_layouts: Union[Placement, Tuple[Optional[Placement], ...]],
        desired_output_layouts: Union[Placement, Tuple[Optional[Placement], ...]],
        use_local_output: bool = True
    ):
        super().__init__()
        # A bare Placement is shorthand for a single-output module.
        self.output_layouts = (
            (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts
        )
        self.desired_output_layouts = (
            (desired_output_layouts,)
            if isinstance(desired_output_layouts, Placement)
            else desired_output_layouts
        )
        self.use_local_output = use_local_output
        # Same invariant PrepareModuleInput enforces: a shorter tuple would make zip()
        # silently drop outputs.
        assert len(self.output_layouts) == len(self.desired_output_layouts), \
            "output_layouts and desired_output_layouts should have same length!"

    def _prepare_out_fn(self, outputs, device_mesh):
        """
        Convert/redistribute the module outputs according to the configured layouts.

        Entries whose ``out_layout`` is ``None`` are passed through untouched. Tensors are
        annotated as DTensors with ``out_layout`` (no cross-rank check), DTensors must
        already carry ``out_layout``, and both are redistributed to the desired layout
        when it differs.

        Returns:
            The single prepared output when there is exactly one, otherwise a tuple.

        Raises:
            ValueError: if the module returned more outputs than ``output_layouts`` entries.
        """
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # zip() below would otherwise silently drop the outputs that have no layout entry.
        if len(outputs) > len(self.output_layouts):
            raise ValueError(
                f"PrepareModuleOutput received {len(outputs)} module outputs but only "
                f"{len(self.output_layouts)} output_layouts; use None as a placeholder "
                "for outputs that should be left untouched."
            )

        prepared_outputs = []
        for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts):
            if out_layout is None:
                prepared_outputs.append(out)
                continue

            if isinstance(out, DTensor):
                assert out.placements[0] == out_layout, (
                    f"DTensor output has placements {out.placements}, expected {out_layout}"
                )
                dt_out = out
            else:
                dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False)

            if out_layout != desired_out_layout:
                dt_out = dt_out.redistribute(placements=(desired_out_layout,))
            prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out)

        # A single output is returned bare rather than wrapped in a 1-tuple.
        if len(prepared_outputs) == 1:
            return prepared_outputs[0]
        return tuple(prepared_outputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """Register a forward hook on ``module`` that prepares its outputs; returns ``module``."""
        module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh))  # type: ignore[misc, call-arg]
        return module
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainer of this site, Facebook applies its Cookies Policy. Learn more, including about the available controls: Cookies Policy.