
Customize Process Group Backends Using Cpp Extensions

Author: Feng Tian, Shen Li, Min Si




This tutorial demonstrates how to implement a custom ProcessGroup backend and plug it into the PyTorch distributed package using cpp extensions. This is helpful when you need a specialized software stack for your hardware, or when you would like to experiment with new collective communication algorithms.


PyTorch collective communications power several widely adopted distributed training features, including DistributedDataParallel, ZeroRedundancyOptimizer, and FullyShardedDataParallel. In order to make the same collective communication API work with different communication backends, the distributed package abstracts collective communication operations into a ProcessGroup class. Different backends can then be implemented as subclasses of ProcessGroup using preferred third-party libraries. PyTorch distributed comes with three default backends, ProcessGroupNCCL, ProcessGroupGloo, and ProcessGroupMPI. However, beyond these three backends, there are also other communication libraries (e.g., UCC, OneCCL), different types of hardware (e.g., TPU, Trainium), and emerging communication algorithms (e.g., Herring, Reduction Server). Therefore, the distributed package exposes extension APIs to allow customizing collective communication backends.

The four steps below show how to implement a dummy ProcessGroup backend and use it in Python application code. Please note that this tutorial focuses on demonstrating the extension APIs, instead of developing a functioning communication backend. Hence, the dummy backend just covers a subset of the APIs (all_reduce and all_gather), and simply sets the values of tensors to 0.

Step 1: Implement a Subclass of ProcessGroup

The first step is to implement a ProcessGroup subclass that overrides target collective communication APIs and runs the custom communication algorithm. The extension also needs to implement a Work subclass, which serves as a future of communication results and allows asynchronous execution in application code. If the extension uses third-party libraries, it can include the headers and call into the library APIs from the ProcessGroupDummy subclass. The two code snippets below present the implementation of dummy.hpp and dummy.cpp. See the dummy collectives repository for the full implementation.

// file name: dummy.hpp
#include <torch/python.h>

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <pybind11/chrono.h>

namespace c10d {

class ProcessGroupDummy : public ProcessGroup {
  public:
    ProcessGroupDummy(int rank, int size);

    c10::intrusive_ptr<Work> allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& opts = AllgatherOptions()) override;

    c10::intrusive_ptr<Work> allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts = AllreduceOptions()) override;

    // The collective communication APIs without a custom implementation
    // will error out if invoked by application code.
};

class WorkDummy : public Work {
  public:
    WorkDummy(
        OpType opType,
        c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
        : Work(
              -1, // rank, only used by recvAnySource, irrelevant in this demo
              opType),
          future_(std::move(future)) {}

    // There are several additional helper functions that need to be
    // implemented. Please refer to https://github.com/mrshenli/dummy_collectives
    // for the full implementation.

  private:
    c10::intrusive_ptr<c10::ivalue::Future> future_;
};
} // namespace c10d

// file name: dummy.cpp
#include "dummy.hpp"

namespace c10d {

// This is a dummy allgather that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<Work> ProcessGroupDummy::allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& /* unused */) {
    for (auto& outputTensorVec : outputTensors) {
        for (auto& outputTensor : outputTensorVec) {
            outputTensor.zero_();
        }
    }

    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
    future->markCompleted(c10::IValue(outputTensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLGATHER, std::move(future));
}

// This is a dummy allreduce that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<Work> ProcessGroupDummy::allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts) {
    for (auto& tensor : tensors) {
        tensor.zero_();
    }

    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::TensorType::get()));
    future->markCompleted(c10::IValue(tensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
}
} // namespace c10d
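From application code, the Work handle returned by each collective is what enables asynchronous execution. The following sketch shows how such a handle is consumed; for illustration it uses the builtin gloo backend in a single-process group (a custom backend's Work handles are consumed the same way):

```python
import os

import torch
import torch.distributed as dist

# Illustration only: a single-process gloo group stands in for the custom
# backend; the MASTER_ADDR/PORT values are arbitrary local settings.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.ones(4)
work = dist.all_reduce(x, async_op=True)  # returns a Work handle immediately
# ... other computation could overlap with the communication here ...
work.wait()  # block until the collective completes

dist.destroy_process_group()
```

With world_size=1, the sum all-reduce leaves the tensor unchanged; the point is the async_op=True / wait() pattern that the Work subclass supports.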

Step 2: Expose The Extension Python APIs

The backend constructors are called from the Python side, so the extension also needs to expose the constructor APIs to Python. This can be done by adding the following methods. In this example, store and timeout are ignored by the ProcessGroupDummy instantiation method, as those are not used in this dummy implementation. However, real-world extensions should consider using the store to perform rendezvous and supporting the timeout argument.

class ProcessGroupDummy : public ProcessGroup {
  public:
    static c10::intrusive_ptr<ProcessGroup> createProcessGroupDummy(
        const c10::intrusive_ptr<::c10d::Store>& store,
        int rank,
        int size,
        const std::chrono::duration<float>& timeout);

    static void ProcessGroupDummyConstructor() __attribute__((constructor)) {
        py::object module = py::module::import("torch.distributed");
        py::object register_backend =
            module.attr("Backend").attr("register_backend");
        // torch.distributed.Backend.register_backend will add `dummy` as a
        // new valid backend.
        register_backend("dummy", py::cpp_function(createProcessGroupDummy));
    }
};

c10::intrusive_ptr<ProcessGroup> ProcessGroupDummy::createProcessGroupDummy(
        const c10::intrusive_ptr<::c10d::Store>& /* unused */,
        int rank,
        int size,
        const std::chrono::duration<float>& /* unused */) {
    return c10::make_intrusive<ProcessGroupDummy>(rank, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("createProcessGroupDummy", &ProcessGroupDummy::createProcessGroupDummy);
}
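The dummy backend ignores the store, but a real extension would typically use it for rendezvous, e.g., publishing per-rank connection info before the first collective. A hedged sketch of that pattern using c10d's TCPStore from Python (the key name and address payload below are made up for illustration):

```python
from datetime import timedelta

import torch.distributed as dist

# Single-process illustration: rank 0 hosts the store (is_master=True) and
# publishes a hypothetical address that other ranks would read on startup.
store = dist.TCPStore("localhost", 29502, 1, True, timedelta(seconds=30))
store.set("rank0_addr", "10.0.0.1:1234")  # hypothetical payload

# Other ranks would construct a TCPStore with is_master=False and read the key.
addr = store.get("rank0_addr")  # returns bytes
```

The same set/get interface is available in C++ through the ::c10d::Store handle passed to the backend constructor.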

Step 3: Build The Custom Extension

Now, the extension source code files are ready. We can then use cpp extensions to build it. To do that, create a setup.py file that prepares the paths and commands. Then call python setup.py install to install the extension.

If the extension depends on third-party libraries, you can also specify library_dirs and libraries in the cpp extension APIs. See the torch-ucc project as a real-world example.
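For instance, a setup.py linking against a third-party library might look like the sketch below; the "mylib" name and its paths are hypothetical placeholders, not part of this tutorial's dummy backend:

```python
# file name: setup.py (sketch; "mylib" and its paths are hypothetical)
import os
from setuptools import setup
from torch.utils import cpp_extension

root = os.path.dirname(os.path.abspath(__file__))

module = cpp_extension.CppExtension(
    name="dummy_collectives",
    sources=["src/dummy.cpp"],
    include_dirs=[os.path.join(root, "include")],
    # where libmylib.so lives; passed to the linker as -L
    library_dirs=[os.path.join(root, "third_party", "mylib", "lib")],
    libraries=["mylib"],  # adds -lmylib at link time
)

setup(
    name="Dummy-Collectives",
    version="0.0.1",
    ext_modules=[module],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
```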

# file name: setup.py
import os
import sys
import torch
from setuptools import setup
from torch.utils import cpp_extension

sources = ["src/dummy.cpp"]
include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"]

if torch.cuda.is_available():
    module = cpp_extension.CUDAExtension(
        name = "dummy_collectives",
        sources = sources,
        include_dirs = include_dirs,
    )
else:
    module = cpp_extension.CppExtension(
        name = "dummy_collectives",
        sources = sources,
        include_dirs = include_dirs,
    )

setup(
    name = "Dummy-Collectives",
    version = "0.0.1",
    ext_modules = [module],
    cmdclass={'build_ext': cpp_extension.BuildExtension}
)

Step 4: Use The Extension in Application

After installation, you can conveniently use the dummy backend when calling init_process_group, as if it were a builtin backend.

import os

import torch
# importing dummy_collectives makes torch.distributed recognize `dummy`
# as a valid backend.
import dummy_collectives

import torch.distributed as dist

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

dist.init_process_group("dummy", rank=0, world_size=1)

# this goes through the dummy backend's allreduce, which zeros the tensor
x = torch.ones(6)
dist.all_reduce(x)
print(f"cpu allreduce: {x}")

if torch.cuda.is_available():
    y = x.cuda()
    dist.all_reduce(y)
    print(f"cuda allreduce: {y}")

try:
    dist.broadcast(x, 0)
except RuntimeError:
    print("got RuntimeError as broadcast is not implemented in Dummy ProcessGroup")

