PyTorch on XLA Devices¶
PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.
Creating an XLA Tensor¶
PyTorch/XLA adds a new xla
device type to PyTorch. This device type works just
like other PyTorch device types. For example, here’s how to create and
print an XLA tensor:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
This code should look familiar. PyTorch/XLA uses the same interface as regular
PyTorch with a few additions. Importing torch_xla
initializes PyTorch/XLA, and
xm.xla_device()
returns the current XLA device. This may be a CPU or TPU
depending on your environment.
XLA Tensors are PyTorch Tensors¶
PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.
For example, XLA tensors can be added together:
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
Or matrix multiplied:
print(t0.mm(t1))
Or used with neural network modules:
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
will throw an error since the torch.nn.Linear module is on the CPU.
Running Models on XLA Devices¶
Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device, multiple devices with XLA multiprocessing, or multiple threads with XLA multithreading.
Running on a Single XLA Device¶
The following snippet shows a network training on a single XLA device:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  xm.optimizer_step(optimizer, barrier=True)
This snippet highlights how easy it is to switch your model to run on XLA. The
model definition, dataloader, optimizer and training loop can work on any device.
The only XLA-specific code is a couple lines that acquire the XLA device and
step the optimizer with a barrier. Calling
xm.optimizer_step(optimizer, barrier=True)
at the end of each training
iteration causes XLA to execute its current graph and update the model’s
parameters. See XLA Tensor Deep Dive for more on
how XLA creates graphs and runs operations.
Running on Multiple XLA Devices with MultiProcessing¶
PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
def _mp_fn(index):
  device = xm.xla_device()
  para_loader = pl.ParallelLoader(train_loader, [device])

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in para_loader.per_device_loader(device):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
There are three differences between this multidevice snippet and the previous single device snippet:
- xmp.spawn() creates the processes that each run an XLA device.
- ParallelLoader loads the training data onto each device.
- xm.optimizer_step(optimizer) no longer needs a barrier. ParallelLoader automatically creates an XLA barrier that evaluates the graph.
The model definition, optimizer definition and training loop remain the same.
NOTE: When using multiprocessing, XLA devices can be acquired and used only from within the target function of xmp.spawn() (or any function which has xmp.spawn() as a parent in the call stack).
See the full multiprocessing example for more on training a network on multiple XLA devices with multiprocessing.
Running on Multiple XLA Devices with MultiThreading¶
Running on multiple XLA devices using processes (see above) is preferred to using
threads. If, however, you want to use threads then PyTorch/XLA has a
DataParallel
interface. The following snippet shows the same network training
with multiple threads:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp
devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(MNIST, device_ids=devices)
def train_loop_fn(model, loader, device, context):
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  model.train()
  for data, target in loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

for epoch in range(1, num_epochs + 1):
  model_parallel(train_loop_fn, train_loader)
The only differences between the multithreading and multiprocessing code are:
- Multiple devices are acquired in the same process with xm.get_xla_supported_devices().
- The model is wrapped in dp.DataParallel and passed both the training loop and dataloader.
See the full multithreading example for more on training a network on multiple XLA devices with multithreading.
XLA Tensor Deep Dive¶
Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors their internals are different. This section describes what makes XLA tensors unique.
XLA Tensors are Lazy¶
CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize the whole computation: a graph of multiple separate operations might be fused into a single optimized operation, for example.
Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device.
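As a rough sketch of this behavior (xm.mark_step() is the torch_xla.core.xla_model call that flushes the pending graph; printing or copying a tensor to the CPU triggers the same synchronization):
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# These operations are only recorded in the graph; nothing executes yet.
a = torch.randn(2, 2, device=device)
b = torch.randn(2, 2, device=device)
c = a @ b + a

# Force XLA to compile and run the recorded graph.
xm.mark_step()

# Copying to the CPU (or printing) also synchronizes and materializes values.
print(c.cpu())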
XLA Tensors and bFloat16¶
PyTorch/XLA can use the
bfloat16
datatype when running on TPUs. In fact, PyTorch/XLA handles float types
(torch.float
and torch.double
) differently on TPUs. This behavior is
controlled by the XLA_USE_BF16
environment variable:
- By default both torch.float and torch.double are torch.float on TPUs.
- If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.
- If a PyTorch tensor has torch.bfloat16 data type, this will be directly mapped to the TPU bfloat16 (XLA BF16 primitive type).
XLA tensors on TPUs will always report their PyTorch datatype regardless of the actual datatype they’re using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU it will be converted from its actual datatype to its PyTorch datatype.
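The following sketch illustrates the reported datatypes, assuming a TPU environment and that XLA_USE_BF16 is read when PyTorch/XLA initializes:
import os
os.environ['XLA_USE_BF16'] = '1'  # assumed to be set before torch_xla initializes

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(4, 4, device=xm.xla_device())

# On a TPU the data is stored as bfloat16, but the tensor still reports
# its PyTorch datatype.
print(t.dtype)        # torch.float32

# Moving back to the CPU converts from the actual on-device datatype
# to the reported PyTorch datatype.
print(t.cpu().dtype)  # torch.float32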
Memory Layout¶
The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor’s memory layout for better performance.
Moving XLA Tensors to and from the CPU¶
XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved then the data it is viewing is copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it.
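A brief sketch of these copies (the tensor names are illustrative):
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t_cpu = torch.randn(4, 4)
t_xla = t_cpu.to(device)   # copy CPU -> XLA device
t_back = t_xla.cpu()       # copy XLA device -> CPU

# After a move, views become independent copies: mutating the original
# CPU tensor does not affect the tensor copied to the XLA device.
v = t_cpu[0]
v_xla = v.to(device)
t_cpu[0] = 0.0
print(v_xla.cpu())  # still holds the values copied before the mutation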
Saving and Loading XLA Tensors¶
XLA tensors should be moved to the CPU before saving, as in the following snippet:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
This lets you put the loaded tensors on any available device.
Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).
A utility API is also provided that takes care of moving the data to the CPU before saving it:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
xm.save(model.state_dict(), path)
In case of multiple devices, the above API will only save the data for the master device ordinal (0).
Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.
Further Reading¶
Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.
PyTorch/XLA API¶
xla_model¶
- torch_xla.core.xla_model.xla_device(n=None, devkind=None)
Returns a given instance of an XLA device.
- Parameters
n (python:int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.
devkind (string..., optional) – If specified, one of TPU, GPU or CPU (the ‘GPU’ XLA device is currently not implemented).
- Returns
A torch.device with the requested instance.
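For example (the device kind and ordinal below are chosen purely for illustration):
import torch_xla.core.xla_model as xm

# First available XLA device (CPU or TPU, depending on the environment).
device = xm.xla_device()

# Explicitly request a TPU device, and a specific ordinal if more than
# one TPU core is available.
tpu = xm.xla_device(devkind='TPU')
tpu_1 = xm.xla_device(n=1, devkind='TPU')
print(device, tpu, tpu_1)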
- torch_xla.core.xla_model.get_xla_supported_devices(devkind=None, max_devices=None)
Returns a list of supported devices of a given kind.
- Parameters
devkind (string..., optional) – If specified, one of TPU, GPU or CPU (the ‘GPU’ XLA device is currently not implemented).
max_devices (python:int, optional) – The maximum number of devices to be returned of that kind.
- Returns
The list of device strings.
- torch_xla.core.xla_model.xrt_world_size(defval=1)
Retrieves the number of devices taking part in the replication.
- Parameters
defval (python:int, optional) – The default value to be returned in case there is no replication information available. Default: 1
- Returns
The number of devices taking part in the replication.
- torch_xla.core.xla_model.get_ordinal(defval=0)
Retrieves the replication ordinal of the current process.
The ordinals range from 0 to xrt_world_size() minus 1.
- Parameters
defval (python:int, optional) – The default value to be returned in case there is no replication information available. Default: 0
- Returns
The replication ordinal of the current process.
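A common pattern, sketched here rather than prescribed by the API, is to combine xrt_world_size() and get_ordinal() to shard the input data across replicas (train_dataset stands for any map-style dataset):
import torch
import torch_xla.core.xla_model as xm

# Each replica sees a disjoint slice of the dataset.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, sampler=train_sampler)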
- torch_xla.core.xla_model.get_local_ordinal(defval=0)
Retrieves the replication local ordinal of the current process.
The local ordinals range from 0 to the number of local devices minus 1.
- Parameters
defval (python:int, optional) – The default value to be returned in case there is no replication information available. Default: 0
- Returns
The replication local ordinal of the current process.
- torch_xla.core.xla_model.is_master_ordinal(local=True)
Checks whether the current process is the master ordinal (0).
- Parameters
local (bool) – Whether the local or global master ordinal should be checked. In case of multi-host replication, there is only one global master ordinal (host 0, device 0), while there are NUM_HOSTS local master ordinals. Default: True
- Returns
A boolean indicating whether the current process is the master ordinal.
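A typical, illustrative use is restricting logging or checkpointing to a single process:
import torch
import torch_xla.core.xla_model as xm

model = torch.nn.Linear(10, 10).to(xm.xla_device())

# Only the master ordinal (process 0) logs; the other replicas stay quiet.
if xm.is_master_ordinal():
  print('master process: writing checkpoint')

# xm.save() already restricts the write to the master by default,
# so no explicit guard is needed around it.
xm.save(model.state_dict(), 'model.pt')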
- torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=[])
Perform an inplace reduce operation on the input tensors.
- Parameters
reduce_type (string) – One of sum, mul, and, or, min and max.
inputs (list) – The list of tensors on which to perform the all reduce op.
scale (python:float) – A default scaling value to be applied after the reduce. Default: 1.0
groups (list) – Reserved.
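As an illustrative sketch, all_reduce() is often used to average a value across replicas by summing it and scaling by the inverse of the world size:
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
local_value = torch.tensor([float(xm.get_ordinal())], device=device)

# Sum the tensor across all replicas in place, then scale by 1/world_size
# so every replica ends up holding the average.
xm.all_reduce('sum', [local_value], scale=1.0 / xm.xrt_world_size())
print(local_value.cpu())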
- torch_xla.core.xla_model.add_step_closure(closure, args=())
Adds a closure to the list of the ones to be run at the end of the step.
Many times during model training there is the need to print or report information (to the console, to TensorBoard, etc.) which requires the content of intermediary tensors to be inspected. Inspecting the content of different tensors at different points of the model code normally requires many executions and typically causes performance issues. Adding a step closure ensures that it will be run after the barrier, when all the live tensors (including the ones captured by the closure arguments) have already been materialized to device data. Using add_step_closure() therefore ensures that a single execution is performed even when multiple closures, requiring multiple tensors to be inspected, are queued. Step closures are run sequentially in the order they were queued. Note that even though execution is optimized with this API, it is advised to throttle printing/reporting events to once every N steps.
- Parameters
closure (callable) – The function to be called.
args (tuple) – The arguments to be passed to the closure.
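A minimal sketch of this pattern (the reporting function and step throttling are illustrative):
import torch
import torch_xla.core.xla_model as xm

def _report(loss, step):
  # Runs after the barrier, once `loss` has been materialized, so calling
  # .item() here does not trigger an extra graph execution.
  if step % 10 == 0:
    print('step {}: loss={:.4f}'.format(step, loss.item()))

device = xm.xla_device()
w = torch.randn(4, 4, device=device, requires_grad=True)
for step in range(100):
  loss = (w * w).sum()
  xm.add_step_closure(_report, args=(loss, step))
  xm.mark_step()  # the barrier; queued closures run after it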
- torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={})
Run the provided optimizer step and issue the XLA device step computation.
- Parameters
optimizer (torch.Optimizer) – The torch.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.
barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False
optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.
- Returns
The same value returned by the optimizer.step() call.
- torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)
Saves the input data into a file.
The saved data is transferred to the PyTorch CPU device before being saved, so a following torch.load() will load CPU data.
- Parameters
data – The input data to be saved. Any nested combination of Python objects (list, tuples, sets, dicts, …).
file_or_path – The destination for the data saving operation. Either a file path or a Python file object. If master_only is False, the path or file objects must point to different destinations, as otherwise all the writes from the same host will override each other.
master_only (bool, optional) – Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part in the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True
global_master (bool, optional) – When master_only is True, this flag controls whether every host's master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False
distributed¶
- class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, fixed_batch_size=False, loader_prefetch_size=8, device_prefetch_size=4)
Wraps an existing PyTorch DataLoader with background data upload.
- Parameters
loader (torch.utils.data.DataLoader) – The PyTorch DataLoader to be wrapped.
devices (torch.device…) – The list of devices where the data has to be sent. The i-th sample returned by the loader will be sent to devices[i % len(devices)].
batchdim (python:int, optional) – The dimension which is holding the batch size. Default: 0
fixed_batch_size (bool, optional) – Ensures that all the batch sizes sent to the devices are of the same size. The original loader iteration stops as soon as a not matching batch size is found. Default: False
loader_prefetch_size (python:int, optional) – The max capacity of the queue used by the thread which is reading samples from the loader, to be processed by the worker threads which upload data to the devices. Default: 8
device_prefetch_size (python:int, optional) – The max size of the per-device queues, where the worker threads deposit tensors which have already been sent to devices. Default: 4
- per_device_loader(device)
Retrieves the loader iterator object for the given device.
- Parameters
device (torch.device) – The device whose loader is being requested.
- Returns
The loader iterator object for the device. This is not a torch.utils.data.DataLoader interface, but a Python iterator which returns the same tensor data structure as returned by the wrapped torch.utils.data.DataLoader, but residing on XLA devices.
- class torch_xla.distributed.data_parallel.DataParallel(network, device_ids=None)
Enable the execution of a model network in replicated mode using threads.
- Parameters
network (torch.nn.Module or callable) – The model's network. Either a subclass of torch.nn.Module or a callable returning a subclass of torch.nn.Module.
device_ids (string… or torch.device…) – The list of devices on which the replication should happen. If the list is empty, the network will be run on PyTorch CPU device.
- __call__(loop_fn, loader, fixed_batch_size=False, batchdim=0)
Runs one EPOCH of training/test.
- Parameters
loop_fn (callable) – The function which will be called on each thread assigned to each device taking part in the replication. The function will be called with the def loop_fn(model, device_loader, device, context) signature. Here model is the per-device network as passed to the DataParallel constructor, device_loader is the ParallelLoader which will be returning samples for the current device, and context is a per thread/device context which has the lifetime of the DataParallel object and can be used by the loop_fn to store objects which need to persist across different epochs.
fixed_batch_size (bool, optional) – Argument passed to the ParallelLoader constructor. Default: False
batchdim (python:int, optional) – The dimension in the samples returned by the loader holding the batch size. Default: 0
- Returns
A list with the values returned by the loop_fn on each device.
- torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')
Enables multi-processing based replication.
- Parameters
fn (callable) – The function to be called for each device which takes part in the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in args.
args (tuple) – The arguments for fn. Default: Empty tuple
nprocs (python:int) – The number of processes/devices for the replication. At the moment, if specified, can be either 1 or the maximum number of devices.
join (bool) – Whether the call should block waiting for the completion of the processes which have been spawned. Default: True
daemon (bool) – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API). Default: False
start_method (string) – The Python multiprocessing process creation method. Default: spawn
- Returns
The same object returned by the torch.multiprocessing.spawn API. If nprocs is 1 the fn function will be called directly, and the API will not return.
utils¶
- class torch_xla.utils.utils.SampleGenerator(data, sample_count)
Iterator which returns multiple samples of a given input data.
Can be used in place of a PyTorch DataLoader to generate synthetic data.
- Parameters
data – The data which should be returned at each iterator step.
sample_count – The maximum number of data samples to be returned.
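For instance, a SampleGenerator can stand in for a DataLoader when benchmarking an input pipeline on synthetic data (the shapes and counts below are illustrative):
import torch
import torch_xla.utils.utils as xu

batch_size = 128

# Every iterator step yields the same (input, target) pair.
train_loader = xu.SampleGenerator(
    data=(torch.zeros(batch_size, 1, 28, 28),
          torch.zeros(batch_size, dtype=torch.int64)),
    sample_count=60000 // batch_size)

for data, target in train_loader:
  pass  # feed the synthetic batch to the model here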