Tensor Parallelism - torch.distributed.tensor.parallel

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor (DTensor) and provides different parallelism styles: Colwise and Rowwise Parallelism.

Warning

Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your nn.Module using Tensor Parallelism is:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan, tp_mesh_dim=0)

Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

parallelize_module parallelizes a module or its sub-modules according to a parallelize_plan. The parallelize_plan contains ParallelStyle objects, which indicate how the user wants each module or sub-module to be parallelized.

Users can also specify a different ParallelStyle for each module, keyed by the module's fully qualified name (FQN).

Note that parallelize_module only accepts a 1-D DeviceMesh. If you have a 2-D or N-D DeviceMesh, slice it to a 1-D sub-DeviceMesh first and pass that to this API (e.g. device_mesh["tp"]).
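To make the slicing step concrete, here is a minimal, dependency-free sketch of what selecting the "tp" dimension of a 2-D mesh means for a single rank. It is a conceptual model only: nested lists stand in for a DeviceMesh, and build_mesh, locate, and sub_mesh are hypothetical helper names, not PyTorch APIs. It assumes 8 ranks laid out row-major in a (2, 4) grid with dimension names ("dp", "tp").

```python
# Conceptual model only: plain lists stand in for a DeviceMesh.
def build_mesh(shape):
    """Lay out ranks 0..N-1 row-major in a 2-D grid."""
    rows, cols = shape
    return [[r * cols + c for c in range(cols)] for r in range(rows)]


def locate(mesh, rank):
    """Return the (row, column) coordinates of `rank` in the grid."""
    for r, row in enumerate(mesh):
        if rank in row:
            return r, row.index(rank)
    raise ValueError(f"rank {rank} is not in the mesh")


def sub_mesh(mesh, dim_names, name, rank):
    """Return the 1-D group of ranks that `rank` belongs to along dimension `name`."""
    r, c = locate(mesh, rank)
    if dim_names.index(name) == 0:
        # Vary the first dimension, keep the column fixed.
        return [row[c] for row in mesh]
    # Vary the second dimension, keep the row fixed.
    return list(mesh[r])


mesh_2d = build_mesh((2, 4))   # [[0, 1, 2, 3], [4, 5, 6, 7]]
dim_names = ("dp", "tp")

tp_group = sub_mesh(mesh_2d, dim_names, "tp", rank=5)
dp_group = sub_mesh(mesh_2d, dim_names, "dp", rank=5)
print(tp_group)  # [4, 5, 6, 7]
print(dp_group)  # [1, 5]
```

In this model, rank 5 would perform Tensor Parallelism with ranks 4 to 7, which is the 1-D group that parallelize_module expects to receive.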

Parameters
  • module (nn.Module) – Module to be parallelized.

  • device_mesh (DeviceMesh) – Object which describes the mesh topology of devices for the DTensor.

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]]) – The plan used to parallelize the module. It can be either a single ParallelStyle object, which describes how to prepare the input/output for Tensor Parallelism, or a dict that maps module FQNs to their corresponding ParallelStyle objects.

  • tp_mesh_dim (int, deprecated) – The dimension of device_mesh on which to perform Tensor Parallelism. This field is deprecated and will be removed in the future. If you have a 2-D or N-D DeviceMesh, consider passing in device_mesh["tp"] instead.

Returns

The parallelized nn.Module object.

Return type

Module

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> # Build a 1-D device mesh over 8 GPUs.
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> # Shard "w1" column-wise and "w2" row-wise, keyed by submodule FQN.
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

Note

For complex module architectures such as Attention and MLP layers, we recommend composing different ParallelStyles (e.g. ColwiseParallel and RowwiseParallel) and passing them as the parallelize_plan to achieve the desired sharded computation.

Tensor Parallelism supports the following parallel styles:

class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)

Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it with RowwiseParallel to shard more complicated modules (e.g. MLP, Attention).

Keyword Arguments
  • input_layouts (Placement, optional) – The DTensor layout of the input tensor for the nn.Module. It is used to annotate the input tensor so that it becomes a DTensor. If not specified, the input tensor is assumed to be replicated.

  • output_layouts (Placement, optional) – The DTensor layout of the output of the nn.Module. It is used to ensure that the output has the layout the user wants. If not specified, the output tensor is sharded on the last dimension.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: True.

Returns

A ParallelStyle object that represents Colwise sharding of the nn.Module.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> ...
>>> # By default, the input of the "w1" Linear is annotated as a replicated DTensor,
>>> # and the output of "w1" is returned as a torch.Tensor sharded on the last dim.
>>>
>>> parallelize_module(
>>>     module=block, # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan={"w1": ColwiseParallel()},
>>> )
>>> ...

Note

By default, the ColwiseParallel output is sharded on the last dimension if output_layouts is not specified. If there are operators that require a specific tensor shape (e.g. before the paired RowwiseParallel), keep in mind that when the output is sharded, the operator might need to be adjusted to the sharded size.
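The effect on shapes can be checked with a small, dependency-free sketch of the underlying arithmetic. This is not how ColwiseParallel is implemented: nested lists stand in for tensors, and matmul and split_columns are hypothetical helpers. For readability the weight is stored as (in_features, out_features), so the column split is literal; nn.Linear itself stores the weight transposed.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_columns(w, parts):
    """Split a matrix into `parts` equal blocks of columns (assumes even division)."""
    width = len(w[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in w] for p in range(parts)]


x = [[1, 2], [3, 4]]               # replicated input, shape (2, 2)
w = [[1, 0, 2, 1], [0, 1, 1, 3]]   # weight, shape (in=2, out=4)

full = matmul(x, w)                # unsharded reference, shape (2, 4)

# Each "rank" holds one block of columns and computes its slice of the output.
local_outputs = [matmul(x, shard) for shard in split_columns(w, 2)]

# Concatenating the local outputs along the last dim recovers the full output.
gathered = [sum((out[i] for out in local_outputs), []) for i in range(len(x))]

print(local_outputs[0])  # [[1, 2], [3, 4]]      local shape (2, 2), not (2, 4)
print(local_outputs[1])  # [[4, 7], [10, 15]]
print(gathered == full)  # True
```

Each local output has a last dimension of 4 / 2 = 2, which is why an operator that hard-codes the full size (for example a reshape between the paired layers) may need to use the sharded size instead.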

class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)

Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear only. Users can compose it with ColwiseParallel to shard more complicated modules (e.g. MLP, Attention).

Keyword Arguments
  • input_layouts (Placement, optional) – The DTensor layout of the input tensor for the nn.Module. It is used to annotate the input tensor so that it becomes a DTensor. If not specified, the input tensor is assumed to be sharded on the last dimension.

  • output_layouts (Placement, optional) – The DTensor layout of the output of the nn.Module. It is used to ensure that the output has the layout the user wants. If not specified, the output tensor is replicated.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module output, default: True.

Returns

A ParallelStyle object that represents Rowwise sharding of the nn.Module.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> ...
>>> # By default, the input of the "w2" Linear is annotated as a DTensor sharded on the last dim,
>>> # and the output of "w2" is returned as a replicated torch.Tensor.
>>>
>>> parallelize_module(
>>>     module=block, # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan={"w2": RowwiseParallel()},
>>> )
>>> ...
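The default layouts above (input sharded on the last dimension, output replicated) follow from the arithmetic of a row-split matrix product, sketched below without any dependencies. This is a conceptual model rather than the RowwiseParallel implementation: nested lists stand in for tensors, the helper names are hypothetical, the weight is stored as (in_features, out_features), and an element-wise sum plays the role of the all-reduce.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_last_dim(x, parts):
    """Split a matrix into `parts` equal blocks of columns (assumes even division)."""
    width = len(x[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in x] for p in range(parts)]


def split_rows(w, parts):
    """Split a matrix into `parts` equal blocks of rows (assumes even division)."""
    height = len(w) // parts
    return [w[p * height:(p + 1) * height] for p in range(parts)]


def all_reduce_sum(partials):
    """Element-wise sum of same-shaped matrices, standing in for an all-reduce."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


x = [[1, 2, 3, 4], [0, 1, 0, 1]]           # input, shape (2, 4)
w = [[1, 0], [0, 1], [2, 1], [1, 3]]       # weight, shape (in=4, out=2)

full = matmul(x, w)                        # unsharded reference

# Each "rank" multiplies its slice of the input by its block of weight rows.
partials = [matmul(xs, ws) for xs, ws in zip(split_last_dim(x, 2), split_rows(w, 2))]

# Every partial already has the full output shape; summing them gives the result.
reduced = all_reduce_sum(partials)

print(partials)         # [[[1, 2], [0, 1]], [[10, 15], [1, 3]]]
print(reduced)          # [[11, 17], [1, 4]]
print(reduced == full)  # True
```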

To simply configure the nn.Module’s inputs and outputs with DTensor layouts and perform the necessary layout redistributions, without distributing the module parameters to DTensors, the following classes can be used in the parallelize_plan of parallelize_module:

class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts, desired_input_layouts, use_local_output=False)

Configure the nn.Module’s inputs to convert the input tensors of the nn.Module to DTensors at runtime according to input_layouts, and perform layout redistribution according to the desired_input_layouts.

Keyword Arguments
  • input_layouts (Union[Placement, Tuple[Placement]]) – The DTensor layouts of the input tensors for the nn.Module, used to convert the input tensors to DTensors. If an input is not a torch.Tensor or does not need to be converted to a DTensor, None must be specified as a placeholder.

  • desired_input_layouts (Union[Placement, Tuple[Placement]]) – The desired DTensor layouts of the input tensors for the nn.Module, used to ensure that the inputs of the nn.Module have the desired DTensor layouts. This argument must have the same length as input_layouts.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module inputs, default: False.

Returns

A ParallelStyle object that prepares the sharding layouts of the nn.Module’s inputs.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.tensor import Replicate, Shard  # torch.distributed._tensor in older releases
>>> ...
>>> # With the style below, the first input of "attn" is annotated as a DTensor sharded on dim 0
>>> # and then redistributed to a replicated DTensor.
>>> parallelize_module(
>>>     module=block, # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...)
>>>         ),
>>>     }
>>> )
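
The pairing of layouts with positional inputs, including the None placeholders, can be illustrated with a small dependency-free model. Everything here is an assumption made for illustration: strings stand in for Placement objects, nested lists stand in for tensors, one tuple per rank holds that rank's inputs, and all_gather_dim0 and prepare_inputs are hypothetical helpers that only handle the Shard(0) to Replicate() case from the example above.

```python
# Conceptual model only: strings stand in for Placement objects.
def all_gather_dim0(chunks):
    """Shard(0) -> Replicate(): concatenate every rank's rows in rank order."""
    return [row for chunk in chunks for row in chunk]


def prepare_inputs(per_rank_inputs, input_layouts, desired_layouts):
    """Redistribute each positional input; a None layout leaves that input untouched."""
    if len(input_layouts) != len(desired_layouts):
        raise ValueError("input_layouts and desired_input_layouts must have the same length")
    world_size = len(per_rank_inputs)
    prepared = [list(args) for args in per_rank_inputs]
    for pos, (src, dst) in enumerate(zip(input_layouts, desired_layouts)):
        if src is None or src == dst:
            continue  # not a tensor, or already in the desired layout
        if (src, dst) != ("Shard(0)", "Replicate()"):
            raise NotImplementedError(f"{src} -> {dst} is not modeled in this sketch")
        full = all_gather_dim0([per_rank_inputs[r][pos] for r in range(world_size)])
        for r in range(world_size):
            prepared[r][pos] = full
    return prepared


per_rank_inputs = [
    ([[1, 2]], "mask", 0.5),   # rank 0 holds the first row of the first input
    ([[3, 4]], "mask", 0.5),   # rank 1 holds the second row
]
prepared = prepare_inputs(
    per_rank_inputs,
    input_layouts=("Shard(0)", None, None),
    desired_layouts=("Replicate()", None, None),
)
print(prepared[0])  # [[[1, 2], [3, 4]], 'mask', 0.5]
print(prepared[1])  # [[[1, 2], [3, 4]], 'mask', 0.5]
```

After the call, every rank sees the full first input, while the second and third inputs pass through unchanged because their layouts are None.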

class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)

Configure the nn.Module’s outputs to convert the output tensors of the nn.Module to DTensors at runtime according to output_layouts, and perform layout redistribution according to the desired_output_layouts.

Keyword Arguments
  • output_layouts (Union[Placement, Tuple[Placement]]) – The DTensor layouts of the output tensors for the nn.Module, used to convert the output tensors to DTensors if they are torch.Tensor. If an output is not a torch.Tensor or does not need to be converted to a DTensor, None must be specified as a placeholder.

  • desired_output_layouts (Union[Placement, Tuple[Placement]]) – The desired DTensor layouts of the output tensors for the nn.Module, used to ensure that the outputs of the nn.Module have the desired DTensor layouts.

  • use_local_output (bool, optional) – Whether to use local torch.Tensor instead of DTensor for the module outputs, default: True.

Returns

A ParallelStyle object that prepares the sharding layouts of the nn.Module’s outputs.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.tensor import Replicate, Shard  # torch.distributed._tensor in older releases
>>> ...
>>> # With the style below, the output of "submodule" is annotated as a replicated DTensor
>>> # and then redistributed to a DTensor sharded on dim 0.
>>> parallelize_module(
>>>     module=block, # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan={
>>>         "submodule": PrepareModuleOutput(
>>>             output_layouts=Replicate(),
>>>             desired_output_layouts=Shard(0)
>>>         ),
>>>     }
>>> )
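
Going from Replicate() to Shard(0), as in the example above, needs no communication in principle, because every rank already holds the full tensor and only has to keep its own block of rows. The sketch below models that with nested lists; shard_dim0 is a hypothetical helper, and it assumes the number of rows divides evenly by the number of ranks.

```python
def shard_dim0(tensor, world_size, rank):
    """Replicate() -> Shard(0): keep only this rank's contiguous block of rows."""
    if len(tensor) % world_size != 0:
        raise ValueError("this sketch assumes the rows divide evenly across ranks")
    rows_per_rank = len(tensor) // world_size
    return tensor[rank * rows_per_rank:(rank + 1) * rows_per_rank]


# The replicated module output, identical on every rank, shape (4, 2).
output = [[1, 2], [3, 4], [5, 6], [7, 8]]

local_outputs = [shard_dim0(output, world_size=2, rank=r) for r in range(2)]
print(local_outputs[0])  # [[1, 2], [3, 4]]
print(local_outputs[1])  # [[5, 6], [7, 8]]
```

With use_local_output=True, the value each rank receives from the module would correspond to its local block, here of shape (2, 2).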

For models like Transformer, we recommend using ColwiseParallel and RowwiseParallel together in the parallelize_plan to achieve the desired sharding for the entire model (e.g. Attention and MLP).
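
Why the two styles pair well can be seen in a dependency-free sketch of a two-layer MLP, y = relu(x · w1) · w2. This is a conceptual model, not the library's implementation: nested lists stand in for tensors, weights are stored as (in_features, out_features), the helper names are hypothetical, and an element-wise sum stands in for the all-reduce. The first layer is split by columns and the second by rows, so the column-sharded hidden activation feeds the row-sharded layer directly and only one reduction is needed at the end.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    """Element-wise ReLU, so it can be applied to each column shard independently."""
    return [[max(v, 0) for v in row] for row in m]


def mlp_on_rank(x, w1_cols, w2_rows):
    """One rank's work: column-sharded first layer, then row-sharded second layer."""
    hidden = relu(matmul(x, w1_cols))   # sharded on the last dim
    return matmul(hidden, w2_rows)      # partial sum with the full output shape


x = [[1, -1], [2, 0]]                        # replicated input, shape (2, 2)
w1 = [[1, -1, 0, 2], [0, 1, -1, 1]]          # shape (2, 4), split by columns
w2 = [[1, 0], [0, 1], [1, 1], [2, -1]]       # shape (4, 2), split by rows

world_size = 2
width = len(w1[0]) // world_size             # hidden units per rank

partials = []
for rank in range(world_size):
    lo, hi = rank * width, (rank + 1) * width
    w1_cols = [row[lo:hi] for row in w1]
    w2_rows = w2[lo:hi]
    partials.append(mlp_on_rank(x, w1_cols, w2_rows))

# A single element-wise sum across ranks produces the replicated output.
y_parallel = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
y_reference = matmul(relu(matmul(x, w1)), w2)

print(y_parallel)                  # [[4, 0], [10, -4]]
print(y_parallel == y_reference)   # True
```

This works because ReLU is element-wise; an operation that mixes values across the hidden dimension between the two layers would not commute with the column split in the same way.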
