Shortcuts

Source code for torch_xla.distributed.xla_multiprocessing

from __future__ import division
from __future__ import print_function

import collections
import contextlib
import os
import re
import socket
import sys
import torch.multiprocessing
import torch_xla
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import traceback

# Result of the pre-fork setup: the device kind ('TPU' or 'GPU') and the number
# of devices to be used by the processes started on this host.
PreForkConfig = collections.namedtuple('PreForkConfig', 'dev_kind num_devices')
# A single parsed worker entry: the worker name, its integer ordinal, and the
# HOST:PORT endpoint string extracted from a workers/TPU configuration.
WorkerConfigEntry = collections.namedtuple('WorkerConfigEntry',
                                           'worker_name ordinal host_port')

# Worker (and TF job) name used for the worker entries generated by this module.
_LOCAL_WORKER = 'localservice'
# Name of the environment variable set per process to select its CUDA device.
_CUDA_VISIBLE_DEVICES = 'CUDA_VISIBLE_DEVICES'


def _get_free_tcp_ports(n=1):
  ports = []
  for _ in range(0, n):
    with contextlib.closing(socket.socket(socket.AF_INET,
                                          socket.SOCK_STREAM)) as s:
      s.bind(('', 0))
      ports.append(s.getsockname()[1])
  return ports


def _is_xla_config():
  """Tells whether the environment carries an XLA configuration.

  Returns:
    True if at least one of the environment variables named by
    `xenv.TPU_CONFIG`, `xenv.LOCAL_WORKER` or `xenv.GPU_NUM_DEVICES` is set,
    False otherwise.
  """
  config_vars = (xenv.TPU_CONFIG, xenv.LOCAL_WORKER, xenv.GPU_NUM_DEVICES)
  return any(os.environ.get(name) is not None for name in config_vars)


def _get_world_size():
  """Reads the world size from the environment.

  Returns:
    The integer value of the `xenv.WORLD_SIZE` environment variable, or 1 if
    the variable is not set.
  """
  # The xla_model.py API is off limits here: the features used in that module
  # rely on the setup which this module is in charge of providing.
  raw_world_size = os.environ.get(xenv.WORLD_SIZE, '1')
  return int(raw_world_size)


def _create_gpu_devices(num_gpus):
  """Publishes the GPU device map into the environment.

  One `GPU:{index};{tf_device}` entry is generated for every GPU of every
  host counted by the current world size, and the entries are stored, joined
  by `|`, in the `xenv.DEVICE_MAP` environment variable.

  Args:
    num_gpus (int): The number of GPU devices per host.
  """
  # CUDA_VISIBLE_DEVICES restricts every process to a single CUDA device, whose
  # device index is therefore always 0. The task number is what tells the TF
  # devices apart.
  tf_device_fmt = '/job:{}/replica:0/task:{}/device:XLA_GPU:0'
  global_indices = [
      host * num_gpus + local
      for host in range(_get_world_size())
      for local in range(num_gpus)
  ]
  device_map = [
      'GPU:{};{}'.format(gindex, tf_device_fmt.format(_LOCAL_WORKER, gindex))
      for gindex in global_indices
  ]
  os.environ[xenv.DEVICE_MAP] = '|'.join(device_map)


def _parse_workers_config(config):
  # XRT_WORKERS='worker:0;ismz9:25822'
  workers = collections.OrderedDict()
  for worker in config.split('|'):
    m = re.match(r'(\w+):(\d+);((grpc://)?[\w.]+:\d+)', worker)
    if not m:
      raise ValueError('Bad worker syntax: {}'.format(worker))
    workers['{}:{}'.format(m.group(1), m.group(2))] = WorkerConfigEntry(
        worker_name=m.group(1), ordinal=int(m.group(2)), host_port=m.group(3))
  return workers


