Shortcuts

Source code for torch.distributed.tensor.parallel.style

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple, Dict, Any
from functools import partial

import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module


# Names exported by ``from <this module> import *``; each entry is a class defined below.
__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "SequenceParallel",
    "ColwiseParallel",
    "PrepareModuleInput",
    "PrepareModuleOutput",
]


class ParallelStyle(ABC):
    """
    Abstract contract describing how a module or submodule should be parallelized.

    The contract consists of a single hook, ``_apply``, which is what ``parallelize_module``
    uses. Keeping the interface this small leaves concrete styles free to implement the
    parallelization however they need to.
    """

    @abstractmethod
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """Parallelize ``module`` over ``device_mesh``; concrete styles return an ``nn.Module``."""


[docs]class ColwiseParallel(ParallelStyle): """ Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules. (i.e. MLP, Attention) Keyword Args: input_layouts (Placement, optional): The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be replicated. output_layouts (Placement, optional): The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module with the user desired layout. If not specified, the output tensor is sharded on the last dimension. use_local_output (bool, optional): Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. Returns: A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module. Example:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) >>> ... .. note:: By default ``ColwiseParallel`` output is sharded on the last dimension if the ``output_layouts`` not specified, if there're operators that require specific tensor shape (i.e. before the paired ``RowwiseParallel``), keep in mind that if the output is sharded the operator might need to be adjusted to the sharded size. 
""" def __init__( self, *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, use_local_output: bool = True ): super().__init__() self.input_layouts = (input_layouts or Replicate(), ) self.output_layouts = (output_layouts or Shard(-1), ) # colwise linear runtime sharding (desired sharding): # 1. requires replicate input # 2. shard output on last dim self.desired_input_layouts = (Replicate(), ) self.use_local_output = use_local_output @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): # TODO: figure out dynamo support for instance method and switch this to instance method # annotate module input placements/sharding with input_layouts input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): # colwise shard weight/bias to Shard(0), weight be Shard(0) # means Colwise as Linear is input * weight^T + bias, where # weight would become Shard(1) for name, param in module.named_parameters(): dist_param = nn.Parameter( distribute_tensor(param, device_mesh, [Shard(0)]) ) module.register_parameter(name, dist_param) def _partition_embedding_fn(self, name, module, device_mesh): # colwise shard embedding.weight is straight forward as Shard(1) for name, param in module.named_parameters(): dist_param = nn.Parameter( distribute_tensor(param, device_mesh, [Shard(1)]) ) module.register_parameter(name, dist_param) @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): # outputs is a shard on last dimension DTensor, i.e. 
Shard(-1) if outputs.placements != output_layouts: outputs = outputs.redistribute(placements=output_layouts, async_op=True) # back to local tensor return outputs.to_local() if use_local_output else outputs def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if isinstance(module, nn.Linear): partition_fn = self._partition_linear_fn elif isinstance(module, nn.Embedding): partition_fn = self._partition_embedding_fn else: raise NotImplementedError("ColwiseParallel currently only support nn.Linear and nn.Embedding!") return distribute_module( module, device_mesh, partition_fn, partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), )
[docs]class RowwiseParallel(ParallelStyle): """ Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding. Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules. (i.e. MLP, Attention) Keyword Args: input_layouts (Placement, optional): The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension. output_layouts (Placement, optional): The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module with the user desired layout. If not specified, the output tensor is replicated. use_local_output (bool, optional): Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. Returns: A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module. Example:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>> ... 
""" def __init__( self, *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, use_local_output: bool = True ): super().__init__() self.input_layouts = (input_layouts or Shard(-1), ) self.output_layouts = (output_layouts or Replicate(), ) self.use_local_output = use_local_output @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) if input_layouts != desired_input_layouts: input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) # means Rowwise as nn.Linear is input * weight^T + bias, where # weight would become Shard(0) module.register_parameter("weight", nn.Parameter( distribute_tensor(module.weight, device_mesh, [Shard(1)]) )) if module.bias is not None: module.register_parameter("bias", nn.Parameter( distribute_tensor(module.bias, device_mesh, [Replicate()]) )) def _partition_embedding_fn(self, name, module, device_mesh): # rowwise shard embedding.weight is Shard(0) for name, param in module.named_parameters(): dist_param = nn.Parameter( distribute_tensor(param, device_mesh, [Shard(0)]) ) module.register_parameter(name, dist_param) @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): # Rowwise sharding produces partial output, depending on output layouts: # 1. to replicate -> allreduce # 2. 
to shard -> reduce_scatter if outputs.placements != output_layouts: outputs = outputs.redistribute(placements=output_layouts, async_op=True) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if isinstance(module, nn.Linear): partition_fn = self._partition_linear_fn # rowwise linear runtime sharding requires input tensor shard on last dim self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1), ) elif isinstance(module, nn.Embedding): partition_fn = self._partition_embedding_fn # rowwise embedding runtime sharding requires input tensor replicated self.desired_input_layouts = (Replicate(), ) else: raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!") return distribute_module( module, device_mesh, partition_fn, partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), )
[docs]class SequenceParallel(ParallelStyle): """ SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__ This style implements the operation that is described in the paper `Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__ Both the input and output of the ``nn.Module`` will be sharded on the sequence dimension. Keyword Args: sequence_dim (int, optional): The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to become a DTensor that is sharded on the sequence dimension, default: 1. use_local_output (bool, optional): Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False. Returns: A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``. Example:: >>> # xdoctest: +SKIP(failing) >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>> ... .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e. ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). 
If you have custom inits for the weights on those modules, you need to broadcast the weights before/after parallelizing to ensure that they are replicated. """ def __init__( self, *, sequence_dim: int = 1, use_local_output: bool = False ): super().__init__() self.sequence_dim = sequence_dim self.use_local_output = use_local_output def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh): for p_name, param in module.named_parameters(): # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow # us to simply just use from_local replicated_param = torch.nn.Parameter( DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) ) module.register_parameter(p_name, replicated_param) @staticmethod def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh): input_tensor = inputs[0] if isinstance(input_tensor, DTensor): return inputs elif isinstance(input_tensor, torch.Tensor): return DTensor.from_local(input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False) else: raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}") @staticmethod def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): return outputs.to_local() if use_local_output else outputs def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: return distribute_module( module, device_mesh, self._replicate_module_fn, partial(self._prepare_input_fn, self.sequence_dim), partial(self._prepare_output_fn, self.use_local_output), )
class PrepareModuleInput(ParallelStyle):
    """
    Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
    ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.

    Keyword Args:
        input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
            The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to
            DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be
            specified as a placeholder. default: None.
        desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
            The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the
            nn.Module have the desired DTensor layouts. This argument needs to have the same length with
            ``input_layouts``. default: None.
        input_kwarg_layouts (Dict[str, Placement]):
            The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors
            to DTensors. default: None
        desired_input_kwarg_layouts: (Dict[str, Placement]):
            The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the
            nn.Module have the desired DTensor layouts. default: None.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs,
            default: False.

    Returns:
        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
        >>> # and then redistributed to Replicated DTensor.
        >>> parallelize_module(
        >>>     block,  # this can be a submodule or module
        >>>     tp_mesh,
        >>>     parallelize_plan={
        >>>         "attn": PrepareModuleInput(
        >>>             input_layouts=(Shard(0), None, None, ...),
        >>>             desired_input_layouts=(Replicate(), None, None, ...)
        >>>         ),
        >>>     }
        >>> )
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
        desired_input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
        input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        use_local_output: bool = False
    ):
        # a single Placement is normalized to a 1-tuple so positional layouts are always indexable
        self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
        self.desired_input_layouts = \
            (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts
        self.use_local_output = use_local_output
        if self.input_layouts is not None:
            assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
            assert len(self.input_layouts) == len(self.desired_input_layouts), \
                "input_layouts and desired_input_layouts should have same length!"
        # kwargs handling is switched on solely by ``input_kwarg_layouts`` being given (even if empty)
        self.with_kwargs = input_kwarg_layouts is not None
        self.input_kwarg_layouts = input_kwarg_layouts or {}
        self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
        if self.with_kwargs:
            assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \
                "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"

    def _prepare_input_arg(
        self,
        input: Any,
        mesh: DeviceMesh,
        input_layout: Optional[Placement],
        desired_layout: Optional[Placement]
    ):
        """
        Convert one input to a DTensor with ``input_layout`` and redistribute it to ``desired_layout``.

        When ``input_layout`` is ``None`` the input is returned unchanged. Otherwise the result is the
        (possibly redistributed) DTensor, or its local tensor when ``use_local_output`` is set.
        """
        if input_layout is not None:
            if isinstance(input, DTensor):
                # TODO: re-enable the check once we fix the compile path
                # assert inp.placements[0] == input_layout
                dt_inp = input
            else:
                assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!"
                dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False)

            if desired_layout is not None and input_layout != desired_layout:
                dt_inp = dt_inp.redistribute(placements=(desired_layout,))

            return dt_inp.to_local() if self.use_local_output else dt_inp
        else:
            return input

    def _prepare_input_fn(self, inputs, device_mesh):
        """
        Prepare the positional inputs of the module, pairing each with its layouts.

        Returns ``inputs`` unchanged when no ``input_layouts`` were configured, otherwise a tuple.

        Raises:
            ValueError: if the number of inputs differs from the number of ``input_layouts``.
        """
        if self.input_layouts is None:
            return inputs
        prepared_inputs = []
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        if len(inputs) != len(self.input_layouts):
            raise ValueError("module inputs and input_layouts should have same length!")

        assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
        for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
            prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout))
        return tuple(prepared_inputs)

    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
        """Prepare positional and keyword inputs; returns ``(prepared_args, prepared_kwargs)``."""
        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
        prepared_kwarg_inputs = {}
        for kwarg_key, kwarg_val in kwarg_inputs.items():
            # kwargs without a configured layout get ``None`` and pass through untouched
            input_layout = self.input_kwarg_layouts.get(kwarg_key)
            desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)

            prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout)

        return (prepared_arg_inputs, prepared_kwarg_inputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """Register a forward pre-hook on ``module`` that prepares its inputs; returns ``module``."""
        if self.with_kwargs:
            module.register_forward_pre_hook(
                lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh),
                with_kwargs=True
            )  # type: ignore[misc]
        else:
            module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh))  # type: ignore[misc, call-arg]
        return module
class PrepareModuleOutput(ParallelStyle):
    """
    Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
    ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.

    Keyword Args:
        output_layouts (Union[Placement, Tuple[Placement]]):
            The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to
            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to
            DTensors, ``None`` need to be specified as a placeholder.
        desired_output_layouts (Union[Placement, Tuple[Placement]]):
            The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the
            nn.Module have the desired DTensor layouts.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs,
            default: True.

    Returns:
        A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # According to the style specified below, the output of the TransformerBlock will be converted to
        >>> # Replicated DTensor and then redistributed to Sharded DTensor.
        >>> parallelize_module(
        >>>     block,  # this can be a submodule or module
        >>>     tp_mesh,
        >>>     parallelize_plan = PrepareModuleOutput(
        >>>         output_layouts=Replicate(),
        >>>         desired_output_layouts=Shard(0)
        >>>     )
        >>> )
    """

    def __init__(
        self,
        *,
        output_layouts: Union[Placement, Tuple[Placement]],
        desired_output_layouts: Union[Placement, Tuple[Placement]],
        use_local_output: bool = True
    ):
        # a single Placement is normalized to a 1-tuple so layouts can be zipped with the outputs
        self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts
        self.desired_output_layouts = \
            (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts
        self.use_local_output = use_local_output
        assert len(self.output_layouts) == len(self.desired_output_layouts), \
            "output_layouts and desired_output_layouts should have same length!"

    def _prepare_out_fn(self, outputs, device_mesh):
        """
        Convert each module output to a DTensor with its ``output_layouts`` entry and redistribute it.

        Outputs whose layout entry is ``None`` pass through unchanged. A single prepared output is
        returned bare; several are returned as a tuple.

        Raises:
            ValueError: if the number of outputs differs from the number of ``output_layouts``.
        """
        prepared_outputs = []
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        if len(outputs) != len(self.output_layouts):
            raise ValueError("module outputs and output_layouts should have same length!")
        for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts):
            if out_layout is not None:
                if isinstance(out, DTensor):
                    # TODO: re-enable the check once we fix the compile path
                    # assert out.placements[0] == out_layout
                    dt_out = out
                else:
                    dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False)

                if out_layout != desired_out_layout:
                    dt_out = dt_out.redistribute(placements=(desired_out_layout,))
                prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out)
            else:
                prepared_outputs.append(out)
        if len(prepared_outputs) == 1:
            return prepared_outputs[0]
        else:
            return tuple(prepared_outputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """Register a forward hook on ``module`` that prepares its outputs; returns ``module``."""
        module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh))  # type: ignore[misc, call-arg]
        return module

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources