
PyTorch on XLA Devices

PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.

Creating an XLA Tensor

PyTorch/XLA adds a new xla device type to PyTorch. This device type works just like other PyTorch device types. For example, here’s how to create and print an XLA tensor:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t)

This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing torch_xla initializes PyTorch/XLA, and xm.xla_device() returns the current XLA device. This may be a CPU or TPU depending on your environment.

XLA Tensors are PyTorch Tensors

PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.

For example, XLA tensors can be added together:

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

Or matrix multiplied:

print(t0.mm(t1))

Or used with neural network modules:

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)

Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
# Input tensor is not an XLA tensor: torch.FloatTensor

will throw an error since the torch.nn.Linear module is on the CPU.

Running Models on XLA Devices

Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device and multiple devices with XLA multi-processing.

Running on a Single XLA Device

The following snippet shows a network training on a single XLA device:

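import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# MNIST, train_loader, lr and momentum are defined elsewhere in your script.
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  optimizer.step()
  xm.mark_step()
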
This snippet highlights how easy it is to switch your model to run on XLA. The model definition, dataloader, optimizer and training loop can work on any device. The only XLA-specific code is a couple of lines that acquire the XLA device and mark the step. Calling xm.mark_step() at the end of each training iteration causes XLA to execute its current graph and update the model’s parameters. See XLA Tensor Deep Dive for more on how XLA creates graphs and runs operations.

Running on Multiple XLA Devices with Multi-processing

PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# MNIST, train_loader, lr and momentum are defined elsewhere in your script.
def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

There are three differences between this multi-device snippet and the previous single-device snippet. Let’s go over them one by one.

  • xmp.spawn()

    • Creates the processes that each run an XLA device.

    • Each process can only access the device assigned to it. For example, on a TPU v4-8, 4 processes are spawned and each process owns one TPU device.

    • Note that if you print xm.xla_device() in each process, you will see xla:0 in all of them. This is because each process can only see one device; it does not mean multi-processing is not working. The only exception is the PJRT runtime on TPU v2 and TPU v3, where there are #devices/2 processes and each process has 2 threads (check this doc for more details).

  • MpDeviceLoader

    • Loads the training data onto each device.

    • MpDeviceLoader wraps a torch DataLoader. It preloads data to the device and overlaps data loading with device execution to improve performance.

    • MpDeviceLoader also calls xm.mark_step for you every batches_per_execution batches yielded (default: 1).

  • xm.optimizer_step(optimizer)

    • Consolidates the gradients between devices and issues the XLA device step computation.

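    • It is roughly all_reduce_gradients + optimizer.step(), plus a mark_step when called with barrier=True (with MpDeviceLoader, the loader issues the mark_step instead). It returns the value returned by optimizer.step().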

The model definition, optimizer definition and training loop remain the same.
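The mark_step cadence of MpDeviceLoader described above can be modeled in a few lines of plain Python. This is a toy sketch, not the real implementation: toy_device_loader is a hypothetical name, mark_step is any callback you pass in, and device placement, prefetching and threading are ignored. The step is marked when the training loop asks for the next batch, mirroring the note in the xm.optimizer_step reference that the barrier is issued by the data loader iterator’s next() call.

```python
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def toy_device_loader(
    batches: Iterable[T],
    mark_step: Callable[[], None],
    batches_per_execution: int = 1,
) -> Iterator[T]:
    """Yield batches and call mark_step() after every `batches_per_execution` of them.

    Toy model only: no device transfer, no prefetching, no background threads.
    """
    if batches_per_execution < 1:
        raise ValueError("batches_per_execution must be >= 1")
    for count, batch in enumerate(batches, start=1):
        yield batch
        # Execution resumes here when the consumer requests the next batch,
        # i.e. after the training step for `batch` has been traced.
        if count % batches_per_execution == 0:
            mark_step()


events = []
for batch in toy_device_loader(range(5), lambda: events.append("step"), batches_per_execution=2):
    events.append(batch)
print(events)
```

With batches_per_execution=2 the work of two iterations is accumulated before each mark_step, and a trailing odd batch is left for whatever comes next.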

NOTE: When using multi-processing, the user can start retrieving and accessing XLA devices only from within the target function of xmp.spawn() (or any function that has xmp.spawn() as a parent in the call stack).

See the full multiprocessing example for more on training a network on multiple XLA devices with multi-processing.

Running on TPU Pods

Multi-host setup for different accelerators can be very different. This doc covers the device-independent parts of multi-host training and uses the TPU + PJRT runtime (currently available in the 1.13 and 2.x releases) as an example.

Before you begin, please take a look at our user guide here, which explains some Google Cloud basics such as how to use the gcloud command and how to set up your project. You can also check here for all Cloud TPU how-to guides. This doc focuses on the PyTorch/XLA perspective of the setup.

Let’s assume you have the MNIST example from the section above in a file named train_mnist_xla.py. For single-host, multi-device training, you would ssh to the TPU VM and run a command like

PJRT_DEVICE=TPU python3 train_mnist_xla.py

Now, to run the same model on a TPU v4-16 (which has 2 hosts, each with 4 TPU devices), you need to

  • Make sure each host can access the training script and training data. This is usually done by using the gcloud scp command or gcloud ssh command to copy the training scripts to all hosts.

  • Run the same training command on all hosts at the same time.

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"

The gcloud ssh command above connects to all hosts in the TPU VM Pod and runs the same command on each of them at the same time.

NOTE: You need to run the above gcloud command from outside the TPU VM.

The model code and training script are the same for multi-process training and multi-host training. PyTorch/XLA and the underlying infrastructure make sure each device is aware of the global topology and of each device’s local and global ordinal. Cross-device communication happens across all devices instead of only the local devices.

For more details regarding the PJRT runtime and how to run it on a pod, please refer to this doc. For more information about PyTorch/XLA and TPU pods, and a complete guide to running ResNet-50 with fake data on a TPU pod, please refer to this guide.

XLA Tensor Deep Dive

Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique.

XLA Tensors are Lazy

CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example.

Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device. For more information about our lazy tensor design, you can read this paper.
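To make the record-then-execute idea concrete, here is a toy model in plain Python. It is only an analogy: LazyValue, leaf, graph_size and materialize are hypothetical names, and PyTorch/XLA’s real lazy tensors trace an IR graph that is lowered to HLO and compiled by XLA, which this sketch does not attempt. Arithmetic on LazyValue objects only builds a graph; nothing is computed until materialize() is called, at which point the whole graph runs at once.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class LazyValue:
    """Toy lazy tensor: records an operation instead of running it."""
    op: str
    inputs: List[LazyValue] = field(default_factory=list)
    value: Optional[float] = None  # known for leaves, filled in on materialization

    def __add__(self, other: LazyValue) -> LazyValue:
        return LazyValue("add", [self, other])

    def __mul__(self, other: LazyValue) -> LazyValue:
        return LazyValue("mul", [self, other])


def leaf(x: float) -> LazyValue:
    return LazyValue("const", value=x)


def graph_size(node: LazyValue) -> int:
    """Number of recorded, not yet executed, operations reachable from node."""
    if node.op == "const":
        return 0
    return 1 + sum(graph_size(i) for i in node.inputs)


def materialize(node: LazyValue, log: List[str]) -> float:
    """Execute the recorded graph, e.g. when the result is copied to the CPU."""
    if node.value is None:
        args = [materialize(i, log) for i in node.inputs]
        node.value = args[0] + args[1] if node.op == "add" else args[0] * args[1]
        log.append(node.op)
    return node.value


out = (leaf(2.0) + leaf(3.0)) * leaf(4.0)  # only records two operations
executed: List[str] = []
```

Only when a result is needed does the toy walk the graph; a real compiler could fuse the add and mul into one kernel at that point.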

XLA Tensors and bFloat16

PyTorch/XLA can use the bfloat16 datatype when running on TPUs. In fact, PyTorch/XLA handles float types (torch.float and torch.double) differently on TPUs. This behavior is controlled by the XLA_USE_BF16 and XLA_DOWNCAST_BF16 environment variables:

  • By default both torch.float and torch.double are torch.float on TPUs.

  • If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.

  • If XLA_DOWNCAST_BF16 is set, then torch.float is bfloat16 on TPUs and torch.double is float32 on TPUs.

  • If a PyTorch tensor has torch.bfloat16 data type, this will be directly mapped to the TPU bfloat16 (XLA BF16 primitive type).

Developers should note that XLA tensors on TPUs will always report their PyTorch datatype regardless of the actual datatype they’re using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU, it will be converted from its actual datatype to its PyTorch datatype. Depending on how your code operates, this device-dependent conversion can be important.
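The environment-variable rules listed above can be summarized as a small lookup. This is a sketch that encodes the documented rules, not code that queries torch_xla: tpu_float_types is a hypothetical helper, it treats a variable as "set" when its value is "1" (an assumption; the text only says "is set"), and because the rules do not say what happens when both variables are set, the sketch refuses that case instead of guessing.

```python
import os
from typing import Mapping, Optional, Tuple


def tpu_float_types(env: Optional[Mapping[str, str]] = None) -> Tuple[str, str]:
    """Storage types used on a TPU for (torch.float, torch.double).

    Encodes the documented rules only; a variable counts as "set" when its
    value is "1", which is an assumption of this sketch.
    """
    env = os.environ if env is None else env
    use_bf16 = env.get("XLA_USE_BF16") == "1"
    downcast_bf16 = env.get("XLA_DOWNCAST_BF16") == "1"
    if use_bf16 and downcast_bf16:
        # The documented rules do not cover both variables being set at once.
        raise ValueError("XLA_USE_BF16 and XLA_DOWNCAST_BF16 are both set")
    if use_bf16:
        return ("bfloat16", "bfloat16")
    if downcast_bf16:
        return ("bfloat16", "float32")
    return ("float32", "float32")  # default: both behave as torch.float
```

Whatever the storage type, the tensor still reports torch.float or torch.double, so code that checks tensor.dtype on the host will not see this mapping.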

Memory Layout

The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor’s memory layout for better performance.

Moving XLA Tensors to and from the CPU

XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved, then the data it is viewing is also copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it. Again, depending on how your code operates, appreciating and accommodating this transition can be important.

Saving and Loading XLA Tensors

XLA tensors should be moved to the CPU before saving, as in the following snippet:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

This lets you put the loaded tensors on any available device, not just the one on which they were initialized.

Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it is recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

A utility API is provided that saves data and takes care of moving it to the CPU first:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xm.save(model.state_dict(), path)

With multiple devices, the above API only saves the data for the master device ordinal (0).

When host memory is limited compared to the size of the model parameters, an API is provided that reduces the memory footprint on the host:

import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)

This API streams XLA tensors to CPU one at a time, reducing the amount of host memory used, but it requires a matching load API to restore:

import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
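Why streaming helps can be shown with simple arithmetic. In this simplified model (an assumption that ignores serialization buffers and Python object overhead), copying the whole state dict to the CPU before saving keeps every CPU copy alive at once, so peak host memory is the sum of the tensor sizes, while streaming one tensor at a time needs only the largest single tensor. peak_host_bytes is a hypothetical helper for illustration.

```python
from typing import Sequence


def peak_host_bytes(tensor_bytes: Sequence[int], streaming: bool) -> int:
    """Rough peak host memory for moving a state dict to the CPU and saving it.

    Copy-everything-first holds all CPU copies at once; streaming holds only
    one tensor's CPU copy at a time. Overheads are ignored.
    """
    if not tensor_bytes:
        return 0
    return max(tensor_bytes) if streaming else sum(tensor_bytes)


GIB = 1 << 30
sizes = [4 * GIB, 2 * GIB, 2 * GIB]  # hypothetical parameter sizes
print(peak_host_bytes(sizes, streaming=False) // GIB, "GiB vs",
      peak_host_bytes(sizes, streaming=True) // GIB, "GiB")
```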

Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.

Compilation Caching

The XLA compiler converts the traced HLO into an executable which runs on the devices. Compilation can be time consuming, and in cases where the HLO doesn’t change across executions, the compilation result can be persisted to disk for reuse, significantly reducing development iteration time.

Note that if the HLO changes between executions, a recompilation will still occur.

This is currently an experimental opt-in API, which must be activated before any computations are executed. Initialization is done through the initialize_cache API:

import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)

This will initialize a persistent compilation cache at the specified path. The readonly parameter can be used to control whether the worker will be able to write to the cache, which can be useful when a shared cache mount is used for an SPMD workload.
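The caching behavior described above (reuse when the HLO is unchanged, recompile when it changes, optional read-only workers) can be sketched as a toy. ToyCompileCache is hypothetical: keying entries on a SHA-256 hash of the HLO text is an assumption made for illustration, and a dictionary stands in for the files PyTorch/XLA would keep under the cache path.

```python
import hashlib
from typing import Callable, Dict


class ToyCompileCache:
    """Toy persistent cache: executables keyed by a hash of the HLO text."""

    def __init__(self, compile_fn: Callable[[str], str], readonly: bool = False):
        self._store: Dict[str, str] = {}  # stands in for files under the cache path
        self._compile_fn = compile_fn
        self.readonly = readonly
        self.compilations = 0

    def get_executable(self, hlo_text: str) -> str:
        key = hashlib.sha256(hlo_text.encode("utf-8")).hexdigest()
        if key in self._store:
            return self._store[key]  # cache hit: no recompilation
        self.compilations += 1
        executable = self._compile_fn(hlo_text)
        if not self.readonly:
            self._store[key] = executable  # only writable workers persist results
        return executable


cache = ToyCompileCache(lambda hlo: "exe:" + hlo)
cache.get_executable("add(a, b)")
cache.get_executable("add(a, b)")  # same HLO: served from the cache
cache.get_executable("mul(a, b)")  # changed HLO: compiled again
```

In an SPMD job sharing one cache mount, the idea is that only some workers write while the rest read, which is what readonly models here.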

Further Reading

Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.



torch_xla.device(index: Optional[int] = None) → device

Returns a given instance of an XLA device.

If SPMD is enabled, returns a virtual device that wraps all devices available to this process.


index – index of the XLA device to be returned. Corresponds to index in torch_xla.devices().


An XLA torch.device.

torch_xla.devices() → List[device]

Returns all devices available in the current process.


A list of XLA torch.devices.

torch_xla.device_count() → int

Returns the number of addressable devices in the current process.


Launches all pending graph operations.


Wraps code that should be dispatched to the runtime.

Experimental: xla.step is still a work in progress. Some code that currently works with xla.step but does not follow best practices will become errors in future releases. See https://github.com/pytorch/xla/issues/6751 for context.


torch_xla.runtime.device_type() → Optional[str]

Returns the current PjRt device type.

Selects a default device if none has been configured.

torch_xla.runtime.local_process_count() → int

Returns the number of processes running on this host.

torch_xla.runtime.local_device_count() → int

Returns the total number of devices on this host.

Assumes each process has the same number of addressable devices.

torch_xla.runtime.addressable_device_count() → int

Returns the number of devices visible to this process.

torch_xla.runtime.global_device_count() → int

Returns the total number of devices across all processes/hosts.

torch_xla.runtime.global_runtime_device_count() → int

Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD.

torch_xla.runtime.world_size() → int

Returns the total number of processes participating in the job.

torch_xla.runtime.global_ordinal() → int

Returns the global ordinal of this thread within all processes.

Global ordinal is in range [0, global_device_count). Global ordinals are not guaranteed to have any predictable relationship to the TPU worker ID nor are they guaranteed to be contiguous on each host.

torch_xla.runtime.local_ordinal() → int

Returns the local ordinal of this thread within this host.

Local ordinal is in range [0, local_device_count).

torch_xla.runtime.get_master_ip() → str

Retrieve the master worker IP for the runtime. This calls into backend-specific discovery APIs.

Returns master worker’s IP address as a string.

torch_xla.runtime.use_spmd(auto: Optional[bool] = False)

Returns if SPMD is set for execution.

torch_xla.runtime.initialize_cache(path: str, readonly: bool = False)

Initializes the persistent compilation cache. This API must be called before any computations have been performed.

  • path – The path at which to store the persistent cache.

  • readonly – Whether or not this worker should have write access to the cache.


torch_xla.core.xla_model.xla_device(n=None, devkind=None)

Returns a given instance of an XLA device.

  • n (python:int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.

  • devkind (string..., optional) – If specified, device type such as TPU, CUDA, CPU, or custom PJRT device. Deprecated.


A torch.device with the requested instance.


Returns the hardware type of the given device.


device (string or torch.device) – The xla device that will be mapped to the real device.


A string representation of the hardware type of the given device.


Checks whether the current process is the master ordinal (0).


local (bool) – Whether the local or global master ordinal should be checked. In case of multi-host replication, there is only one global master ordinal (host 0, device 0), while there are NUM_HOSTS local master ordinals. Default: True


A boolean indicating whether the current process is the master ordinal.

torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True)

Performs an in-place reduce operation on the input tensor(s).

  • reduce_type (string) – One of xm.REDUCE_SUM, xm.REDUCE_MUL, xm.REDUCE_AND, xm.REDUCE_OR, xm.REDUCE_MIN and xm.REDUCE_MAX.

  • inputs – Either a single torch.Tensor or a list of torch.Tensor to perform the all reduce op to.

  • scale (python:float) – A default scaling value to be applied after the reduce. Default: 1.0

  • groups (list, optional) –

    A list of list, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]]

    defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication run slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.


If a single torch.Tensor is passed, the return value is a torch.Tensor holding the reduced value (across the replicas). If a list/tuple is passed, this function performs an in-place all-reduce op on the input tensors, and returns the list/tuple itself.
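The effect of the groups and scale arguments can be checked with a pure-Python simulation of what each replica holds after an xm.REDUCE_SUM all-reduce. simulate_all_reduce_sum is a hypothetical helper working on one plain float per replica rather than XLA tensors, and it assumes the groups cover every replica exactly once.

```python
from typing import Dict, List, Optional, Sequence


def simulate_all_reduce_sum(
    values: Sequence[float],
    groups: Optional[List[List[int]]] = None,
    scale: float = 1.0,
) -> List[float]:
    """Per-replica result of a SUM all-reduce; values[i] belongs to replica i."""
    if groups is None:
        groups = [list(range(len(values)))]  # a single group with all replicas
    result: Dict[int, float] = {}
    for group in groups:
        total = sum(values[r] for r in group) * scale  # scale applied after the reduce
        for r in group:
            result[r] = total
    return [result[r] for r in range(len(values))]


vals = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(simulate_all_reduce_sum(vals, groups=[[0, 1, 2, 3], [4, 5, 6, 7]]))
```

With the two groups from the example above, replicas 0 to 3 end up with 10.0 and replicas 4 to 7 with 26.0; without groups every replica gets 36.0.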

torch_xla.core.xla_model.all_gather(value, dim=0, groups=None, output=None, pin_layout=True)

Performs an all-gather operation along a given dimension.

  • value (torch.Tensor) – The input tensor.

  • dim (python:int) – The gather dimension. Default: 0

  • groups (list, optional) –

    A list of list, representing the replica groups for the all_gather() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]]

    defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • output (torch.Tensor) – Optional output tensor.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication run slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.


A tensor which has, in the dim dimension, all the values from the participating replicas.
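A pure-Python simulation of all_gather along dim 0 for 1-D inputs shows how groups change what each replica receives. simulate_all_gather is hypothetical; concatenating in the order replicas are listed within each group is an assumption of this sketch, and it matches ascending replica order for the groups used in the examples above.

```python
from typing import Dict, List, Optional, Sequence


def simulate_all_gather(
    values: Sequence[Sequence[float]],
    groups: Optional[List[List[int]]] = None,
) -> List[List[float]]:
    """Each replica's result of an all-gather along dim 0 of 1-D inputs.

    values[i] is replica i's input; inputs are concatenated in group order.
    """
    if groups is None:
        groups = [list(range(len(values)))]
    result: Dict[int, List[float]] = {}
    for group in groups:
        gathered = [x for r in group for x in values[r]]
        for r in group:
            result[r] = list(gathered)  # every member of the group gets a copy
    return [result[r] for r in range(len(values))]


vals = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
print(simulate_all_gather(vals, groups=[[0, 1], [2, 3]]))
```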

torch_xla.core.xla_model.all_to_all(value, split_dimension, concat_dimension, split_count, groups=None, pin_layout=True)

Performs an XLA AllToAll() operation on the input tensor.

See: https://www.tensorflow.org/xla/operation_semantics#alltoall

  • value (torch.Tensor) – The input tensor.

  • split_dimension (python:int) – The dimension upon which the split should happen.

  • concat_dimension (python:int) – The dimension upon which the concat should happen.

  • split_count (python:int) – The split count.

  • groups (list, optional) –

    A list of list, representing the replica groups for the all_to_all() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]]

    defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication run slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.


The result torch.Tensor of the all_to_all() operation.
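The AllToAll data movement is easiest to see on 1-D inputs with split_dimension and concat_dimension both 0: each replica splits its input into split_count chunks, sends chunk j to replica j, and each replica concatenates the chunks it receives in replica order. simulate_all_to_all is a hypothetical helper over plain lists; following XLA’s AllToAll semantics, it requires split_count to equal the number of participating replicas.

```python
from typing import List, Sequence


def simulate_all_to_all(values: Sequence[Sequence[int]], split_count: int) -> List[List[int]]:
    """1-D AllToAll across all replicas (split and concat along dimension 0)."""
    n = len(values)
    if split_count != n:
        raise ValueError("split_count must equal the number of replicas")
    chunks = []
    for v in values:
        if len(v) % split_count:
            raise ValueError("input length must be divisible by split_count")
        size = len(v) // split_count
        chunks.append([list(v[j * size:(j + 1) * size]) for j in range(split_count)])
    # Replica i receives chunk i from every source replica, concatenated in order.
    return [[x for src in range(n) for x in chunks[src][i]] for i in range(n)]


print(simulate_all_to_all([[0, 1, 2], [3, 4, 5], [6, 7, 8]], split_count=3))
```

With one element per chunk, the operation behaves like a transpose across replicas.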

torch_xla.core.xla_model.add_step_closure(closure, args=(), run_async=False)

Adds a closure to the list of the ones to be run at the end of the step.

Model training often needs to print or report information (print to console, post to TensorBoard, etc.) that requires inspecting the content of intermediary tensors. Inspecting the content of different tensors at different points of the model code requires many executions and typically causes performance issues. Adding a step closure ensures that it runs after the barrier, when all the live tensors have already been materialized to device data; the live tensors include the ones captured by the closure arguments. So using add_step_closure() ensures a single execution is performed, even when multiple closures are queued that require multiple tensors to be inspected. Step closures are run sequentially in the order they were queued. Note that even though this API optimizes the execution, it is advised to throttle the printing/reporting events to once every N steps.

  • closure (callable) – The function to be called.

  • args (tuple) – The arguments to be passed to the closure.

  • run_async – If True, run the closure asynchronously.
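The queue-then-run-after-the-barrier behavior can be modeled with a small class. ToyStepClosures is hypothetical and ignores run_async; in real code you would call xm.add_step_closure(closure, args=(...)) as documented above, and PyTorch/XLA would execute the pending graph before running the closures.

```python
from typing import Any, Callable, List, Tuple


class ToyStepClosures:
    """Queue closures and run them, in order, when the step is marked."""

    def __init__(self) -> None:
        self._pending: List[Tuple[Callable[..., Any], tuple]] = []

    def add_step_closure(self, closure: Callable[..., Any], args: tuple = ()) -> None:
        self._pending.append((closure, args))

    def mark_step(self) -> None:
        # In PyTorch/XLA the graph would execute here, materializing the tensors
        # the closures captured; this toy only runs the queued closures.
        pending, self._pending = self._pending, []
        for closure, args in pending:
            closure(*args)


log: List[str] = []
steps = ToyStepClosures()
steps.add_step_closure(lambda step, loss: log.append(f"step {step}: {loss}"), (1, 0.5))
steps.add_step_closure(log.append, ("second",))
```

Nothing is logged until mark_step() runs, and then both closures run once, in the order they were queued.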


Waits for all the async operations on the given devices to complete.


devices (string..., optional) – The devices whose async ops need to be waited for. If empty, all the local devices will be waited for.

torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None, pin_layout=True)

Run the provided optimizer step and issue the XLA device step computation.

  • optimizer (torch.Optimizer) – The torch.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.

  • barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False

  • optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.

  • groups (list, optional) –

    A list of list, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]]

    defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – whether to pin the layout when reducing gradients. See xm.all_reduce for details.


The same value returned by the optimizer.step() call.

torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)

Saves the input data into a file.

The saved data is transferred to PyTorch CPU device before being saved, so a following torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

  • data – The input data to be saved. Any nested combination of Python objects (list, tuples, sets, dicts, …).

  • file_or_path – The destination for the data saving operation. Either a file path or a Python file object. If master_only is False, the path or file objects must point to different destinations, as otherwise all the writes from the same host will overwrite each other.

  • master_only (bool, optional) – Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part in the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True

  • global_master (bool, optional) – When master_only is True this flag controls whether every host’s master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False

  • sync (bool, optional) – Whether to synchronize all replicas after saving tensors. If True, all replicas must call xm.save or the main process will hang.
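The master_only and global_master rules above decide which processes actually write. writes_checkpoint is a hypothetical helper that encodes those rules for a process identified by its local and global ordinal; it does not call torch_xla.

```python
def writes_checkpoint(local_ordinal: int, global_ordinal: int,
                      master_only: bool = True, global_master: bool = False) -> bool:
    """Whether this process writes data under the xm.save rules described above."""
    if not master_only:
        return True  # every replica writes, so each needs its own file_or_path
    if global_master:
        return global_ordinal == 0  # only the single global master (ordinal 0)
    return local_ordinal == 0  # every host's local master writes


# Hypothetical process on the second host of a multi-host job.
print(writes_checkpoint(local_ordinal=0, global_ordinal=4))
```

With the defaults, the first process on each host writes, which is why every host needs to see a valid destination.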

torch_xla.core.xla_model.rendezvous(tag, payload=b'', replicas=[])

Waits for all the mesh clients to reach the named rendezvous.

Note: PJRT does not support the XRT mesh server, so this is effectively an alias to xla_rendezvous.

  • tag (string) – The name of the rendezvous to join.

  • payload (bytes, optional) – The payload to be sent to the rendezvous.

  • replicas (list, python:int) – The replica ordinals taking part in the rendezvous. Empty means all replicas in the mesh. Default: []


The payloads exchanged by all the other cores, with the payload of core ordinal i at position i in the returned tuple.

torch_xla.core.xla_model.mesh_reduce(tag, data, reduce_fn)[source]

Performs an out-of-graph client mesh reduction.

  • tag (string) – The name of the rendezvous to join.

  • data – The data to be reduced. The reduce_fn callable will receive a list with the copies of the same data coming from all the mesh client processes (one per core).

  • reduce_fn (callable) – A function which receives a list of data-like objects and returns the reduced result.


The reduced value.
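The contract of mesh_reduce is that reduce_fn receives a list with every process’s copy of data, and every process gets the reduced result back. simulate_mesh_reduce is a hypothetical single-process stand-in for that contract; in real code the call would follow the signature above, for example xm.mesh_reduce('test_accuracy', accuracy, mean), where the tag name is illustrative.

```python
from statistics import mean
from typing import Any, Callable, List, Sequence


def simulate_mesh_reduce(per_process_data: Sequence[Any],
                         reduce_fn: Callable[[List[Any]], Any]) -> List[Any]:
    """What each process gets back: reduce_fn applied to all processes' data."""
    reduced = reduce_fn(list(per_process_data))
    return [reduced] * len(per_process_data)


# Hypothetical per-process accuracies from 4 processes.
print(simulate_mesh_reduce([0.5, 0.75, 1.0, 0.75], mean))
```

Because this reduction happens outside the XLA graph, reduce_fn can be any Python callable, such as statistics.mean or max.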

torch_xla.core.xla_model.set_rng_state(seed, device=None)

Sets the random number generator state.

  • seed (python:integer) – The state to be set.

  • device (string, optional) – The device where the RNG state needs to be set. If missing the default device seed will be set.


Gets the current running random number generator state.


device (string, optional) – The device whose RNG state needs to be retrieved. If missing, the default device’s RNG state is retrieved.


The RNG state, as integer.

torch_xla.core.xla_model.get_memory_info(device: device) → MemoryInfo

Retrieves the device memory usage.


device – The device whose memory information is requested.


MemoryInfo dict with memory usage for the given device.

torch_xla.core.xla_model.get_stablehlo(tensors=None) → str

Get StableHLO for the computation graph in string format.

If tensors is not empty, the graph with tensors as outputs will be dumped. If tensors is empty, the whole computation graph will be dumped. TODO(lsy323): When tensors is empty, some intermediate tensors will also be dumped as outputs. Needs further investigation.

For an inference graph, it is recommended to pass the model outputs to tensors. For a training graph, it is not straightforward to identify the “outputs”, so using empty tensors is recommended.

To enable source line info in StableHLO, please set env var XLA_HLO_DEBUG=1.


tensors (list[torch.Tensor], optional) – Tensors that represent the output/root of the StableHLO graph.


StableHLO Module in string format.

torch_xla.core.xla_model.get_stablehlo_bytecode(tensors=None) → bytes

Get StableHLO for the computation graph in bytecode format.

If tensors is not empty, the graph with tensors as outputs will be dumped. If tensors is empty, the whole computation graph will be dumped. TODO(lsy323): When tensors is empty, some intermediate tensors will also be dumped as outputs. Needs further investigation.

For an inference graph, it is recommended to pass the model outputs to tensors. For a training graph, it is not straightforward to identify the “outputs”, so using empty tensors is recommended.


tensors (list[torch.Tensor], optional) – Tensors that represent the output/root of the StableHLO graph.


StableHLO Module in bytecode format.


class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, batches_per_execution=1, loader_prefetch_size=8, device_prefetch_size=4, host_to_device_transfer_threads=1, input_sharding=None)

Wraps an existing PyTorch DataLoader with background data upload.

  • loader (torch.utils.data.DataLoader) – The PyTorch DataLoader to be wrapped.

  • devices (torch.device…) – The list of devices where the data has to be sent. The i-th sample returned by the loader will be sent to devices[i % len(devices)].

  • batchdim (python:int, optional) – The dimension which is holding the batch size. Default: 0

  • loader_prefetch_size (python:int, optional) – The max capacity of the queue used by the thread which is reading samples from the loader, to be processed by the worker threads which upload data to the devices. Default: 8

  • device_prefetch_size (python:int, optional) – The max size of the per-device queues, where the worker threads deposit tensors which have already been sent to devices. Default: 4

  • host_to_device_transfer_threads (python:int, optional) – The number of threads that work in parallel to transfer data from loader queue to device queue. Default: 1

  • input_sharding (ShardingSpec, optional) – Sharding spec to apply to compatible input tensors after loading. Default: None
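The devices parameter above distributes samples round-robin: sample i goes to devices[i % len(devices)]. assign_round_robin is a hypothetical helper that reproduces only that assignment rule on plain Python values, without any threads, queues or device transfers.

```python
from typing import Dict, Iterable, List, Sequence, TypeVar

T = TypeVar("T")


def assign_round_robin(samples: Iterable[T], devices: Sequence[str]) -> Dict[str, List[T]]:
    """Group samples by target device using devices[i % len(devices)]."""
    if not devices:
        raise ValueError("need at least one device")
    per_device: Dict[str, List[T]] = {d: [] for d in devices}
    for i, sample in enumerate(samples):
        per_device[devices[i % len(devices)]].append(sample)
    return per_device


print(assign_round_robin(range(5), ["xla:0", "xla:1"]))
```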


Retrieves the loader iterator object for the given device.


device (torch.device) – The device whose loader is being requested.


The loader iterator object for the device. This is not a torch.utils.data.DataLoader interface, but a Python iterator which returns the same tensor data structure as returned by the wrapped torch.utils.data.DataLoader, but residing on XLA devices.

torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')

Enables multi processing based replication.

  • fn (callable) – The function to be called for each device that takes part in the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in args.

  • args (tuple) – The arguments for fn. Default: Empty tuple

  • nprocs (python:int) – The number of processes/devices for the replication. At the moment, if specified, can be either 1 or the maximum number of devices.

  • join (bool) – Whether the call should block waiting for the completion of the processes that have been spawned. Default: True

  • daemon (bool) – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API). Default: False

  • start_method (string) – The Python multiprocessing process creation method. Default: spawn


The same object returned by the torch.multiprocessing.spawn API. If nprocs is 1 the fn function will be called directly, and the API will return None.


torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Optional[Union[Tuple, int, str]]]) → XLAShardedTensor

Annotates the tensor provided with XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.

  • t (Union[torch.Tensor, XLAShardedTensor]) – Input tensor to be annotated with partition_spec.

  • mesh (Mesh) – Describes the logical XLA device topology and the underlying device IDs.

  • partition_spec – A tuple of device_mesh dimension index or None. Each index is an int, str if the mesh axis is named, or tuple of int or str. This specifies how each input rank is sharded (index to mesh_shape) or replicated (None). When a tuple is specified, the corresponding input tensor axis will be sharded along all logical axes in the tuple. Note that the order the mesh axes are specified in the tuple will impact the resulting sharding. For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.

  • dynamo_custom_op (bool) – If set to True, it calls the dynamo custom op variant of mark_sharding to make itself recognizable and traceable by dynamo.

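Examples:

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()  # enable SPMD mode before creating sharded tensors

mesh_shape = (4, 2)  # assumes 8 devices
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

# 4-way data parallel
input = torch.randn(8, 32).to(xm.xla_device())
xs.mark_sharding(input, mesh, (0, None))

# 2-way model parallel
linear = nn.Linear(32, 10).to(xm.xla_device())
xs.mark_sharding(linear.weight, mesh, (None, 1))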

torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) Tensor[source]

Clears the sharding annotation from the input tensor and returns a CPU-cast tensor.

torch_xla.distributed.spmd.set_global_mesh(mesh: Mesh)[source]
class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None)[source]

Describe the logical XLA device topology mesh and the underlying resources.

  • device_ids (Union[np.ndarray, List]) – A raveled list of devices (IDs) in a custom order. The list is reshaped to a mesh_shape array, filling the elements using C-like index order.

  • mesh_shape (Tuple[python:int, ...]) – An int tuple describing the logical topology shape of the device mesh; each element describes the number of devices in the corresponding axis.

  • axis_names (Tuple[str, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.

Example:

mesh_shape = (4, 2)
num_devices = len(xm.get_xla_supported_devices())
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh.get_logical_mesh()
>> array([[0, 1],
          [2, 3],
          [4, 5],
          [6, 7]])
mesh.shape()
>> OrderedDict([('x', 4), ('y', 2)])
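The reshape behind get_logical_mesh() in the example above can be sketched in plain Python, assuming a 2-D mesh. logical_mesh_sketch and mesh_axes_sketch are hypothetical helpers, not the Mesh API; they only show C-like (row-major) filling, including what happens with a custom device order.

```python
from collections import OrderedDict
from typing import List, Sequence, Tuple


def logical_mesh_sketch(device_ids: Sequence[int],
                        mesh_shape: Tuple[int, int]) -> List[List[int]]:
    """Reshape a raveled list of device IDs into a 2-D mesh in C (row-major) order."""
    rows, cols = mesh_shape
    if rows * cols != len(device_ids):
        raise ValueError("mesh_shape must account for every device ID exactly once")
    # Row r takes the next `cols` IDs from the flat list.
    return [list(device_ids[r * cols:(r + 1) * cols]) for r in range(rows)]


def mesh_axes_sketch(mesh_shape: Tuple[int, ...],
                     axis_names: Tuple[str, ...]) -> "OrderedDict[str, int]":
    # Pair each axis name with its size, like mesh.shape() in the example above.
    return OrderedDict(zip(axis_names, mesh_shape))


print(logical_mesh_sketch(list(range(8)), (4, 2)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
print(logical_mesh_sketch([0, 2, 4, 6, 1, 3, 5, 7], (4, 2)))
# [[0, 2], [4, 6], [1, 3], [5, 7]]
print(mesh_axes_sketch((4, 2), ('x', 'y')))
```

With a custom order, the IDs still fill rows first, so the order you pass in directly decides which devices share a row (mesh axis 'y').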

class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: Tuple[int, ...], dcn_mesh_shape: Optional[Tuple[int, ...]] = None, axis_names: Optional[Tuple[str, ...]] = None)[source]
Creates a hybrid device mesh of devices connected with ICI and DCN networks.

The shape of the logical mesh should be ordered by increasing network intensity, e.g. [replica, data, model], where model has the most network communication requirements.

  • ici_mesh_shape – shape of the logical mesh for inner-connected devices (connected via ICI).

  • dcn_mesh_shape – shape of the logical mesh for outer-connected devices (connected via DCN).


from torch_xla.distributed.spmd import HybridMesh

# This example assumes 2 slices of v4-8.
ici_mesh_shape = (1, 4, 1)  # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)

# HybridMesh takes keyword-only arguments.
mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('data', 'fsdp', 'tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
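The printed shape is the per-axis product of ici_mesh_shape and dcn_mesh_shape. Only that arithmetic is sketched below with a hypothetical helper; it assumes both shapes have the same rank and does not model how HybridMesh places physical devices at each mesh position.

```python
from typing import Tuple


def hybrid_mesh_shape_sketch(ici_mesh_shape: Tuple[int, ...],
                             dcn_mesh_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Per-axis product of the ICI and DCN mesh shapes (hypothetical helper)."""
    if len(ici_mesh_shape) != len(dcn_mesh_shape):
        raise ValueError("ici_mesh_shape and dcn_mesh_shape must have the same rank")
    return tuple(i * d for i, d in zip(ici_mesh_shape, dcn_mesh_shape))


# Two v4-8 slices: 4 devices per slice over ICI, 2 slices over DCN.
shape = hybrid_mesh_shape_sketch((1, 4, 1), (2, 1, 1))
print(dict(zip(('data', 'fsdp', 'tensor'), shape)))
# {'data': 2, 'fsdp': 4, 'tensor': 1}
```

The product of all entries (8 here) should match the total number of devices across both slices.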

class torch_xla.distributed.spmd.ShardingSpec(mesh: torch_xla.distributed.spmd.xla_sharding.Mesh, partition_spec: Tuple[Optional[int]], minibatch: Optional[bool] = False)[source]


torch_xla.experimental.eager_mode(enable: bool)[source]

Configure torch_xla’s default execution mode.

In eager mode, only functions that were compiled with torch_xla.compile are traced and compiled. Other torch ops are executed eagerly.


Compiles func with LazyTensor.

Returns the optimized function, which takes exactly the same inputs. Compile runs the target func in tracing mode using LazyTensor.



Retrieves a string containing the full metrics and counters report.

torch_xla.debug.metrics.short_metrics_report(counter_names: Optional[list] = None, metric_names: Optional[list] = None)[source]

Retrieves a string containing a short report of the selected counters and metrics.

  • counter_names (list) – The list of counter names whose data needs to be printed.

  • metric_names (list) – The list of metric names whose data needs to be printed.


Retrieves all the currently active counter names.


Returns the value of an active counter.


name (string) – The name of the counter whose value needs to be retrieved.


The counter value as an integer.


Retrieves all the currently active metric names.


Returns the data of an active metric.


name (string) – The name of the metric whose data needs to be retrieved.


The metric data, which is a tuple of (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES). The TOTAL_SAMPLES is the total number of samples which have been posted to the metric. A metric retains only a given number of samples (in a circular buffer). The ACCUMULATOR is the sum of the samples over TOTAL_SAMPLES. The SAMPLES is a list of (TIME, VALUE) tuples.
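To make the (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES) layout concrete, here is a pure-Python toy metric. MetricSketch is hypothetical and only mirrors the description above; the real buffer size and time units are internal to PyTorch/XLA, and real data comes from the function documented above.

```python
from collections import deque
from typing import Deque, List, Tuple


class MetricSketch:
    """Toy metric mirroring the (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES) layout."""

    def __init__(self, max_samples: int) -> None:
        self.total_samples = 0
        self.accumulator = 0.0
        # Circular buffer: only the most recent max_samples (time, value) pairs survive.
        self.samples: Deque[Tuple[float, float]] = deque(maxlen=max_samples)

    def post(self, time: float, value: float) -> None:
        self.total_samples += 1
        self.accumulator += value  # sums every posted value, retained or not
        self.samples.append((time, value))

    def data(self) -> Tuple[int, float, List[Tuple[float, float]]]:
        return self.total_samples, self.accumulator, list(self.samples)


metric = MetricSketch(max_samples=3)
for t, v in enumerate([10.0, 20.0, 30.0, 40.0, 50.0]):
    metric.post(float(t), v)

total, acc, samples = metric.data()
print(total, acc, samples)  # 5 150.0 [(2.0, 30.0), (3.0, 40.0), (4.0, 50.0)]
print(acc / total)  # 30.0, the mean over all posted samples
```

Because ACCUMULATOR covers all TOTAL_SAMPLES while SAMPLES is truncated, ACCUMULATOR / TOTAL_SAMPLES gives the overall mean even after old samples have been dropped.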

Beginner’s Guide to PyTorch/XLA

This document provides a high-level overview of PyTorch XLA and illustrates a few examples of how PyTorch code is converted to run on XLA devices (e.g. TPUs). This is not a complete solution, and additional changes may be required depending on the specific code. However, this document should serve as a starting point for the conversion process.

Basic high-level understanding of some XLA details

This section provides a brief overview of the basic details of PyTorch XLA, which should help readers better understand the required modifications and optimizations of code. It is a supplement to the API guide described here.

Unlike regular PyTorch, which executes code line by line and does not block execution until the value of a PyTorch tensor is fetched, PyTorch XLA iterates through the Python code and records the operations on (PyTorch) XLA tensors in an intermediate representation (IR) graph until it encounters a barrier (discussed below). This process of generating the IR graph is referred to as tracing (LazyTensor tracing or code tracing). PyTorch XLA then converts the IR graph to a lower-level machine-readable format called HLO (High-Level Opcodes). HLO is a representation of a computation that is specific to the XLA compiler and allows it to generate efficient code for the hardware that it is running on. HLO is fed to the XLA compiler for compilation and optimization. The compilation is then cached by PyTorch XLA to be reused later if/when needed.

The compilation of the graph is done on the host (CPU), which is the machine that runs the Python code. If there are multiple XLA devices, the host compiles the code for each of the devices separately, except when using SPMD (single-program, multiple-data). For example, v4-8 has one host machine and four devices; in this case the host compiles the code for each of the four devices separately. In the case of pod slices, where there are multiple hosts, each host compiles for the XLA devices it is attached to. If SPMD is used, the code is compiled only once (for given shapes and computations) on each host for all the devices.


For more details and examples, please refer to the LazyTensor guide.

The operations in the IR graph are executed only when values of tensors are needed. This is referred to as evaluation or materialization of tensors. Sometimes this is also called lazy evaluation and it can lead to significant performance improvements.
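The difference between tracing and materialization can be modeled with a toy lazy scalar. LazyValue is hypothetical and much simpler than PyTorch/XLA's LazyTensor; it only shows that arithmetic builds up a recorded graph and that nothing runs until a value is requested.

```python
from typing import Callable, List, Optional


class LazyValue:
    """Toy lazy scalar: operations are recorded, not run, until .value() is called."""

    executions = 0  # how many times any recorded graph was actually evaluated

    def __init__(self, fn: Callable[[], float], ops: List[str]) -> None:
        self._fn = fn
        self.ops = ops  # recorded "IR" describing how to compute this value
        self._cached: Optional[float] = None

    @staticmethod
    def constant(x: float) -> "LazyValue":
        return LazyValue(lambda: x, [f"const({x})"])

    def __add__(self, other: "LazyValue") -> "LazyValue":
        # Tracing: extend the recipe instead of computing a number.
        return LazyValue(lambda: self._fn() + other._fn(), self.ops + other.ops + ["add"])

    def __mul__(self, other: "LazyValue") -> "LazyValue":
        return LazyValue(lambda: self._fn() * other._fn(), self.ops + other.ops + ["mul"])

    def value(self) -> float:
        # Materialization: the recorded ops run only when the value is needed.
        if self._cached is None:
            LazyValue.executions += 1
            self._cached = self._fn()
        return self._cached


a = LazyValue.constant(2.0)
b = LazyValue.constant(3.0)
c = (a + b) * b
print(c.ops)  # ['const(2.0)', 'const(3.0)', 'add', 'const(3.0)', 'mul']
print(LazyValue.executions)  # 0: nothing has run yet
print(c.value(), LazyValue.executions)  # 15.0 1
```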

The synchronous operations in PyTorch XLA, like printing, logging, checkpointing or callbacks, block tracing and result in slower execution. When an operation requires a specific value of an XLA tensor, e.g. print(xla_tensor_z), tracing is blocked until the value of that tensor is available to the host. Note that only the part of the graph responsible for computing that tensor value is executed. These operations do not cut the IR graph, but they trigger host-device communication through TransferFromDevice, which results in slower performance.

A barrier is a special instruction that tells XLA to execute the IR graph and materialize the tensors. This means that the PyTorch XLA tensors will be evaluated, and the results will be available to the host. The user-exposed barrier in PyTorch XLA is xm.mark_step(), which breaks the IR graph and results in code execution on the XLA devices. One of the key properties of xm.mark_step is that, unlike synchronous operations, it does not block further tracing while the device is executing the graph. However, it does block access to the values of the tensors that are being materialized.

The example in the LazyTensor guide illustrates what happens in a simple case of adding two tensors. Now, suppose we have a for loop that adds XLA tensors and uses the value later:

z = torch.zeros_like(tensors_on_device[0][0])  # running sum on the same XLA device
for x, y in tensors_on_device:
    z += x + y

Without a barrier, the Python tracing will result in a single graph that wraps the addition of tensors len(tensors_on_device) times. This is because the for loop is not captured by the tracing, so each iteration of the loop will create a new subgraph corresponding to the computation of z += x+y and add it to the graph. Here is an example when len(tensors_on_device)=3.


However, introducing a barrier at the end of the loop will result in a smaller graph that will be compiled once during the first pass inside the for loop and will be reused for the next len(tensors_on_device)-1 iterations. The barrier will signal to the tracing that the graph traced so far can be submitted for execution, and if that graph has been seen before, a cached compiled program will be reused.

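z = torch.zeros_like(tensors_on_device[0][0])
for x, y in tensors_on_device:
    z += x + y
    xm.mark_step()  # barrier: execute this iteration's small graph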

In this case there will be a small graph that is used len(tensors_on_device)=3 times.
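The two loops above can be modeled with a toy tracer that keeps a compile cache keyed by graph structure. ToyTracer and run_loop are hypothetical; real PyTorch/XLA hashes its IR graph, but the counting below follows the behavior described in this section.

```python
from typing import List, Set, Tuple


class ToyTracer:
    """Toy model of lazy tracing plus a compile cache (not torch_xla internals)."""

    def __init__(self) -> None:
        self.pending: List[str] = []  # ops traced since the last barrier
        self.compiled: Set[Tuple[str, ...]] = set()  # graph structures already compiled
        self.compiles = 0
        self.executions = 0

    def trace(self, op: str) -> None:
        self.pending.append(op)

    def mark_step(self) -> None:
        # Barrier: submit the pending graph, compiling it only if it is new.
        graph = tuple(self.pending)
        if graph not in self.compiled:
            self.compiled.add(graph)
            self.compiles += 1
        self.executions += 1
        # After execution z is concrete device data, so the next graph starts fresh.
        self.pending = []


def run_loop(num_pairs: int, barrier_each_iteration: bool) -> ToyTracer:
    tracer = ToyTracer()
    for _ in range(num_pairs):
        tracer.trace("tmp = add(x, y)")
        tracer.trace("z = add(z, tmp)")
        if barrier_each_iteration:
            tracer.mark_step()
    if not barrier_each_iteration:
        tracer.mark_step()  # a single barrier after the loop
    return tracer


for barrier in (False, True):
    t = run_loop(3, barrier)
    largest = max(len(g) for g in t.compiled)
    print(f"barrier={barrier}: compiles={t.compiles}, executions={t.executions}, ops per graph={largest}")
# barrier=False: compiles=1, executions=1, ops per graph=6
# barrier=True: compiles=1, executions=3, ops per graph=2
```

Without the barrier the single graph grows with len(tensors_on_device), so a different length produces a new graph to compile; with the barrier the same two-op graph is reused however long the list is.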


It is important to highlight that in PyTorch XLA, Python code inside for loops is traced and a new graph is constructed for each iteration if there is a barrier at the end. This can be a significant performance bottleneck.

The XLA graphs can be reused when the same computation happens on tensors of the same shapes. If the shapes of the inputs or intermediate tensors change, the XLA compiler will compile a new graph for the new tensor shapes. This means that if you have dynamic shapes, or if your code does not reuse tensor graphs, running your model on XLA will not be suitable for that use case: a significant amount of time will be spent by the compiler on optimizing and fusing operations that will not be used again. Padding the inputs to a fixed shape can be an option to help avoid dynamic shapes.
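A shape-keyed cache makes the recompilation cost easy to count. The sketch below treats each distinct (padded) length as one compilation and pads token lists with a hypothetical pad_id and attention mask; the lengths and bucket sizes are made up for illustration, and in a real model the padded positions must be masked so that results do not change.

```python
from typing import List, Set, Tuple


def pad_to_length(tokens: List[int], target_len: int,
                  pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Pad a token list to target_len; the mask marks real (1) vs padded (0) positions."""
    if len(tokens) > target_len:
        raise ValueError("sequence longer than target_len")
    padding = target_len - len(tokens)
    return tokens + [pad_id] * padding, [1] * len(tokens) + [0] * padding


def count_compiles(lengths: List[int], bucket: int = 0) -> int:
    """Number of distinct input shapes, i.e. compilations in a shape-keyed cache.

    bucket > 0 rounds each length up to the next multiple of bucket.
    """
    shapes: Set[int] = set()
    for length in lengths:
        if bucket > 0:
            length = -(-length // bucket) * bucket  # ceiling to a multiple of bucket
        shapes.add(length)
    return len(shapes)


lengths = [17, 23, 31, 64, 40, 12, 58, 33]  # made-up sequence lengths
print(count_compiles(lengths))             # 8: every new shape recompiles
print(count_compiles(lengths, bucket=32))  # 2: shapes 32 and 64
print(count_compiles(lengths, bucket=64))  # 1: one fixed shape
print(pad_to_length([5, 6, 7], 5))         # ([5, 6, 7, 0, 0], [1, 1, 1, 0, 0])
```

Padding everything to one fixed length compiles once but spends compute on padding; a few buckets trade a small number of extra compilations for less wasted work.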

The trade-off between graph size and compilation time is also important to consider. If there is one large IR graph, the XLA compiler can spend a lot of time on optimization and fusion of the ops. This can result in a very long compilation time. However, the later execution may be much faster, due to the optimizations that were performed during compilation.

Sometimes it is worth breaking the IR graph with xm.mark_step(). As explained above, this will result in a smaller graph that can be reused later. However, making graphs smaller can reduce optimizations that otherwise could be done by the XLA compiler.

Another important point to consider is MPDeviceLoader. Once your code is running on an XLA device, consider wrapping the torch dataloader with the XLA MPDeviceLoader, which preloads data to the device to improve performance and includes xm.mark_step() in it. The latter automatically breaks the iterations over batches of data and sends them for execution. Note that if you are not using MPDeviceLoader, you might need to set barrier=True in xm.optimizer_step() to enable xm.mark_step() if you are running a training job, or explicitly add xm.mark_step().

TPU Setup

Create a TPU VM with the base image to use nightly wheels, or with a stable release image, by specifying RUNTIME_VERSION.

export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8 # v4-16, v4-32, …
export RUNTIME_VERSION=tpu-vm-v4-pt-2.0 # or tpu-vm-v4-base
export TPU_NAME=your_tpu_name

gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION}

If you have a single-host VM (e.g. v4-8), you can SSH into your VM and run the following commands from the VM directly. Otherwise, in the case of TPU pods, you can use --worker=all --command="" similar to

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=us-central2-b \
--worker=all \
--command="pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl"

Next, if you are using the base image, install the nightly packages and required libraries:

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
sudo apt-get install libopenblas-dev -y

sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific

Converting code to PyTorch XLA

General guidelines to modify your code:

  • Replace cuda with xm.xla_device()

  • Remove progress bars and printing that would access the XLA tensor values

  • Reduce logging and callbacks that would access the XLA tensor values

  • Wrap data loader with MPDeviceLoader

  • Profile to further optimize the code

Remember: each case is unique so you might need to do something different for each case.

Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device

As a first example consider the inference code of the stable diffusion model in PyTorch Lightning which can be run from command line as

python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse"

For your reference, the diff of modifications described below can be found here. Let’s go over them step by step. As in the general guideline above, start with changes related to cuda device. This inference code is written to run on GPUs and cuda can be found in multiple places. Start making changes by removing model.cuda() from this line, and precision_scope from here. Additionally, replace the cuda device in this line with the xla device similar to the code below:

Next, this particular configuration of the model is using FrozenCLIPEmbedder, therefore we will modify this line as well. For simplicity we will directly define the device in this tutorial, but you can pass the device value to the function as well.

import torch_xla.core.xla_model as xm
self.device = xm.xla_device()

Another place in the code that has cuda-specific code is the DDIM scheduler. Add import torch_xla.core.xla_model as xm at the top of the file, then replace these lines

if attr.device != torch.device("cuda"):
   attr = attr.to(torch.device("cuda"))

with

device = xm.xla_device()
attr = attr.to(torch.device(device))

Next, you can reduce device (TPU) and host (CPU) communication by removing print statements, disabling progress bars, and reducing or removing callbacks and logging. These operations require the device to stop executing, falling back to the CPU, executing the logging/callbacks, and then returning to the device. This can be a significant performance bottleneck, especially on large models.

After making these changes, the code will run on TPUs. However, the performance will be very slow. This is because the XLA compiler tries to build a single (huge) graph that wraps the number of inference steps (in this case, 50) as there is no barrier inside the for loop. It is difficult for the compiler to optimize the graph, and this leads to significant performance degradation. As discussed above, breaking the for loop with the barrier (xm.mark_step()) will result in a smaller graph that is easier for the compiler to optimize. This will also allow the compiler to reuse the graph from the previous step, which can improve performance.

Now the code is ready to run on TPUs in a reasonable time. More optimization and analysis can be done by capturing a profile and investigating further. However, this is not covered here.

Note: if you are running on a v4-8 TPU, then you have 4 available XLA (TPU) devices. Running the code as above will only use one XLA device. In order to run on all 4 devices, you need to use the xmp.spawn() function to spawn the code on all the devices. We will discuss xmp.spawn in the next example.

Example 2. HF Stable Diffusion Inference

Now, consider using Stable Diffusion Inference in the HuggingFace diffusers library for both the SD-XL and 2.1 versions of the model. For your reference, the changes described below can be found in this repo. You can clone the repo and run the inference using the following command on your TPU VM:

(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git
(vm)$ cd diffusers/examples/text_to_image/
(vm)$ python3 inference_tpu_single_device.py

Since there is no bf16 version of the SD-XL model available, you can use the XLA_USE_BF16=1 flag to convert all values to bf16 and speed up inference.

(vm)$ XLA_USE_BF16=1 python3 inference_tpu_single_device.py # uses sd-xl version


(vm)$ python3 inference_tpu_multidevice.py # uses 2.1 version

(The 2.1 version of the model already uses torch.bfloat16.)

Warning: watch out for caveats highlighted here.

Running on a Single TPU device

This section describes the changes that need to be made to the text_to_image inference example code to run it on TPUs.

The original code uses LoRA for inference, but this tutorial will not use it. Instead, we will set the model_id argument to stabilityai/stable-diffusion-xl-base-0.9 when initializing the pipeline. We will also use the default scheduler (DPMSolverMultistepScheduler). However, similar changes can be made to the other schedulers as well.

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install . # pip install -e .

cd examples/text_to_image/
pip install -r requirements.txt
pip install invisible_watermark transformers accelerate safetensors

(If accelerate is not found, log out and log back in.)

Log in to HF and agree to the sd-xl 0.9 license on the model card. Next, go to account→settings→access token and generate a new token. Copy the token and run the following command with that specific token value on your VM:

(vm)$ huggingface-cli login --token _your_copied_token__

The HuggingFace readme provides PyTorch code that is written to run on GPUs. To run it on TPUs, the first step is to change the CUDA device to an XLA device. This can be done by replacing the line pipe.to("cuda") with the following lines:

import torch_xla.core.xla_model as xm
device = xm.xla_device()
pipe.to(device)

Additionally, it is important to note that the first time you run inference with XLA, it will take a long time to compile. For example, compiling stable diffusion XL model inference from HuggingFace can take about an hour, whereas the actual inference may take only 5 seconds, depending on the batch size. Likewise, a GPT-2 model can take about 10-15 mins to compile, after which the training epoch time becomes much faster. This is because XLA builds a graph of the computation that will be performed, and then optimizes this graph for the specific hardware that it is running on. However, once the graph has been compiled, it can be reused for subsequent inferences, which will be much faster. Therefore, if you are only running inference once, you may not benefit from using XLA. However, if you are running inference multiple times, or if you are running inference on a list of prompts, you will start to see the advantages of XLA after the first few inferences. For example, if you run inference on a list of 10 prompts, the first inference (maybe two) may take a long time to compile, but the remaining inference steps will be much faster. This is because XLA will reuse the graph that it compiled for the first inference.

If you try to run the code without making any additional changes, you will notice that the compilation time is very long (>6 hours). This is because the XLA compiler tries to build a single graph for all of the scheduler steps at once similar to what we have discussed in the previous example. To make the code run faster, we need to break the graph up into smaller pieces with xm.mark_step() and reuse them in the next steps. This happens inside the pipe.__call__ function in these lines. Disabling the progress bar, removing callbacks and adding xm.mark_step() at the end of the for loop speeds up the code significantly. Changes are provided in this commit.

Additionally, the self.scheduler.step() function, which by default uses the DPMSolverMultistepScheduler scheduler, has a few issues that are described in the PyTorch XLA caveats. The .nonzero() and .item() calls in this function send requests to the CPU for tensor evaluation, which trigger device-host communication. This is not desirable, as it can slow down the code. In this particular case, we can avoid these calls by passing the index to the function directly. This will prevent the function from sending requests to the CPU, and will improve the performance of the code. Changes are available in this commit. The code now is ready to be run on TPUs.

Profiling and performance analysis

To further investigate the performance of the model, we can profile it using the profiling guide. As a rule of thumb, the profiling script should be run with the maximum batch size that fits into the memory for optimal memory usage. It also helps to overlap tracing of the code with device execution which leads to more optimal device usage. The duration of profiling should be long enough to capture at least one step. Good performance of the model on TPUs means that device-host communication is minimized and the device is constantly running processes with no idle time.

Starting a server in the inference_tpu_*.py file and running the capture_profile.py script as described in the guide will give us information on processes that run on the devices. Currently, only one XLA device is profiled. To better understand the TPU idle time (gaps in the profile), profiling traces (xp.Trace()) should be added to the code. xp.Trace() measures the time the host machine spends tracing the Python code wrapped by the trace. For this example, xp.Trace() traces were added inside the pipeline and the U-Net model to measure the time to run specific sections of the code on the host (CPU).

If the gaps in the profile are due to Python code tracing that happens on the host, then this might be a bottleneck, and there is no further straightforward optimization that can be done. Otherwise, the code should be analyzed further to understand the caveats and improve the performance further. Note that you cannot wrap portions of the code where xm.mark_step() is called in xp.Trace().

To illustrate this we can look at already captured profiles that were uploaded to tensorboard following the profiling guide.

Starting from Stable Diffusion model version 2.1

If we capture a profile without inserting any traces, we will see the following:


The single TPU device on v4-8, which has two cores, appears to be busy. There are no significant gaps in its usage, except for a small one in the middle. If we scroll up to try to find which process is occupying the host machine, we will not find any information. Therefore, we will add xp.Trace() traces to the pipeline file as well as the U-Net function. The latter may not be useful for this particular use case, but it does demonstrate how traces can be added in different places and how their information is displayed in TensorBoard.

If we add traces and re-capture the profile with the largest batch size that can fit on the device (32 in this case), we will see that the gap in the device is caused by a Python process that is running on the host machine.


We can use the appropriate tool to zoom in on the timeline and see which process is running during that period. This is when the Python code tracing happens on the host, and we cannot improve the tracing further at this point.

Now, let’s examine the XL version of the model and do the same thing. We will add traces to the pipeline file in the same way that we did for the 2.1 version and capture a profile.


This time, in addition to the large gap in the middle, which is caused by the pipe_watermark tracing, there are many small gaps between the inference steps within this loop.

First, look closer at the large gap that is caused by pipe_watermark. The gap is preceded by TransferFromDevice, which indicates that something is happening on the host machine that is waiting for computation to finish before proceeding. Looking into the watermark code, we can see that tensors are transferred to the CPU and converted to numpy arrays in order to be processed with the cv2 and pywt libraries later. Since this part is not straightforward to optimize, we will leave it as is.

Now if we zoom in on the loop, we can see that the graph within the loop is broken into smaller parts because the TransferFromDevice operation happens.


If we investigate the U-Net function and the scheduler, we can see that the U-Net code does not contain any optimization targets for PyTorch/XLA. However, there are .item() and .nonzero() calls inside scheduler.step. We can rewrite the function to avoid those calls. If we fix this issue and rerun a profile, we will not see much difference. However, since we have reduced the device-host communication that was introducing smaller graphs, we allowed the compiler to optimize the code better. The function scale_model_input has similar issues, and we can fix these by making the changes we made above to the step function. Overall, since many of the gaps are caused by Python-level code tracing and graph building, these gaps are not possible to optimize with the current version of PyTorch XLA, but we may see improvements in the future when dynamo is enabled in PyTorch XLA.

Running on Multiple TPU Devices

To use multiple TPU devices, you can use the xmp.spawn function to spawn the function you ran on a single device to multiple devices. The xmp.spawn function will start processes on multiple TPU devices and sync them when needed. The function that runs on a single device must accept the process index as its first argument; xmp.spawn supplies it automatically. For example,

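import torch_xla.distributed.xla_multiprocessing as xmp

def my_function(index):
  # Function that runs on a single device; xmp.spawn passes the process index.
  print(f"Running process {index}")

if __name__ == '__main__':
  # args holds only extra arguments; the index is passed automatically.
  xmp.spawn(my_function, args=(), nprocs=4)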

In this example, the my_function function will be spawned on 4 TPU devices on v4-8, with each device being assigned an index from 0 to 3.

This file illustrates how xmp.spawn can be used to run the stable diffusion 2.1 version on multiple TPU devices. For this version, changes similar to the ones above were made to the pipeline file.

Running on Pods

Once you have the code for running on a single host, no further changes are needed. You can create the TPU pod, for example, by following these instructions. Then run your script with

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --worker=all \
  --command="python3 your_script.py"

0 and 1 are magic numbers in XLA and treated as constants in the HLO. So if there is a random number generator in the code that can generate these values, the code will compile for each value separately. This can be disabled with XLA_NO_SPECIAL_SCALARS=1 environment variable.
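The effect described above can be modeled with a graph key in which 0 and 1 are baked in as constants and every other scalar becomes a parameter. graph_key and count_graphs are hypothetical helpers, and special_scalars=False merely stands in for running with XLA_NO_SPECIAL_SCALARS=1; the sample values are made up.

```python
from typing import Iterable, Union

Scalar = Union[int, float]


def graph_key(value: Scalar, special_scalars: bool = True) -> str:
    """Hypothetical graph key: 0 and 1 become constants, other values a parameter."""
    if special_scalars and value in (0, 1):
        return f"const({int(value)})"  # 0 and 0.0 map to the same key
    return "param"


def count_graphs(values: Iterable[Scalar], special_scalars: bool = True) -> int:
    # One compiled graph per distinct key.
    return len({graph_key(v, special_scalars) for v in values})


samples = [0.25, 0, 0.7, 1, 0.0, 0.9]  # made-up scalar values
print(count_graphs(samples))                         # 3
print(count_graphs(samples, special_scalars=False))  # 1
```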


Note that the information in this section may be removed in future releases of the PyTorch/XLA software, since much of it is specific to a given internal implementation, which might change.

Sanity Check

Before performing any in-depth debugging, we want to do a sanity check on the installed PyTorch/XLA.

Check PyTorch/XLA Version

PyTorch and PyTorch/XLA versions should match. Check out our README for more details on the versions available.

vm:~$ python
>>> import torch
>>> import torch_xla
>>> print(torch.__version__)
>>> print(torch_xla.__version__)

Perform A Simple Calculation

vm:~$ export PJRT_DEVICE=TPU
vm:~$ python3
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> t1 = torch.tensor(100, device=xm.xla_device())
>>> t2 = torch.tensor(200, device=xm.xla_device())
>>> print(t1 + t2)
tensor(300, device='xla:0')

Run Resnet With Fake Data

For nightly

vm:~$ git clone https://github.com/pytorch/xla.git
vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data

For release version x.y, you want to use the branch rx.y. For example if you installed 2.1 release, you should do

vm:~$ git clone --branch r2.1 https://github.com/pytorch/xla.git
vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data

If ResNet runs successfully, we can conclude that torch_xla is installed correctly.

Performance Debugging

To diagnose performance issues, we can use the execution metrics and counters provided by PyTorch/XLA. The first thing to check when a model is slow is to generate a metrics report.

Metrics report is extremely helpful in diagnosing issues. Please try to include it in your bug report sent to us if you have it.

PyTorch/XLA Debugging Tool

You can enable the PyTorch/XLA debugging tool by setting PT_XLA_DEBUG_LEVEL=2, which provides a couple of useful debugging features. You can also lower the debug level to 1 to skip the execution analysis.

Perform an Auto-Metrics Analysis

The debugging tool will analyze the metrics report and provide a summary. Some example output would be

pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps
pt-xla-profiler: TransferFromDeviceTime too frequent: 11 counts during 11 steps
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward,  Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps
pt-xla-profiler: TransferFromDeviceTime too frequent: 12 counts during 12 steps

Compilation & Execution Analysis

The debugging tool will analyze every compilation and execution for your model. Some example output would be

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   mark_step in parallel loader at step end
Compilation Analysis: Graph Info:
Compilation Analysis:   Graph Hash: c74c3b91b855b2b123f833b0d5f86943
Compilation Analysis:   Number of Graph Inputs: 35
Compilation Analysis:   Number of Graph Outputs: 107
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis:   mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055)
Compilation Analysis:   next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44)
Compilation Analysis:   __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32)
Compilation Analysis:   train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48)
Compilation Analysis:   start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65)
Compilation Analysis:   <module> (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 1.548000 GB
Post Compilation Analysis: Graph output size: 7.922460 GB
Post Compilation Analysis: Aliased Input size: 1.547871 GB
Post Compilation Analysis: Intermediate tensor size: 12.124478 GB
Post Compilation Analysis: Compiled program size: 0.028210 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   mark_step in parallel loader at step end
Execution Analysis: Graph Info:
Execution Analysis:   Graph Hash: c74c3b91b855b2b123f833b0d5f86943
Execution Analysis:   Number of Graph Inputs: 35
Execution Analysis:   Number of Graph Outputs: 107
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis:   mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055)
Execution Analysis:   next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44)
Execution Analysis:   __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32)
Execution Analysis:   train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:48)
Execution Analysis:   start_training (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:65)
Execution Analysis:   <module> (/workspaces/dk3/pytorch/xla/examples/train_decoder_only_base.py:73)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================

Some common causes of compilation/execution are:

  1. The user manually calls mark_step.

  2. The parallel loader calls mark_step every x (configurable) batches.

  3. Exiting a profiler StepTrace region.

  4. Dynamo decides to compile/execute the graph.

  5. The user tries to access the value of a tensor (often for logging) before the mark_step.

Executions caused by 1-4 are expected. Avoid 5 by either accessing tensor values less often or manually adding a mark_step before accessing them.

Users should expect to see these Compilation Cause + Execution Cause pairs for the first couple of steps. After the model stabilizes, users should expect to see only Execution Cause (you can disable execution analysis with PT_XLA_DEBUG_LEVEL=1). To use PyTorch/XLA efficiently, we expect the same model code to run on every step and each graph to be compiled only once. If you keep seeing Compilation Cause, try dumping the IR/HLO as described in this section, then compare the graphs from each step to understand the source of the differences.

The following section explains how to get and interpret a more detailed metrics report.

Get A Metrics Report

Put the following line in your program to generate a report:

import torch_xla.debug.metrics as met

# For a short report that only contains a few key metrics.
print(met.short_metrics_report())
# For a full report that includes all metrics.
print(met.metrics_report())

Understand The Metrics Report

The report includes things like:

  • how many times we issue XLA compilations and the time spent on issuing them.

  • how many times we execute and the time spent on execution.

  • how many device data handles we create/destroy, etc.

This information is reported in terms of percentiles of the samples. An example is:

Metric: CompileTime
  TotalSamples: 202
  Counter: 06m09s401ms746.001us
  ValueRate: 778ms572.062us / second
  Rate: 0.425201 / second
  Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us
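Durations in the report are written as concatenated unit fields (m, s, ms and us in this example). Converting them to seconds makes them easier to compare: for instance, the 99th percentile above (about 21 s) dwarfs the median (about 1.2 ms). This is a minimal sketch that only recognizes the units appearing in this example; duration_to_seconds is our own hypothetical helper.

```python
import re

# Multipliers (to seconds) for the units that appear in the example report.
_UNIT_SECONDS = {"m": 60.0, "s": 1.0, "ms": 1e-3, "us": 1e-6}
# "ms" and "us" come before "m" and "s" so the two-letter suffixes win.
_FIELD = re.compile(r"(\d+(?:\.\d+)?)(ms|us|m|s)")


def duration_to_seconds(text):
  """Convert a report duration such as '06m09s401ms746.001us' to seconds."""
  text = text.strip()
  total, consumed = 0.0, 0
  for match in _FIELD.finditer(text):
    if match.start() != consumed:  # reject anything between fields
      raise ValueError(f"unrecognized duration: {text!r}")
    value, unit = match.groups()
    total += float(value) * _UNIT_SECONDS[unit]
    consumed = match.end()
  if consumed == 0 or consumed != len(text):
    raise ValueError(f"unrecognized duration: {text!r}")
  return total
```

For a ValueRate such as "778ms572.062us / second", strip the trailing " / second" before calling the helper.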

We also provide counters, which are named integer variables which track internal software status. For example:

Counter: CachedSyncTensors
  Value: 395

In this report, any counter that starts with aten:: indicates a context switch between the XLA device and CPU, which can be a potential performance optimization area in the model code.

Counters are useful to understand which operations are routed back to the CPU engine of PyTorch. They are fully qualified with their C++ namespace:

Counter: aten::nonzero
  Value: 33

If you see aten:: ops other than nonzero and _local_scalar_dense, that usually means a missing lowering in PyTorch/XLA. Feel free to open a feature request for it on GitHub issues.

Clear The Metrics Report

If you want to clear the metrics between steps/epochs, you can use

import torch_xla.debug.metrics as met

met.clear_all()

PyTorch/XLA + Dynamo Debugging Tool

You can enable the PyTorch/XLA + Dynamo debugging tool by setting XLA_DYNAMO_DEBUG=1.

Performance Profiling

To profile your workload in depth and understand its bottlenecks, please check the following resources:

Simple Benchmarking

Take a look at examples/train_resnet_benchmark.py (https://github.com/pytorch/xla/blob/master/examples/train_resnet_benchmark.py) for how to benchmark a PyTorch/XLA model.

Known Performance Caveats

PyTorch/XLA behaves semantically like regular PyTorch and XLA tensors share the full tensor interface with CPU & GPU tensors. However, constraints in XLA/hardware and the lazy evaluation model suggest certain patterns might result in bad performance.

If your model shows bad performance, keep in mind the following caveats:

  1. XLA/TPU yield degraded performance with too many recompilations.

    XLA compilation is expensive. PyTorch/XLA automatically recompiles the graph every time new shapes are encountered. Usually models should stabilize within a few steps, and you can see a huge speedup for the rest of training.

    In order to avoid recompilations, not only must shapes be constant, but computations across XLA devices in all hosts should also be constant.

    Possible sources:

    • Direct or indirect uses of nonzero introduce dynamic shapes; for example, masked indexing base[index] where index is a mask tensor.

    • Loops with a different number of iterations between steps can result in different execution graphs, and thus require recompilations.

    Solution:
    • Tensor shapes should be the same between iterations, or a low number of shape variations should be used.

    • Pad tensors to fixed sizes when possible.

  2. Certain operations don’t have native translations to XLA.

    For these operations PyTorch/XLA automatically transfers to the CPU memory, evaluates on CPU, and transfers the result back to the XLA device. Doing too many such operations during the training step can lead to significant slowdowns.

    Possible sources:

    • The item() operation explicitly asks to evaluate the result. Don’t use it unless it’s necessary.

    Solution:

    • Most of these ops can be lowered to XLA to fix the problem. Check the metrics report section to find the missing ops and open a feature request on GitHub.

    • Even when a PyTorch tensor is known as a scalar, avoid using tensor.item(). Keep it as a tensor and use tensor operations on it.

    • Use torch.where to substitute control flow when applicable. For example, the control flow with item() used in clip_grad_norm_ is problematic and hurts performance, so we have patched clip_grad_norm_ to call torch.where instead, which gives a dramatic performance improvement:

      ...
      else:
        device = parameters[0].device
        total_norm = torch.zeros([], device=device if parameters else None)
        for p in parameters:
          param_norm = p.grad.data.norm(norm_type) ** norm_type
          total_norm.add_(param_norm)
        total_norm = (total_norm ** (1. / norm_type))
      clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
      for p in parameters:
        # torch.where keeps the clipping decision on the device instead of calling item()
        p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

  3. Iterators in torch_xla.distributed.data_parallel may drop the last few batches in the input iterator.

    This is to make sure we do the same amount of work on all XLA devices.

    Solution:

    • When the dataset is small and there are too few steps, this may result in a no-op epoch. Therefore, it is better to use small batch sizes in those cases.
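Caveat 1 recommends padding to fixed sizes or keeping the number of shape variations low. One common way to do both is to round each batch's length up to one of a few bucket sizes. The sketch below uses plain Python lists to stay self-contained; with real tensors you would apply the same rounding and pad with torch.nn.functional.pad. The bucket sizes and helper names are hypothetical. The returned mask also sidesteps the masked-indexing pattern from caveat 1: multiplying by the mask keeps the shape fixed, whereas base[mask] goes through nonzero.

```python
import bisect

# Hypothetical bucket sizes: a handful of lengths that cover the data.
BUCKETS = (64, 128, 256, 512)


def bucket_length(length, buckets=BUCKETS):
  """Round length up to the smallest bucket, so at most len(buckets) shapes occur."""
  i = bisect.bisect_left(buckets, length)
  if i == len(buckets):
    raise ValueError(f"length {length} exceeds the largest bucket {buckets[-1]}")
  return buckets[i]


def pad_batch(sequences, pad_value=0, buckets=BUCKETS):
  """Pad every sequence to one bucketed length and return (padded, mask)."""
  target = bucket_length(max(len(s) for s in sequences), buckets)
  padded = [list(s) + [pad_value] * (target - len(s)) for s in sequences]
  # 1 marks real elements and 0 marks padding; multiply by it instead of masked indexing.
  mask = [[1] * len(s) + [0] * (target - len(s)) for s in sequences]
  return padded, mask
```

With these buckets, a batch whose longest sequence has 70 elements and one with 100 elements both become length 128, so, other things being equal, they can reuse the same compiled graph instead of triggering two compilations.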

XLA Tensor Quirks

  1. XLA tensor internals are opaque. XLA tensors always appear to be contiguous and without storage. Networks should not try to check the strides of XLA tensors.

  2. XLA tensors should be moved to the CPU before saving them. Saving XLA tensors directly causes them to be loaded back on the device(s) they were saved from. If a device is unavailable at load time then the load will fail. Moving XLA tensors to the CPU before saving them lets you decide which device(s) to put the loaded tensors on. This is necessary if you want to load the tensors on a machine without XLA devices. Care should be taken moving the XLA tensors to the CPU before saving them, however, as moving tensors across device types does not preserve view relationships. Instead, views should be reconstructed as necessary after the tensors are loaded.

  3. Copying an XLA Tensor with Python’s copy.copy returns a deep copy, not a shallow copy. Use a view of an XLA tensor to get a shallow copy of it.

  4. Handling shared weights. Modules can share weights by setting the Parameters of one module to another. This “tying” of module weights should be done AFTER the modules are moved to an XLA device. Otherwise two independent copies of the shared tensor will be made on the XLA device.

More Debugging Tools

We don’t expect users to use the tools in this section to debug their models, but we might ask for them when you submit a bug report, since they provide additional information that the metrics report doesn’t have.

  • print(torch_xla._XLAC._get_xla_tensors_text([res])) where res is the result tensor prints out the IR.

  • print(torch_xla._XLAC._get_xla_tensors_hlo([res])) where res is the result tensor prints out the generated XLA HLO.

Note these functions must be called prior to mark_step(), otherwise the tensor will already be materialized.

Environment Variables

There are also a number of environment variables which control the behavior of the PyTorch/XLA software stack.

Setting such variables will cause different degrees of performance degradation, so they should only be enabled for debugging.

  • XLA_IR_DEBUG: Enables capturing the Python stack trace when IR nodes are created, which makes it possible to tell which PyTorch operation generated the IR.

  • XLA_HLO_DEBUG: Enables the Python stack frames captured when XLA_IR_DEBUG is active to be propagated to the XLA HLO metadata.

  • XLA_SAVE_TENSORS_FILE: The path to a file which will be used to dump the IR graphs during execution. Note that the file can become really big if the option is left enabled and the PyTorch program runs for a long time. The graphs are appended to the file, so to have a clean sheet from run to run, the file should be explicitly removed.

  • XLA_SAVE_TENSORS_FMT: The format of the graphs stored within the XLA_SAVE_TENSORS_FILE file. Can be text (the default), dot (the Graphviz format) or hlo.

  • XLA_FLAGS=--xla_dump_to: If set to =/tmp/dir_name, the XLA compiler will dump the unoptimized and optimized HLO for each compilation.

  • XLA_METRICS_FILE: If set, the path to a local file where the internal metrics will be saved at every step. Metrics will be appended to the file, if already existing.

  • XLA_SAVE_HLO_FILE: If set, the path to a local file where, in case of compilation/execution error, the offending HLO graph will be saved.

  • XLA_SYNC_WAIT: Forces the XLA tensor sync operation to wait for its completion, before moving to the next step.

  • XLA_USE_EAGER_DEBUG_MODE: Forces the XLA tensor to execute eagerly, meaning the torch operations are compiled and executed one by one. This is useful for bypassing long compilation times, but the overall step time will be a lot slower and memory usage will be higher, since all compiler optimizations are skipped.

  • TF_CPP_LOG_THREAD_ID: If set to 1, the TF logs will show the thread ID helping with debugging multithreaded processes.

  • TF_CPP_VMODULE: Environment variable used for TF VLOGs and takes the form of TF_CPP_VMODULE=name=value,.... Note that for VLOGs you must set TF_CPP_MIN_LOG_LEVEL=0.

  • TF_CPP_MIN_LOG_LEVEL: Level to print messages for. TF_CPP_MIN_LOG_LEVEL=0 will turn on INFO logging, TF_CPP_MIN_LOG_LEVEL=1 WARNING and so on. Our PyTorch/XLA TF_VLOG uses tensorflow::INFO level by default so to see VLOGs set TF_CPP_MIN_LOG_LEVEL=0.

  • XLA_DUMP_HLO_GRAPH: If set to =1 in case of a compilation or execution error the offending HLO graph will be dumped as part of the runtime error raised by xla_util.cc.

Common Debugging Environment Variables Combinations

  • Record the graph execution in the IR format

  • Record the graph execution in the HLO format

  • Show debugging VLOG for runtime and graph compilation/execution

    TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3"
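The first two combinations above both come down to pointing XLA_SAVE_TENSORS_FILE at a file and choosing XLA_SAVE_TENSORS_FMT. If you launch runs from Python, a helper like the hypothetical graph_dump_env below can build that environment and also delete the previous dump, since graphs are appended to the file. Passing the result to a fresh process, rather than editing os.environ inside a script that may already have imported torch_xla, is a conservative way to make sure the variables are in place from the start.

```python
import os

# The three values XLA_SAVE_TENSORS_FMT accepts, per the list above.
_FORMATS = ("text", "dot", "hlo")


def graph_dump_env(fmt, dump_file, base_env=None):
  """Return an environment that records graph execution to dump_file.

  Assumption: XLA_IR_DEBUG and XLA_HLO_DEBUG are also turned on so the dump
  carries Python stack frames; drop them if you only need the graphs.
  """
  if fmt not in _FORMATS:
    raise ValueError(f"fmt must be one of {_FORMATS}, got {fmt!r}")
  # Graphs are appended to the file, so remove any previous run's output.
  if os.path.exists(dump_file):
    os.remove(dump_file)
  env = dict(os.environ if base_env is None else base_env)
  env.update({
      "XLA_IR_DEBUG": "1",
      "XLA_HLO_DEBUG": "1",
      "XLA_SAVE_TENSORS_FMT": fmt,
      "XLA_SAVE_TENSORS_FILE": dump_file,
  })
  return env
```

For example, subprocess.run([sys.executable, "train.py"], env=graph_dump_env("hlo", "/tmp/graphs.hlo"), check=True), where train.py stands in for your own training script.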

Reproducing PyTorch/XLA CI/CD unit test failures.

You may see some test failures for a PR such as:

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8

Running this directly on the command line does not work. You need to set the environment variable TORCH_TEST_DEVICES to your local pytorch/xla/test/pytorch_test_base.py. For example, the following should work:

TORCH_TEST_DEVICES=/path/to/pytorch/xla/test/pytorch_test_base.py PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8
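If you reproduce CI failures often, the prefixing step can be scripted. This is a sketch with hypothetical helper names: xla_repo is wherever your pytorch/xla checkout lives, and the test file and test name come from the CI hint.

```python
import os
import shlex


def ci_repro_command(xla_repo, test_file, test_name):
  """Return (env, argv) that rerun a PyTorch test against the XLA test base."""
  env = {
      "TORCH_TEST_DEVICES": os.path.join(xla_repo, "test", "pytorch_test_base.py"),
      "PYTORCH_TEST_WITH_SLOW": "1",
  }
  argv = ["python", test_file, "-k", test_name]
  return env, argv


def as_shell(env, argv):
  """Render env and argv as one copy-pasteable shell line."""
  prefix = " ".join(f"{key}={shlex.quote(value)}" for key, value in env.items())
  return f"{prefix} {shlex.join(argv)}"
```

To run it from Python instead of a shell, pass {**os.environ, **env} as the env argument of subprocess.run(argv, ...).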

PJRT Runtime

PyTorch/XLA has migrated from the TensorFlow-based XRT runtime to the PJRT runtime used by JAX.

If you encounter a bug with PJRT, please file an issue on GitHub with the runtime tag.

New features in PyTorch/XLA r2.1:

  • PJRT is stable in PyTorch/XLA r2.1!

  • Public runtime APIs have moved from torch_xla.experimental.pjrt to torch_xla.runtime.

    • The pjrt:// init method has been renamed to xla://, and it is registered by torch_xla.distributed.xla_backend.

    • The previous torch_xla.experimental.* names are still available in this release for compatibility.

  • torchrun is now supported when using init_method='xla://'.

  • New plugins for XPU and Neuron via the PJRT C API.

New features in PyTorch/XLA r2.0:

  • PJRT will be configured by default if you don’t pass in any other runtime configuration. If you continue to set XRT configuration (XRT_TPU_CONFIG), this change has no impact

  • New TPU runtime implementation in libtpu improves performance by up to 30%.

  • New xm.rendezvous implementation that scales to thousands of TPU cores

  • [experimental] torch.distributed support for TPU v2 and v3, including pjrt:// init_method


  • To use the PJRT preview runtime, set the PJRT_DEVICE environment variable to CPU, TPU, or CUDA

  • In XRT, all distributed workloads are multiprocess, with one process per device. On TPU v2 and v3 in PJRT, workloads are multiprocess and multithreaded (4 processes with 2 threads each), so your workload should be thread-safe. See Multithreading on TPU v2/v3 and the Multiprocessing section of the API guide for more information. Key differences to keep in mind:

    • To initialize a model in a thread-safe way, either broadcast the parameters across replicas after initialization (torch_xla.experimental.pjrt.broadcast_master_param) or load each replica’s parameters from a common checkpoint.

    • For other random number generation, use torch.Generator where possible. The global torch RNG is not thread-safe, even if you set the same torch.manual_seed across replicas.

    • To use torch.distributed, import torch_xla.experimental.pjrt_backend and use the xla:// init_method.

    • These steps are optional for GPU and TPU v4.

Sample diff from XRT to PJRT:

 import os

 import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.optim as optim
 import torch.distributed as dist
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.distributed.xla_backend
 import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.runtime as xr

 def _mp_fn(index):
   device = xm.xla_device()
-  dist.init_process_group('xla', rank=xm.get_ordinal(), world_size=xm.xrt_world_size())
+  dist.init_process_group('xla', init_method='xla://')

   model = nn.Linear(128, 10).to(device)

+  # Optional for TPU v4 and GPU
+  xm.broadcast_master_param(model)
   model = DDP(model, gradient_as_bucket_view=True)

   loss_fn = nn.MSELoss()
   optimizer = optim.SGD(model.parameters(), lr=.001)

   for i in range(10):
     data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()
     optimizer.step()
     xm.mark_step()

   # Print mean parameters so we can confirm they're the same across replicas
   print([p.mean() for p in model.parameters()])

 if __name__ == '__main__':
-  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'

+  # Recommended: set PJRT_DEVICE to your local device type
+  os.environ['PJRT_DEVICE'] = 'TPU'

   xmp.spawn(_mp_fn)

  • Simple runtime configuration: just set PJRT_DEVICE to TPU, CPU, or CUDA and start using XLA! Or, let PJRT select a device automatically based on your environment.

  • Improved performance: reduced overhead from gRPC means faster end-to-end execution. On TorchBench 2.0, we observed a >35% improvement in training time on TPU v4.

  • Easy pod execution: just copy your code to each TPU worker, and execute them all at the same time with gcloud compute tpus tpuvm ssh --worker=all.

  • Better scaling: removes XRT’s limitation on parameter sizes and supports up to 2048 TPU chips.


To start using PJRT with PyTorch/XLA, all you need to do is set the PJRT_DEVICE environment variable. If you’re working on a TPU v2 or v3, keep reading to learn about the differences between TPU v2/v3 and TPU v4.


On any machine with PyTorch/XLA installed, you can run our MNIST example on CPU like this:

PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data


To create a new TPU with PyTorch/XLA r2.0 installed:

gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT

On a v4-8, you can run our ResNet50 example like this:

git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

By default, PJRT will use all TPU chips. To use only one TPU chip, configure TPU_PROCESS_BOUNDS and TPU_VISIBLE_CHIPS:

TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1


On TPU Pods, use gcloud to run your command on each TPU in parallel:

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"


You can also use Docker to run your workload in a container with PyTorch/XLA preinstalled:

export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"

Note that docker run requires privileged access to the host (--privileged) to expose the TPU device to the container. Docker on TPU pods is only supported with host networking --net=host at this time. See the Cloud TPU documentation for more information.


Single-node GPU training

To use GPUs with PJRT, simply set PJRT_DEVICE=CUDA and configure GPU_NUM_DEVICES to the number of devices on the host. For example:

PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

You can also use torchrun to initiate the single-node multi-GPU training. For example,

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

In the above example, --nnodes is the number of machines (physical machines or VMs) to use (it is 1 since we do single-node training), and --nproc-per-node is the number of GPU devices to use.

Multi-node GPU training

Note that this feature only works with CUDA 12+. Similar to how PyTorch does multi-node training, you can run the command as below:

PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py

  • --nnodes: how many GPU machines to be used.

  • --node_rank: the index of the current GPU machine. The value can be 0, 1, …, ${NUMBER_GPU_VM}-1.

  • --nproc_per_node: the number of GPU devices to be used on the current machine.

  • --rdzv_endpoint: the endpoint of the GPU machine with node_rank==0, in the form host:port. The host will be the internal IP address. The port can be any available port on the machine. For single-node training/inference, this parameter can be omitted.

For example, if you want to train on 2 GPU machines: machine_0 and machine_1, on the first GPU machine machine_0, run

PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

On the second GPU machine, run

PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

The only differences between the two commands above are --node_rank and, if you want to use a different number of GPU devices on each machine, --nproc_per_node. Everything else is identical. For more information about torchrun, please refer to this page.
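Because the per-machine commands differ only in --node_rank (and possibly --nproc_per_node), one option is to generate them from a single shared configuration so they cannot drift apart. torchrun_argv below is a hypothetical helper; the flag names are exactly those in the commands above, and the IP address in the usage note is a placeholder.

```python
def torchrun_argv(node_rank, nnodes, nproc_per_node, rdzv_endpoint, script, script_args=()):
  """Build the torchrun argv for one machine of a multi-node run."""
  if not 0 <= node_rank < nnodes:
    raise ValueError(f"node_rank must be in [0, {nnodes}), got {node_rank}")
  return [
      "torchrun",
      f"--nnodes={nnodes}",
      f"--node_rank={node_rank}",
      f"--nproc_per_node={nproc_per_node}",
      # Every machine points at the machine with node_rank == 0.
      f"--rdzv_endpoint={rdzv_endpoint}",
      script,
      *script_args,
  ]
```

On each machine you could then run subprocess.run(torchrun_argv(rank, 2, 4, "10.0.0.2:12355", "pytorch/xla/test/test_train_mp_imagenet.py", ["--fake_data", "--pjrt_distributed"]), env={**os.environ, "PJRT_DEVICE": "CUDA"}, check=True), passing PJRT_DEVICE through the environment rather than the argv.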

Differences from XRT

Although in most cases we expect PJRT and XRT to be largely interchangeable from the end user’s perspective (especially on TPU v4), there are some subtle differences that are important to keep in mind. Importantly, XRT was designed around the TPU Node architecture, so it will always spawn a client and a server process, even on TPU VMs. Thus, every batch of inputs has additional latency from serializing and deserializing data to send it over the network.

PJRT uses the local device directly with no intermediate server process. In the default configuration, PJRT will create one process per TPU chip, or 4 processes per TPU host. See the Cloud TPU documentation for more information about TPU architecture.

  • Performance gains are possible for workloads constrained by overhead from gRPC.

  • Under XRT, the server process is the only process that interacts with the TPU devices, and client processes don’t have direct access to the TPU devices. When profiling a single-host TPU (e.g. v3-8 or v4-8), you would normally see 8 device traces (one for each TPU core). With PJRT, each process has one chip, and a profile from that process will show only 2 TPU cores.

    • For the same reason, profiling does not work on TPU Pods with XRT, because the server process runs independently from the user’s model code. PJRT does not have that constraint, so it is possible to profile 2 TPU cores per process in a TPU Pod.

  • PJRT only supports the TPU VM architecture and we have no plans to support the TPU Node architecture with PJRT.

  • Runtime configuration is significantly simpler with PJRT. xla_dist is not required to run TPU Pod workloads. Instead, copy your code to each TPU host with gcloud compute tpus tpu-vm scp (https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp) and run the code on each host in parallel, e.g. with gcloud compute tpus tpu-vm ssh --worker=all --command="PJRT_DEVICE=TPU python run.py" (https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh).

  • xm.rendezvous has been reimplemented using XLA-native collective communication to enhance stability on large TPU pods. See below for more details.

Multithreading on TPU v2/v3

On TPU v2 and v3, distributed workloads always run multithreaded, since each TPU chip exposes two TPU cores as devices and only one process may open a TPU chip at a time. In its default configuration, xmp.spawn automatically spawns as many processes as possible (4 per TPU host) and creates two threads per process (one per TPU core).

Note: on TPU v4, each TPU chip is represented as one PyTorch device, so distributed workloads will run across 4 processes, each with only one thread. This is identical to XRT’s behavior.

In most cases, this will not require substantial changes to your existing code. The main change you will have to make in most cases is to model initialization. Because torch’s global RNG is shared between threads, results will vary between threads and runs even if you set torch.manual_seed to the same value in every replica. To get consistent parameters between replicas, either use torch_xla.experimental.pjrt.broadcast_master_param to broadcast one replica’s parameters to all other replicas, or load each replica’s parameters from a common checkpoint.

Changes to xm.rendezvous

New in PyTorch/XLA r2.0

With XRT, worker 0 runs a mesh master service, and all processes on all workers connect to that service over gRPC. In practice, we found that running a single mesh master process was unreliable on TPU pods with thousands of chips due to the number of inbound connections to worker 0. A single client process timing out could cause a failure and force the entire workload to restart.

Thus, we have reimplemented xm.rendezvous with native XLA collective communication, which is much more stable and well-tested on large TPU pods. This imposes two new constraints compared to the XRT implementation:

  • Because the payload has to become part of the XLA graph, xm.mark_step is called both before and after the data is transferred. Calling xm.rendezvous in the middle of model code may force an unwanted compilation.

  • Because XLA does not permit collective operations to run on a subset of workers, all workers must participate in the rendezvous.

If you require the old behavior of xm.rendezvous (i.e. communicating data without altering the XLA graph and/or synchronizing a subset of workers), consider using torch.distributed.barrier (https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier) or torch.distributed.all_gather_object (https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object) with a gloo process group. If you are also using the xla torch.distributed backend, you can use torch.distributed.new_group to create a gloo subgroup. See this example from the PyTorch documentation. Keep in mind these constraints:

  • torch.distributed is not fully supported on TPU v2/v3. Only a subset of operations with the xla backend are implemented, and gloo will likely not work as expected in a multithreaded context.

  • In our experiments, gloo does not scale well to thousands of TPU chips, so expect this alternative to be less reliable than using xm.rendezvous with PJRT at large scales.

PJRT and torch.distributed

New in PyTorch/XLA r2.0

When using PJRT with torch.distributed and torch.nn.parallel.DistributedDataParallel (https://github.com/pytorch/xla/blob/master/docs/ddp.md), we strongly recommend using the new xla:// init_method, which automatically finds the replica IDs, world size, and master IP by querying the runtime. For example:

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  xmp.spawn(_all_gather)

Note: Although the xla:// init_method is not required on TPU v4, it is still recommended. If you use env://, MASTER_ADDR must be set to the IP of the host that has device 0, which is not always worker 0. The xla:// init_method finds this IP automatically.

Note: For TPU v2/v3, you still need to import torch_xla.experimental.pjrt_backend, as TPU v2/v3 support in torch.distributed is still experimental.

For more information about using DistributedDataParallel on PyTorch/XLA, see ddp.md (./ddp.md) on TPU v4. For an example that uses DDP and PJRT together, run the following example script on a TPU:

PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1


TorchBench shows improvements in average training time across tasks with PJRT compared to XRT, with an average improvement of over 35% on TPU v4-8. The benefits vary significantly by task and model type, ranging from 0% to 175%. The following chart shows the breakdown by task:


New TPU runtime

New in PyTorch/XLA r2.0

The PyTorch/XLA r2.0 release introduces support for the PJRT Plugin API, used to access the new TFRT-based TPU runtime in libtpu. This is now the default runtime when PJRT_DEVICE=TPU is set. The legacy StreamExecutor-based TPU runtime used in 1.13 will still be available with PJRT_DEVICE=TPU_LEGACY in the 2.0 release, but it will be removed in a future version. If you encounter an issue that only happens on TPU and not TPU_LEGACY, please file an issue on GitHub.

In most cases, we expect performance to be similar between the two runtimes, but in some cases, the new runtime may be up to 30% faster. The following chart shows the breakdown by task:

TFRT vs StreamExecutor

Note: the improvements shown in this chart are also included in the PJRT vs XRT comparison.

TorchDynamo (torch.compile) integration in PyTorch/XLA

TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook in, and its biggest feature is dynamically modifying Python bytecode right before it is executed. In the pytorch/xla 2.0 release, PyTorch/XLA provided an experimental backend for TorchDynamo for both inference and training.

The XLA bridge works as follows: Dynamo provides a TorchFX graph when it recognizes a model pattern, and PyTorch/XLA uses the existing Lazy Tensor technology to compile the FX graph and return the compiled function.


To use PyTorch/XLA with Dynamo, pass the backend='openxla' argument to torch.compile. For example:

import torch
import torch_xla.core.xla_model as xm

def add(a, b):
  a_xla = a.to(xm.xla_device())
  b_xla = b.to(xm.xla_device())
  return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))


Here is a small code example of running resnet18 with torch.compile

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  dynamo_resnet18 = torch.compile(
    xla_resnet18, backend='openxla')
  for data, _ in loader:
    with torch.no_grad():
      output = dynamo_resnet18(data)

With torch.compile, PyTorch/XLA traces the resnet18 model only once at init time and executes the compiled binary every time dynamo_resnet18 is invoked, instead of tracing the model every time. Here is an inference speed analysis comparing Dynamo and Lazy using TorchBench on Cloud TPU v4-8:

model                   | Speedup
resnet18                | 2.59
resnet50                | 2.64
resnext50_32x4d         | 1.91
alexnet                 | 1.28
mobilenet_v2            | 18.62
mnasnet1_0              | 2.68
vgg16                   | 1.33
BERT_pytorch            | 7.49
squeezenet1_1           | 2.29
timm_vision_transformer | 3.52
geomean                 | 3.04


PyTorch/XLA also supports Dynamo for training, but it is experimental and we are working with the PyTorch Compiler team to iterate on the implementation. Here is an example of training resnet18 with torch.compile:

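import torch
import torch.optim as optim
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target, optimizer):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  optimizer.step()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  # Create the optimizer once, over the model parameters (not the input data).
  xla_optimizer = optim.SGD(xla_resnet18.parameters(), lr=0.1, weight_decay=1e-2)
  dynamo_train_model = torch.compile(
        train_model, backend='openxla')
  for data, target in loader:
    # Move the batch to the XLA device in case the loader yields CPU tensors.
    data, target = data.to(device), target.to(device)
    xla_optimizer.zero_grad()
    output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)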

With Dynamo we expect to extract and execute 3 graphs per training step, compared with 1 graph per training step when using Lazy Tensor. Here is a training speed analysis comparing Dynamo and Lazy using TorchBench on Cloud TPU v4-8:

model | Speed up
resnet50 | 1.33
resnet18 | 1.33
BERT_pytorch | 3.07
resnext50_32x4d | 1.43
alexnet | 1.12
mobilenet_v2 | 1.4
mnasnet1_0 | 1.19
vgg16 | 0.81
timm_vision_transformer | 1.87
squeezenet1_1 | 1.41
geomean | 1.41
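Assuming each value is Lazy time divided by Dynamo time (so values above 1.0 favor Dynamo), a short plain-Python check over the training table above picks out the one model that regressed and reproduces the geomean:

```python
import math

# Training speedups of Dynamo over Lazy Tensor, copied from the table above.
TRAINING_SPEEDUPS = {
    "resnet50": 1.33,
    "resnet18": 1.33,
    "BERT_pytorch": 3.07,
    "resnext50_32x4d": 1.43,
    "alexnet": 1.12,
    "mobilenet_v2": 1.4,
    "mnasnet1_0": 1.19,
    "vgg16": 0.81,
    "timm_vision_transformer": 1.87,
    "squeezenet1_1": 1.41,
}


def geomean(values):
    """Geometric mean: exp of the average natural log of the values."""
    values = list(values)
    return math.exp(sum(math.log(v) for v in values) / len(values))


# Under the assumption above, a ratio below 1.0 means Dynamo was slower for that model.
regressions = sorted(m for m, s in TRAINING_SPEEDUPS.items() if s < 1.0)
print("regressions:", regressions)                             # regressions: ['vgg16']
print(f"geomean = {geomean(TRAINING_SPEEDUPS.values()):.2f}")  # geomean = 1.41
```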

NOTE: We run each model's forward and backward pass for a single step and then collect the end-to-end time. In real-world training jobs we run many steps, which can easily hide the tracing cost behind execution (since execution is async). Lazy Tensor will have much better performance in that scenario.

Feature gaps

There is one gap we want to call out that prevents us from using TorchDynamo on larger-scale models.

  1. TorchDynamo traces forward and backward into separate graphs. For PyTorch/XLA it is important to let the XLA compiler see the whole step as one graph to best optimize speed. There is also a fixed overhead to launch every device execution, which makes executing multiple graphs per training step less ideal.

Compared with Lazy Tensor, this gap makes Dynamo less efficient in real-world training use cases, especially because with Lazy Tensor the tracing cost can be overlapped with execution during training.

Take away

TorchDynamo provides a really promising way for the compiler backend to hide the complexity from the user and easily retrieve the modeling code in a graph format. Compared with PyTorch/XLA’s traditional Lazy Tensor way of extracting the graph, TorchDynamo can skip the graph tracing for every iteration, hence providing a much better inference response time.
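To make the "trace once vs. trace every iteration" point concrete, here is a toy model in plain Python. It is only an illustration, not how TorchDynamo or PyTorch/XLA are implemented: the class names are made up, "tracing" is just a counter, and the cache key is a stand-in for Dynamo's guard checks.

```python
# Toy model only: nothing here uses torch or torch_xla. "Tracing" is just a counter.
from typing import Callable, Dict, Hashable


class RetraceEveryCall:
    """Mimics per-iteration graph capture: the tracing cost is paid on every call."""

    def __init__(self, fn: Callable[[float], float]):
        self.fn = fn
        self.traces = 0

    def __call__(self, x: float) -> float:
        self.traces += 1  # re-trace the Python code on every iteration
        return self.fn(x)


class TraceOnceAndCache:
    """Mimics compile-once behaviour: trace on a cache miss, then reuse the result."""

    def __init__(self, fn: Callable[[float], float]):
        self.fn = fn
        self.traces = 0
        self.cache: Dict[Hashable, Callable[[float], float]] = {}

    def __call__(self, x: float) -> float:
        key = type(x)  # stand-in for Dynamo's guards (e.g. shape/dtype checks)
        if key not in self.cache:
            self.traces += 1
            self.cache[key] = self.fn  # stand-in for the compiled binary
        return self.cache[key](x)


def step(x: float) -> float:
    return 2.0 * x + 1.0


lazy_like = RetraceEveryCall(step)
dynamo_like = TraceOnceAndCache(step)
for i in range(100):
    assert lazy_like(float(i)) == dynamo_like(float(i))
print(lazy_like.traces, dynamo_like.traces)  # 100 1
```

Over 100 calls the per-iteration runner traces 100 times while the caching runner traces once, which is the inference advantage described above; in training, Lazy Tensor can overlap its tracing with asynchronous execution.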

Most models supported by PyTorch/XLA have seen significant speedups when running inference with the new dynamo-xla bridge. Our community is working hard to expand the set of supported models. Regarding the training feature gap mentioned above, the PyTorch/XLA community is excited to close it in our upcoming development work. The team continues to invest heavily in TorchDynamo and to work with upstream to mature the training story.

Fully Sharded Data Parallel (FSDP) in PyTorch XLA

Fully Sharded Data Parallel (FSDP) in PyTorch XLA is a utility for sharding Module parameters across data-parallel workers.

Example usage:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

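# my_module, x and y stand for your own nn.Module and input tensors on the XLA device.
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()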

It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters.
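The idea of nested wrapping, with an outer wrapper catching the leftovers, can be modeled in a few lines of plain Python. This is a simplified sketch, not the torch_xla implementation: ToyModule and auto_wrap are hypothetical names, and the 100M threshold mirrors the figure quoted for size_based_auto_wrap_policy further down, whose exact rules may differ.

```python
# Simplified model of nested FSDP wrapping. ToyModule and auto_wrap are
# hypothetical names, not torch_xla APIs.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyModule:
    name: str
    own_params: int  # parameters held directly by this module (not by its children)
    children: List["ToyModule"] = field(default_factory=list)


def auto_wrap(module: ToyModule, min_num_params: int, wrapped: List[str]) -> int:
    """Post-order walk. Wraps any subtree whose still-unwrapped parameter count
    reaches min_num_params, and returns the count left for the enclosing wrapper."""
    remaining = module.own_params
    for child in module.children:
        remaining += auto_wrap(child, min_num_params, wrapped)
    if remaining >= min_num_params:
        wrapped.append(module.name)  # this subtree gets its own inner FSDP wrapper
        return 0                     # its parameters now belong to that inner wrapper
    return remaining


model = ToyModule("model", 0, [
    ToyModule("embed", 50_000_000),
    ToyModule("block0", 120_000_000),
    ToyModule("block1", 120_000_000),
    ToyModule("head", 30_000_000),
])

wrapped: List[str] = []
leftover = auto_wrap(model, min_num_params=100_000_000, wrapped=wrapped)
print(wrapped, leftover)  # ['block0', 'block1'] 80000000
```

Here block0 and block1 get their own inner wrappers, while the 80M parameters of embed and head stay with the outer FSDP(model) wrapper.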


  • The XlaFullyShardedDataParallel class supports both the ZeRO-2 optimizer (sharding gradients and optimizer states) and the ZeRO-3 optimizer (sharding parameters, gradients, and optimizer states) described in https://arxiv.org/abs/1910.02054.

    • The ZeRO-3 optimizer should be implemented via nested FSDP with reshard_after_forward=True. See test/test_train_mp_mnist_fsdp_with_ckpt.py and test/test_train_mp_imagenet_fsdp.py for an example.

    • For large models that cannot fit into a single TPU's memory or the host CPU memory, one should interleave submodule construction with inner FSDP wrapping. See FSDPViTModel (https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py) for an example.

  • A simple wrapper, checkpoint_module, is provided (based on torch_xla.utils.checkpoint.checkpoint from https://github.com/pytorch/xla/pull/3524) to perform gradient checkpointing over a given nn.Module instance. See test/test_train_mp_mnist_fsdp_with_ckpt.py and test/test_train_mp_imagenet_fsdp.py for an example.

  • Auto-wrapping submodules: instead of manually nesting FSDP wrappers, one can also specify an auto_wrap_policy argument to automatically wrap the submodules with inner FSDP. size_based_auto_wrap_policy in torch_xla.distributed.fsdp.wrap is an example of an auto_wrap_policy callable; this policy wraps layers with more than 100M parameters. transformer_auto_wrap_policy in torch_xla.distributed.fsdp.wrap is an example of an auto_wrap_policy callable for transformer-like model architectures.

For example, to automatically wrap all torch.nn.Conv2d submodules with inner FSDP, one can use:

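import torch
from functools import partial
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap every torch.nn.Conv2d submodule with its own inner FSDP instance.
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})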

Additionally, one can also specify an auto_wrapper_callable argument to use a custom callable wrapper for the submodules (the default wrapper is just the XlaFullyShardedDataParallel class itself). For example, one can use the following to apply gradient checkpointing (i.e. activation checkpointing/rematerialization) to each auto-wrapped submodule.

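from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel, checkpoint_module

# Wrap each auto-wrapped submodule in checkpoint_module before sharding it with FSDP.
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
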
  • When stepping the optimizer, directly call optimizer.step and do not call xm.optimizer_step. The latter reduces the gradient across ranks, which is not needed for FSDP (where the parameters are already sharded).

  • When saving model and optimizer checkpoints during training, each training process needs to save its own checkpoint of the (sharded) model and optimizer state dicts (use master_only=False and set different paths for each rank in xm.save). When resuming, it needs to load the checkpoint for the corresponding rank.

  • Please also save model.get_shard_metadata() along with model.state_dict() as follows, and use consolidate_sharded_model_checkpoints to stitch the sharded model checkpoints together into a full model state dict. See test/test_train_mp_mnist_fsdp_with_ckpt.py for an example.

    ckpt = {
        'model': model.state_dict(),
        'shard_metadata': model.get_shard_metadata(),
        'optimizer': optimizer.state_dict(),
    }
    ckpt_path = f'/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth'
    xm.save(ckpt, ckpt_path, master_only=False)

  • The checkpoint consolidation script can also be launched from the command line as follows:

    # consolidate the saved checkpoints via command line tool
    python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
      --ckpt_prefix /path/to/your_sharded_checkpoint_files \
      --ckpt_suffix "_rank-*-of-*.pth"
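Because every rank writes its own shard, the file names need to be unique per rank and discoverable by the consolidation step. The plain-Python sketch below assumes the tool matches files against ckpt_prefix + ckpt_suffix as a glob pattern; rank_ckpt_path is a hypothetical helper, and the exact names the test scripts write (for example, any zero padding of the rank) may differ.

```python
# Sketch of a per-rank checkpoint naming scheme that a glob-style suffix such as
# "_rank-*-of-*.pth" can match. rank_ckpt_path is a hypothetical helper.
import fnmatch


def rank_ckpt_path(prefix: str, rank: int, world_size: int) -> str:
    """One file per training process, so no two ranks write to the same path."""
    return f"{prefix}_rank-{rank}-of-{world_size}.pth"


prefix = "/tmp/mnist-fsdp/final_ckpt"
suffix = "_rank-*-of-*.pth"
world_size = 8

paths = [rank_ckpt_path(prefix, r, world_size) for r in range(world_size)]
# Every shard matches prefix + suffix, which is how a consolidation step can find them.
assert all(fnmatch.fnmatch(p, prefix + suffix) for p in paths)
print(paths[0])  # /tmp/mnist-fsdp/final_ckpt_rank-0-of-8.pth
```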

The implementation of this class is largely inspired by and mostly follows the structure of fairscale.nn.FullyShardedDataParallel in https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html. One of the biggest differences from fairscale.nn.FullyShardedDataParallel is that in XLA we don’t have explicit parameter storage, so here we resort to a different approach to free full parameters for ZeRO-3.

Example training scripts on MNIST and ImageNet


FSDP is available in the PyTorch/XLA 1.12 release and newer nightly builds. Please refer to https://github.com/pytorch/xla#-available-images-and-wheels for the installation guide.

Clone PyTorch/XLA repo

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

Train MNIST on v3-8 TPU

It reaches around 98.9% accuracy after 2 epochs:

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

This script automatically tests checkpoint consolidation at the end. You can also manually consolidate the sharded checkpoints via:

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

Train ImageNet with ResNet-50 on v3-8 TPU

It reaches around 75.9% accuracy after 100 epochs. Download ImageNet-1k to /datasets/imagenet-1k first:

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100

You can also add --use_gradient_checkpointing (which needs to be used along with --use_nested_fsdp or --auto_wrap_policy) to apply gradient checkpointing on the residual blocks.

Example training scripts on TPU pod (with 10 billion parameters)

To train large models that cannot fit into a single TPU, one should apply auto-wrap or manually wrap the submodules with inner FSDP when building the entire model to implement the ZeRO-3 algorithm.

Please see https://github.com/ronghanghu/vit_10b_fsdp_example for an example of sharded training of a Vision Transformer (ViT) model using this XLA FSDP PR.

How to do DistributedDataParallel

This document shows how to use torch.nn.parallel.DistributedDataParallel with PyTorch/XLA, and describes how it differs from the native XLA data parallel approach.

Background / Motivation

Customers have long requested the ability to use PyTorch's DistributedDataParallel API with XLA. Here we enable it as an experimental feature.

How to use DistributedDataParallel

For those switching from PyTorch eager mode to XLA, here are all the changes you need to make to convert your eager DDP model into an XLA model. We assume that you already know how to use XLA on a single device.

  1. Import the XLA-specific distributed packages:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend

  2. Initialize the xla process group, just like other process groups such as nccl and gloo:

dist.init_process_group("xla", rank=rank, world_size=world_size)

  3. Use the XLA-specific APIs to get rank and world_size if you need them:

new_rank = xm.get_ordinal()
world_size = xm.xrt_world_size()

  4. Pass gradient_as_bucket_view=True to the DDP wrapper:

ddp_model = DDP(model, gradient_as_bucket_view=True)

  5. Finally, launch your model with the XLA-specific launcher:

import torch_xla.distributed.xla_multiprocessing as xmp
xmp.spawn(demo_fn)

Here is everything put together (the example is taken from the DDP tutorial). The code is very similar to the eager version: it only adds the XLA-specific touches for a single device plus the five changes above.

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

# additional imports for xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xm.get_ordinal()
    assert new_rank == rank
    world_size = xm.xrt_world_size()

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    # currently, gradient_as_bucket_view is needed to make DDP work for xla
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()

    cleanup()


def run_demo(demo_fn):
    # xla specific launcher
    xmp.spawn(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)

Resnet50 with fake data

The following results were collected with the command python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1 on a TPU VM v3-8 environment with ToT PyTorch and PyTorch/XLA. The statistical metrics were produced using the script in this pull request. The unit for the rate is images per second.

Type | Mean | Median | 90th % | Std Dev | CV
xm.optimizer_step | 418.54 | 419.22 | 430.40 | 9.76 | 0.02
DDP | 395.97 | 395.54 | 407.13 | 7.60 | 0.02

The performance difference between our native approach for distributed data parallel and the DistributedDataParallel wrapper is 1 - 395.97 / 418.54 = 5.39%. This result seems reasonable given that the DDP wrapper introduces extra overhead from tracing the DDP runtime.
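The 5.39% figure and the CV column can both be reproduced from the table. This plain-Python sketch assumes CV is the standard deviation divided by the mean, which matches the reported values:

```python
# Numbers copied from the ResNet-50 fake-data table (images per second).
RESNET50 = {
    "xm.optimizer_step": {"mean": 418.54, "std_dev": 9.76},
    "DDP": {"mean": 395.97, "std_dev": 7.60},
}


def slowdown(ddp: float, native: float) -> float:
    """Fraction of native throughput lost when switching to the DDP wrapper."""
    return 1 - ddp / native


def coefficient_of_variation(std_dev: float, mean: float) -> float:
    """CV column: standard deviation relative to the mean."""
    return std_dev / mean


loss = slowdown(RESNET50["DDP"]["mean"], RESNET50["xm.optimizer_step"]["mean"])
print(f"slowdown: {loss:.2%}")  # slowdown: 5.39%
for name, row in RESNET50.items():
    print(name, round(coefficient_of_variation(row["std_dev"], row["mean"]), 2))
```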

MNIST with fake data

The following results were collected with the command python test/test_train_mp_mnist.py --fake_data on a TPU VM v3-8 environment with ToT PyTorch and PyTorch/XLA. The statistical metrics were produced using the script in this pull request. The unit for the rate is images per second.

Type | Mean | Median | 90th % | Std Dev | CV
xm.optimizer_step | 17864.19 | 20108.96 | 24351.74 | 5866.83 | 0.33
DDP | 10701.39 | 11770.00 | 14313.78 | 3102.92 | 0.29

The performance difference between our native approach for distributed data parallel and the DistributedDataParallel wrapper is 1 - 14313.78 / 24351.74 = 41.22%. Here we compare the 90th percentile instead, since the dataset is small and the first few rounds are heavily impacted by data loading. This slowdown is large but makes sense given how small the model is: the additional DDP runtime tracing overhead is hard to amortize.
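To see whether the choice of statistic matters, the same slowdown can be computed for the mean, median, and 90th percentile of the MNIST table (plain Python, numbers copied from the table above):

```python
# MNIST fake-data results (images per second), copied from the table above.
NATIVE = {"mean": 17864.19, "median": 20108.96, "p90": 24351.74}
DDP = {"mean": 10701.39, "median": 11770.00, "p90": 14313.78}

# Fraction of native throughput lost with the DDP wrapper, per statistic.
slowdowns = {stat: 1 - DDP[stat] / NATIVE[stat] for stat in NATIVE}
for stat, loss in slowdowns.items():
    print(f"{stat:>6}: {loss:.2%}")
```

All three statistics put the slowdown at roughly 40 to 41.5%, so the conclusion does not hinge on using the 90th percentile.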

MNIST with real data

The following results are collected with the command: python test/test_train_mp_mnist.py --logdir mnist/ on a TPU VM V3-8 environment with ToT PyTorch and PyTorch/XLA.


The DDP wrapper converges more slowly than the native XLA approach, although it still reaches a high final accuracy of 97.48% (the native approach reaches 99%).


This feature is still experimental and under active development. Use it with caution, and feel free to file any bugs in the xla GitHub repo. For those who are interested in the native XLA data parallel approach, here is the tutorial.

Here are some of the known issues that are under investigation:

  • gradient_as_bucket_view=True needs to be enforced.

  • There are some issues when it is used with torch.utils.data.DataLoader; for example, test_train_mp_mnist.py with real data crashes before exiting.

How to run with PyTorch/XLA:GPU

PyTorch/XLA enables PyTorch users to utilize the XLA compiler, which supports accelerators including TPU, GPU, and CPU. This doc goes over the basic steps to run PyTorch/XLA on NVIDIA GPU instances.

Create a GPU instance

You can either use a local machine with GPU attached or a GPU VM on the cloud. For example in Google Cloud you can follow this doc to create the GPU VM.

Environment Setup

Make sure the CUDA driver is installed on the host.


PyTorch/XLA currently publishes prebuilt Docker images and wheels with CUDA 11.8/12.1 and Python 3.8. We recommend creating a Docker container with the corresponding config. For a full list of Docker images and wheels, please refer to this doc.

sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1

# Installing the NVIDIA Container Toolkit per https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
# For example
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
  && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
    sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
    sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit

# Configuring the NVIDIA Container Toolkit
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker

sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 /bin/bash
sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash

Note that you need to restart Docker (as above) to make the GPU devices visible inside the container. After logging into the container, you can use nvidia-smi to verify that the device is set up correctly.

(pytorch) root@20ab2c7a2d06:/# nvidia-smi
Thu Dec  8 06:24:29 2022
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    38W / 300W |      0MiB / 16384MiB |      1%      Default |
|                               |                      |                  N/A |

| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|  No running processes found                                                 |

Check environment variable

Make sure the PATH and LD_LIBRARY_PATH environment variables account for CUDA. Run echo $PATH and echo $LD_LIBRARY_PATH to verify. If they do not, please follow the link to set them up. Example:

echo "export PATH=\$PATH:/usr/local/cuda-12.1/bin" >> ~/.bashrc
echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> ~/.bashrc
source ~/.bashrc
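If you prefer to check this from Python, for example in a setup script, the sketch below looks for an entry under /usr/local/cuda in both variables. The helper name and the /usr/local/cuda root are assumptions based on the example above; adjust them for your install.

```python
# Sketch: report which of PATH / LD_LIBRARY_PATH lack an entry under the CUDA root.
# missing_cuda_entries is a hypothetical helper; "/usr/local/cuda" matches the example above.
import os
from typing import List, Mapping


def missing_cuda_entries(env: Mapping[str, str], cuda_root: str = "/usr/local/cuda") -> List[str]:
    """Return the variable names whose entries contain no path under cuda_root."""
    missing = []
    for var in ("PATH", "LD_LIBRARY_PATH"):
        entries = env.get(var, "").split(os.pathsep)
        if not any(entry.startswith(cuda_root) for entry in entries):
            missing.append(var)
    return missing


if __name__ == "__main__":
    print(missing_cuda_entries(os.environ) or "CUDA paths look configured")
```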


**NOTE:** The wheel file is compatible only with x86_64 Linux-based architectures. To check the architecture of your Linux system, execute the following command:

uname -a

Then install PyTorch and the PyTorch/XLA GPU wheel:

pip3 install torch==2.3.0
# GPU whl for python 3.10 + cuda 12.1
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl

Wheels for other Python versions and CUDA versions can be found here.
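The wheel name itself encodes these constraints: cp310 means CPython 3.10 and manylinux_2_28_x86_64 means x86_64 Linux. This plain-Python sketch reads those tags from the file name (following the standard {name}-{version}-{python tag}-{abi tag}-{platform tag}.whl layout) and compares them with the current interpreter. It is a rough check only: it ignores the glibc version implied by manylinux_2_28, and the helper names are hypothetical.

```python
# Sketch: compare a wheel's Python and platform tags with the current interpreter.
# wheel_tags and is_compatible are hypothetical helpers.
import platform
import sys
from typing import Tuple

WHEEL = "torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl"


def wheel_tags(filename: str) -> Tuple[str, str]:
    """Return (python tag, platform tag) from a wheel name without a build tag."""
    _, _, py_tag, _, plat_tag = filename[: -len(".whl")].split("-")
    return py_tag, plat_tag


def is_compatible(filename: str, machine: str, major: int, minor: int) -> bool:
    py_tag, plat_tag = wheel_tags(filename)
    # Ignores the glibc requirement implied by the manylinux_2_28 prefix.
    return py_tag == f"cp{major}{minor}" and plat_tag.endswith(machine)


if __name__ == "__main__":
    print(is_compatible(WHEEL, platform.machine(), *sys.version_info[:2]))
```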

Run some simple models

To run the examples below, you need to clone the pytorch/xla repository.

MP_ImageNet Example

This example uses ImageNet. It is included in what we already cloned in our Docker container.

(pytorch) root@20ab2c7a2d06:/# export GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA
(pytorch) root@20ab2c7a2d06:/# git clone --recursive https://github.com/pytorch/xla.git
(pytorch) root@20ab2c7a2d06:/# python xla/test/test_train_mp_imagenet.py --fake_data
==> Preparing data..
Epoch 1 train begin 06:12:38
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=2.82 GlobalRate=2.82 Time=06:13:23
| Training Device=xla:0/0 Epoch=1 Step=20 Loss=6.79297 Rate=117.16 GlobalRate=45.84 Time=06:13:36
| Training Device=xla:0/0 Epoch=1 Step=40 Loss=6.43628 Rate=281.16 GlobalRate=80.49 Time=06:13:43
| Training Device=xla:0/0 Epoch=1 Step=60 Loss=5.83108 Rate=346.88 GlobalRate=108.82 Time=06:13:49
| Training Device=xla:0/0 Epoch=1 Step=80 Loss=4.99023 Rate=373.62 GlobalRate=132.43 Time=06:13:56
| Training Device=xla:0/0 Epoch=1 Step=100 Loss=3.92699 Rate=384.33 GlobalRate=152.40 Time=06:14:02
| Training Device=xla:0/0 Epoch=1 Step=120 Loss=2.68816 Rate=388.35 GlobalRate=169.49 Time=06:14:09

ResNet Example

This example uses ResNet.

(pytorch) root@20ab2c7a2d06:/# python3 /xla/examples/train_resnet_base.py
1:35PM UTC on Jun 08, 2024
epoch: 1, step: 0, loss: 6.887794017791748, rate: 8.746502586051985
epoch: 1, step: 10, loss: 6.877807140350342, rate: 238.4789458412044
epoch: 1, step: 20, loss: 6.867819786071777, rate: 329.86095958663503
epoch: 1, step: 30, loss: 6.857839584350586, rate: 367.3038003653586
epoch: 1, step: 40, loss: 6.847847938537598, rate: 381.53141087190835
epoch: 1, step: 50, loss: 6.837860584259033, rate: 387.80462249591113
epoch: 1, step: 260, loss: 6.628140926361084, rate: 391.135639565343
epoch: 1, step: 270, loss: 6.618192195892334, rate: 391.6901797745233
epoch: 1, step: 280, loss: 6.608224391937256, rate: 391.1602680460045
epoch: 1, step: 290, loss: 6.598264217376709, rate: 391.6731498290759
Epoch 1 train end  1:36PM UTC


AMP is very useful for GPU training, and PyTorch/XLA reuses CUDA's AMP rules. You can check out our mnist example and imagenet example. Note that we also use a modified version of the optimizers to avoid an additional sync between the device and the host.

Develop PyTorch/XLA on a GPU instance (build PyTorch/XLA from source with GPU support)

  1. Inside a GPU VM, create a docker container from a development docker image. For example:

sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1

# Installing the NVIDIA Container Toolkit per https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
# For example
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
  && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
    sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
    sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit

# Configuring the NVIDIA Container Toolkit
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker

sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1
sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash
  2. Build PyTorch and PyTorch/XLA from source.

Make sure the PATH and LD_LIBRARY_PATH environment variables account for CUDA. See above for more info.

git clone https://github.com/pytorch/pytorch.git
cd pytorch
USE_CUDA=1 python setup.py install

git clone https://github.com/pytorch/xla.git
cd xla
XLA_CUDA=1 python setup.py install
  3. Verify that PyTorch and PyTorch/XLA have been installed successfully.

If you can run the tests in the section Run some simple models successfully, then PyTorch and PyTorch/XLA should have been installed successfully.

PyTorch/XLA SPMD User Guide

In this user guide, we discuss how GSPMD is integrated in PyTorch/XLA, and provide a design overview to illustrate how the SPMD sharding annotation API and its constructs work.

What is PyTorch/XLA SPMD?

GSPMD is an automatic parallelization system for common ML workloads. Based on user-provided sharding hints, the XLA compiler transforms the single-device program into a partitioned one with the proper collectives. This lets developers write PyTorch programs as if they ran on a single large device, without any custom sharded computation ops or collective communications, and still scale.


*Figure 1. Comparison of two different execution strategies, (a) for non-SPMD and (b) for SPMD.*
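To build intuition for what "a partitioned program with proper collectives" means, here is a toy in plain Python (not XLA's partitioner) that splits one matrix multiply across two "devices" in two ways. Sharding the batch rows needs no communication, while sharding the contracting dimension leaves each device with a partial sum that must be added across devices, which is what an all-reduce does.

```python
# Toy illustration in plain Python (not XLA's partitioner): split a matrix multiply
# across two "devices" two ways and check that both match the single-device result.
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


a = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]       # 2 x 4
b = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # 4 x 2
full = matmul(a, b)                                    # single-device result

# (1) Shard the rows of `a` (the batch dimension): each device computes its own
# rows with no communication, and the results are simply concatenated.
row_shards = [a[:1], a[1:]]
rows_result = [row for shard in row_shards for row in matmul(shard, b)]

# (2) Shard the contracting dimension: each device holds half of the columns of `a`
# and half of the rows of `b`, so it only produces a partial sum. The partial sums
# must be added across devices (the job of an all-reduce).
partials = [
    matmul([r[:2] for r in a], b[:2]),
    matmul([r[2:] for r in a], b[2:]),
]
contract_result = add(partials[0], partials[1])

assert rows_result == full and contract_result == full
print(full)  # [[12.0, 1.0], [28.0, 5.0]]
```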

How to use PyTorch/XLA SPMD?

Here is a simple example of using SPMD:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

t = torch.randn(8, 4).to(xm.xla_device())

# Mesh partitioning: each device holds 1/num_devices of the input (1/8 on 8 devices).
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)

Let’s explain these concepts one by one:


To use SPMD, you need to enable it via xr.use_spmd(). In SPMD mode there is only one logical device. Distributed computation and collectives are handled by mark_sharding. Note that users cannot mix SPMD with other distributed libraries.


For a given cluster of devices, a physical mesh is a representation of the interconnect topology.

  1. mesh_shape is a tuple whose elements multiply to the total number of physical devices.

  2. device_ids is almost always np.array(range(num_devices)).

  3. Users are also encouraged to give each mesh dimension a name. In the above example, the first mesh dimension is the data dimension and the second mesh dimension is the model dimension.

You can also inspect more mesh info via:

>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])
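A device mesh is essentially the device ids laid out over mesh_shape with named axes. The plain-Python sketch below assumes a row-major layout (as when reshaping np.array(range(num_devices))) and reproduces the 4-device mesh above; build_mesh is a hypothetical helper, not torch_xla's Mesh class.

```python
# Plain-Python model of a 2D logical device mesh. build_mesh is a hypothetical
# helper, not torch_xla's Mesh; it assumes a row-major layout of device ids.
from math import prod
from typing import Dict, List, Sequence, Tuple


def build_mesh(device_ids: Sequence[int], mesh_shape: Tuple[int, int],
               axis_names: Tuple[str, str]) -> Tuple[List[List[int]], Dict[str, int]]:
    # The mesh dimensions must multiply to the number of devices.
    if prod(mesh_shape) != len(device_ids):
        raise ValueError(f"mesh_shape {mesh_shape} does not cover {len(device_ids)} devices")
    rows, cols = mesh_shape
    grid = [list(device_ids[r * cols:(r + 1) * cols]) for r in range(rows)]
    return grid, dict(zip(axis_names, mesh_shape))


grid, axes = build_mesh(list(range(4)), (4, 1), ("data", "model"))
print(axes)  # {'data': 4, 'model': 1}
print(grid)  # [[0], [1], [2], [3]]
```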

Partition Spec

partition_spec has the same rank as the input tensor. Each dimension describes how the corresponding input tensor dimension is sharded across the device mesh. In the above example, tensor t's first dimension is sharded along the data mesh dimension and its second dimension along the model mesh dimension.

Users can also shard tensors whose number of dimensions differs from that of the mesh shape.

device = xm.xla_device()
t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)

# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))

# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))

# First dimension is sharded across both mesh axes.
xs.mark_sharding(t2, mesh, (('data', 'model'),))
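The per-device shard shape follows from the tensor shape, the mesh axis sizes, and the partition spec. The plain-Python sketch below (shard_shape is a hypothetical helper) assumes each sharded dimension divides evenly and simply rejects uneven splits rather than modeling how the runtime handles them. The mesh used here is a hypothetical 8-device mesh with data=4 and model=2.

```python
# Plain-Python sketch of the per-device shard shape implied by a partition spec,
# assuming every sharded dimension divides evenly. shard_shape is hypothetical.
from math import prod
from typing import Dict, Optional, Sequence, Tuple, Union

AxisSpec = Optional[Union[str, Tuple[str, ...]]]


def shard_shape(tensor_shape: Sequence[int], mesh_axes: Dict[str, int],
                partition_spec: Sequence[AxisSpec]) -> Tuple[int, ...]:
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec must have the same rank as the tensor")
    out = []
    for dim, spec in zip(tensor_shape, partition_spec):
        if spec is None:                 # replicated along this dimension
            ways = 1
        elif isinstance(spec, str):      # sharded over one mesh axis
            ways = mesh_axes[spec]
        else:                            # sharded over several mesh axes
            ways = prod(mesh_axes[name] for name in spec)
        if dim % ways:
            raise ValueError(f"dimension {dim} is not divisible by {ways}")
        out.append(dim // ways)
    return tuple(out)


mesh_axes = {"data": 4, "model": 2}  # hypothetical 8-device mesh
print(shard_shape((8, 8, 16), mesh_axes, (None, "data", "model")))  # (8, 2, 8)
print(shard_shape((8,), mesh_axes, ("data",)))                     # (2,)
print(shard_shape((8,), mesh_axes, (("data", "model"),)))          # (1,)
```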

Further Reading

  1. Example to use SPMD to express data parallelism.

  2. Example to use SPMD to express FSDP(Fully Sharded Data Parallel).

  3. SPMD advanced topics

  4. SPMD Distributed Checkpoint

Fully Sharded Data Parallel via SPMD

Fully Sharded Data Parallel via SPMD, or FSDPv2, is a utility that re-expresses the well-known FSDP algorithm in SPMD. This experimental feature aims to offer a familiar interface so that users can enjoy all the benefits SPMD brings to the table. The design doc is here.

Please review the SPMD user guide before proceeding.

Example usage:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

# Enable XLA SPMD execution mode (see the SPMD user guide).
xr.use_spmd()

# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))

# my_module, x and y stand for your own nn.Module and input tensors on the XLA device.
# Shard the input, and assume x is a 2D tensor.
x = xs.mark_sharding(x, mesh, ('fsdp', None))

# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. Here is an example to autowrap each DecoderLayer.

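import functools
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Apply FSDP sharding on each DecoderLayer layer.
# DecoderLayer stands for your model's decoder layer class.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={DecoderLayer},
)
model = FSDPv2(
    model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)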

Sharding output

To ensure the XLA compiler correctly implements the FSDP algorithm, we need to shard both weights and activations. This means sharding the output of the forward method. Since the forward function output can vary, we offer shard_output to shard activations in cases where your module output doesn’t fall into one of these categories:

  1. A single tensor

  2. A tuple of tensors where the 0th element is the activation.

Example usage:

def shard_output(output, mesh):
    xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))

model = FSDPv2(my_module, mesh, shard_output)

Gradient checkpointing

Currently, gradient checkpointing needs to be applied to the module before the FSDP wrapper. Otherwise, recursively looping into child modules ends up in an infinite loop. We will fix this issue in future releases.

Example usage:

from torch_xla.distributed.fsdp import checkpoint_module

model = FSDPv2(checkpoint_module(my_module), mesh)

HuggingFace Llama 2 Example

We have a fork of HF Llama 2 to demonstrate a potential integration here.