def _parse_tpu_config(config):
  # XRT_TPU_CONFIG='tpu_worker;0;ismz9:25822'
  workers = collections.OrderedDict()
  for worker in config.split('|'):
    m = re.match(r'(\w+);(\d+);([\w.]+:\d+)', worker)
    if not m:
      raise ValueError('Bad worker syntax: {}'.format(worker))
    workers['{}:{}'.format(m.group(1), m.group(2))] = WorkerConfigEntry(
        worker_name=m.group(1), ordinal=int(m.group(2)), host_port=m.group(3))
  return workers


def _get_devices_per_worker():
  """Retrieves the per-worker device count and kind from the environment.

  A TPU setup is selected when either `xenv.TPU_CONFIG` or
  `xenv.TPU_NUM_DEVICES` is set, with the device count defaulting to 8 when
  `xenv.TPU_NUM_DEVICES` is missing or empty. Otherwise a GPU setup is
  selected if `xenv.GPU_NUM_DEVICES` is set.

  Returns:
    A `(device_count, device_kind)` tuple, with `device_kind` being either
    'TPU' or 'GPU'.

  Raises:
    RuntimeError: If neither a TPU nor a GPU configuration is found.
  """
  env = os.environ
  tpu_count = env.get(xenv.TPU_NUM_DEVICES)
  if tpu_count is not None or env.get(xenv.TPU_CONFIG) is not None:
    return int(tpu_count or '8'), 'TPU'
  gpu_count = env.get(xenv.GPU_NUM_DEVICES)
  if gpu_count is None:
    raise RuntimeError('Missing TPU or GPU configuration')
  return int(gpu_count), 'GPU'


def _get_multiprocessing_device():
  """Returns the value of the `xenv.MP_DEVICE` variable, or None if unset."""
  return os.environ.get(xenv.MP_DEVICE)


def _get_local_worker_index():
  """Extracts the ordinal of the local worker from the environment.

  Returns:
    The integer ordinal parsed out of the `NAME:ORDINAL` value stored in the
    `xenv.LOCAL_WORKER` environment variable, or 0 if the variable is not set.

  Raises:
    ValueError: If the variable value does not start with `NAME:ORDINAL`.
  """
  local_worker = os.environ.get(xenv.LOCAL_WORKER)
  if local_worker is None:
    return 0
  parsed = re.match(r'(\w+):(\d+)', local_worker)
  if parsed is None:
    raise ValueError('Bad worker syntax: {}'.format(local_worker))
  _, ordinal = parsed.groups()
  return int(ordinal)


def _local_index_to_global(index, num_devices):
  """Converts a host-local device index into a global one.

  Args:
    index (int): The index of the device within the local worker.
    num_devices (int): The number of devices per worker.

  Returns:
    The local worker index times `num_devices`, plus `index`.
  """
  worker_offset = _get_local_worker_index() * num_devices
  return worker_offset + index


def _setup_world_size(num_devices):
  """Scales the world size stored in the environment by `num_devices`.

  The current `xenv.WORLD_SIZE` value (1 if unset) is multiplied by
  `num_devices` and written back to the same environment variable.

  Args:
    num_devices (int): The number of devices per host.
  """
  # Calling xm.xrt_world_size() is deliberately avoided: at this stage we do
  # not know whether xla_model code would trigger the XLA library
  # initialization, which must not happen yet.
  host_count = _get_world_size()
  os.environ[xenv.WORLD_SIZE] = str(host_count * num_devices)


def _setup_workers(num_devices):
  """Expands the workers configuration to one worker entry per device.

  If the `xenv.WORKERS` environment variable is set, every configured worker
  is expanded into `num_devices` entries, whose global ordinals are
  consecutive and whose ports are consecutive starting from the configured
  one. Otherwise, `num_devices` local workers are generated on free TCP ports
  of this host. The resulting `NAME:ORDINAL;grpc://HOST:PORT` entries are
  joined by `|` and stored back into `xenv.WORKERS`.

  Args:
    num_devices (int): The number of devices (processes) per host.

  Raises:
    AssertionError: If the world size does not match the number of configured
      workers, or if it is greater than 1 with no workers configuration.
    RuntimeError: If a configured worker has a malformed HOST:PORT.
  """
  world_size = _get_world_size()
  workers_env = os.environ.get(xenv.WORKERS, None)
  workers = []
  if workers_env is not None:
    wcfg = _parse_workers_config(workers_env)
    assert world_size == len(
        wcfg), 'World size ({}) must match the configured workers ({})'.format(
            world_size, len(wcfg))
    # The parsed configuration is a mapping, so iterate over its entries and
    # not over its keys.
    for h, worker in enumerate(wcfg.values()):
      # The parser accepts an optional grpc:// prefix. Drop it here, as it gets
      # added back when the expanded entries are generated.
      m = re.match(r'(?:grpc://)?(.*):(\d+)$', worker.host_port)
      if not m:
        raise RuntimeError('Bad worker HOST:PORT format: {}'.format(
            worker.host_port))
      for i in range(0, num_devices):
        gindex = h * num_devices + i
        workers.append('{}:{};grpc://{}:{}'.format(worker.worker_name, gindex,
                                                   m.group(1),
                                                   int(m.group(2)) + i))
  else:
    assert world_size == 1, ('Cannot use more than one host without {} '
                             'configuration: {}').format(
                                 xenv.WORKERS, world_size)
    ports = _get_free_tcp_ports(num_devices)
    host = socket.getfqdn()
    for wid in range(0, num_devices):
      workers.append('{}:{};grpc://{}:{}'.format(_LOCAL_WORKER, wid, host,
                                                 ports[wid]))
  os.environ[xenv.WORKERS] = '|'.join(workers)


def _pre_fork_setup(num_devices):
  """Prepares the environment before the per-device processes are created.

  Validates the requested number of devices, makes sure a mesh service
  address is configured when more than one device is used, and, for GPU
  setups, generates the workers and device map configurations.

  Args:
    num_devices (int): The requested number of devices, or None to use all
      the devices of the worker.

  Returns:
    A `PreForkConfig` with the device kind and the number of devices to use.

  Raises:
    ValueError: If `num_devices` is neither 1 nor the worker device count.
  """
  available, kind = _get_devices_per_worker()
  if num_devices is None:
    num_devices = available
  if num_devices != 1 and num_devices != available:
    raise ValueError(
        'The number of devices must be either 1 or {}, got {} instead'.format(
            available, num_devices))

  # Even with a single XLA host, the mesh service is brought up when running
  # in multi-processing mode.
  if num_devices > 1 and not os.environ.get(xenv.SERVICE_ADDRESS):
    host = socket.getfqdn()
    port = _get_free_tcp_ports()[0]
    os.environ[xenv.SERVICE_ADDRESS] = '{}:{}'.format(host, port)

  if kind == 'GPU':
    _setup_workers(num_devices)
    _create_gpu_devices(num_devices)
  return PreForkConfig(dev_kind=kind, num_devices=num_devices)


def _setup_gpu_worker(index, gindex, pf_cfg):
  """Configures the environment of a process driving one GPU device.

  Args:
    index (int): The host-local index of the device.
    gindex (int): The global index of the device.
    pf_cfg (PreForkConfig): The pre-fork configuration (unused here).
  """
  env = os.environ
  env[xenv.MP_DEVICE] = 'GPU:{}'.format(gindex)
  env[xenv.LOCAL_WORKER] = '{}:{}'.format(_LOCAL_WORKER, gindex)
  # Each process only gets to see a single GPU device, which within that
  # process is going to be named XLA_GPU:0.
  env[_CUDA_VISIBLE_DEVICES] = str(index)
  # The device map already lists the expanded GPU devices (see
  # _create_gpu_devices()), so the key gets dropped from the environment, as
  # it would otherwise trigger a second device generation in
  # computation_client.cc.
  env.pop(xenv.GPU_NUM_DEVICES, None)


def _setup_tpu_worker(index, gindex, pf_cfg, tpu_env_config):
  """Configures the environment of a process driving one TPU device.

  Args:
    index (int): The host-local index of the device (unused here).
    gindex (int): The global index of the device.
    pf_cfg (PreForkConfig): The pre-fork configuration (unused here).
    tpu_env_config (string): The TPU configuration string, required when the
      `xenv.LOCAL_WORKER` environment variable is not set.
  """
  env = os.environ
  env[xenv.MP_DEVICE] = 'TPU:{}'.format(gindex)
  if xenv.LOCAL_WORKER not in env:
    # A 1 TPU host setup might come without a local worker, so one is derived
    # from the first entry of the TPU configuration.
    assert tpu_env_config is not None, 'tpu_env_config must not be None'
    parsed_config = _parse_tpu_config(tpu_env_config)
    worker_name, ordinal, _ = list(parsed_config.values())[0]
    env[xenv.LOCAL_WORKER] = '{}:{}'.format(worker_name, ordinal)
  if gindex <= 0:
    return
  # When multi-processing, the TPU mesh initialization is performed only by
  # the process which handles the first device of the master worker. For all
  # the others, the environment configs which would keep the client away from
  # the mesh client config path have to be removed.
  for name in (xenv.TPU_CONFIG, xenv.TPU_NUM_DEVICES):
    env.pop(name, None)


def _prepare_env_for_index(index, pf_cfg):
  """Sets up the environment of the process handling the given local device.

  Args:
    index (int): The host-local index of the device.
    pf_cfg (PreForkConfig): The configuration produced by the pre-fork setup.

  Returns:
    The global index of the device.
  """
  devices_per_host = pf_cfg.num_devices
  _setup_world_size(devices_per_host)
  gindex = _local_index_to_global(index, devices_per_host)
  os.environ[xenv.ORDINAL] = str(gindex)
  os.environ[xenv.LOCAL_ORDINAL] = str(index)

  kind = pf_cfg.dev_kind
  if kind == 'GPU':
    _setup_gpu_worker(index, gindex, pf_cfg)
  elif kind == 'TPU':
    tpu_env_config = os.environ.get(xenv.TPU_CONFIG)
    _setup_tpu_worker(index, gindex, pf_cfg, tpu_env_config)
  return gindex


def _setup_replication():
  """Enables replication on the process device if the world size is above 1."""
  # The environment setup is complete at this point, hence the xla_model.py
  # APIs can be used.
  if xm.xrt_world_size() <= 1:
    return
  device = xm.xla_device()
  xm.set_replication(device, [device])


def _start_fn(index, pf_cfg, fn, args):
  """Process entry point: sets up the environment, runs `fn`, and exits.

  The process exits with code 0 if `fn` completes, or with code 17, after
  reporting the error on stderr, if `fn` raises an exception.

  Args:
    index (int): The host-local index of the device handled by the process.
    pf_cfg (PreForkConfig): The configuration produced by the pre-fork setup.
    fn (callable): The user function, called as `fn(global_index, *args)`.
    args (tuple): The extra arguments for `fn`.
  """
  gindex = _prepare_env_for_index(index, pf_cfg)
  # The environment has to be fully setup before _setup_replication() is
  # called, as that triggers the XLA library initialization.
  _setup_replication()
  try:
    fn(gindex, *args)
  except Exception as e:
    device = _get_multiprocessing_device()
    message = 'Exception in device={}: {}'.format(device, str(e))
    print(message, file=sys.stderr)
    traceback.print_exc(limit=16, file=sys.stderr)
    exit_code = 17
  else:
    exit_code = 0
  sys.exit(exit_code)


def _run_direct(fn, args, nprocs, join, daemon, start_method):
  nprocs = nprocs or 1
  if nprocs == 1 and join:
    fn(0, *args)
  else:
    return torch.multiprocessing.spawn(
        fn, args=args, nprocs=nprocs, join=join, daemon=daemon)


