Shortcuts

Source code for torch_xla.distributed.xla_multiprocessing

import torch.multiprocessing
from torch_xla import runtime as xr
from torch_xla._internal import pjrt


[docs]@xr.requires_pjrt def spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn'): """Enables multi processing based replication. Args: fn (callable): The function to be called for each device which takes part of the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in `args`. args (tuple): The arguments for `fn`. Default: Empty tuple nprocs (int): The number of processes/devices for the replication. At the moment, if specified, can be either 1 or the maximum number of devices. join (bool): Whether the call should block waiting for the completion of the processes which have being spawned. Default: True daemon (bool): Whether the processes being spawned should have the `daemon` flag set (see Python multi-processing API). Default: False start_method (string): The Python `multiprocessing` process creation method. Default: `spawn` Returns: The same object returned by the `torch.multiprocessing.spawn` API. If `nprocs` is 1 the `fn` function will be called directly, and the API will return None. """ return pjrt.spawn(fn, nprocs, start_method, args)
class MpModelWrapper(object): """Wraps a model to minimize host memory usage when `fork` method is used. This class should be used together with the `spawn(..., start_method='fork')` API to minimize the use of host memory. Instead of creating models on each multiprocessing process, hence replicating the model's initial host memory, the model is created once at global scope, and then moved into each device inside the `spawn()` target function. Example:: WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork()) def _mp_fn(index, ...): device = xm.xla_device() model = WRAPPED_MODEL.to(device) ... xmp.spawn(_mp_fn, ..., start_method='fork') This method has two advantages. First it uses only one copy of the memory pages to host the original model weights, and second it serializes the move of the wrapped model into each device, by lowering the load onto the system memory during the process. """ def __init__(self, model): """Creates a new `MpModelWrapper` object. Args: model (torch.nn.Module): The model to be wrapped. Should be on PyTorch CPU device (which is the default when creating new models). """ self._model = model self._lock = torch.multiprocessing.Lock() def to(self, device): """Retrieves the model moved onto the specified device. Args: device (torch.device): The device where the model should be moved onto. Returns: The model on the specified device. """ with self._lock: self._model.to(device) return self._model class MpSerialExecutor(object): """Utility to run a function in a serialized fashion among multi-core processes. Example:: # At global scope. SERIAL_EXEC = xmp.MpSerialExecutor() def load_dataset(path): return maybe_download_and_load(path) def _mp_fn(index, ...): # Avoid all cores downloading the same data with the serial executor. dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data')) ... xmp.spawn(_mp_fn, ...) """ def __init__(self): self._lock = torch.multiprocessing.Lock() def run(self, fn): """Runs the provided function serialized WRT each per-core process. Args: fn (callable): The function to run in a serialized fashion. Returns: The `fn` return value. """ with self._lock: return fn()

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources