
Source code for torch_xla.distributed.xla_multiprocessing

from __future__ import division
from __future__ import print_function

import collections
import contextlib
import os
import socket
import sys
import torch.multiprocessing
import torch_xla
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import traceback

TpuConfigEntry = collections.namedtuple('TpuConfigEntry',
                                        'worker_name ordinal host_port')


def _find_free_tcp_port():
  with contextlib.closing(socket.socket(socket.AF_INET,
                                        socket.SOCK_STREAM)) as s:
    s.bind(('', 0))
    return s.getsockname()[1]


def _is_tpu_config():
  for env in [xenv.TPU_CONFIG, xenv.LOCAL_WORKER]:
    if os.environ.get(env, None) is not None:
      return True
  return False


def _parse_tpu_config(config):
  # XRT_TPU_CONFIG='tpu_worker;0;ismz9:25822'
  parsed = []
  for worker in config.split('|'):
    parts = worker.split(';')
    if len(parts) != 3:
      raise ValueError('Bad worker syntax: {}'.format(worker))
    parsed.append(
        TpuConfigEntry(
            worker_name=parts[0], ordinal=int(parts[1]), host_port=parts[2]))
  return parsed
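
# Illustrative sketch of the XRT_TPU_CONFIG layout (hypothetical host
# addresses): workers are separated by '|' and fields by ';', so a two-worker
# configuration parses as follows.
#
#   entries = _parse_tpu_config(
#       'tpu_worker;0;10.0.0.1:8470|tpu_worker;1;10.0.0.2:8470')
#   assert entries[0] == TpuConfigEntry(
#       worker_name='tpu_worker', ordinal=0, host_port='10.0.0.1:8470')
#   assert entries[1] == TpuConfigEntry(
#       worker_name='tpu_worker', ordinal=1, host_port='10.0.0.2:8470')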


def _get_devices_per_worker():
  return int(os.environ.get(xenv.TPU_NUM_DEVICES, '8'))


def _get_multiprocessing_device():
  return os.environ.get(xenv.MP_DEVICE, None)


def _get_local_worker_index():
  worker = os.environ.get(xenv.LOCAL_WORKER, None)
  return int(worker.split(':')[1]) if worker is not None else 0


def _local_index_to_global(index):
  return _get_local_worker_index() * _get_devices_per_worker() + index
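
# Worked example (illustrative values): with the default of 8 devices per
# worker, local device index 3 on local worker 1 maps to global ordinal
# 1 * 8 + 3 == 11.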


def _setup_world_size(num_devices):
  # We cannot call into xla_model code at this point, as we do not know whether
  # the called code would trigger XLA library initializations (which we must
  # not do at this point). So we avoid calling into xm.xrt_world_size().
  world_size = int(os.environ.get(xenv.WORLD_SIZE, '1')) * num_devices
  os.environ[xenv.WORLD_SIZE] = str(world_size)


def _pre_fork_setup(num_devices):
  if num_devices is None:
    num_devices = _get_devices_per_worker()
  elif num_devices not in [1, _get_devices_per_worker()]:
    raise ValueError(
        'The number of devices must be either 1 or {}, got {} instead'.format(
            _get_devices_per_worker(), num_devices))
  if not os.environ.get(xenv.SERVICE_ADDRESS, None):
    # In multi-processing mode, even if there is only one TPU host, we still
    # bring up the mesh service.
    os.environ[xenv.SERVICE_ADDRESS] = 'localhost:{}'.format(
        _find_free_tcp_port())
  return num_devices


def _prepare_env_for_index(index, num_devices):
  _setup_world_size(num_devices)
  gindex = _local_index_to_global(index)
  os.environ[xenv.MP_DEVICE] = 'TPU:{}'.format(gindex)
  os.environ[xenv.ORDINAL] = str(gindex)
  os.environ[xenv.LOCAL_ORDINAL] = str(index)
  if xenv.LOCAL_WORKER not in os.environ:
    # The local worker can be missing for a 1 TPU host setup. Make sure we
    # always have one.
    tpu_config = _parse_tpu_config(os.environ[xenv.TPU_CONFIG])
    worker = tpu_config[0]
    os.environ[xenv.LOCAL_WORKER] = '{}:{}'.format(worker.worker_name,
                                                   worker.ordinal)
  if gindex > 0 and xenv.TPU_CONFIG in os.environ:
    # In multi-processing mode, only the process handling the first device of
    # the master worker will do TPU mesh initialization.
    del os.environ[xenv.TPU_CONFIG]
  return gindex
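
# Illustrative sketch (hypothetical single-host, 8-device setup): after
# _prepare_env_for_index(2, 8) runs in the child process for local index 2 on
# local worker 0, the environment contains
#
#   os.environ[xenv.MP_DEVICE] == 'TPU:2'
#   os.environ[xenv.ORDINAL] == '2'
#   os.environ[xenv.LOCAL_ORDINAL] == '2'
#   os.environ[xenv.WORLD_SIZE] == '8'
#
# and xenv.TPU_CONFIG has been removed, since only the process owning the
# first device of the master worker performs the TPU mesh initialization.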


def _setup_replication():
  if xm.xrt_world_size() > 1:
    device = xm.xla_device()
    xm.set_replication(str(device), [str(device)])


def _start_fn(index, num_devices, fn, args):
  gindex = _prepare_env_for_index(index, num_devices)
  # Calling _setup_replication() will trigger XLA library initialization, so the
  # environment must be fully set up before doing so.
  _setup_replication()
  exit_code = 0
  try:
    fn(gindex, *args)
  except Exception as e:
    print(
        'Exception in device={}: {}'.format(_get_multiprocessing_device(),
                                            str(e)),
        file=sys.stderr)
    traceback.print_exc(limit=16, file=sys.stderr)
    exit_code = 17
  sys.exit(exit_code)


def spawn(fn, args=(), nprocs=None, join=True, daemon=False,
          start_method='spawn'):
  """Enables multi-processing based replication.

  Args:
    fn (callable): The function to be called for each device which takes part
      of the replication. The function will be called with a first argument
      being the global index of the process within the replication, followed
      by the arguments passed in `args`.
    args (tuple): The arguments for `fn`.
      Default: Empty tuple
    nprocs (int): The number of processes/devices for the replication. At the
      moment, if specified, can be either 1 or the maximum number of devices.
    join (bool): Whether the call should block waiting for the completion of
      the processes which have been spawned.
      Default: True
    daemon (bool): Whether the processes being spawned should have the
      `daemon` flag set (see Python multi-processing API).
      Default: False
    start_method (string): The Python `multiprocessing` process creation
      method.
      Default: `spawn`

  Returns:
    The same object returned by the `torch.multiprocessing.spawn` API. If
    `nprocs` is 1 the `fn` function will be called directly, and the API will
    not return.
  """
  if not _is_tpu_config():
    # If this is not a TPU setup, jump to normal multi-processing.
    nprocs = nprocs or 1
    if nprocs == 1:
      fn(0, *args)
      sys.exit(0)
    else:
      return torch.multiprocessing.spawn(
          fn, args=args, nprocs=nprocs, join=join, daemon=daemon)

  nprocs = _pre_fork_setup(nprocs)
  if nprocs == 1:
    _start_fn(0, nprocs, fn, args)
  else:
    return torch.multiprocessing.start_processes(
        _start_fn,
        args=(nprocs, fn, args),
        nprocs=nprocs,
        join=join,
        daemon=daemon,
        start_method=start_method)
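
A minimal usage sketch (the training function and its flags dict below are hypothetical, and an already configured TPU environment is assumed):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index, flags):
  # `index` is the global replica ordinal assigned by spawn().
  device = xm.xla_device()
  print('Replica {} running on {} with flags {}'.format(index, device, flags))


if __name__ == '__main__':
  # One process per TPU core; leaving nprocs=None uses all available devices.
  xmp.spawn(_mp_fn, args=({'lr': 0.01},), nprocs=None, start_method='spawn')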
