# Source code for torch.distributed.rpc.backend_registry

# Public names exported by ``from torch.distributed.rpc.backend_registry import *``.
__all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"]

import collections
import enum
from typing import cast, Dict, List, Set, Tuple

import torch
import torch.distributed as dist
from ._utils import _group_membership_management, _update_group_membership

from . import api
from . import constants as rpc_constants

# Value stored in every ``BackendType`` member: the pair of handlers that
# build the backend's options object and initialize its RPC agent.
BackendValue = collections.namedtuple(
    "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
)

def _backend_type_repr(self):
    return "BackendType." +

_backend_type_doc = """
    An enum class of available backends.

    PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend.
    Additional ones can be registered using the
    :func:`~torch.distributed.rpc.backend_registry.register_backend` function.

# `BackendType` starts out with no members; `register_backend` later rebuilds
# it (and rebinds this global) with each newly registered backend added.
# The functional Enum API is not understood by mypy (bug #9079).
BackendType = enum.Enum(value="BackendType", names={})  # type: ignore[misc]
# Install the custom repr via setattr; mypy rejects assigning a function as a
# method directly (bug #2427).
setattr(BackendType, "__repr__", _backend_type_repr)

# Only replace the docstring when the enum already carries one.
if BackendType.__doc__:
    BackendType.__doc__ = _backend_type_doc

def backend_registered(backend_name):
    """
    Checks if backend_name is registered as an RPC backend.

    Args:
        backend_name (str): string to identify the RPC backend.
    Returns:
        True if the backend has been registered with ``register_backend``, else
        False.
    """
    # `__members__` maps member names to members, so `in` tests the names.
    return backend_name in BackendType.__members__

def register_backend(
    backend_name, construct_rpc_backend_options_handler, init_backend_handler
):
    """Registers a new RPC backend.

    Enum classes cannot gain members after creation, so this builds a new
    ``BackendType`` enum containing the new backend followed by all previously
    registered ones, and rebinds the module-level ``BackendType`` name to it.
    Code that captured the previous ``BackendType`` class object keeps the old
    class.

    Args:
        backend_name (str): backend string to identify the handler.
        construct_rpc_backend_options_handler (function):
            Handler that is invoked when
            rpc_backend.construct_rpc_backend_options(**dict) is called.
        init_backend_handler (function): Handler that is invoked when the
            `_init_rpc_backend()` function is called with a backend.
             This returns the agent.

    Returns:
        The ``BackendType`` member named ``backend_name``.

    Raises:
        RuntimeError: if ``backend_name`` is already registered.
    """
    global BackendType
    if backend_registered(backend_name):
        raise RuntimeError("RPC backend {}: already registered".format(backend_name))
    # Create a new enum type, `BackendType`, with extended members.
    existing_enum_dict = {member.name: member.value for member in BackendType}
    # New backend first, then every previously registered backend.
    extended_enum_dict = {
        backend_name: BackendValue(
            construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
            init_backend_handler=init_backend_handler,
        ),
        **existing_enum_dict,
    }
    # Can't handle Function Enum API (mypy bug #9079)
    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)  # type: ignore[misc]
    # Unable to assign a function a method (mypy bug #2427)
    BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
    if BackendType.__doc__:
        BackendType.__doc__ = _backend_type_doc
    return BackendType[backend_name]

def construct_rpc_backend_options(
    backend,
    rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method=rpc_constants.DEFAULT_INIT_METHOD,
    **kwargs
):
    """Build the options object for ``backend`` via its registered handler.

    Args:
        backend (BackendType): a registered backend member.
        rpc_timeout: passed positionally to the handler.
        init_method: passed positionally to the handler.
        **kwargs: forwarded unchanged to the handler.
    Returns:
        Whatever ``backend.value.construct_rpc_backend_options_handler`` returns.
    """
    return backend.value.construct_rpc_backend_options_handler(
        rpc_timeout, init_method, **kwargs
    )

def init_backend(backend, *args, **kwargs):
    """Initialize ``backend`` by delegating to its registered init handler.

    All positional and keyword arguments are forwarded unchanged, and the
    handler's result (the RPC agent) is returned.
    """
    handler = backend.value.init_backend_handler
    return handler(*args, **kwargs)

def _init_process_group(store, rank, world_size):
    """Create a Gloo process group on ``store`` for a static RPC group.

    Args:
        store (dist.Store): store used to rendezvous the group members.
        rank (int): expected rank of this process, or -1 to skip the check.
        world_size (int): expected group size, or -1 to skip the check.
    Returns:
        The created ``dist.ProcessGroupGloo``.
    Raises:
        RuntimeError: if ``rank`` or ``world_size`` disagrees with the group.
    """
    # Initialize ProcessGroup.
    process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT

    # We're using a bunch of private APIs here since `new_group` requires the
    # default group to be initialized.
    group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)

    assert group is not None, "Failed to initialize default ProcessGroup."

    if (rank != -1) and (rank != group.rank()):
        raise RuntimeError(
            "rank argument {} doesn't match pg rank {}".format(rank, group.rank())
        )
    if (world_size != -1) and (world_size != group.size()):
        raise RuntimeError(
            "world_size argument {} doesn't match pg size {}".format(
                world_size, group.size()
            )
        )
    return group

def _tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
    _transports=None,
    _channels=None,
    **kwargs
):
    """Options handler for the TensorPipe backend.

    Builds a ``TensorPipeRpcBackendOptions`` from the given values. Any extra
    keyword arguments in ``kwargs`` are accepted but not used.
    """
    from . import TensorPipeRpcBackendOptions

    return TensorPipeRpcBackendOptions(
        rpc_timeout=rpc_timeout,
        init_method=init_method,
        num_worker_threads=num_worker_threads,
        _transports=_transports,
        _channels=_channels,
    )

def _tensorpipe_validate_devices(devices, device_count):
    return all(
        d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
        for d in devices

# detect if any worker has invalid device_map configurations, and return
# reverse device maps
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    """All-gather device info across a static group, validate it, and derive
    this worker's reverse device maps and device list.

    Args:
        my_name (str): this worker's name.
        my_device_count (int): number of CUDA devices on this worker.
        my_device_maps: this worker's ``{target_name: {local_dev: remote_dev}}``.
        my_devices: explicitly configured devices (may be empty).
        group: process group used for ``dist.all_gather_object``.
    Returns:
        ``(reverse_device_maps, my_devices)``.
    Raises:
        ValueError: if any worker's devices or device maps are invalid.
    """
    # Placeholders sized to the group; all_gather_object fills them in place.
    gathered: List[Tuple[
        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
    ]] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)

    # passed all checked, construct reverse mapping and get list of devices handled by this agent
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
    my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
    return reverse_device_maps, my_devices

def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True):
    """Check every worker's device list and device maps for consistency.

    Args:
        all_names (list[str]): names of the workers to check.
        all_device_counts (dict): worker name -> CUDA device count.
        all_device_maps (dict): worker name -> ``{target_name: {src_dev: dst_dev}}``.
        all_devices (dict): worker name -> explicitly configured devices.
        is_static_group (bool): if False, target names in device maps are not
            required to be in ``all_names`` (they may not have joined yet), and
            targets missing from ``all_devices``/``all_device_counts`` are
            skipped rather than raising KeyError.
    Raises:
        ValueError: on duplicated or out-of-range devices, unknown target
            names (static groups only), duplicated target devices in a map, or
            map entries that fall outside a worker's configured devices.
    """
    for node in all_names:
        devices = all_devices[node]
        if len(set(devices)) != len(devices):
            raise ValueError(
                f"Node {node} has duplicated devices\n"
                f"devices = {devices}"
            )
        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {all_device_counts[node]}"
            )

    for source_node in all_names:
        # For dynamic group (non-static) do not check the target node name since it may not have joined yet
        if is_static_group and not set(all_device_maps[source_node].keys()).issubset(all_names):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {all_device_maps[source_node].keys()}\n"
                f"node names = {all_names}"
            )
        for target_node, map_ in all_device_maps[source_node].items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}"
                )
            # Source side: if explicit devices are configured the map must stay
            # within them; otherwise only the index range is checked.
            if all_devices[source_node]:
                if not set(map_.keys()).issubset(all_devices[source_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[source_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.keys(), all_device_counts[source_node]
            ):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}"
                )
            # Target side: same checks, tolerating targets that are unknown
            # (possible in a dynamic group).
            if all_devices.get(target_node, []):
                if not set(map_.values()).issubset(all_devices[target_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[target_node]}"
                    )
            elif target_node in all_device_counts and not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}"
                )

def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
    if not my_devices:
        devices_set: Set[torch.device] = set()
        for _, map_ in my_device_maps.items():
        for _, map_ in reverse_device_maps.items():
        my_devices = list(devices_set)
    my_devices = sorted(my_devices, key=lambda d: d.index)
    return my_devices

def _create_reverse_mapping(my_name, all_names, all_device_maps):
    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
    for node in all_names:
        if my_name in all_device_maps[node]:
            reverse_device_maps[node] = {
                v: k for k, v in all_device_maps[node][my_name].items()
    return reverse_device_maps

def _get_device_infos():
    """Report this process's device info for a dynamic-group peer.

    Returns:
        ``(cuda_device_count, device_maps, devices)``, where the last two come
        from the current RPC agent's backend options.
    """
    from . import TensorPipeAgent

    current_agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    backend_opts = current_agent._get_backend_options()
    return torch.cuda.device_count(), backend_opts.device_maps, backend_opts.devices

def _set_devices_and_reverse_device_map(agent):
    """Exchange device info with the other workers of a dynamic RPC group.

    Collects device count, device maps and devices from each worker the local
    agent knows about: through ``api.rpc_sync`` for remote workers, and from
    the local backend options for this worker. It validates them with
    ``is_static_group=False`` and builds this worker's reverse device map.
    Finally it calls ``_update_group_membership`` on every worker, including
    itself.

    Args:
        agent: the local ``TensorPipeAgent``.
    """
    from . import TensorPipeAgent
    agent = cast(TensorPipeAgent, agent)
    # Group state is retrieved from local agent
    # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid
    my_worker_info = agent.get_worker_info()
    my_name = my_worker_info.name
    all_worker_infos = agent.get_worker_infos()
    # One round to get device_maps of all workers and construct reverse device maps
    all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
    for worker_info in all_worker_infos:
        worker_name = worker_info.name
        if worker_name != my_name:
            # TODO: make async?
            device_count, device_map, devices = api.rpc_sync(worker_name, _get_device_infos)
        else:
            opts = agent._get_backend_options()
            device_count, device_map, devices = torch.cuda.device_count(), opts.device_maps, opts.devices
        all_device_counts[worker_name] = device_count
        all_device_maps[worker_name] = device_map
        all_devices[worker_name] = devices
        all_names.append(worker_name)

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=False)
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)

    # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps
    for worker_name in all_names:
        # Set device list for each worker
        all_devices[worker_name] = _create_device_list(all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps)
        # NOTE(review): the trailing True presumably flags a join (vs. leave)
        # in _update_group_membership — confirm against its definition.
        api.rpc_sync(worker_name, _update_group_membership,
                     args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True))

def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
    """Create and initialize the ``TensorPipeAgent`` for this process.

    A truthy ``world_size`` selects a static group. A Gloo process group on
    ``store`` is used to exchange and validate device maps, the agent is
    created, and a dummy all-gather plus a barrier warm up the channels.

    A falsy ``world_size`` (e.g. ``None`` or 0) selects a dynamic group. The
    agent is created with empty reverse device maps and devices inside
    ``_group_membership_management``, then broadcasts its device info. On
    failure the RPC layer is shut down and the error is re-raised.

    Args:
        store (dist.Store): store shared by the workers.
        name (str): this worker's name.
        rank (int): this worker's rank.
        world_size: group size; falsy means a dynamic group.
        rpc_backend_options (TensorPipeRpcBackendOptions): backend options.
    Returns:
        The initialized ``TensorPipeAgent``.
    Raises:
        TypeError: if ``store`` or ``rpc_backend_options`` has the wrong type.
    """
    # NOTE(review): this is presumably registered as the builtin TENSORPIPE
    # backend's init handler via register_backend(); that call is not visible
    # in this listing — confirm it exists.
    from . import TensorPipeAgent
    from . import TensorPipeRpcBackendOptions
    if not isinstance(store, dist.Store):
        raise TypeError("`store` must be a c10d::Store. {}".format(store))

    if not isinstance(
        rpc_backend_options, TensorPipeRpcBackendOptions
    ):
        raise TypeError(
            "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {}".format(
                rpc_backend_options
            )
        )

    if torch.cuda.is_available():
        # It's necessary to initialize PyTorch CUDA states here (e.g.,
        # CUDACachingAllocator). If this is missing, we could hit errors like
        # "allocator not initialized", because other processes might send
        # CUDA-related RPC request to this process before user code in this
        # process initializes its PyTorch CUDA states.
        torch.cuda.init()
        device_count = torch.cuda.device_count()
    else:
        device_count = 0

    is_static_group = bool(world_size)
    # world_size is specified so this is a static group (ranks cannot join and leave)
    if is_static_group:
        # The agent's join method is required to behave like a barrier and perform
        # collective operations, for which it relies on a process group, instead of
        # re-implementing this on top of RPCs.
        group = _init_process_group(store, rank, world_size)

        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
            name,
            device_count,
            rpc_backend_options.device_maps,
            rpc_backend_options.devices,
            group,
        )

        # TODO: add try-except and destroy _agent in all processes if any fails.
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            reverse_device_maps,
            devices,
        )

        api._init_rpc_states(agent)

        # Run one dummy round of RPC to initialize channels/transports. Without
        # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC
        # on that process before rpc.shutdown(), as the agent initialization can
        # take longer than 5s.
        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
        # Need a barrier here to make sure no peers leave before the rank0 finishes
        # _all_gather
        group.barrier().wait()

        return agent
    # initialization for dynamic rpc (ranks can join and leave)
    else:
        with _group_membership_management(store, name, True):
            # Construct TPAgent with empty reverse_device_map and devices
            # these properties will be updated after initialization
            agent = TensorPipeAgent(
                store,
                name,
                rank,
                world_size,
                rpc_backend_options,
                {},
                [],
            )
            api._init_rpc_states(agent)

            try:
                # Notify all workers in group this rank has joined and set devices and reverse_device_map
                # This is a synchronous operation that completes once all existing ranks are updated
                _set_devices_and_reverse_device_map(agent)
            except Exception:
                # Tear down the half-initialized RPC state before propagating.
                api.shutdown()
                raise
            return agent



# --- Documentation-site footer captured with this source listing (not part of the module) ---
# Access comprehensive developer documentation for PyTorch: View Docs
# Get in-depth tutorials for beginners and advanced developers: View Tutorials
# Find development resources and get your questions answered: View Resources