
PyTorch on XLA Devices

PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.

Creating an XLA Tensor

PyTorch/XLA adds a new xla device type to PyTorch. This device type works just like other PyTorch device types. For example, here’s how to create and print an XLA tensor:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t)

This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing torch_xla initializes PyTorch/XLA, and xm.xla_device() returns the current XLA device. This may be a CPU or TPU depending on your environment.

XLA Tensors are PyTorch Tensors

PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.

For example, XLA tensors can be added together:

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

Or matrix multiplied:

print(t0.mm(t1))

Or used with neural network modules:

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)

Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
# Input tensor is not an XLA tensor: torch.FloatTensor

will throw an error since the torch.nn.Linear module is on the CPU.

Running Models on XLA Devices

Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device, multiple devices with XLA multiprocessing, or multiple threads with XLA multithreading.

Running on a Single XLA Device

The following snippet shows a network training on a single XLA device:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  xm.optimizer_step(optimizer, barrier=True)

This snippet highlights how easy it is to switch your model to run on XLA. The model definition, dataloader, optimizer and training loop can work on any device. The only XLA-specific code is a couple lines that acquire the XLA device and step the optimizer with a barrier. Calling xm.optimizer_step(optimizer, barrier=True) at the end of each training iteration causes XLA to execute its current graph and update the model’s parameters. See XLA Tensor Deep Dive for more on how XLA creates graphs and runs operations.

Running on Multiple XLA Devices with MultiProcessing

PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.
def _mp_fn(index):
  device = xm.xla_device()
  para_loader = pl.ParallelLoader(train_loader, [device])

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in para_loader.per_device_loader(device):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    # No barrier needed here: the ParallelLoader iterator issues it.
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

There are three differences between this multidevice snippet and the previous single device snippet:

  • xmp.spawn() creates the processes that each run an XLA device.

  • ParallelLoader loads the training data onto each device.

  • xm.optimizer_step(optimizer) no longer needs a barrier. ParallelLoader automatically creates an XLA barrier that evaluates the graph.

The model definition, optimizer definition and training loop remain the same.

NOTE: When using multiprocessing, XLA devices can only be retrieved and accessed from within the target function of xmp.spawn() (or any function that has xmp.spawn() as a parent in the call stack).

See the full multiprocessing example for more on training a network on multiple XLA devices with multiprocessing.

Running on Multiple XLA Devices with MultiThreading

Running on multiple XLA devices using processes (see above) is preferred to using threads. If, however, you want to use threads then PyTorch/XLA has a DataParallel interface. The following snippet shows the same network training with multiple threads:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp

# MNIST, train_loader, lr, momentum and num_epochs are assumed to be defined elsewhere.
devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(MNIST, device_ids=devices)

def train_loop_fn(model, loader, device, context):
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

for epoch in range(1, num_epochs + 1):
  model_parallel(train_loop_fn, train_loader)

The only differences between the multithreading and multiprocessing code are:

  • Multiple devices are acquired in the same process with xm.get_xla_supported_devices().

  • The model is wrapped in dp.DataParallel and passed both the training loop and dataloader.

See the full multithreading example for more on training a network on multiple XLA devices with multithreading.

XLA Tensor Deep Dive

Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique.

XLA Tensors are Lazy

CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example.

Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device.
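To make the idea of lazy execution concrete, here is a toy, pure-Python model. It is not how PyTorch/XLA is implemented; the LazyValue class and const() helper are hypothetical. Operations only append nodes to a shared graph, and nothing is computed until materialize() is called, which plays the role of a synchronization point such as copying data to the CPU or taking an optimizer step with a barrier.

```python
class LazyValue:
    """Toy lazy value that records operations instead of running them (hypothetical)."""

    def __init__(self, op, inputs=(), constant=None, graph=None):
        self.op = op
        self.inputs = inputs
        self.constant = constant
        # Shared list of recorded nodes, so the pending "graph" can be inspected.
        self.graph = graph if graph is not None else []
        self.graph.append(op)

    def __add__(self, other):
        # Record an "add" node; no arithmetic happens yet.
        return LazyValue("add", (self, other), graph=self.graph)

    def __mul__(self, other):
        # Record a "mul" node; no arithmetic happens yet.
        return LazyValue("mul", (self, other), graph=self.graph)

    def materialize(self):
        # Evaluate the recorded graph on demand, like a synchronization point.
        if self.op == "const":
            return self.constant
        a, b = (x.materialize() for x in self.inputs)
        return a + b if self.op == "add" else a * b


def const(value, graph):
    # Hypothetical helper that creates a leaf node holding a concrete value.
    return LazyValue("const", constant=value, graph=graph)


graph = []
x = const(2.0, graph)
y = const(3.0, graph)
z = x * y + x           # only records nodes, computes nothing
print(graph)            # ['const', 'const', 'mul', 'add']
print(z.materialize())  # 8.0
```

Because the whole expression is visible before anything runs, a real compiler in this position is free to fuse or reorder the recorded operations; the toy simply evaluates them one by one.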

XLA Tensors and bfloat16

PyTorch/XLA can use the bfloat16 datatype when running on TPUs. In fact, PyTorch/XLA handles float types (torch.float and torch.double) differently on TPUs. This behavior is controlled by the XLA_USE_BF16 environment variable:

  • By default both torch.float and torch.double are torch.float on TPUs.

  • If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.

  • If a PyTorch tensor has torch.bfloat16 data type, this will be directly mapped to the TPU bfloat16 (XLA BF16 primitive type).

XLA tensors on TPUs will always report their PyTorch datatype regardless of the actual datatype they’re using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU it will be converted from its actual datatype to its PyTorch datatype.
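The precision lost in such a round trip can be reproduced without a TPU. bfloat16 keeps the sign bit, the 8 exponent bits and the top 7 mantissa bits of a float32, so a float32 value can be converted by keeping its upper 16 bits. The sketch below assumes round-to-nearest-even when dropping the low bits (the rounding a given TPU actually applies is an assumption here) and ignores NaN handling; the to_bfloat16 helper is hypothetical.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a value through float32 -> bfloat16 -> float (hypothetical helper)."""
    # Reinterpret the value as the 32 bits of an IEEE-754 float32.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Assumed rounding: nearest, ties to even, on the 16 low bits being dropped.
    # NaN handling is omitted for brevity.
    bits += 0x7FFF + ((bits >> 16) & 1)
    upper = (bits >> 16) & 0xFFFF
    # A bfloat16 is the upper half of a float32, so pad with 16 zero bits.
    return struct.unpack("<f", struct.pack("<I", upper << 16))[0]


print(to_bfloat16(1.0))          # 1.0 (exactly representable)
print(to_bfloat16(3.14159265))   # 3.140625
print(to_bfloat16(1.0 + 2**-8))  # 1.0 (a tie, rounded to even)
```

The tensor still reports torch.float, but values like 3.14159265 come back as 3.140625 once moved to the CPU.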

Memory Layout

The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor’s memory layout for better performance.

Moving XLA Tensors to and from the CPU

XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved, then the data it's viewing is copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it.

Saving and Loading XLA Tensors

XLA tensors should be moved to the CPU before saving, as in the following snippet:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

This lets you put the loaded tensors on any available device.

Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

A utility API is provided that moves the data to the CPU before saving it:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# model and path are assumed to be defined elsewhere.
xm.save(model.state_dict(), path)

When using multiple devices, the above API only saves the data for the master device ordinal (0).

Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.

Further Reading

Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.



torch_xla.core.xla_model.xla_device(n=None, devkind=None)[source]

Returns a given instance of an XLA device.

  • n (python:int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.

  • devkind (string..., optional) – If specified, one of TPU, GPU or CPU (the ‘GPU’ XLA device is currently not implemented).


Returns: A torch.device with the requested instance.

torch_xla.core.xla_model.get_xla_supported_devices(devkind=None, max_devices=None)[source]

Returns a list of supported devices of a given kind.

  • devkind (string..., optional) – If specified, one of TPU, GPU or CPU (the ‘GPU’ XLA device is currently not implemented).

  • max_devices (python:int, optional) – The maximum number of devices to be returned of that kind.


Returns: The list of device strings.


Retrieves the number of devices taking part in the replication.


defval (python:int, optional) – The default value to be returned in case there is no replication information available. Default: 1


Returns: The number of devices taking part in the replication.


Retrieves the replication ordinal of the current process.

The ordinals range from 0 to xrt_world_size() minus 1.


defval (python:int, optional) – The default value to be returned in case there is no replication information available. Default: 0


Returns: The replication ordinal of the current process.


Retrieves the replication local ordinal of the current process.

The local ordinals range from 0 to the number of local devices minus 1.


defval (python:int, optional) – The default value to be returned in case there is no replication information available. Default: 0


Returns: The replication local ordinal of the current process.


Checks whether the current process is the master ordinal (0).


local (bool) – Whether the local or global master ordinal should be checked. In case of multi-host replication, there is only one global master ordinal (host 0, device 0), while there are NUM_HOSTS local master ordinals. Default: True


Returns: A boolean indicating whether the current process is the master ordinal.
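To illustrate how global ordinals, local ordinals and master ordinals relate, the sketch below enumerates a hypothetical layout of 2 hosts with 4 devices each. It assumes global ordinals are assigned host by host (ordinal = host * devices_per_host + local_ordinal) and that every host has the same number of devices; both are assumptions for illustration, not guarantees of this API, and describe_replicas is a hypothetical helper.

```python
def describe_replicas(num_hosts: int, devices_per_host: int):
    """Enumerate replicas in a hypothetical, uniform multi-host layout."""
    world_size = num_hosts * devices_per_host  # what xrt_world_size() would report
    replicas = []
    for host in range(num_hosts):
        for local_ordinal in range(devices_per_host):
            # Assumed mapping: global ordinals are assigned host by host.
            ordinal = host * devices_per_host + local_ordinal
            assert 0 <= ordinal < world_size
            replicas.append({
                "ordinal": ordinal,
                "local_ordinal": local_ordinal,
                # One local master per host (local ordinal 0) ...
                "local_master": local_ordinal == 0,
                # ... but a single global master (ordinal 0).
                "global_master": ordinal == 0,
            })
    return replicas


replicas = describe_replicas(num_hosts=2, devices_per_host=4)
print(sum(r["local_master"] for r in replicas))   # 2 (one per host)
print(sum(r["global_master"] for r in replicas))  # 1
```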

torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=[])[source]

Performs an in-place reduce operation on the input tensors.

  • reduce_type (string) – One of sum, mul, and, or, min and max.

  • inputs (list) – List of tensors on which to perform the all-reduce operation.

  • scale (python:float) – A default scaling value to be applied after the reduce. Default: 1.0

  • groups (list) – Reserved.
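The semantics of an all-reduce can be simulated on plain Python lists. In the toy below, simulate_all_reduce is a hypothetical single-process helper, not the torch_xla implementation: each inner list stands for one replica's copy of a tensor, and after the call every replica holds the element-wise reduction across replicas, multiplied by scale and written back in place. For example, with a sum, a scale of 1 divided by the number of replicas turns the result into a mean.

```python
import functools
import operator

# Map the reduce_type strings listed above to binary Python operations.
REDUCERS = {
    "sum": operator.add,
    "mul": operator.mul,
    "and": operator.and_,
    "or": operator.or_,
    "min": min,
    "max": max,
}


def simulate_all_reduce(reduce_type, replicas, scale=1.0):
    """Toy all-reduce: replicas is a list of equal-length lists, updated in place."""
    op = REDUCERS[reduce_type]
    # Reduce element-wise across replicas, then apply the scale.
    reduced = [functools.reduce(op, column) * scale for column in zip(*replicas)]
    # Every replica receives the same result, written in place.
    for values in replicas:
        values[:] = reduced


replicas = [[1, 2], [3, 4]]
simulate_all_reduce("sum", replicas, scale=0.5)  # 0.5 = 1 / number of replicas
print(replicas)  # [[2.0, 3.0], [2.0, 3.0]]
```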

torch_xla.core.xla_model.add_step_closure(closure, args=())[source]

Adds a closure to the list of closures to be run at the end of the step.

During model training there is often a need to print or report information (print to the console, post to TensorBoard, etc.) that requires inspecting the contents of intermediate tensors. Inspecting the contents of different tensors at different points in the model code requires many executions and typically causes performance issues. Adding a step closure ensures that it is run after the barrier, when all the live tensors have already been materialized to device data. The live tensors include the ones captured by the closure arguments. Using add_step_closure() therefore ensures that a single execution is performed, even when multiple closures are queued that require multiple tensors to be inspected. Step closures are run sequentially in the order in which they were queued. Note that even though this API optimizes the execution, it is advisable to throttle the printing/reporting events to once every N steps.

  • closure (callable) – The function to be called.

  • args (tuple) – The arguments to be passed to the closure.
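The queueing behaviour described above can be modelled in a few lines. The StepClosureQueue class below is a hypothetical, pure-Python stand-in rather than the PyTorch/XLA implementation: closures are only queued during the step, and run_step_barrier() plays the role of the barrier, running them once, sequentially, in the order they were added. The loop also shows the suggested throttling, reporting only every 2 steps.

```python
class StepClosureQueue:
    """Toy model of step closures (hypothetical; not the torch_xla implementation)."""

    def __init__(self):
        self._closures = []

    def add_step_closure(self, closure, args=()):
        # Queue the closure; nothing runs until the end of the step.
        self._closures.append((closure, args))

    def run_step_barrier(self):
        # At the barrier, run queued closures sequentially in FIFO order,
        # then start the next step with an empty queue.
        closures, self._closures = self._closures, []
        for closure, args in closures:
            closure(*args)


log = []
queue = StepClosureQueue()
for step in range(4):
    loss = 1.0 / (step + 1)  # stand-in for a value computed during the step
    # Throttle reporting: only queue a report every 2 steps.
    if step % 2 == 0:
        # Values are passed via args so each closure sees this step's values.
        queue.add_step_closure(lambda s, l: log.append((s, l)), args=(step, loss))
    queue.run_step_barrier()

print(log)  # [(0, 1.0), (2, 0.3333333333333333)]
```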

torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={})[source]

Runs the provided optimizer step and issues the XLA device step computation.

  • optimizer (torch.Optimizer) – The torch.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.

  • barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False

  • optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.


Returns: The same value returned by the optimizer.step() call.

torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)[source]

Saves the input data into a file.

The saved data is transferred to the PyTorch CPU device before being saved, so a subsequent torch.load() will load CPU data.

  • data – The input data to be saved. Any nested combination of Python objects (list, tuples, sets, dicts, …).

  • file_or_path – The destination for the data saving operation. Either a file path or a Python file object. If master_only is False, the path or file objects must point to different destinations, as otherwise all the writes from the same host will overwrite each other.

  • master_only (bool, optional) – Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part in the replication; otherwise all the replicas on the same host will write to the same location. Default: True

  • global_master (bool, optional) – When master_only is True this flag controls whether every host’s master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False
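The interplay between master_only and global_master can be summarised as a small decision function. This is a sketch of the behaviour documented above, not the library's code; should_save is hypothetical, and its two flags correspond to the local and global master checks described earlier in this reference.

```python
def should_save(is_local_master: bool, is_global_master: bool,
                master_only: bool = True, global_master: bool = False) -> bool:
    """Decide whether the current replica writes its data (hypothetical helper)."""
    if not master_only:
        # Every replica saves, so each one needs its own file_or_path.
        return True
    if global_master:
        # Only the global master (ordinal 0) saves.
        return is_global_master
    # Default: the master of every host saves.
    return is_local_master


# Two hosts with two devices each: (is_local_master, is_global_master) per replica.
flags = [(True, True), (False, False), (True, False), (False, False)]
print([should_save(l, g) for l, g in flags])                      # [True, False, True, False]
print([should_save(l, g, global_master=True) for l, g in flags])  # [True, False, False, False]
```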


class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, fixed_batch_size=False, loader_prefetch_size=8, device_prefetch_size=4)[source]

Wraps an existing PyTorch DataLoader with background data upload.

  • loader (torch.utils.data.DataLoader) – The PyTorch DataLoader to be wrapped.

  • devices (torch.device…) – The list of devices where the data has to be sent. The i-th sample returned by the loader will be sent to devices[i % len(devices)].

  • batchdim (python:int, optional) – The dimension which is holding the batch size. Default: 0

  • fixed_batch_size (bool, optional) – Ensures that all the batches sent to the devices are of the same size. The original loader iteration stops as soon as a non-matching batch size is found. Default: False

  • loader_prefetch_size (python:int, optional) – The max capacity of the queue used by the thread which is reading samples from the loader, to be processed by the worker threads which upload data to the devices. Default: 8

  • device_prefetch_size (python:int, optional) – The max size of the per-device queues, where the worker threads deposit tensors which have already been sent to devices. Default: 4


Retrieves the loader iterator object for the given device.


device (torch.device) – The device whose loader is being requested.


Returns: The loader iterator object for the device. This is not a torch.utils.data.DataLoader interface, but a Python iterator that returns the same tensor data structure as the wrapped torch.utils.data.DataLoader, with the tensors residing on XLA devices.
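The device routing and fixed_batch_size rules of ParallelLoader can be illustrated with plain Python. The generator below is a hypothetical sketch, not the ParallelLoader implementation (which uploads data using background threads and queues): sample i is routed to devices[i % len(devices)], and with fixed_batch_size=True iteration stops at the first batch whose size differs from the first one. Batches are modelled as lists, so only batchdim=0 is handled, and the device names are placeholders.

```python
def route_batches(batches, devices, fixed_batch_size=False):
    """Yield (device, batch) pairs in round-robin order (hypothetical sketch)."""
    expected_size = None
    for i, batch in enumerate(batches):
        if fixed_batch_size:
            if expected_size is None:
                expected_size = len(batch)  # size of the first batch
            elif len(batch) != expected_size:
                # Stop as soon as a non-matching batch size is found.
                return
        yield devices[i % len(devices)], batch


batches = [[1, 2], [3, 4], [5, 6], [7]]  # the last batch is short
print([d for d, _ in route_batches(batches, ["xla:0", "xla:1"])])
# ['xla:0', 'xla:1', 'xla:0', 'xla:1']
print(len(list(route_batches(batches, ["xla:0", "xla:1"], fixed_batch_size=True))))
# 3
```

Dropping the short final batch is what lets every device see batches of one shape, which avoids compiling a separate graph for the odd-sized batch.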

class torch_xla.distributed.data_parallel.DataParallel(network, device_ids=None)[source]

Enable the execution of a model network in replicated mode using threads.

  • network (torch.nn.Module or callable) – The model’s network. Either a subclass of torch.nn.Module or a callable returning a subclass of torch.nn.Module.

  • device_ids (string… or torch.device…) – The list of devices on which the replication should happen. If the list is empty, the network will be run on the PyTorch CPU device.

__call__(loop_fn, loader, fixed_batch_size=False, batchdim=0)[source]

Runs one epoch of training/testing.

  • loop_fn (callable) – The function that will be called on each thread assigned to each device taking part in the replication. The function is called with the signature def loop_fn(model, device_loader, device, context), where model is the per-device network as passed to the DataParallel constructor, device_loader is the ParallelLoader that returns samples for the current device, and context is a per-thread/device context that has the lifetime of the DataParallel object and can be used by loop_fn to store objects that need to persist across epochs.

  • fixed_batch_size (bool, optional) – Argument passed to the ParallelLoader constructor. Default: False

  • batchdim (python:int, optional) – The dimension in the samples returned by the loader holding the batch size. Default: 0


Returns: A list with the values returned by the loop_fn on each device.
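The calling convention of loop_fn can be sketched with standard-library threads. The toy_data_parallel function below is hypothetical and much simpler than DataParallel: it gives each thread a strided slice of a list instead of a ParallelLoader, uses a string in place of a model replica, and takes the per-device contexts as an argument. It does show the pieces described above: one thread per device, a per-device context that survives across epochs, and a returned list with one result per device, in device order.

```python
import threading


def toy_data_parallel(loop_fn, data, devices, contexts):
    """Run loop_fn once per device on its own thread (hypothetical sketch)."""
    results = [None] * len(devices)

    def worker(index, device):
        # Strided split of the data, standing in for a per-device loader.
        device_loader = data[index::len(devices)]
        model = f"model-replica-{index}"  # placeholder for the per-device network
        results[index] = loop_fn(model, device_loader, device, contexts[index])

    threads = [threading.Thread(target=worker, args=(i, d)) for i, d in enumerate(devices)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results  # one entry per device, in device order


def loop_fn(model, device_loader, device, context):
    # The context persists across epochs, so it can accumulate state.
    context["samples_seen"] = context.get("samples_seen", 0) + len(device_loader)
    return context["samples_seen"]


devices = ["dev0", "dev1"]
contexts = [{} for _ in devices]  # outlives a single epoch
for epoch in range(2):
    print(toy_data_parallel(loop_fn, list(range(6)), devices, contexts))
# [3, 3] after the first epoch, then [6, 6]
```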

torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]

Enables multiprocessing-based replication.

  • fn (callable) – The function to be called for each device that takes part in the replication. The function will be called with the global index of the process within the replication as its first argument, followed by the arguments passed in args.

  • args (tuple) – The arguments for fn. Default: Empty tuple

  • nprocs (python:int) – The number of processes/devices for the replication. Currently, if specified, this can be either 1 or the maximum number of devices.

  • join (bool) – Whether the call should block waiting for the completion of the processes that have been spawned. Default: True

  • daemon (bool) – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API). Default: False

  • start_method (string) – The Python multiprocessing process creation method. Default: spawn


Returns: The same object returned by the torch.multiprocessing.spawn API. If nprocs is 1, the fn function will be called directly, and the API will not return.


class torch_xla.utils.utils.SampleGenerator(data, sample_count)[source]

Iterator which returns multiple samples of a given input data.

Can be used in place of a PyTorch DataLoader to generate synthetic data.

  • data – The data which should be returned at each iterator step.

  • sample_count – The maximum number of data samples to be returned.
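A minimal sketch of an iterator with this behaviour is shown below. RepeatSampleGenerator is a hypothetical name and this is not the torch_xla source; it returns the same data object up to sample_count times, which is why such an iterator can stand in for a DataLoader when feeding synthetic batches, for example to benchmark a model without any input pipeline.

```python
class RepeatSampleGenerator:
    """Iterator that yields the same `data` up to `sample_count` times (hypothetical)."""

    def __init__(self, data, sample_count):
        self._data = data
        self._sample_count = sample_count
        self._count = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Stop once sample_count items have been produced.
        if self._count >= self._sample_count:
            raise StopIteration
        self._count += 1
        return self._data


# A fake (input, target) batch, repeated three times.
fake_batch = ([0.0] * 4, [1])
for data, target in RepeatSampleGenerator(fake_batch, sample_count=3):
    print(len(data), target)
```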