def spawn(fn,
          args=(),
          nprocs=None,
          join=True,
          daemon=False,
          start_method='spawn'):
  """Enables multi processing based replication.

  Args:
    fn (callable): The function to be called for each device which takes part of
      the replication. The function will be called with a first argument being
      the global index of the process within the replication, followed by the
      arguments passed in `args`.
    args (tuple): The arguments for `fn`.
      Default: Empty tuple
    nprocs (int): The number of processes/devices for the replication. At the
      moment, if specified, can be either 1 or the maximum number of devices.
    join (bool): Whether the call should block waiting for the completion of the
      processes which have been spawned.
      Default: True
    daemon (bool): Whether the processes being spawned should have the `daemon`
      flag set (see Python multi-processing API).
      Default: False
    start_method (string): The Python `multiprocessing` process creation method.
      Default: `spawn`

  Returns:
    The same object returned by the `torch.multiprocessing.spawn` API. If
    `nprocs` is 1 the `fn` function will be called directly, and the API will
    not return.
  """
  if not _is_xla_config():
    # If this is not an XLA setup, jump to normal multi-processing.
    return _run_direct(fn, args, nprocs, join, daemon, start_method)

  pf_cfg = _pre_fork_setup(nprocs)
  if pf_cfg.num_devices == 1:
    # Single device: run in this very process. _start_fn() ends with a
    # sys.exit() call, so control does not come back here.
    _start_fn(0, pf_cfg, fn, args)
  else:
    return torch.multiprocessing.start_processes(
        _start_fn,
        args=(pf_cfg, fn, args),
        nprocs=pf_cfg.num_devices,
        join=join,
        daemon=daemon,
        start_method=start_method)


class MpModelWrapper(object):
  """Wraps a model to minimize host memory usage when `fork` method is used.

  This class should be used together with the `spawn(..., start_method='fork')`
  API to minimize the use of host memory.
  Instead of creating models on each multiprocessing process, hence replicating
  the model's initial host memory, the model is created once at global scope,
  and then moved into each device inside the `spawn()` target function.
  Example::

    WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())

    def _mp_fn(index, ...):
      device = xm.xla_device()
      model = WRAPPED_MODEL.to(device)
      ...

    xmp.spawn(_mp_fn, ..., start_method='fork')

  This method has two advantages. First it uses only one copy of the memory
  pages to host the original model weights, and second it serializes the move
  of the wrapped model into each device, by lowering the load onto the system
  memory during the process.
  """

  def __init__(self, model):
    """Creates a new `MpModelWrapper` object.

    Args:
      model (torch.nn.Module): The model to be wrapped. Should be on PyTorch CPU
        device (which is the default when creating new models).
    """
    self._model = model
    # Serializes the moves of the model onto the devices across processes.
    self._lock = torch.multiprocessing.Lock()

  def to(self, device):
    """Retrieves the model moved onto the specified device.

    Args:
      device (torch.device): The device where the model should be moved onto.
    Returns:
      The model on the specified device.
    """
    with self._lock:
      self._model.to(device)
    return self._model


class MpSerialExecutor(object):
  """Utility to run a function in a serialized fashion among multi-core processes.

  Example::

    # At global scope.
    SERIAL_EXEC = xmp.MpSerialExecutor()

    def load_dataset(path):
      return maybe_download_and_load(path)

    def _mp_fn(index, ...):
      # Avoid all cores downloading the same data with the serial executor.
      dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data'))
      ...

    xmp.spawn(_mp_fn, ...)
  """

  def __init__(self):
    # Only the process holding the lock gets to run its function.
    self._lock = torch.multiprocessing.Lock()

  def run(self, fn):
    """Runs the provided function serialized WRT each per-core process.

    Args:
      fn (callable): The function to run in a serialized fashion.
    Returns:
      The `fn` return value.
    """
    with self._lock:
      return fn()

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources