PyTorch on XLA Devices

PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.

Creating an XLA Tensor

PyTorch/XLA adds a new xla device type to PyTorch. This device type works just like other PyTorch device types. For example, here’s how to create and print an XLA tensor:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing torch_xla initializes PyTorch/XLA, and xm.xla_device() returns the current XLA device. This may be a CPU or TPU depending on your environment.

XLA Tensors are PyTorch Tensors

PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.

For example, XLA tensors can be added together:

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

Or matrix multiplied:

print(t0.mm(t1))

Or used with neural network modules:

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)

Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

will throw an error since the torch.nn.Linear module is on the CPU.

Running Models on XLA Devices

Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device and multiple devices with XLA multi-processing.

Running on a Single XLA Device

The following snippet shows a network training on a single XLA device:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  optimizer.step()
  xm.mark_step()  # execute the graph recorded during this iteration

This snippet highlights how easy it is to switch your model to run on XLA. The model definition, dataloader, optimizer and training loop can work on any device. The only XLA-specific code is a couple lines that acquire the XLA device and mark the step. Calling xm.mark_step() at the end of each training iteration causes XLA to execute its current graph and update the model’s parameters. See XLA Tensor Deep Dive for more on how XLA creates graphs and runs operations.

Running on Multiple XLA Devices with Multi-processing

PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:

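import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.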

def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

There are three differences between this multi-device snippet and the previous single-device snippet. Let’s go over them one by one.

  • xmp.spawn()

    • Creates the processes that each run an XLA device.

    • Each process can access only the device assigned to it. For example, on a TPU v4-8, 4 processes are spawned and each process owns one TPU device.

    • Note that if you print xm.xla_device() in each process, you will see xla:0 in all of them. This is because each process can see only one device; it does not mean that multi-processing is not working. The only exception is the PJRT runtime on TPU v2 and TPU v3, where there are #devices/2 processes and each process has 2 threads (check this doc for more details).

  • MpDeviceLoader

    • Loads the training data onto each device.

    • MpDeviceLoader wraps a torch DataLoader. It can preload data onto the device and overlap data loading with device execution to improve performance.

    • MpDeviceLoader also calls xm.mark_step for you each time batches_per_execution (default: 1) batches have been yielded.

  • xm.optimizer_step(optimizer)

    • Consolidates the gradients between devices and issues the XLA device step computation.

    • It is roughly equivalent to all_reduce_gradients + optimizer.step() + mark_step, and it returns the value returned by optimizer.step().

The model definition, optimizer definition and training loop remain the same.
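To make the gradient consolidation step more concrete, here is a minimal conceptual sketch in plain Python. It does not call torch_xla: lists stand in for per-device gradients, the helper names all_reduce_mean and sgd_step are hypothetical, and averaging (sum followed by scaling by 1/world_size) is an assumption made for illustration.

```python
# Conceptual sketch only: plain Python lists stand in for per-device gradients.
# Nothing here calls torch_xla; all names are hypothetical.

def all_reduce_mean(per_replica_grads):
    """Average each gradient entry across replicas (sum, then scale by 1/world_size)."""
    world_size = len(per_replica_grads)
    summed = [sum(entries) for entries in zip(*per_replica_grads)]
    return [value / world_size for value in summed]


def sgd_step(params, grads, lr):
    """Apply one plain SGD update."""
    return [p - lr * g for p, g in zip(params, grads)]


# Four replicas start from identical parameters but see different data,
# so their local gradients differ.
params = [1.0, 2.0]
local_grads = [[0.5, 0.0], [0.0, 1.0], [0.25, 0.5], [0.25, 0.5]]

# After the all-reduce every replica holds the same gradients, so every
# replica applies the same update and the model copies stay in sync.
shared_grads = all_reduce_mean(local_grads)
replicas = [sgd_step(params, shared_grads, lr=0.5) for _ in local_grads]
```

Because every replica applies the identical averaged gradient, the parameter copies remain equal after each step without ever being exchanged directly.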

NOTE: When using multi-processing, XLA devices can be retrieved and accessed only from within the target function of xmp.spawn() (or any function that has xmp.spawn() as a parent in the call stack).

See the full multiprocessing example for more on training a network on multiple XLA devices with multi-processing.

Running on TPU Pods

Multi-host setup can differ considerably between accelerators. This doc covers the device-independent parts of multi-host training and uses the TPU + PJRT runtime (currently available on 1.13 and 2.x releases) as an example.

Before you begin, please take a look at our user guide here, which explains some Google Cloud basics such as how to use the gcloud command and how to set up your project. You can also check here for all Cloud TPU how-tos. This doc focuses on the PyTorch/XLA perspective of the setup.

Let’s assume you have the MNIST example from the section above in a file named train_mnist_xla.py. For single-host, multi-device training, you would ssh to the TPU VM and run a command like

PJRT_DEVICE=TPU python3 train_mnist_xla.py

Now, in order to run the same model on a TPU v4-16 (which has 2 hosts, each with 4 TPU devices), you will need to:

  • Make sure each host can access the training script and training data. This is usually done by using the gcloud scp command or gcloud ssh command to copy the training scripts to all hosts.

  • Run the same training command on all hosts at the same time.

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"

The gcloud ssh command above will ssh to all hosts in the TPU VM Pod and run the same command on each of them at the same time.

NOTE: You need to run the above gcloud command from outside the TPU VM.

The model code and training script are the same for multi-process training and multi-host training. PyTorch/XLA and the underlying infrastructure make sure each device is aware of the global topology and of each device’s local and global ordinal. Cross-device communication happens across all devices instead of only the local devices.

For more details regarding the PJRT runtime and how to run it on a pod, please refer to this doc. For more information about PyTorch/XLA on TPU pods and a complete guide to running a resnet50 with fake data on a TPU pod, please refer to this guide.

XLA Tensor Deep Dive

Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique.

XLA Tensors are Lazy

CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example.

Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device. For more information about our lazy tensor design, you can read this paper.
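The idea of recording operations and running them only when a result is needed can be illustrated with a toy model in plain Python. This is a sketch under simplifying assumptions, not how PyTorch/XLA is implemented: LazyValue and const are hypothetical names, and materialize() loosely plays the role of copying a result back to the CPU.

```python
# Toy model of lazy evaluation; not PyTorch/XLA's implementation.

class LazyValue:
    """Records an operation instead of running it immediately."""

    executed_ops = 0  # counts the operations that have actually run

    def __init__(self, op, inputs=(), constant=None):
        self.op = op
        self.inputs = inputs
        self.constant = constant
        self._result = None

    def __add__(self, other):
        return LazyValue("add", (self, other))

    def __mul__(self, other):
        return LazyValue("mul", (self, other))

    def materialize(self):
        """Walk the recorded graph and compute the value on demand."""
        if self._result is None:
            if self.op == "const":
                self._result = self.constant
            else:
                left, right = (node.materialize() for node in self.inputs)
                LazyValue.executed_ops += 1
                self._result = left + right if self.op == "add" else left * right
        return self._result


def const(value):
    return LazyValue("const", constant=value)


a, b = const(2.0), const(3.0)
c = (a + b) * a                    # only builds a graph; nothing has run yet
ops_before = LazyValue.executed_ops
result = c.materialize()           # the recorded graph executes here
ops_after = LazyValue.executed_ops
```

A real compiler sees the whole recorded graph before anything runs, which is what gives it the opportunity to fuse or reorder operations.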

Memory Layout

The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor’s memory layout for better performance.

Moving XLA Tensors to and from the CPU

XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved, the data it is viewing is also copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it. Depending on how your code operates, appreciating and accommodating this transition can be important.
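The loss of the view relationship can be shown with an analogy from the Python standard library. In this sketch a memoryview stands in for a tensor view and copying into a new bytearray stands in for moving data to another device; it is only an analogy and does not use torch or torch_xla.

```python
# Analogy using only the standard library: a memoryview plays the role of a
# tensor view, and copying into a new bytearray plays the role of a device move.

storage = bytearray(b"\x01\x02\x03\x04")
view = memoryview(storage)[1:3]   # shares memory with storage

moved = bytearray(view)           # the copy owns its own data

storage[1] = 0xFF                 # mutate the original storage
view_sees_change = view[0] == 0xFF
copy_sees_change = moved[0] == 0xFF
```

The view observes the mutation because it aliases the original storage, while the copy does not, which is the behavior to expect after moving a view across devices.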

Saving and Loading XLA Tensors

XLA tensors should be moved to the CPU before saving, as in the following snippet:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

This lets you put the loaded tensors on any available device, not just the one on which they were initialized.

Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it is recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

A utility API is provided that saves data and takes care of moving it to the CPU first:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xm.save(model.state_dict(), path)

In case of multiple devices, the above API will only save the data for the master device ordinal (0).

For cases where host memory is limited compared to the size of the model parameters, an API is provided that reduces the memory footprint on the host:

import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)

This API streams XLA tensors to CPU one at a time, reducing the amount of host memory used, but it requires a matching load API to restore:

import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
model.load_state_dict(state_dict)

Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.

Compilation Caching

The XLA compiler converts the traced HLO into an executable which runs on the devices. Compilation can be time consuming, and in cases where the HLO doesn’t change across executions, the compilation result can be persisted to disk for reuse, significantly reducing development iteration time.

Note that if the HLO changes between executions, a recompilation will still occur.

This is currently an experimental opt-in API, which must be activated before any computations are executed. Initialization is done through the initialize_cache API:

import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)

This will initialize a persistent compilation cache at the specified path. The readonly parameter can be used to control whether the worker will be able to write to the cache, which can be useful when a shared cache mount is used for an SPMD workload.
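The behavior described above (reuse when the program is unchanged, recompilation when it changes, and an optional read-only mode) can be sketched with a toy on-disk cache. ToyCompilationCache is a hypothetical class written for illustration; keying on a SHA-256 hash of the program text is an assumption and says nothing about how the real cache computes its keys.

```python
import hashlib
import tempfile
from pathlib import Path


class ToyCompilationCache:
    """Hypothetical on-disk cache keyed by a hash of the program text."""

    def __init__(self, path, readonly=False):
        self.path = Path(path)
        self.readonly = readonly
        self.compile_count = 0
        if not readonly:
            self.path.mkdir(parents=True, exist_ok=True)

    def _compile(self, program_text):
        self.compile_count += 1            # stands in for an expensive compile
        return program_text.upper().encode()

    def get_executable(self, program_text):
        key = hashlib.sha256(program_text.encode()).hexdigest()
        entry = self.path / key
        if entry.exists():
            return entry.read_bytes()      # cache hit: compilation is skipped
        executable = self._compile(program_text)
        if not self.readonly:
            entry.write_bytes(executable)  # only writers persist new entries
        return executable


with tempfile.TemporaryDirectory() as cache_dir:
    writer = ToyCompilationCache(cache_dir)
    writer.get_executable("add(x, y)")
    writer.get_executable("add(x, y)")     # same program: served from disk
    writer.get_executable("mul(x, y)")     # changed program: compiled again
    writer_compiles = writer.compile_count

    reader = ToyCompilationCache(cache_dir, readonly=True)
    reader.get_executable("add(x, y)")     # hit on the shared cache
    reader.get_executable("sub(x, y)")     # miss: compiled but not persisted
    reader_compiles = reader.compile_count
    entries_on_disk = len(list(Path(cache_dir).iterdir()))
```

The read-only worker benefits from entries written by others but never adds its own, which mirrors the shared-mount use case mentioned for the readonly parameter.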

Further Reading

Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.

PyTorch/XLA API

torch_xla

torch_xla.device(index: Optional[int] = None) -> device

Returns a given instance of an XLA device.

If SPMD is enabled, returns a virtual device that wraps all devices available to this process.

Parameters

index – index of the XLA device to be returned. Corresponds to index in torch_xla.devices().

Returns

An XLA torch.device.

torch_xla.devices() -> List[device]

Returns all devices available in the current process.

Returns

A list of XLA torch.devices.

torch_xla.device_count() -> int

Returns number of addressable devices in the current process.

torch_xla.sync()

Launches all pending graph operations.

torch_xla.step(f: Optional[Callable] = None)

Wraps code that should be dispatched to the runtime.

Experimental: xla.step is still a work in progress. Some code that currently works with xla.step but does not follow best practices will become errors in future releases. See https://github.com/pytorch/xla/issues/6751 for context.

torch_xla.manual_seed(seed, device=None)

Sets the seed for generating random numbers for the current XLA device.

Parameters
  • seed (int) – The state to be set.

  • device (torch.device, optional) – The device where the RNG state needs to be set. If missing the default device seed will be set.

runtime

torch_xla.runtime.device_type() -> Optional[str]

Returns the current PjRt device type.

Selects a default device if none has been configured.

torch_xla.runtime.local_process_count() -> int

Returns the number of processes running on this host.

torch_xla.runtime.local_device_count() -> int

Returns the total number of devices on this host.

Assumes each process has the same number of addressable devices.

torch_xla.runtime.addressable_device_count() -> int

Returns the number of devices visible to this process.

torch_xla.runtime.global_device_count() -> int

Returns the total number of devices across all processes/hosts.

torch_xla.runtime.global_runtime_device_count() -> int

Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD.

torch_xla.runtime.world_size() -> int

Returns the total number of processes participating in the job.

torch_xla.runtime.global_ordinal() -> int

Returns global ordinal of this thread within all processes.

Global ordinal is in range [0, global_device_count). Global ordinals are not guaranteed to have any predictable relationship to the TPU worker ID nor are they guaranteed to be contiguous on each host.

torch_xla.runtime.local_ordinal() -> int

Returns local ordinal of this thread within this host.

Local ordinal is in range [0, local_device_count).
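As a rough illustration of how the device counts and ordinal ranges described above relate to each other, the sketch below models a job with a hypothetical ToyTopology class (not a torch_xla type). It assumes, as local_device_count() does, that every process has the same number of addressable devices, and it assumes one process per device on TPU v4.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTopology:
    """Hypothetical description of a job; not part of torch_xla."""

    hosts: int
    processes_per_host: int
    devices_per_process: int

    def local_device_count(self):
        # Devices on one host, assuming identical processes.
        return self.processes_per_host * self.devices_per_process

    def global_device_count(self):
        # Devices across all hosts.
        return self.hosts * self.local_device_count()

    def valid_local_ordinals(self):
        return range(self.local_device_count())    # [0, local_device_count)

    def valid_global_ordinals(self):
        return range(self.global_device_count())   # [0, global_device_count)


# TPU v4-16: 2 hosts with 4 devices each, one process per device (assumed).
v4_16 = ToyTopology(hosts=2, processes_per_host=4, devices_per_process=1)

# TPU v3-8 with PJRT: #devices/2 processes, each driving 2 devices via threads.
v3_8 = ToyTopology(hosts=1, processes_per_host=4, devices_per_process=2)
```

The sketch only captures the ranges; as noted above, the actual assignment of global ordinals to hosts is not guaranteed to follow any particular pattern.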

torch_xla.runtime.get_master_ip() -> str

Retrieve the master worker IP for the runtime. This calls into backend-specific discovery APIs.

Returns master worker’s IP address as a string.

torch_xla.runtime.use_spmd(auto: Optional[bool] = False)

torch_xla.runtime.is_spmd()

Returns whether SPMD is set for execution.

torch_xla.runtime.initialize_cache(path: str, readonly: bool = False)

Initializes the persistent compilation cache. This API must be called before any computations have been performed.

Parameters
  • path – The path at which to store the persistent cache.

  • readonly – Whether or not this worker should have write access to the cache.

xla_model

torch_xla.core.xla_model.xla_device(n=None, devkind=None)

Returns a given instance of an XLA device.

Parameters
  • n (int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.

  • devkind (string..., optional) – If specified, device type such as TPU, CUDA, CPU, or custom PJRT device. Deprecated.

Returns

A torch.device with the requested instance.

torch_xla.core.xla_model.xla_device_hw(device)

Returns the hardware type of the given device.

Parameters

device (string or torch.device) – The xla device that will be mapped to the real device.

Returns

A string representation of the hardware type of the given device.

torch_xla.core.xla_model.is_master_ordinal(local=True)

Checks whether the current process is the master ordinal (0).

Parameters

local (bool) – Whether the local or global master ordinal should be checked. In case of multi-host replication, there is only one global master ordinal (host 0, device 0), while there are NUM_HOSTS local master ordinals. Default: True

Returns

A boolean indicating whether the current process is the master ordinal.

torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True)

Performs an inplace reduce operation on the input tensor(s).

Parameters
  • reduce_type (string) – One of xm.REDUCE_SUM, xm.REDUCE_MUL, xm.REDUCE_AND, xm.REDUCE_OR, xm.REDUCE_MIN and xm.REDUCE_MAX.

  • inputs – Either a single torch.Tensor or a list of torch.Tensor on which to perform the all-reduce op.

  • scale (float) – A default scaling value to be applied after the reduce. Default: 1.0

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_reduce() operation. For example, [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication have slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.

Returns

If a single torch.Tensor is passed, the return value is a torch.Tensor holding the reduced value (across the replicas). If a list/tuple is passed, this function performs an inplace all-reduce op on the input tensors, and returns the list/tuple itself.
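The effect of the groups and scale parameters can be sketched without any devices. The function below simulates the reduction over a list that holds one scalar per replica; simulate_all_reduce and the string reduce types are hypothetical stand-ins (not the xm.REDUCE_* constants), and the logical AND/OR reductions are left out for brevity.

```python
import math

# Hypothetical reduce-type names for this sketch only.
REDUCERS = {
    "sum": sum,
    "mul": math.prod,
    "min": min,
    "max": max,
}


def simulate_all_reduce(reduce_type, per_replica_values, scale=1.0, groups=None):
    """Return the value each replica would hold after a grouped all-reduce."""
    replica_count = len(per_replica_values)
    if groups is None:
        groups = [list(range(replica_count))]  # one group with every replica
    reducer = REDUCERS[reduce_type]
    results = [None] * replica_count
    for group in groups:
        # Reduce within the group, then apply the scaling factor.
        reduced = reducer(per_replica_values[r] for r in group) * scale
        for r in group:
            results[r] = reduced
    return results


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
two_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]

everyone = simulate_all_reduce("sum", values)
grouped = simulate_all_reduce("sum", values, groups=two_groups)
grouped_mean = simulate_all_reduce("sum", values, scale=0.25, groups=two_groups)
grouped_max = simulate_all_reduce("max", values, groups=two_groups)
```

Passing scale=1/len(group) together with a sum turns the reduction into a per-group mean, which is a common way to average values across replicas.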

torch_xla.core.xla_model.all_gather(value, dim=0, groups=None, output=None, pin_layout=True)

Performs an all-gather operation along a given dimension.

Parameters
  • value (torch.Tensor) – The input tensor.

  • dim (int) – The gather dimension. Default: 0

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_gather() operation. For example, [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • output (torch.Tensor) – Optional output tensor.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication have slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.

Returns

A tensor which has, in the dim dimension, all the values from the participating replicas.

torch_xla.core.xla_model.all_to_all(value, split_dimension, concat_dimension, split_count, groups=None, pin_layout=True)

Performs an XLA AllToAll() operation on the input tensor.

See: https://www.tensorflow.org/xla/operation_semantics#alltoall

Parameters
  • value (torch.Tensor) – The input tensor.

  • split_dimension (int) – The dimension upon which the split should happen.

  • concat_dimension (int) – The dimension upon which the concat should happen.

  • split_count (int) – The split count.

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_to_all() operation. For example, [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication have slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.

Returns

The result torch.Tensor of the all_to_all() operation.
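For intuition, the data movement of an all-to-all can be sketched for the one-dimensional case, where the split and concat dimensions coincide. The helper simulate_all_to_all is hypothetical, works on plain lists with one list per replica, and assumes a single group in which split_count equals the number of replicas.

```python
def simulate_all_to_all(per_replica_rows, split_count):
    """1-D toy all-to-all: block j of every replica is sent to replica j.

    Each replica then concatenates the blocks it received in sender order.
    """
    replica_count = len(per_replica_rows)
    if split_count != replica_count:
        raise ValueError("this sketch assumes split_count == number of replicas")

    # Split every replica's row into split_count equally sized blocks.
    blocks = []
    for row in per_replica_rows:
        if len(row) % split_count:
            raise ValueError("row length must be divisible by split_count")
        size = len(row) // split_count
        blocks.append([row[i * size:(i + 1) * size] for i in range(split_count)])

    # Replica `receiver` gets block `receiver` from every sender.
    return [
        [item for sender in range(replica_count) for item in blocks[sender][receiver]]
        for receiver in range(replica_count)
    ]


rows = [[0, 1, 2, 3], [10, 11, 12, 13]]
exchanged = simulate_all_to_all(rows, split_count=2)
```

In this example replica 0 ends up with the first half of every row and replica 1 with the second half, so the operation behaves like a distributed transpose of blocks.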

torch_xla.core.xla_model.add_step_closure(closure, args=(), run_async=False)

Adds a closure to the list of the ones to be run at the end of the step.

During model training there is often a need to print or report (print to console, post to TensorBoard, etc.) information that requires inspecting the content of intermediate tensors. Inspecting the content of different tensors at different points in the model code requires many executions and typically causes performance issues. Adding a step closure ensures that it runs after the barrier, when all the live tensors, including the ones captured by the closure arguments, have already been materialized to device data. Using add_step_closure() therefore ensures that a single execution is performed even when multiple closures are queued and multiple tensors need to be inspected. Step closures run sequentially in the order in which they were queued. Note that even though this API optimizes execution, it is still advisable to throttle printing/reporting events to once every N steps.

Parameters
  • closure (callable) – The function to be called.

  • args (tuple) – The arguments to be passed to the closure.

  • run_async – If True, run the closure asynchronously.
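The queue-then-run-at-the-barrier behavior can be sketched in a few lines of plain Python. ToyStep is a hypothetical class that only mimics the ordering guarantees described above; it does not materialize tensors or call torch_xla.

```python
class ToyStep:
    """Hypothetical stand-in for a training step with deferred closures."""

    def __init__(self):
        self._closures = []
        self.executions = 0

    def add_step_closure(self, closure, args=()):
        # Only queue the closure; nothing runs yet.
        self._closures.append((closure, args))

    def mark_step(self):
        # One execution stands in for materializing all live tensors.
        self.executions += 1
        pending, self._closures = self._closures, []
        for closure, args in pending:   # run in the order they were queued
            closure(*args)


log = []


def report(name, value):
    log.append((name, value))


step = ToyStep()
step.add_step_closure(report, args=("loss", 0.5))
step.add_step_closure(report, args=("accuracy", 0.9))
logged_before_barrier = list(log)   # still empty: closures are deferred
step.mark_step()
```

Both closures run after a single execution, instead of each report forcing its own execution at the point where it was requested.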

torch_xla.core.xla_model.wait_device_ops(devices=[])

Waits for all the async operations on the given devices to complete.

Parameters

devices (string..., optional) – The devices whose async ops need to be waited for. If empty, all the local devices will be waited for.

torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None, pin_layout=True)

Run the provided optimizer step and issue the XLA device step computation.

Parameters
  • optimizer (torch.Optimizer) – The torch.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.

  • barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False

  • optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_reduce() operation. For example, [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – whether to pin the layout when reducing gradients. See xm.all_reduce for details.

Returns

The same value returned by the optimizer.step() call.

torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)

Saves the input data into a file.

The saved data is transferred to PyTorch CPU device before being saved, so a following torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

Parameters
  • data – The input data to be saved. Any nested combination of Python objects (list, tuples, sets, dicts, …).

  • file_or_path – The destination for the data saving operation. Either a file path or a Python file object. If master_only is False the path or file objects must point to different destinations, as otherwise all the writes from the same host will overwrite each other.

  • master_only (bool, optional) – Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part in the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True

  • global_master (bool, optional) – When master_only is True this flag controls whether every host’s master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False

  • sync (bool, optional) – Whether to synchronize all replicas after saving tensors. If True, all replicas must call xm.save or the main process will hang.
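The interaction between master_only and global_master can be summarized as a small decision function. This is a sketch of the rules as described in the parameter list, not the library's code: should_save is a hypothetical helper, and the mapping from (host, local ordinal) to global ordinal used in the example is a simplifying assumption.

```python
def should_save(local_ordinal, global_ordinal, master_only=True, global_master=False):
    """Decide whether a replica writes the file, following the rules above."""
    if not master_only:
        return True                    # every replica saves (to its own path)
    if global_master:
        return global_ordinal == 0     # only the single global master saves
    return local_ordinal == 0          # each host's local master saves


# Example job: 2 hosts with 4 devices each. For illustration only, assume
# global_ordinal = host * 4 + local_ordinal.
replicas = [(host, local, host * 4 + local) for host in range(2) for local in range(4)]

default_writers = [(h, l) for h, l, g in replicas if should_save(l, g)]
global_writers = [(h, l) for h, l, g in replicas if should_save(l, g, global_master=True)]
all_writers = [(h, l) for h, l, g in replicas if should_save(l, g, master_only=False)]
```

With the defaults, one replica per host writes, so each host needs a destination that does not collide with the other hosts unless they share no file system.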

torch_xla.core.xla_model.rendezvous(tag, payload=b'', replicas=[])

Waits for all the mesh clients to reach the named rendezvous.

Note: PJRT does not support the XRT mesh server, so this is effectively an alias to xla_rendezvous.

Parameters
  • tag (string) – The name of the rendezvous to join.

  • payload (bytes, optional) – The payload to be sent to the rendezvous.

  • replicas (list, int) – The replica ordinals taking part in the rendezvous. Empty means all replicas in the mesh. Default: []

Returns

The payloads exchanged by all the other cores, with the payload of core ordinal i at position i in the returned tuple.

torch_xla.core.xla_model.mesh_reduce(tag, data, reduce_fn)

Performs an out-of-graph client mesh reduction.

Parameters
  • tag (string) – The name of the rendezvous to join.

  • data – The data to be reduced. The reduce_fn callable will receive a list with the copies of the same data coming from all the mesh client processes (one per core).

  • reduce_fn (callable) – A function which receives a list of data-like objects and returns the reduced result.

Returns

The reduced value.
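The pattern behind a mesh reduction (every participant posts its data, waits at a rendezvous, then applies reduce_fn to the full list) can be sketched with threads standing in for processes. run_mesh_reduce is a hypothetical helper built on threading.Barrier; it illustrates the pattern only and is unrelated to how PyTorch/XLA exchanges payloads.

```python
import threading


def run_mesh_reduce(per_process_data, reduce_fn):
    """Simulate an out-of-graph reduction with one thread per 'process'."""
    count = len(per_process_data)
    barrier = threading.Barrier(count)
    payloads = [None] * count
    results = [None] * count

    def worker(ordinal):
        payloads[ordinal] = per_process_data[ordinal]   # post local data
        barrier.wait()                                  # rendezvous point
        # Every participant now sees all payloads, ordered by ordinal.
        results[ordinal] = reduce_fn(list(payloads))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(count)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return results


def mean(values):
    return sum(values) / len(values)


accuracies = [0.5, 0.75, 1.0, 0.25]
reduced = run_mesh_reduce(accuracies, mean)
```

Because the reduction happens on the host after the rendezvous, reduce_fn can be any Python callable, for example one that averages per-replica accuracy for reporting.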

torch_xla.core.xla_model.set_rng_state(seed, device=None)

Sets the random number generator state.

Parameters
  • seed (int) – The state to be set.

  • device (string, optional) – The device where the RNG state needs to be set. If missing the default device seed will be set.

torch_xla.core.xla_model.get_rng_state(device=None)

Gets the current running random number generator state.

Parameters

device (string, optional) – The device whose RNG state needs to be retrieved. If missing, the default device is used.

Returns

The RNG state, as integer.

torch_xla.core.xla_model.get_memory_info(device: Optional[device] = None) -> MemoryInfo

Retrieves the device memory usage.

Parameters

device (torch.device, optional) – The device whose memory information is requested. If not passed, the default device is used.

Returns

MemoryInfo dict with memory usage for the given device.

torch_xla.core.xla_model.get_stablehlo(tensors=None) -> str

Get StableHLO for the computation graph in string format.

If tensors is not empty, the graph with tensors as outputs will be dumped. If tensors is empty, the whole computation graph will be dumped. TODO(lsy323): When tensors is empty, some intermediate tensors will also be dumped as outputs. Need further investigation.

For inference graph, it is recommended to pass the model outputs to tensors. For training graph, it is not straightforward to identify the “outputs”. Using empty tensors is recommended.

To enable source line info in StableHLO, please set env var XLA_HLO_DEBUG=1.

Parameters

tensors (list[torch.Tensor], optional) – Tensors that represent the output/root of the StableHLO graph.

Returns

StableHLO Module in string format.

torch_xla.core.xla_model.get_stablehlo_bytecode(tensors=None) -> bytes

Get StableHLO for the computation graph in bytecode format.

If tensors is not empty, the graph with tensors as outputs will be dumped. If tensors is empty, the whole computation graph will be dumped. TODO(lsy323): When tensors is empty, some intermediate tensors will also be dumped as outputs. Need further investigation.

For inference graph, it is recommended to pass the model outputs to tensors. For training graph, it is not straightforward to identify the “outputs”. Using empty tensors is recommended.

Parameters

tensors (list[torch.Tensor], optional) – Tensors that represent the output/root of the StableHLO graph.

Returns

StableHLO Module in bytecode format.

distributed

class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, batches_per_execution=1, loader_prefetch_size=8, device_prefetch_size=4, host_to_device_transfer_threads=1, input_sharding=None)

Wraps an existing PyTorch DataLoader with background data upload.

Parameters
  • loader (torch.utils.data.DataLoader) – The PyTorch DataLoader to be wrapped.

  • devices (torch.device…) – The list of devices where the data has to be sent. The i-th sample returned by the loader will be sent to devices[i % len(devices)].

  • batchdim (int, optional) – The dimension which is holding the batch size. Default: 0

  • loader_prefetch_size (int, optional) – The max capacity of the queue used by the thread which is reading samples from the loader, to be processed by the worker threads which upload data to the devices. Default: 8

  • device_prefetch_size (int, optional) – The max size of the per-device queues, where the worker threads deposit tensors which have already been sent to devices. Default: 4

  • host_to_device_transfer_threads (int, optional) – The number of threads that work in parallel to transfer data from loader queue to device queue. Default: 1

  • input_sharding (ShardingSpec, optional) – Sharding spec to apply to compatible input tensors after loading. Default: None
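The round-robin rule stated for the devices parameter (the i-th sample goes to devices[i % len(devices)]) can be sketched as follows. The helper distribute_round_robin and the device name strings are illustrative only; the real loader also prefetches and uploads data in background threads, which this sketch omits.

```python
def distribute_round_robin(batches, devices):
    """Assign the i-th batch to devices[i % len(devices)]."""
    per_device = {device: [] for device in devices}
    for i, batch in enumerate(batches):
        per_device[devices[i % len(devices)]].append(batch)
    return per_device


# Five batches spread over two (illustrative) device names.
assignment = distribute_round_robin(
    ["b0", "b1", "b2", "b3", "b4"],
    ["xla:0", "xla:1"],
)
```

When the number of batches is not a multiple of the number of devices, the first devices receive one batch more than the rest, as the example shows.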

per_device_loader(device)

Retrieves the loader iterator object for the given device.

Parameters

device (torch.device) – The device whose loader is being requested.

Returns

The loader iterator object for the device. This is not a torch.utils.data.DataLoader interface, but a Python iterator which returns the same tensor data structure as returned by the wrapped torch.utils.data.DataLoader, but residing on XLA devices.

torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')

Enables multi processing based replication.

Parameters
  • fn (callable) – The function to be called for each device which takes part in the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in args.

  • args (tuple) – The arguments for fn. Default: Empty tuple

  • nprocs (int) – The number of processes/devices for the replication. At the moment, if specified, can be either 1 or the maximum number of devices.

  • join (bool) – Whether the call should block waiting for the completion of the processes which have been spawned. Default: True

  • daemon (bool) – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API). Default: False

  • start_method (string) – The Python multiprocessing process creation method. Default: spawn

Returns

The same object returned by the torch.multiprocessing.spawn API. If nprocs is 1 the fn function will be called directly, and the API will return None.

spmd

torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Optional[Union[Tuple, int, str]]]) XLAShardedTensor[source]

Annotates the provided tensor with an XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.

Parameters
  • t (Union[torch.Tensor, XLAShardedTensor]) – Input tensor to be annotated with partition_spec.

  • mesh (Mesh) – Describes the logical XLA device topology and the underlying device IDs.

  • partition_spec (Tuple[Optional[Union[Tuple, int, str]]]) – A tuple of device_mesh dimension indices or None. Each index is an int, a str if the mesh axis is named, or a tuple of int or str. This specifies how each input rank is sharded (index into mesh_shape) or replicated (None). When a tuple is specified, the corresponding input tensor axis is sharded along all logical axes in the tuple. Note that the order in which the mesh axes are specified in the tuple affects the resulting sharding. For example, an 8x10 tensor can be sharded 4-way row-wise and replicated column-wise.

  • dynamo_custom_op (bool) – If set to True, calls the dynamo custom op variant of mark_sharding so that it is recognizable and traceable by dynamo.
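To make the partition_spec semantics concrete, the sketch below computes the per-device shard shape that a given spec implies, such as sharding an 8x10 tensor 4-way row-wise. It is a pure-Python illustration under stated assumptions, not how torch_xla computes shardings: `shard_shape` is a hypothetical helper, mesh axes are given as integer indices only, and every sharded dimension is assumed to divide evenly.

```python
from math import prod


def shard_shape(tensor_shape, mesh_shape, partition_spec):
    """Return the per-device shard shape implied by a partition spec.

    Hypothetical helper for illustration; assumes integer mesh-axis indices
    and tensor dimensions that divide evenly across the mesh axes.
    """
    result = []
    for dim_size, spec in zip(tensor_shape, partition_spec):
        if spec is None:
            # None: this tensor dimension is replicated, so it keeps its size.
            result.append(dim_size)
            continue
        # An int names one mesh axis; a tuple shards across several axes.
        axes = spec if isinstance(spec, tuple) else (spec,)
        ways = prod(mesh_shape[axis] for axis in axes)
        result.append(dim_size // ways)
    return tuple(result)


mesh_shape = (4, 2)
# Shard rows 4-way along mesh axis 0, replicate columns.
row_sharded = shard_shape((8, 10), mesh_shape, (0, None))
# Shard rows across both mesh axes (4 * 2 = 8 ways), replicate columns.
both_axes = shard_shape((8, 10), mesh_shape, ((0, 1), None))
# Replicate rows, shard columns 2-way along mesh axis 1.
column_sharded = shard_shape((10, 32), mesh_shape, (None, 1))
```

With a (4, 2) mesh, the spec (0, None) leaves each device holding a 2x10 slice of the 8x10 tensor.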

Example

>>> import numpy as np
>>> import torch
>>> import torch.nn as nn
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.runtime as xr
>>> import torch_xla.distributed.spmd as xs
>>> mesh_shape = (4, 2)
>>> num_devices = xr.global_runtime_device_count()
>>> device_ids = np.array(range(num_devices))
>>> mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> # 4-way data parallel
>>> input = torch.randn(8, 32).to(xm.xla_device())
>>> xs.mark_sharding(input, mesh, (0, None))
>>> # 2-way model parallel
>>> linear = nn.Linear(32, 10).to(xm.xla_device())
>>> xs.mark_sharding(linear.weight, mesh, (None, 1))

torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) -> Tensor

Clears the sharding annotation from the input tensor and returns a CPU-cast tensor. This is an in-place operation, but it also returns the same torch.Tensor.

Parameters

t (Union[torch.Tensor, XLAShardedTensor]) – Tensor whose sharding annotation should be cleared.

Returns

The tensor without sharding.

Return type

torch.Tensor

Example

>>> import torch
>>> import torch_xla
>>> import torch_xla.runtime as xr
>>> import torch_xla.distributed.spmd as xs
>>> xr.use_spmd()
>>> t1 = torch.randn(8, 8).to(torch_xla.device())
>>> mesh = xs.get_1d_mesh()
>>> xs.mark_sharding(t1, mesh, (0, None))
>>> xs.clear_sharding(t1)

torch_xla.distributed.spmd.set_global_mesh(mesh: Mesh)

Set the global mesh that can be used for the current process.

Parameters

mesh (Mesh) – Mesh object that will be the global mesh.

Example

>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> xs.set_global_mesh(mesh)

torch_xla.distributed.spmd.get_global_mesh() -> Optional[Mesh]

Get the global mesh for the current process.

Returns

The Mesh object if a global mesh is set, otherwise None.

Return type

Optional[Mesh]

Example

>>> import torch_xla.distributed.spmd as xs
>>> xs.get_global_mesh()

torch_xla.distributed.spmd.get_1d_mesh(axis_name: Optional[str] = None) -> Mesh

Helper function to return the mesh with all devices in one dimension.

Parameters

axis_name (Optional[str]) – Optional string giving the axis name of the mesh.

Returns

Mesh object

Return type

Mesh

Example

>>> # This example assumes a single TPU v4-8
>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> print(mesh.mesh_shape)
(4,)
>>> print(mesh.axis_names)
('data',)

class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None)

Describe the logical XLA device topology mesh and the underlying resources.

Parameters
  • device_ids (Union[np.ndarray, List]) – A raveled list of devices (IDs) in a custom order. The list is reshaped to a mesh_shape array, filling the elements using C-like index order.

  • mesh_shape (Tuple[int, ...]) – An int tuple describing the logical topology shape of the device mesh. Each element gives the number of devices along the corresponding axis.

  • axis_names (Tuple[str, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
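The "C-like index order" used to lay out device_ids can be shown with a small pure-Python sketch. `reshape_c_order` is a hypothetical helper written for this illustration; it mirrors the default row-major behaviour of a NumPy reshape and is not taken from the Mesh implementation.

```python
def reshape_c_order(device_ids, mesh_shape):
    """Reshape a flat sequence into nested lists in C-like (row-major) order.

    Hypothetical helper for illustration only.
    """
    flat = list(device_ids)
    total = 1
    for size in mesh_shape:
        total *= size
    if total != len(flat):
        raise ValueError("mesh_shape does not match the number of devices")

    def build(items, shape):
        if len(shape) == 1:
            return list(items)
        # The last axis varies fastest, so each outer entry takes a
        # contiguous run of the flat list.
        step = len(items) // shape[0]
        return [build(items[i * step:(i + 1) * step], shape[1:])
                for i in range(shape[0])]

    return build(flat, tuple(mesh_shape))


# Eight devices in natural order on a (4, 2) mesh.
logical_mesh = reshape_c_order(range(8), (4, 2))
# A custom device order changes which devices share a mesh row.
custom_order = reshape_c_order([0, 2, 4, 6, 1, 3, 5, 7], (2, 4))
```

Because the order of device_ids determines which devices end up adjacent along each axis, passing a custom order is how a physical topology can be mapped onto the logical mesh.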

Example

>>> import numpy as np
>>> import torch_xla.core.xla_model as xm
>>> from torch_xla.distributed.spmd import Mesh
>>> mesh_shape = (4, 2)
>>> num_devices = len(xm.get_xla_supported_devices())
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> mesh.get_logical_mesh()
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]])
>>> mesh.shape()
OrderedDict([('x', 4), ('y', 2)])

class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: Tuple[int, ...], dcn_mesh_shape: Optional[Tuple[int, ...]] = None, axis_names: Optional[Tuple[str, ...]] = None)

Creates a hybrid device mesh of devices connected with ICI and DCN networks.

The shape of the logical mesh should be ordered by increasing network intensity, e.g. [replica, data, model], where model has the most network communication requirements.

Parameters
  • ici_mesh_shape – shape of the logical mesh for inner connected devices.

  • dcn_mesh_shape – shape of logical mesh for outer connected devices.
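As a rough illustration of how the two shapes relate to the resulting logical mesh, the sketch below combines them axis by axis. It assumes the logical size of each axis is the product of its ICI and DCN sizes, which is consistent with the example output in this section. `hybrid_mesh_shape` is a hypothetical helper, not part of torch_xla.

```python
from collections import OrderedDict


def hybrid_mesh_shape(ici_mesh_shape, dcn_mesh_shape, axis_names):
    """Combine per-slice (ICI) and cross-slice (DCN) shapes axis by axis.

    Hypothetical helper for illustration only; it assumes the logical shape
    is the elementwise product of the two shapes.
    """
    if not (len(ici_mesh_shape) == len(dcn_mesh_shape) == len(axis_names)):
        raise ValueError("shapes and axis_names must have the same length")
    sizes = [ici * dcn for ici, dcn in zip(ici_mesh_shape, dcn_mesh_shape)]
    return OrderedDict(zip(axis_names, sizes))


# Two slices of four devices each: data parallel across slices (DCN),
# fsdp within a slice (ICI).
shape = hybrid_mesh_shape((1, 4, 1), (2, 1, 1), ("data", "fsdp", "tensor"))
```

Placing the axis with the least communication on the DCN dimension follows the ordering advice above, since cross-slice links are the slower network.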

Example

>>> # This example assumes 2 slices of v4-8.
>>> from torch_xla.distributed.spmd import HybridMesh
>>> ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
>>> dcn_mesh_shape = (2, 1, 1)
>>> # The constructor arguments are keyword-only.
>>> mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('data', 'fsdp', 'tensor'))
>>> print(mesh.shape())
OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

experimental

torch_xla.experimental.eager_mode(enable: bool)

Configure torch_xla’s default execution mode.

Under eager mode, only functions compiled with torch_xla.compile are traced and compiled. Other torch ops are executed eagerly.

torch_xla.experimental.compile(func)

Compile func with Lazy Tensor.

Returns an optimized function that takes exactly the same inputs. Compile runs the target func in tracing mode using Lazy Tensor.

debug

torch_xla.debug.metrics.metrics_report()

Retrieves a string containing the full metrics and counters report.

torch_xla.debug.metrics.short_metrics_report(counter_names: Optional[list] = None, metric_names: Optional[list] = None)

Retrieves a string containing a short metrics and counters report.

Parameters
  • counter_names (list) – The list of counter names whose data needs to be printed.

  • metric_names (list) – The list of metric names whose data needs to be printed.

torch_xla.debug.metrics.counter_names()

Retrieves all the currently active counter names.

torch_xla.debug.metrics.counter_value(name)

Returns the value of an active counter.

Parameters

name (string) – The name of the counter whose value needs to be retrieved.

Returns

The counter value as an integer.

torch_xla.debug.metrics.metric_names()

Retrieves all the currently active metric names.

torch_xla.debug.metrics.metric_data(name)

Returns the data of an active metric.

Parameters

name (string) – The name of the metric whose data needs to be retrieved.

Returns

The metric data, which is a tuple of (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES). The TOTAL_SAMPLES is the total number of samples which have been posted to the metric. A metric retains only a given number of samples (in a circular buffer). The ACCUMULATOR is the sum of the samples over TOTAL_SAMPLES. The SAMPLES is a list of (TIME, VALUE) tuples.
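The tuple layout described above can be unpacked as in the following sketch. `summarize_metric` is a hypothetical helper, and the sample tuple is made-up data rather than real output of metric_data. The point is that ACCUMULATOR and TOTAL_SAMPLES cover every sample ever posted, while SAMPLES holds only those still retained in the circular buffer.

```python
def summarize_metric(metric_data):
    """Summarize a (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES) tuple.

    Hypothetical helper for illustration only.
    """
    total_samples, accumulator, samples = metric_data
    # Mean over every sample ever posted, not just the retained ones.
    mean = accumulator / total_samples if total_samples else 0.0
    # Each retained sample is a (TIME, VALUE) tuple.
    retained_values = [value for _time, value in samples]
    return {
        "total_samples": total_samples,
        "retained_samples": len(retained_values),
        "mean": mean,
        "retained_sum": sum(retained_values),
    }


# Made-up data: 10 samples were posted in total, summing to 50.0, but only
# 3 of them are still held in the circular buffer.
example = (10, 50.0, [(1.0, 4.0), (2.0, 6.0), (3.0, 5.0)])
summary = summarize_metric(example)
```

The retained sum (15.0 here) can be smaller than ACCUMULATOR (50.0) because older samples have been dropped from the buffer while still counting toward the totals.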
