PyTorch on XLA Devices

PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.

Creating an XLA Tensor

PyTorch/XLA adds a new xla device type to PyTorch. This device type works just like other PyTorch device types. For example, here’s how to create and print an XLA tensor:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing torch_xla initializes PyTorch/XLA, and xm.xla_device() returns the current XLA device. This may be a CPU or TPU depending on your environment.

XLA Tensors are PyTorch Tensors

PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.

For example, XLA tensors can be added together:

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

Or matrix multiplied:

print(t0.mm(t1))

Or used with neural network modules:

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)

Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

will throw an error since the torch.nn.Linear module is on the CPU.

Running Models on XLA Devices

Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device, multiple devices with XLA multiprocessing, or multiple threads with XLA multithreading.

Running on a Single XLA Device

The following snippet shows a network training on a single XLA device:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  # Step the optimizer and issue a barrier so XLA executes the pending graph.
  xm.optimizer_step(optimizer, barrier=True)

This snippet highlights how easy it is to switch your model to run on XLA. The model definition, dataloader, optimizer and training loop can work on any device. The only XLA-specific code is a couple of lines that acquire the XLA device and step the optimizer with a barrier. Calling xm.optimizer_step(optimizer, barrier=True) at the end of each training iteration causes XLA to execute its current graph and update the model’s parameters. See XLA Tensor Deep Dive for more on how XLA creates graphs and runs operations.

Running on Multiple XLA Devices with MultiProcessing

PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:

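import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.

def _mp_fn(index):
  # index is this process's global index, supplied by xmp.spawn().
  device = xm.xla_device()
  # ParallelLoader uploads batches to the device in the background.
  para_loader = pl.ParallelLoader(train_loader, [device])

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in para_loader.per_device_loader(device):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    # No barrier needed: the ParallelLoader iterator issues it.
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())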

There are three differences between this multidevice snippet and the previous single device snippet:

  • xmp.spawn() creates the processes that each run an XLA device.

  • ParallelLoader loads the training data onto each device.

  • xm.optimizer_step(optimizer) no longer needs a barrier. ParallelLoader automatically creates an XLA barrier that evaluates the graph.

The model definition, optimizer definition and training loop remain the same.

See the full multiprocessing example for more on training a network on multiple XLA devices with multiprocessing.

Running on Multiple XLA Devices with MultiThreading

Running on multiple XLA devices using processes (see above) is preferred to using threads. If, however, you want to use threads, PyTorch/XLA provides a DataParallel interface. The following snippet shows the same network training with multiple threads:

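import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp

# MNIST, train_loader, lr, momentum and num_epochs are assumed to be defined elsewhere.

devices = xm.get_xla_supported_devices()
# MNIST (the class) is passed as a callable that builds the per-device network.
model_parallel = dp.DataParallel(MNIST, device_ids=devices)

def train_loop_fn(model, loader, device, context):
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  model.train()
  for _, (data, target) in loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

# Each call runs one epoch of train_loop_fn on every device, one thread per device.
for epoch in range(1, num_epochs + 1):
  model_parallel(train_loop_fn, train_loader)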

The only differences between the multithreading and multiprocessing code are:

  • Multiple devices are acquired in the same process with xm.get_xla_supported_devices().

  • The model is wrapped in dp.DataParallel, which is then called with both the training loop function and the dataloader.

See the full multithreading example for more on training a network on multiple XLA devices with multithreading.

XLA Tensor Deep Dive

Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique.

XLA Tensors are Lazy

CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example.

Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device.
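The following toy model of lazy evaluation, written in plain Python, makes the record-then-execute idea concrete. It does not reflect how PyTorch/XLA is implemented. The `LazyTensor` class, its `constant`, `graph_size` and `materialize` methods, and the `executed` counter are all hypothetical. A real XLA graph is compiled and run on the device rather than interpreted node by node.

```python
class LazyTensor:
    """Toy lazy value (hypothetical, not a PyTorch/XLA class).

    Arithmetic only records a graph node; nothing is computed until
    materialize() is called, loosely mirroring how XLA tensors defer work
    until a result is needed.
    """

    executed = 0  # number of non-constant ops actually computed so far

    def __init__(self, op, inputs=(), value=None):
        self.op = op          # 'const', 'add' or 'mul'
        self.inputs = inputs  # upstream LazyTensor nodes
        self.value = value    # payload for 'const' nodes only

    @staticmethod
    def constant(x):
        return LazyTensor('const', value=x)

    def __add__(self, other):
        return LazyTensor('add', (self, other))  # record, don't compute

    def __mul__(self, other):
        return LazyTensor('mul', (self, other))  # record, don't compute

    def graph_size(self):
        """Count the distinct nodes recorded in this value's graph."""
        seen, stack = set(), [self]
        while stack:
            node = stack.pop()
            if id(node) not in seen:
                seen.add(id(node))
                stack.extend(node.inputs)
        return len(seen)

    def materialize(self):
        """Walk the recorded graph and compute the result on demand."""
        if self.op == 'const':
            return self.value
        a, b = (node.materialize() for node in self.inputs)
        LazyTensor.executed += 1
        return a + b if self.op == 'add' else a * b


a = LazyTensor.constant(2.0)
b = LazyTensor.constant(3.0)
c = (a + b) * b             # two ops recorded, none executed yet
print(LazyTensor.executed)  # 0
print(c.graph_size())       # 4 (a, b, add, mul)
print(c.materialize())      # 15.0
print(LazyTensor.executed)  # 2
```

Because nothing runs until `materialize()`, the whole `(a + b) * b` graph is available in advance. That is what allows an optimizer to fuse operations, as the paragraph on lazy tensors describes.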

XLA Tensors and bfloat16

PyTorch/XLA can use the bfloat16 datatype when running on TPUs. In fact, PyTorch/XLA handles float types (torch.float and torch.double) differently on TPUs. This behavior is controlled by the XLA_USE_BF16 environment variable:

  • By default both torch.float and torch.double are torch.float on TPUs.

  • If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.

XLA tensors on TPUs will always report their PyTorch datatype regardless of the actual datatype they’re using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU it will be converted from its actual datatype to its PyTorch datatype.
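The sketch below shows how much precision a value keeps when it is stored as bfloat16. It runs in plain Python with no TPU involved. bfloat16 keeps float32's sign bit and 8 exponent bits, but only 7 of its 23 mantissa bits. The helper `to_bfloat16` is hypothetical and assumes round-to-nearest-even. This document does not specify the exact rounding a particular TPU or XLA version applies.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value and return it as a Python float.

    Hypothetical helper for illustration. Assumes x is finite and within
    float32 range; NaN/Inf handling is deliberately omitted.
    """
    bits = struct.unpack('<I', struct.pack('<f', x))[0]  # float32 bit pattern
    # Round to nearest, ties to even, on the 16 low bits that will be dropped.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # keep sign, exponent, 7 mantissa bits
    return struct.unpack('<f', struct.pack('<I', bits))[0]


print(to_bfloat16(1.0))      # 1.0 (exactly representable)
print(to_bfloat16(math.pi))  # 3.140625
print(to_bfloat16(1.001))    # 1.0 (the 0.001 is lost)
```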

Memory Layout

The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor’s memory layout for better performance.

Moving XLA Tensors to and from the CPU

XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved, the data it’s viewing is copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it.

Saving and Loading XLA Tensors

XLA tensors should be moved to the CPU before saving, as in the following snippet:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

This lets you put the loaded tensors on any available device.

Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.

Further Reading

Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.

PyTorch/XLA API

xla_model

torch_xla.core.xla_model.xla_device(n=None, devkind=None)

Returns a given instance of an XLA device.

Parameters
  • n (int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.

  • devkind (str, optional) – If specified, one of TPU, GPU or CPU (the ‘GPU’ XLA device is currently not implemented).

Returns

A torch.device with the requested instance.

torch_xla.core.xla_model.get_xla_supported_devices(devkind=None, max_devices=None)

Returns a list of supported devices of a given kind.

Parameters
  • devkind (str, optional) – If specified, one of TPU, GPU or CPU (the ‘GPU’ XLA device is currently not implemented).

  • max_devices (int, optional) – The maximum number of devices of that kind to be returned.

Returns

The list of device strings.

torch_xla.core.xla_model.xrt_world_size(defval=1)

Retrieves the number of devices taking part in the replication.

Parameters

defval (int, optional) – The default value to be returned in case there is no replication information available. Default: 1

Returns

The number of devices taking part in the replication.

torch_xla.core.xla_model.get_ordinal(defval=0)

Retrieves the replication ordinal of the current process.

The ordinals range from 0 to xrt_world_size() minus 1.

Parameters

defval (int, optional) – The default value to be returned in case there is no replication information available. Default: 0

Returns

The replication ordinal of the current process.

torch_xla.core.xla_model.is_master_ordinal()

Checks whether the current process is the master ordinal (0).

Returns

A boolean indicating whether the current process is the master ordinal.

torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={})

Run the provided optimizer step and issue the XLA device step computation.

Parameters
  • optimizer (torch.Optimizer) – The torch.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.

  • barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False

  • optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.

Returns

The same value returned by the optimizer.step() call.
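The documented contract of optimizer_step can be summarized in a few lines of plain Python. This is a hypothetical sketch, not torch_xla's source. `optimizer_step_sketch`, the `issue_barrier` callback standing in for the XLA barrier, and `FakeOptimizer` are illustrative names. The real function may also do work that is not described here.

```python
def optimizer_step_sketch(optimizer, barrier=False, optimizer_args=None,
                          issue_barrier=lambda: None):
    """Hypothetical sketch of the behavior documented for xm.optimizer_step."""
    # step() receives optimizer_args as named arguments.
    result = optimizer.step(**(optimizer_args or {}))
    if barrier:
        # Only needed without ParallelLoader/DataParallel, whose iterators
        # issue the barrier themselves.
        issue_barrier()
    # Whatever step() returned is passed through unchanged.
    return result


class FakeOptimizer:
    """Stand-in optimizer that records how step() was called."""

    def __init__(self):
        self.calls = []

    def step(self, closure=None):
        # Like torch.optim optimizers, return the closure's value if given.
        self.calls.append(closure)
        return closure() if closure is not None else None


barriers = []
opt = FakeOptimizer()
loss = optimizer_step_sketch(opt, barrier=True,
                             optimizer_args={'closure': lambda: 0.5},
                             issue_barrier=lambda: barriers.append('barrier'))
print(loss, len(barriers))  # 0.5 1
```

The sketch defaults optimizer_args to None rather than a shared `{}`. This is the usual Python idiom for avoiding a mutable default argument. The documented signature itself uses `optimizer_args={}`.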

distributed

class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, fixed_batch_size=False, loader_prefetch_size=8, device_prefetch_size=4)

Wraps an existing PyTorch DataLoader with background data upload.

Parameters
  • loader (torch.utils.data.DataLoader) – The PyTorch DataLoader to be wrapped.

  • devices (list of torch.device) – The list of devices where the data has to be sent. The i-th sample returned by the loader will be sent to devices[i % len(devices)].

  • batchdim (int, optional) – The dimension which is holding the batch size. Default: 0

  • fixed_batch_size (bool, optional) – Ensures that all the batch sizes sent to the devices are the same. The original loader iteration stops as soon as a mismatching batch size is found. Default: False

  • loader_prefetch_size (int, optional) – The maximum capacity of the queue used by the thread which reads samples from the loader, to be processed by the worker threads which upload data to the devices. Default: 8

  • device_prefetch_size (int, optional) – The maximum size of the per-device queues, where the worker threads deposit tensors which have already been sent to devices. Default: 4

per_device_loader(device)

Retrieves the loader object for the given device.

Parameters

device (torch.device) – The device whose loader is being requested.

Returns

The data loader for the device.
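The routing rule for devices and the effect of fixed_batch_size can be modeled with plain lists. The function `route_batches` below is a hypothetical, single-threaded sketch with several simplifications:

- It ignores the background threads and prefetch queues.
- It treats each batch as a list whose len() is its batch size.
- It assumes the reference batch size is taken from the first batch.

The returned dict plays the role of per_device_loader(device).

```python
def route_batches(batches, devices, fixed_batch_size=False):
    """Send the i-th batch to devices[i % len(devices)] (hypothetical sketch).

    With fixed_batch_size=True, iteration stops at the first batch whose
    size differs from the first batch's size.
    """
    per_device = {device: [] for device in devices}
    expected_size = None
    for i, batch in enumerate(batches):
        if fixed_batch_size:
            if expected_size is None:
                expected_size = len(batch)
            elif len(batch) != expected_size:
                break  # stop the original loader iteration here
        per_device[devices[i % len(devices)]].append(batch)
    return per_device


batches = [[1, 2], [3, 4], [5, 6], [7]]  # last batch is short
devices = ['xla:0', 'xla:1']             # placeholder device names
print(route_batches(batches, devices))
# {'xla:0': [[1, 2], [5, 6]], 'xla:1': [[3, 4], [7]]}
print(route_batches(batches, devices, fixed_batch_size=True))
# {'xla:0': [[1, 2], [5, 6]], 'xla:1': [[3, 4]]}
```

One plausible reason to drop the short final batch is that XLA generally compiles a separate graph for each distinct input shape. Uniform batch sizes avoid those extra compilations.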

class torch_xla.distributed.data_parallel.DataParallel(network, device_ids=None)

Enables the execution of a model network in replicated mode using threads.

Parameters
  • network (torch.nn.Module or callable) – The model’s network. Either an instance of a torch.nn.Module subclass or a callable returning one.

  • device_ids (list of str or torch.device) – The list of devices on which the replication should happen. If the list is empty, the network will be run on the PyTorch CPU device.

__call__(loop_fn, loader, fixed_batch_size=False, batchdim=0)

Runs one epoch of training or testing.

Parameters
  • loop_fn (callable) – The function which will be called on each thread assigned to each device taking part in the replication. The function will be called with the signature def loop_fn(model, device_loader, device, context). Here, model is the per-device network as passed to the DataParallel constructor, and device_loader is the ParallelLoader which returns samples for the current device. The context is a per-thread/device context which has the lifetime of the DataParallel object, and loop_fn can use it to store objects which need to persist across epochs.

  • fixed_batch_size (bool, optional) – Argument passed to the ParallelLoader constructor. Default: False

  • batchdim (int, optional) – The dimension in the samples returned by the loader holding the batch size. Default: 0

Returns

A list with the values returned by the loop_fn on each device.
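The loop_fn contract can be mimicked with the standard threading module. That contract consists of one thread per device, a context that outlives a single epoch, and a list of per-device return values. Everything here is a hypothetical stand-in rather than DataParallel's code. This includes `run_epoch`, the attribute-bag `Context`, the placeholder device names, and the use of plain lists as loaders and None as models.

```python
import threading


class Context:
    """Hypothetical per-device context: a plain attribute bag."""


def run_epoch(loop_fn, models, loaders, devices, contexts):
    """Run loop_fn once per device on its own thread; collect the results."""
    results = [None] * len(devices)

    def worker(i):
        results[i] = loop_fn(models[i], loaders[i], devices[i], contexts[i])

    threads = [threading.Thread(target=worker, args=(i,))
               for i in range(len(devices))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results  # one entry per device, in device order


def loop_fn(model, device_loader, device, context):
    # The context persists across epochs, so it can carry state such as
    # an epoch counter (or, in real training code, an optimizer).
    context.epochs = getattr(context, 'epochs', 0) + 1
    return device, sum(device_loader), context.epochs


devices = ['xla:0', 'xla:1']
contexts = [Context() for _ in devices]  # created once, reused every epoch
loaders = [[1, 2], [3, 4]]
models = [None, None]
for epoch in range(2):
    print(run_epoch(loop_fn, models, loaders, devices, contexts))
# [('xla:0', 3, 1), ('xla:1', 7, 1)]
# [('xla:0', 3, 2), ('xla:1', 7, 2)]
```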

torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False)

Enables multiprocessing-based replication.

Parameters
  • fn – The function to be called for each device which takes part in the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in args.

  • args – The arguments for fn.

  • nprocs – The number of processes/devices for the replication. At the moment, if specified, it can be either 1 or the maximum number of devices.

  • join – Whether the call should block waiting for the completion of the processes which have been spawned.

  • daemon – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API).

Returns

The same object returned by the torch.multiprocessing.spawn API.

utils

class torch_xla.utils.utils.SampleGenerator(data, sample_count)

Iterator which returns multiple samples of a given input data.

Can be used in place of a PyTorch DataLoader to generate synthetic data.

Parameters
  • data – The data which should be returned at each iterator step.

  • sample_count – The maximum number of data samples to be returned.
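A minimal pure-Python iterator with the behavior described above might look like the following. `RepeatedSample` is a hypothetical name. It sketches the documented behavior and is not torch_xla's implementation. The `fake_batch` value is an illustrative placeholder for what would usually be a tuple of tensors.

```python
class RepeatedSample:
    """Yield the same `data` object up to `sample_count` times (sketch)."""

    def __init__(self, data, sample_count):
        self._data = data
        self._sample_count = sample_count

    def __len__(self):
        # Like a DataLoader, report how many batches one pass yields.
        return self._sample_count

    def __iter__(self):
        # A fresh generator per pass, so the object can be iterated every epoch.
        for _ in range(self._sample_count):
            yield self._data


fake_batch = ([0.0] * 4, 1)                    # hypothetical (inputs, target) pair
loader = RepeatedSample(fake_batch, sample_count=3)
print(len(loader))                             # 3
print(sum(1 for _ in loader))                  # 3
print(all(b is fake_batch for b in loader))    # True
```

Because every step returns the same object, such a loader takes input-pipeline cost out of the picture. This can make synthetic data handy when timing a model.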
