
Pipeline Parallelism

Note

torch.distributed.pipelining is a package migrated from the PiPPy project. It is currently in alpha state and under extensive development. For examples that work with our APIs, please refer to PiPPy’s examples directory.

Why Pipeline Parallel?

One of the most important techniques for advancing the state of the art in deep learning is scaling. Common techniques for scaling neural networks include data parallelism, tensor/operation parallelism, and pipeline parallelism (or pipelining). Pipelining is a technique in which the code of the model is partitioned and multiple micro-batches execute different parts of the model code concurrently. In many cases, pipeline parallelism can be an effective technique for scaling, in particular for large-scale jobs or bandwidth-limited interconnects. To learn more about pipeline parallelism in deep learning, see this article.

What is torch.distributed.pipelining?

While promising for scaling, pipelining is often difficult to implement, requiring intrusive changes to model code and hard-to-get-right runtime orchestration. torch.distributed.pipelining aims to provide a toolkit that automates these steps, allowing high-productivity scaling of models. It consists of a compiler and a runtime stack for easy pipelining of PyTorch models. In particular, it provides the following features:

  • Splitting of model code based on your specification. The goal is for the user to provide model code as-is to the system for parallelization, without having to make heavyweight modifications to make parallelism work. The specification is also simple.

  • Support for rich pipeline scheduling paradigms, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS. It is also easy to customize your own schedule under this framework.

  • First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects).

  • Composability with other PyTorch parallel schemes such as data parallelism (DDP, FSDP) or tensor parallelism (together known as “3D parallelism”).

Examples

In the PiPPy repo, from which this package was migrated, we provide rich examples based on realistic models. In particular, we show how to apply pipelining without any model code change. You can refer to the HuggingFace examples directory. Popular examples include GPT2 and LLaMA.

Techniques Explained

torch.distributed.pipelining consists of two parts: a compiler and a runtime. The compiler takes your model code, splits it up, and transforms it into a Pipe, a wrapper that describes the model at each pipeline stage and the data-flow relationships between stages. The runtime executes the PipelineStages in parallel, handling micro-batch splitting, scheduling, communication, and gradient propagation. We will cover the APIs for these concepts in this section.

Splitting a Model with pipeline

To see how we can split a model into a pipeline, let's first look at a trivial example neural network:

import torch

class MyNetworkBlock(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        x = self.lin(x)
        x = torch.relu(x)
        return x


class MyNetwork(torch.nn.Module):
    def __init__(self, in_dim, layer_dims):
        super().__init__()

        prev_dim = in_dim
        for i, dim in enumerate(layer_dims):
            setattr(self, f'layer{i}', MyNetworkBlock(prev_dim, dim))
            prev_dim = dim

        self.num_layers = len(layer_dims)
        # 10 output classes
        self.output_proj = torch.nn.Linear(layer_dims[-1], 10)

    def forward(self, x):
        for i in range(self.num_layers):
            x = getattr(self, f'layer{i}')(x)

        return self.output_proj(x)


in_dim = 512
layer_dims = [512, 1024, 256]
# `device` is assumed to have been defined earlier, e.g.:
# device = torch.device("cuda:0")
mn = MyNetwork(in_dim, layer_dims).to(device)

This network is written as free-form Python code; it has not been modified for any specific parallelism technique.

Let us see our usage of the pipeline interface:

from torch.distributed.pipelining import annotate_split_points, pipeline, Pipe, SplitPoint

annotate_split_points(mn, {'layer0': SplitPoint.END,
                           'layer1': SplitPoint.END})

batch_size = 32
example_input = torch.randn(batch_size, in_dim, device=device)
chunks = 4

pipe = pipeline(mn, chunks, example_args=(example_input,))
print(pipe)
************************************* pipe *************************************
GraphModule(
  (submod_0): GraphModule(
    (layer0): InterpreterModule(
      (lin): InterpreterModule()
    )
  )
  (submod_1): GraphModule(
    (layer1): InterpreterModule(
      (lin): InterpreterModule()
    )
  )
  (submod_2): GraphModule(
    (layer2): InterpreterModule(
      (lin): InterpreterModule()
    )
    (output_proj): InterpreterModule()
  )
)

def forward(self, arg8_1):
    submod_0 = self.submod_0(arg8_1);  arg8_1 = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    submod_2 = self.submod_2(submod_1);  submod_1 = None
    return (submod_2,)

So what’s going on here? First, pipeline turns our model into a directed acyclic graph (DAG) by tracing the model. Then, it groups together the operations and parameters into pipeline stages. Stages are represented as submod_N submodules, where N is a natural number.

We used annotate_split_points to specify that the code should be split at the end of layer0 and layer1. Our code has thus been split into three pipeline stages. Our library also provides SplitPoint.BEGINNING if a user wants to split before a certain annotation point.
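For instance, the same three stages could be produced with SplitPoint.BEGINNING annotations instead. This is a minimal sketch (an alternative to the SplitPoint.END spec above, not meant to be applied on top of it); the module names refer to the MyNetwork example:

from torch.distributed.pipelining import annotate_split_points, SplitPoint

# Split *before* layer1 and layer2 instead of *after* layer0 and layer1;
# both specs place the boundaries between the same pairs of layers.
annotate_split_points(mn, {'layer1': SplitPoint.BEGINNING,
                           'layer2': SplitPoint.BEGINNING})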

While the annotate_split_points API gives users a way to specify the split points without modifying the model, our library also provides an API for in-model annotation: pipe_split(). For details, you can read this example.
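As a rough sketch of this in-model style (pipe_split is a no-op when the model runs eagerly, per the API reference below), the example network could mark its stage boundaries directly in forward:

from torch.distributed.pipelining import pipe_split

class MyNetworkWithSplits(MyNetwork):
    def forward(self, x):
        x = self.layer0(x)
        pipe_split()  # boundary between stage 0 and stage 1
        x = self.layer1(x)
        pipe_split()  # boundary between stage 1 and stage 2
        x = self.layer2(x)
        return self.output_proj(x)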

This covers the basic usage of the Pipe API. For more information, please see the documentation.

Using PipelineSchedule for Execution

Given the above Pipe object, we can use one of the PipelineStage classes to execute our model in a pipelined fashion. First off, let us instantiate a PipelineStage instance:

import os

# We are using `torchrun` to run this example with multiple processes.
# `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Initialize distributed environment
import torch.distributed as dist
dist.init_process_group(rank=rank, world_size=world_size)

# Pipeline stage is our main pipeline runtime. It takes in the pipe object,
# the rank of this process, and the device.
from torch.distributed.pipelining import PipelineStage
stage = PipelineStage(pipe, rank, device)

We can now attach the PipelineStage to a pipeline schedule, GPipe for example, and run with data:

from torch.distributed.pipelining import ScheduleGPipe
schedule = ScheduleGPipe(stage, chunks)

# Input data
x = torch.randn(batch_size, in_dim, device=device)

# Run the pipeline with input `x`. Divide the batch into 4 micro-batches
# and run them in parallel on the pipeline
if rank == 0:
    schedule.step(x)
else:
    output = schedule.step()

Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use torchrun to run multiple processes within a single machine for demonstration purposes. We can collect all of the code blocks above into a file named example.py and then run it with torchrun like so:

torchrun --nproc_per_node=3 example.py

Pipeline Transformation APIs

The following set of APIs transform your model into a pipeline representation.

class torch.distributed.pipelining.SplitPoint(value)[source]

An enumeration of the supported split positions: SplitPoint.BEGINNING (split before the annotated module executes) and SplitPoint.END (split after it executes).

torch.distributed.pipelining.pipeline(module, num_chunks, example_args, example_kwargs=None, split_spec=None, split_policy=None)[source]

Creates a pipeline representation for the provided module.

See Pipe for more details.

Parameters
  • module (Module) – The module to be transformed into a Pipe.

  • num_chunks (int) – The number of microbatches to be run with this pipeline.

  • example_args (Tuple[Any, ...]) – Example positional inputs to be used with this pipeline.

  • example_kwargs (Optional[Dict[str, Any]]) – Example keyword inputs to be used with this pipeline. (default: None)

  • split_spec (Optional[Dict[str, SplitPoint]]) – A dictionary mapping module names to SplitPoint values. (default: None)

  • split_policy (Optional[Callable[[GraphModule], GraphModule]]) – The policy to use for splitting the module. (default: None)

Return type

A pipeline representation of class Pipe.
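
For example, the split shown earlier with annotate_split_points could instead be expressed by passing split_spec directly. This is a sketch based on the signature above, reusing the mn model and example_input from the earlier example:

from torch.distributed.pipelining import pipeline, SplitPoint

pipe = pipeline(
    mn,                             # module to be transformed into a Pipe
    num_chunks=4,                   # number of microbatches
    example_args=(example_input,),  # example inputs used for tracing
    split_spec={'layer0': SplitPoint.END,
                'layer1': SplitPoint.END},
)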

class torch.distributed.pipelining.Pipe(split_gm, splitter_qualname_map, num_stages, has_loss_and_backward, loss_spec, tracer_qualname_map=None)[source]

args_chunk_spec:

Chunking specification for positional inputs. (default: None)

kwargs_chunk_spec:

Chunking specification for keyword inputs. (default: None)

torch.distributed.pipelining.annotate_split_points(mod, spec)[source]
torch.distributed.pipelining.pipe_split()[source]

pipe_split is a special operator that is used to mark the boundary between stages in a module. It is used to split the module into stages. It is a no-op if your annotated module is run eagerly.

Example:

>>> def forward(self, x):
>>>     x = torch.mm(x, self.mm_param)
>>>     x = torch.relu(x)
>>>     pipe_split()
>>>     x = self.lin(x)
>>>     return x

The above example will be split into two stages.

class torch.distributed.pipelining.ArgsChunkSpec(chunk_dims)[source]

Example:

>>> # There are three positional arguments to the model, and
>>> # we are chunking them along dimension 0, 0 and 1, respectively
>>> with ArgsChunkSpec((0, 0, 1)):
>>>     pipe = pipeline(model, num_chunks, example_args)

class torch.distributed.pipelining.KwargsChunkSpec(chunk_dims)[source]

Example:

>>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
>>> with KwargsChunkSpec({"id": 0, "mask": 1}):
>>>     pipe = pipeline(model, num_chunks, (), example_kwargs)

Microbatch Utilities

class torch.distributed.pipelining.microbatch.TensorChunkSpec(split_dim)[source]
torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks(args, kwargs, chunks, args_chunk_spec=None, kwargs_chunk_spec=None)[source]

Given a sequence of args and kwargs, split them into a number of chunks according to their respective chunking specs.

Returns

args_split: list of sharded args

kwargs_split: list of sharded kwargs
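
A minimal sketch of how this utility could be used on its own. The shapes are illustrative, and passing a per-argument tuple of TensorChunkSpec as args_chunk_spec is an assumption modeled on ArgsChunkSpec above:

import torch
from torch.distributed.pipelining.microbatch import (
    TensorChunkSpec,
    split_args_kwargs_into_chunks,
)

x = torch.randn(8, 16)
# Split the single positional argument into 4 chunks along dimension 0.
args_split, kwargs_split = split_args_kwargs_into_chunks(
    (x,), {}, 4, args_chunk_spec=(TensorChunkSpec(0),)
)
# args_split should be a list of 4 tuples, each containing a (2, 16) tensor.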

torch.distributed.pipelining.microbatch.merge_chunks(chunks, chunk_spec)[source]

Given a list of chunks, merge them into a single value according to the chunk spec.

Parameters
  • chunks (List[Any]) – list of chunks

  • chunk_spec – Chunking spec for the chunks

Returns

Merged value

Return type

value
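
Continuing the sketch above, the chunks could be merged back into a single value. The exact chunk-spec structure expected here is an assumption, not a documented guarantee:

from torch.distributed.pipelining.microbatch import merge_chunks, TensorChunkSpec

# Re-concatenate the 4 chunks of `x` along dimension 0.
chunks = [t[0] for t in args_split]
merged = merge_chunks(chunks, TensorChunkSpec(0))
# `merged` should compare equal to the original (8, 16) tensor `x`.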

Pipeline Schedules

class torch.distributed.pipelining.PipelineSchedule.ScheduleGPipe(stage, n_microbatches, loss_fn=None, output_merge_spec=None)[source]

The GPipe schedule. Will go through all the microbatches in a fill-drain manner.

class torch.distributed.pipelining.PipelineSchedule.Schedule1F1B(stage, n_microbatches, loss_fn=None, output_merge_spec=None)[source]

The 1F1B schedule. Will perform one forward and one backward on the microbatches in steady state.
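
As a sketch of how a different schedule can be swapped in for the ScheduleGPipe example above, assuming Schedule1F1B is importable from torch.distributed.pipelining like ScheduleGPipe and that loss_fn is only consumed on the stage producing the final output:

import torch
from torch.distributed.pipelining import Schedule1F1B

# Hypothetical loss function for the last stage, passed via the `loss_fn`
# parameter documented in the signature above.
def loss_fn(output, target):
    return torch.nn.functional.cross_entropy(output, target)

# Same stage and microbatch count as before; only the schedule changes.
schedule = Schedule1F1B(stage, chunks, loss_fn=loss_fn)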

class torch.distributed.pipelining.PipelineSchedule.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, output_merge_spec=None)[source]

The Interleaved 1F1B schedule. Will perform one forward and one backward on the microbatches in steady state and supports multiple stages per rank. When microbatches are ready for multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch (also called “depth first”).

class torch.distributed.pipelining.PipelineSchedule.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None)[source]

Breadth-First Pipeline Parallelism. See https://arxiv.org/abs/2211.05953 for details. Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. The difference is that when microbatches are ready for multiple local stages, Looped BFS prioritizes the earlier stage, running all available microbatches at once.
