
PyTorch on XLA Devices

PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.

Creating an XLA Tensor

PyTorch/XLA adds a new xla device type to PyTorch. This device type works just like other PyTorch device types. For example, here’s how to create and print an XLA tensor:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing torch_xla initializes PyTorch/XLA, and xm.xla_device() returns the current XLA device. This may be a CPU or TPU depending on your environment.

XLA Tensors are PyTorch Tensors

PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.

For example, XLA tensors can be added together:

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

Or matrix multiplied:

print(t0.mm(t1))

Or used with neural network modules:

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)

Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

will throw an error since the torch.nn.Linear module is on the CPU.

Running Models on XLA Devices

Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device and multiple devices with XLA multiprocessing.

Running on a Single XLA Device

The following snippet shows a network training on a single XLA device:

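import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# MNIST (an nn.Module), train_loader (a DataLoader), lr and momentum
# are assumed to be defined elsewhere.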

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  optimizer.step()
  xm.mark_step()

This snippet shows how easy it is to switch your model to run on XLA. The model definition, data loader, optimizer and training loop can work on any device. The only XLA-specific code is a couple of lines that acquire the XLA device and mark the step. Calling xm.mark_step() at the end of each training iteration causes XLA to execute its current graph and update the model’s parameters. See XLA Tensor Deep Dive for more on how XLA creates graphs and runs operations.

Running on Multiple XLA Devices with MultiProcessing

PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:

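import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.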

def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

There are three differences between this multi-device snippet and the previous single-device snippet:

  • xmp.spawn() creates the processes that each run an XLA device.

  • MpDeviceLoader loads the training data onto each device.

  • xm.optimizer_step(optimizer) consolidates the gradients between cores and issues the XLA device step computation.

The model definition, optimizer definition and training loop remain the same.

NOTE: When using multiprocessing, XLA devices can be retrieved and accessed only from within the target function of xmp.spawn() (or from any function that has xmp.spawn() as a parent in the call stack).

See the full multiprocessing example for more on training a network on multiple XLA devices with multiprocessing.

XLA Tensor Deep Dive

Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique.

XLA Tensors are Lazy

CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example.

Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device.
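To make deferred execution concrete, here is a toy model in plain Python. It is not how PyTorch/XLA is implemented: ToyLazyTensor is a hypothetical class that only records arithmetic into a graph and runs nothing until item() is called, which plays the role of copying a tensor to the CPU. Real XLA also compiles and may fuse the recorded graph, which this sketch does not attempt.

```python
from typing import Callable, Optional


class ToyLazyTensor:
    """Hypothetical toy: operations build a graph; nothing runs until item()."""

    ops_executed = 0  # how many recorded operations have actually run

    def __init__(self, compute: Callable[[], float], graph: str) -> None:
        self._compute = compute
        self.graph = graph  # readable record of the pending computation
        self._cached: Optional[float] = None

    @classmethod
    def constant(cls, x: float) -> "ToyLazyTensor":
        return cls(lambda: x, repr(x))

    def _record(self, other: "ToyLazyTensor", symbol: str,
                fn: Callable[[float, float], float]) -> "ToyLazyTensor":
        def compute() -> float:
            ToyLazyTensor.ops_executed += 1
            return fn(self.item(), other.item())
        return ToyLazyTensor(compute, f"({self.graph} {symbol} {other.graph})")

    def __add__(self, other: "ToyLazyTensor") -> "ToyLazyTensor":
        return self._record(other, "+", lambda a, b: a + b)

    def __mul__(self, other: "ToyLazyTensor") -> "ToyLazyTensor":
        return self._record(other, "*", lambda a, b: a * b)

    def item(self) -> float:
        # Materialization point, analogous to copying an XLA tensor to the CPU.
        if self._cached is None:
            self._cached = self._compute()
        return self._cached
```

Building c = (a + b) * b from two constants only records "((2.0 + 3.0) * 3.0)"; both operations run when c.item() is first called, and later calls reuse the cached result.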

XLA Tensors and bfloat16

PyTorch/XLA can use the bfloat16 datatype when running on TPUs. In fact, PyTorch/XLA handles float types (torch.float and torch.double) differently on TPUs. This behavior is controlled by the XLA_USE_BF16 environment variable:

  • By default both torch.float and torch.double are torch.float on TPUs.

  • If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.

  • If a PyTorch tensor has torch.bfloat16 data type, this will be directly mapped to the TPU bfloat16 (XLA BF16 primitive type).

XLA tensors on TPUs will always report their PyTorch datatype regardless of the actual datatype they’re using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU it will be converted from its actual datatype to its PyTorch datatype.
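The precision cost of XLA_USE_BF16 follows from the bfloat16 format itself: it keeps the sign bit, the 8 exponent bits and the top 7 mantissa bits of a float32. The pure-Python sketch below rounds a value to the nearest bfloat16 (round half to even) so you can see what a float32 loses. It models the standard format only; that TPUs round exactly this way is an assumption, and NaN inputs are not handled.

```python
import struct


def float32_bits(x: float) -> int:
    """Bit pattern of x after conversion to IEEE-754 float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round x (via float32) to the nearest bfloat16 value, ties to even."""
    bits = float32_bits(x)
    # bfloat16 = upper 16 bits of float32: 1 sign, 8 exponent, 7 mantissa bits.
    lsb = (bits >> 16) & 1                  # lowest bit that survives truncation
    rounded = (bits + 0x7FFF + lsb) >> 16   # round half to even
    return struct.unpack("<f", struct.pack("<I", (rounded << 16) & 0xFFFFFFFF))[0]
```

For example, to_bfloat16(0.1) is 0.10009765625 and to_bfloat16(1 + 2**-8) is exactly 1.0, which is why bfloat16 keeps float32's range but only about two to three significant decimal digits.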

Memory Layout

The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor’s memory layout for better performance.

Moving XLA Tensors to and from the CPU

XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved, the data it is viewing is copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it.

Saving and Loading XLA Tensors

XLA tensors should be moved to the CPU before saving, as in the following snippet:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

This lets you put the loaded tensors on any available device.

Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

A utility API is provided that takes care of moving the data to the CPU before saving it:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# model and path are assumed to be defined elsewhere.
xm.save(model.state_dict(), path)

In case of multiple devices, the above API will only save the data for the master device ordinal (0).

When host memory is limited compared to the size of the model parameters, an API is provided that reduces the memory pressure on the host:

import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)

This API streams XLA tensors to the CPU one at a time, reducing the amount of host memory used, but it requires the matching load API to restore the data:

import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
model.load_state_dict(state_dict)

Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.

Further Reading

Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.

PyTorch/XLA API

xla_model

torch_xla.core.xla_model.xla_device(n=None, devkind=None)

Returns a given instance of an XLA device.

Parameters
  • n (int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.

  • devkind (string, optional) – If specified, one of TPU, GPU or CPU.

Returns

A torch.device with the requested instance.

torch_xla.core.xla_model.get_xla_supported_devices(devkind=None, max_devices=None)

Returns a list of supported devices of a given kind.

Parameters
  • devkind (string, optional) – If specified, one of TPU, GPU or CPU (the ‘GPU’ XLA device is currently not implemented).

  • max_devices (int, optional) – The maximum number of devices of that kind to be returned.

Returns

The list of device strings.

torch_xla.core.xla_model.xla_device_hw(device)

Returns the hardware type of the given device.

Parameters

device (string or torch.device) – The XLA device that will be mapped to the real device.

Returns

A string representation of the hardware type (CPU, TPU, GPU) of the given device.

torch_xla.core.xla_model.get_ordinal(defval=0)

Retrieves the replication ordinal of the current process.

The ordinals range from 0 to xrt_world_size() minus 1.

Parameters

defval (int, optional) – The default value to be returned in case there is no replication information available. Default: 0

Returns

The replication ordinal of the current process.

torch_xla.core.xla_model.get_local_ordinal(defval=0)

Retrieves the replication local ordinal of the current process.

The local ordinals range from 0 to the number of local devices minus 1.

Parameters

defval (int, optional) – The default value to be returned in case there is no replication information available. Default: 0

Returns

The replication local ordinal of the current process.

torch_xla.core.xla_model.is_master_ordinal(local=True)

Checks whether the current process is the master ordinal (0).

Parameters

local (bool) – Whether the local or global master ordinal should be checked. In case of multi-host replication, there is only one global master ordinal (host 0, device 0), while there are NUM_HOSTS local master ordinals. Default: True

Returns

A boolean indicating whether the current process is the master ordinal.
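The difference between local and global master ordinals is easiest to see on a concrete layout. The sketch below assumes, for illustration only, that every host has the same number of devices and that global ordinals are assigned contiguously host by host (global = host * devices_per_host + local). Replica, replica_layout and is_master are hypothetical helpers, not torch_xla APIs.

```python
from typing import List, NamedTuple


class Replica(NamedTuple):
    host: int
    local_ordinal: int   # what get_local_ordinal() would report
    global_ordinal: int  # what get_ordinal() would report


def replica_layout(num_hosts: int, devices_per_host: int) -> List[Replica]:
    # Assumed contiguous numbering: host 0 owns ordinals 0..devices_per_host-1, etc.
    return [
        Replica(host, local, host * devices_per_host + local)
        for host in range(num_hosts)
        for local in range(devices_per_host)
    ]


def is_master(replica: Replica, local: bool = True) -> bool:
    # local=True mirrors is_master_ordinal()'s default: one master per host.
    if local:
        return replica.local_ordinal == 0
    return replica.global_ordinal == 0
```

With 2 hosts of 4 devices each, the world size is 8, there are two local masters (global ordinals 0 and 4) and a single global master (ordinal 0).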

torch_xla.core.xla_model.xrt_world_size(defval=1)

Retrieves the number of devices taking part in the replication.

Parameters

defval (int, optional) – The default value to be returned in case there is no replication information available. Default: 1

Returns

The number of devices taking part in the replication.

torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None)

Performs an inplace reduce operation on the input tensor(s).

Parameters
  • reduce_type (string) – One of REDUCE_SUM, REDUCE_MUL, REDUCE_AND, REDUCE_OR, REDUCE_MIN and REDUCE_MAX.

  • inputs – Either a single torch.Tensor or a list of torch.Tensor to perform the all-reduce op on.

  • scale (float) – A default scaling value to be applied after the reduce. Default: 1.0

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None, there will be only one group with all the replicas in it.

Returns

If a single torch.Tensor is passed, the return value is a torch.Tensor holding the reduced value (across the replicas). If a list/tuple is passed, this function performs an inplace all-reduce op on the input tensors, and returns the list/tuple itself.
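The effect of groups and scale on a REDUCE_SUM all-reduce can be simulated without any devices. In the sketch below each replica's contribution is a plain float keyed by its ordinal, and simulate_all_reduce_sum is a hypothetical helper: every replica ends up with the scaled sum over its own group only.

```python
from typing import Dict, List, Optional


def simulate_all_reduce_sum(
    values: Dict[int, float],
    groups: Optional[List[List[int]]] = None,
    scale: float = 1.0,
) -> Dict[int, float]:
    """Return what each replica would hold after a REDUCE_SUM all-reduce."""
    if groups is None:
        # No groups: a single group containing every replica.
        groups = [sorted(values)]
    result: Dict[int, float] = {}
    for group in groups:
        total = sum(values[r] for r in group)
        for r in group:
            # Every member of the group receives the same scaled total.
            result[r] = total * scale
    return result
```

With values equal to the ordinals 0..7 and groups [[0, 1, 2, 3], [4, 5, 6, 7]], replicas 0 to 3 hold 6.0 and replicas 4 to 7 hold 22.0. Passing scale=1.0 / world_size turns the sum into a mean, the usual way gradients are averaged across replicas.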

torch_xla.core.xla_model.all_gather(value, dim=0, groups=None)

Performs an all-gather operation along a given dimension.

Parameters
  • value (torch.Tensor) – The input tensor.

  • dim (int) – The gather dimension. Default: 0

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_gather() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None, there will be only one group with all the replicas in it.

Returns

A tensor which has, in the dim dimension, all the values from the participating replicas.
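Similarly, an all-gather along dim 0 amounts to concatenating each replica's rows in group order. In this hypothetical simulation a tensor is represented as a list of rows, and the ordering of replicas inside a group is assumed to follow the order in which they are listed.

```python
from typing import Dict, List, Optional

Rows = List[List[int]]  # stand-in for a tensor: a list of rows along dim 0


def simulate_all_gather(
    values: Dict[int, Rows], groups: Optional[List[List[int]]] = None
) -> Dict[int, Rows]:
    """Return what each replica would hold after an all-gather along dim 0."""
    if groups is None:
        groups = [sorted(values)]
    result: Dict[int, Rows] = {}
    for group in groups:
        # Concatenate along dim 0, one block of rows per group member.
        gathered = [row for r in group for row in values[r]]
        for r in group:
            result[r] = [list(row) for row in gathered]
    return result
```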

torch_xla.core.xla_model.all_to_all(value, split_dimension, concat_dimension, split_count, groups=None)

Performs an XLA AllToAll() operation on the input tensor.

See: https://www.tensorflow.org/xla/operation_semantics#alltoall

Parameters
  • value (torch.Tensor) – The input tensor.

  • split_dimension (int) – The dimension upon which the split should happen.

  • concat_dimension (int) – The dimension upon which the concat should happen.

  • split_count (int) – The split count.

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_to_all() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None, there will be only one group with all the replicas in it.

Returns

The result torch.Tensor of the all_to_all() operation.
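The AllToAll semantics linked above can be illustrated on 1-D lists, where split_dimension and concat_dimension are both 0. Each replica cuts its input into split_count equal chunks and sends chunk i to replica i; each replica then concatenates the chunks it receives in source-replica order. This is a hypothetical simulation that assumes a single group containing all replicas, with split_count equal to the number of replicas.

```python
from typing import List


def simulate_all_to_all(values: List[List[int]], split_count: int) -> List[List[int]]:
    """values[j] is replica j's 1-D input; returns each replica's output."""
    if len(values) != split_count:
        raise ValueError("this sketch needs exactly one chunk per replica")
    chunks: List[List[List[int]]] = []
    for v in values:
        if len(v) % split_count != 0:
            raise ValueError("input length must be divisible by split_count")
        size = len(v) // split_count
        chunks.append([v[k * size:(k + 1) * size] for k in range(split_count)])
    # Replica i receives chunk i from every replica j, concatenated in order of j.
    return [
        [x for j in range(len(values)) for x in chunks[j][i]]
        for i in range(len(values))
    ]
```

With two replicas holding [0, 1, 2, 3] and [4, 5, 6, 7], replica 0 ends up with [0, 1, 4, 5] and replica 1 with [2, 3, 6, 7].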

torch_xla.core.xla_model.add_step_closure(closure, args=())

Adds a closure to the list of the ones to be run at the end of the step.

Model training often needs to print or report information (print to console, post to TensorBoard, etc.) that requires inspecting the contents of intermediate tensors. Inspecting the contents of different tensors at different points in the model code requires many executions and typically causes performance issues. Adding a step closure ensures that it will be run after the barrier, when all the live tensors, including the ones captured by the closure arguments, have already been materialized to device data. So using add_step_closure() ensures that a single execution is performed, even when multiple closures that require multiple tensors to be inspected are queued. Step closures are run sequentially in the order they have been queued. Note that even though this API optimizes the execution, it is advisable to throttle the printing/reporting events to once every N steps.

Parameters
  • closure (callable) – The function to be called.

  • args (tuple) – The arguments to be passed to the closure.
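The ordering guarantees described above (closures are only queued during the step, then run after the barrier, sequentially in queue order) can be modeled with a small queue. StepClosureQueue is a hypothetical stand-in with no device or graph behind it; in real code you would call xm.add_step_closure(fn, args=(...)) and let the step barrier run the queued closures.

```python
from typing import Any, Callable, List, Tuple


class StepClosureQueue:
    """Toy model of step closures: queued during a step, run after the barrier."""

    def __init__(self) -> None:
        self._closures: List[Tuple[Callable[..., Any], Tuple[Any, ...]]] = []

    def add_step_closure(self, closure: Callable[..., Any],
                         args: Tuple[Any, ...] = ()) -> None:
        # Nothing runs yet; the closure and its arguments are only recorded.
        self._closures.append((closure, args))

    def mark_step(self) -> None:
        # In PyTorch/XLA the pending graph would execute here, materializing the
        # live tensors; this toy has no graph, so it only drains the queue.
        closures, self._closures = self._closures, []
        for closure, args in closures:  # sequential, in the order queued
            closure(*args)
```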

torch_xla.core.xla_model.wait_device_ops(devices=[])

Waits for all the async operations on the given devices to complete.

Parameters

devices (list of strings, optional) – The devices whose async ops need to be waited for. If empty, all the local devices will be waited for.

torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None)

Runs the provided optimizer step and issues the XLA device step computation.

Parameters
  • optimizer (torch.optim.Optimizer) – The torch.optim.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.

  • barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False

  • optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None, there will be only one group with all the replicas in it.

Returns

The same value returned by the optimizer.step() call.

torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)

Saves the input data into a file.

The saved data is transferred to the PyTorch CPU device before being saved, so a subsequent torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views, it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

Parameters
  • data – The input data to be saved. Any nested combination of Python objects (lists, tuples, sets, dicts, …).

  • file_or_path – The destination for the data saving operation. Either a file path or a Python file object. If master_only is False, the path or file objects must point to different destinations, as otherwise all the writes from the same host will overwrite each other.

  • master_only (bool, optional) – Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part in the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True

  • global_master (bool, optional) – When master_only is True, this flag controls whether every host’s master saves the content (if global_master is False), or only the global master (ordinal 0). Default: False
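The interaction between master_only and global_master decides which replicas actually write. The hypothetical helper below encodes the parameter descriptions above as a truth table, using the local and global ordinals reported by get_local_ordinal() and get_ordinal(); it is a reading of these docs, not the library's code.

```python
def writes_file(global_ordinal: int, local_ordinal: int,
                master_only: bool = True, global_master: bool = False) -> bool:
    """Whether a replica would write, per the save() parameter descriptions."""
    if not master_only:
        # Every replica writes, so each one needs its own file_or_path.
        return True
    if global_master:
        # Only the single global master (ordinal 0) writes.
        return global_ordinal == 0
    # Default: the master of each host (local ordinal 0) writes.
    return local_ordinal == 0
```

Assuming 2 hosts with 4 devices each, the defaults produce two writers (one per host), global_master=True produces one, and master_only=False produces eight.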

torch_xla.core.xla_model.rendezvous(tag, payload=b'', replicas=[])

Waits for all the mesh clients to reach the named rendezvous.

Parameters
  • tag (string) – The name of the rendezvous to join.

  • payload (bytes, optional) – The payload to be sent to the rendezvous.

  • replicas (list of int) – The replica ordinals taking part in the rendezvous. Empty means all replicas in the mesh. Default: []

Returns

The payloads exchanged by all the other cores, with the payload of core ordinal i at position i in the returned tuple.
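The rendezvous contract (block until everyone arrives, then hand every participant the payloads ordered by ordinal) can be imitated with threads standing in for replica processes. RendezvousSketch and run_rendezvous_demo are hypothetical and single-use; the real API coordinates separate processes, and its transport is not modeled here.

```python
import threading
from typing import Dict, List, Tuple


class RendezvousSketch:
    """Single-use, thread-based imitation of a named rendezvous."""

    def __init__(self, world_size: int) -> None:
        self._payloads: List[bytes] = [b""] * world_size
        self._barrier = threading.Barrier(world_size)

    def rendezvous(self, ordinal: int, payload: bytes = b"") -> Tuple[bytes, ...]:
        self._payloads[ordinal] = payload
        self._barrier.wait()  # block until every participant has arrived
        # The payload of ordinal i sits at position i.
        return tuple(self._payloads)


def run_rendezvous_demo(world_size: int = 4) -> Dict[int, Tuple[bytes, ...]]:
    rdv = RendezvousSketch(world_size)
    results: Dict[int, Tuple[bytes, ...]] = {}

    def worker(ordinal: int) -> None:
        results[ordinal] = rdv.rendezvous(ordinal, str(ordinal).encode())

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```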

torch_xla.core.xla_model.do_on_ordinals(target, data=(), ordinals=(0, ))

Runs a function only on a given set of ordinals.

Parameters
  • target (callable) – The function to be run on ordinals.

  • data – Any input data for the target function which contains tensors. All the XLA tensors used by the target function must be passed in this argument. Any other data used by the function can be captured by the Python interpreter as usual. Default: ()

  • ordinals (list of int) – The list/set of ordinals where the target function should run. Default: (0,)

Returns

In the ordinals that ran the target function, the function return value, otherwise None.

torch_xla.core.xla_model.mesh_reduce(tag, data, reduce_fn)

Performs an out-of-graph client mesh reduction.

Parameters
  • tag (string) – The name of the rendezvous to join.

  • data – The data to be reduced. The reduce_fn callable will receive a list with the copies of the same data coming from all the mesh client processes (one per core).

  • reduce_fn (callable) – A function which receives a list of data-like objects and returns the reduced result.

Returns

The reduced value.
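What reduce_fn receives is a plain Python list with one entry per mesh client, and every client gets back the same reduced value. The sketch below simulates that contract in a single process; simulate_mesh_reduce is a hypothetical helper. In training code, the real call would typically reduce a per-core metric, for example by passing a mean function as reduce_fn.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def simulate_mesh_reduce(per_core_data: Sequence[T],
                         reduce_fn: Callable[[List[T]], R]) -> List[R]:
    """Return the value each core would receive from mesh_reduce()."""
    # reduce_fn is called once, on the list of all the cores' copies.
    reduced = reduce_fn(list(per_core_data))
    return [reduced] * len(per_core_data)
```

Any callable that accepts a list works as reduce_fn, such as statistics.mean, max, or a function that merges per-core dictionaries.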

torch_xla.core.xla_model.set_rng_state(seed, device=None)

Sets the random number generator state.

Parameters
  • seed (int) – The state to be set.

  • device (string, optional) – The device where the RNG state needs to be set. If missing the default device seed will be set.

torch_xla.core.xla_model.get_rng_state(device=None)

Gets the current running random number generator state.

Parameters

device (string, optional) – The device whose RNG state needs to be retrieved. If missing, the default device's RNG state will be retrieved.

Returns

The RNG state, as integer.

torch_xla.core.xla_model.get_memory_info(device)

Retrieves the device memory information.

Parameters

device (string) – The device whose memory information is requested.

Returns

A dictionary with kb_free (free memory in KB) and kb_total (total memory in KB) keys.

torch_xla.core.functions.all_reduce(reduce_type, value, scale=1.0, groups=None)

Performs an inplace reduce operation on the input tensor.

This is the same as xm.all_reduce() but supports autograd differentiation.

Parameters
  • reduce_type (string) – One of REDUCE_SUM, REDUCE_MUL, REDUCE_AND, REDUCE_OR, REDUCE_MIN and REDUCE_MAX.

  • value (torch.Tensor) – The tensor to perform the all-reduce op on.

  • scale (float) – A default scaling value to be applied after the reduce. Default: 1.0

  • groups (list, optional) –

    A list of lists, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None, there will be only one group with all the replicas in it.

Returns

The reduced value across the selected replicas.

torch_xla.core.functions.all_gather(value, dim=0)

Performs an all-gather operation along a given dimension.

This is the same as xm.all_gather() but supports autograd differentiation.

Parameters
  • value (torch.Tensor) – The input tensor.

  • dim (int) – The gather dimension. Default: 0

Returns

A tensor which has, in the dim dimension, all the values from the participating replicas.

torch_xla.core.functions.nms(boxes, scores, score_threshold, iou_threshold, output_size)

Performs a Non-Maximum Suppression operation.

Parameters
  • boxes (torch.Tensor) – A torch.Tensor of shape [N, 4] listing the box coordinates in (y0, x0, y1, x1) form.

  • scores (torch.Tensor) – A torch.Tensor of shape [N] listing the scores of each box.

  • score_threshold (torch.Tensor) – The minimum score for a box to qualify as valid.

  • iou_threshold (torch.Tensor) – The minimum IoU (Intersection over Union) score to trigger overlap logic.

  • output_size (int) – The maximum number of returned indices (must be less than or equal to N).

Returns

A tuple of torch.Tensor with the first element being the selected box indices, and the second element being the number of valid boxes.
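For reference, the classic greedy NMS algorithm that these parameters describe can be written in plain Python. This is a sketch of the textbook algorithm, not torch_xla's kernel. Assumptions: a box is valid when its score is at least score_threshold, a candidate is suppressed when its IoU with an already selected box is at least iou_threshold, ties in score keep the original order, and the result is an unpadded list (the real op returns tensors plus the count of valid boxes).

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (y0, x0, y1, x1), as in the nms() docs


def iou(a: Box, b: Box) -> float:
    """Intersection over Union of two axis-aligned boxes."""
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float], score_threshold: float,
               iou_threshold: float, output_size: int) -> Tuple[List[int], int]:
    """Return (selected box indices, number of valid boxes)."""
    # Visit boxes from highest to lowest score; sorted() keeps ties in order.
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    selected: List[int] = []
    for i in order:
        if len(selected) == output_size:
            break
        if scores[i] < score_threshold:
            break  # every remaining box scores even lower
        # Keep the box only if it does not overlap too much with any kept box.
        if all(iou(boxes[i], boxes[j]) < iou_threshold for j in selected):
            selected.append(i)
    return selected, len(selected)
```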

distributed

class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, fixed_batch_size=False, loader_prefetch_size=8, device_prefetch_size=4)

Wraps an existing PyTorch DataLoader with background data upload.

Parameters
  • loader (torch.utils.data.DataLoader) – The PyTorch DataLoader to be wrapped.

  • devices (list of torch.device) – The list of devices where the data has to be sent. The i-th sample returned by the loader will be sent to devices[i % len(devices)].

  • batchdim (int, optional) – The dimension which is holding the batch size. Default: 0

  • fixed_batch_size (bool, optional) – Ensures that all the batch sizes sent to the devices are of the same size. The original loader iteration stops as soon as a batch with a non-matching size is found. Default: False

  • loader_prefetch_size (int, optional) – The max capacity of the queue used by the thread which is reading samples from the loader, to be processed by the worker threads which upload data to the devices. Default: 8

  • device_prefetch_size (int, optional) – The max size of the per-device queues, where the worker threads deposit tensors which have already been sent to devices. Default: 4
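Two of the behaviors documented above, round-robin placement (devices[i % len(devices)]) and the early stop under fixed_batch_size, can be sketched as a plain generator. assign_batches is hypothetical: it ignores the background threads and prefetch queues, the device strings are illustrative, and taking the first batch's size as the reference size is an assumption.

```python
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple, TypeVar

B = TypeVar("B")


def assign_batches(batches: Iterable[Sequence[B]], devices: List[str],
                   fixed_batch_size: bool = False) -> Iterator[Tuple[str, Sequence[B]]]:
    """Yield (device, batch) pairs the way the parameter docs describe."""
    expected_size: Optional[int] = None
    for i, batch in enumerate(batches):
        if fixed_batch_size:
            if expected_size is None:
                expected_size = len(batch)  # assumed reference size
            elif len(batch) != expected_size:
                return  # stop at the first batch with a non-matching size
        yield devices[i % len(devices)], batch
```

With batches [[1, 2], [3, 4], [5, 6], [7]] and two devices, fixed_batch_size=True yields three batches alternating between the devices and drops the short final batch; without it, all four are yielded.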

per_device_loader(device)

Retrieves the loader iterator object for the given device.

Parameters

device (torch.device) – The device whose loader is being requested.

Returns

The loader iterator object for the device. This is not a torch.utils.data.DataLoader interface, but a Python iterator which returns the same tensor data structure as returned by the wrapped torch.utils.data.DataLoader, but residing on XLA devices.

torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')

Enables multiprocessing-based replication.

Parameters
  • fn (callable) – The function to be called for each device taking part in the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in args.

  • args (tuple) – The arguments for fn. Default: Empty tuple

  • nprocs (int) – The number of processes/devices for the replication. At the moment, if specified, it can be either 1 or the maximum number of devices.

  • join (bool) – Whether the call should block waiting for the completion of the processes which have been spawned. Default: True

  • daemon (bool) – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API). Default: False

  • start_method (string) – The Python multiprocessing process creation method. Default: spawn

Returns

The same object returned by the torch.multiprocessing.spawn API. If nprocs is 1 the fn function will be called directly, and the API will not return.

class torch_xla.distributed.xla_multiprocessing.MpModelWrapper(model)

Wraps a model to minimize host memory usage when the fork method is used.

This class should be used together with the spawn(…, start_method='fork') API to minimize the use of host memory. Instead of creating models on each multiprocessing process, hence replicating the model’s initial host memory, the model is created once at global scope, and then moved into each device inside the spawn() target function. Example:

WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())

def _mp_fn(index, ...):
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  ...

xmp.spawn(_mp_fn, ..., start_method='fork')

This method has two advantages. First, it uses only one copy of the memory pages to host the original model weights, and second, it serializes the move of the wrapped model into each device, which lowers the load on system memory during the process.

to(device)

Retrieves the model moved onto the specified device.

Parameters

device (torch.device) – The device the model should be moved onto.

Returns

The model on the specified device.

class torch_xla.distributed.xla_multiprocessing.MpSerialExecutor

Utility to run a function in a serialized fashion among multi-core processes.

Example:

# At global scope.
SERIAL_EXEC = xmp.MpSerialExecutor()

def load_dataset(path):
  return maybe_download_and_load(path)

def _mp_fn(index, ...):
  # Avoid all cores downloading the same data with the serial executor.
  dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data'))
  ...

xmp.spawn(_mp_fn, ...)

run(fn)

Runs the provided function serialized with respect to each per-core process.

Parameters

fn (callable) – The function to run in a serialized fashion.

Returns

The fn return value.
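The effect of run() is mutual exclusion: only one caller executes fn at a time. The thread-based analogue below shows that contract with a threading.Lock; SerialExecutorSketch and demo_serial_runs are hypothetical, and the real MpSerialExecutor coordinates separate processes rather than threads.

```python
import threading
import time
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")


class SerialExecutorSketch:
    """Runs callables one at a time, whichever thread calls run()."""

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def run(self, fn: Callable[[], T]) -> T:
        with self._lock:  # callers queue up here
            return fn()


def demo_serial_runs(num_workers: int = 4) -> List[Tuple[str, int]]:
    executor = SerialExecutorSketch()
    events: List[Tuple[str, int]] = []

    def job(worker: int) -> None:
        def critical() -> int:
            events.append(("start", worker))
            time.sleep(0.01)  # simulate slow work such as a download
            events.append(("end", worker))
            return worker
        executor.run(critical)

    threads = [threading.Thread(target=job, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events
```

Every "start" event in the returned list is immediately followed by the matching "end", showing that no two workers were inside fn at the same time.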

utils

class torch_xla.utils.tf_record_reader.TfRecordReader(path, compression='', buffer_size=16777216, transforms=None)

Reads TfRecords or TfExamples.

Parameters
  • path (string) – The path to the file containing TfRecords.

  • compression (string, optional) – The compression type. The empty string for no compression, otherwise ZLIB or GZIP. Default: No compression.

  • buffer_size (int, optional) – The size of the buffer to be used to read TfRecords. Default: 16 * 1024 * 1024

  • transforms (dict, optional) – A dictionary with the key matching the TfExample label name, and value which is either a callable which will be called to transform the matching tensor data, or STR for string conversion.

class torch_xla.utils.utils.SampleGenerator(data, sample_count)

Iterator which returns multiple samples of a given input data.

Can be used in place of a PyTorch DataLoader to generate synthetic data.

Parameters
  • data – The data which should be returned at each iterator step.

  • sample_count – The maximum number of data samples to be returned.
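The behavior described for SampleGenerator, handing back the same data object a fixed number of times, fits in a few lines. SampleGeneratorSketch is a hypothetical re-creation for illustration; whether the real class also supports len() this way is an assumption.

```python
from typing import Generic, Iterator, TypeVar

T = TypeVar("T")


class SampleGeneratorSketch(Generic[T]):
    """Yields the same `data` object `sample_count` times per iteration pass."""

    def __init__(self, data: T, sample_count: int) -> None:
        self._data = data
        self._sample_count = sample_count

    def __iter__(self) -> Iterator[T]:
        # A fresh generator each time, so the object can be iterated once per epoch.
        for _ in range(self._sample_count):
            yield self._data

    def __len__(self) -> int:
        return self._sample_count
```

Because no new data is produced per step, the input pipeline costs almost nothing, which makes synthetic data handy for isolating device throughput.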

class torch_xla.utils.utils.DataWrapper

Utility class to wrap data structures to be sent to device.

torch_xla.utils.serialization.save(data, path, master_only=True, global_master=False)

Saves the input data into a file.

The saved data is transferred to the PyTorch CPU device before being saved, so a subsequent torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views, it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).

Parameters
  • data – The input data to be saved. Any nested combination of Python objects (lists, tuples, sets, dicts, …).

  • path – The destination file for the data saving operation. If master_only is False, the path must point to different destinations, as otherwise all the writes from the same host will overwrite each other.

  • master_only (bool, optional) – Whether only the master device should save the data. If False, the path argument should be a different path for each of the ordinals taking part in the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True

  • global_master (bool, optional) – When master_only is True, this flag controls whether every host’s master saves the content (if global_master is False), or only the global master (ordinal 0). Default: False

torch_xla.utils.serialization.load(path)

Loads data previously saved with the save() API.

Parameters

path (str) – The path passed to the save() API.

Returns

The loaded data.

torch_xla.utils.gcsfs.open(path, mode='r', encoding=None)

Opens a Google Cloud Storage (GCS) file for reading or writing.

Parameters
  • path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.

  • mode (string, optional) – The open mode, similar to the open() API. Default: ‘r’

  • encoding (string, optional) – The character encoding to be used to decode bytes into strings when opening in text mode. Default: None

Returns

The GCS file object.

torch_xla.utils.gcsfs.list(path)

Lists the content of a GCS bucket.

Parameters

path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.

Returns

A list of GcsBlob objects.

torch_xla.utils.gcsfs.stat(path)

Fetches the information of a GCS file.

Parameters

path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.

Returns

A GcsBlob object.

torch_xla.utils.gcsfs.remove(path)

Removes a GCS blob.

Parameters

path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.

torch_xla.utils.gcsfs.rmtree(path)

Removes all the GCS blobs within a given path.

Parameters

path (string) – The GCS path of the file pattern or folder. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.

torch_xla.utils.gcsfs.read(path)

Reads the whole content of a GCS blob.

Parameters

path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.

Returns

The bytes stored within the GCS blob.

torch_xla.utils.gcsfs.write(path, content)

Writes a string/bytes or file into a GCS blob.

Parameters
  • path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.

  • content (string, bytes or file object) – The content to be written into path.

torch_xla.utils.gcsfs.generic_open(path, mode='r', encoding=None)

Opens a file (GCS or not) for reading or writing.

Parameters
  • path (string) – The path of the file to be opened. If a GCS path, it must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.

  • mode (string, optional) – The open mode, similar to the open() API. Default: ‘r’

  • encoding (string, optional) – The character encoding to be used to decode bytes into strings when opening in text mode. Default: None

Returns

The opened file object.

torch_xla.utils.gcsfs.generic_read(path)

Reads the whole content of the provided location.

Parameters

path (string) – The GCS path or local path to be read.

Returns

The bytes stored within the GCS blob or local file.

torch_xla.utils.gcsfs.generic_write(output_string, path, makedirs=False)

Writes a string/bytes or file into a GCS blob or to local disk.

Depending on the path passed in, this API writes either to a GCS blob or to a local file: it checks whether the path starts with the ‘gs://’ prefix, and uses open() otherwise.

Parameters
  • output_string (string) – The string to be written to the output.

  • path (string) – The GCS path or local path of the output.

  • makedirs (bool) – Whether the parent folders of path should be created if missing. Default: False

torch_xla.utils.gcsfs.is_gcs_path(path)

Checks whether a path is a GCS path.

Parameters

path (string) – The path to be checked.

Returns

Whether path is a GCS path.

class torch_xla.utils.cached_dataset.CachedDataset(data_set, path, max_files_per_folder=1000, compress=True)

Wraps an existing torch.utils.data.Dataset by providing file caching.

The CachedDataset can be used to trade the CPU/RAM resources required to process a raw dataset for storage/network resources. Example:

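from torchvision import datasets, transforms
from torch_xla.utils.cached_dataset import CachedDataset

# FLAGS holds the training script's parsed command-line options (placeholder).
train_dataset = datasets.MNIST(
    FLAGS.datadir,
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
train_dataset = CachedDataset(train_dataset, FLAGS.dscache_dir)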

The CachedDataset transparently caches the original Dataset samples, so every run after the first incurs no further CPU/RAM usage for processing the raw samples. Once a CachedDataset is fully cached, it can be exported (for example, as a tar.gz archive) and used on different machines. Just unpack the archive and pass None as the original Dataset. Example:

train_dataset = CachedDataset(None, FLAGS.dscache_dir)
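Exporting a fully cached local folder is ordinary archiving. A minimal sketch using Python's standard tarfile module; `export_cache` and `import_cache` are hypothetical helpers, and the cache is assumed to live in a local directory rather than on GCS:

```python
import os
import tarfile


def export_cache(cache_dir, archive_path):
    # Pack the whole cache folder into a gzip-compressed tarball.
    name = os.path.basename(os.path.normpath(cache_dir))
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.add(cache_dir, arcname=name)


def import_cache(archive_path, dest_dir):
    # Unpack on the target machine. The extracted folder is the path to pass
    # as CachedDataset(None, path).
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Safer extraction on Python versions that support filters.
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
```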

To fully cache a CachedDataset, call its warmup() API. A CachedDataset saved on GCS has the advantage that it can be used from different machines without explicit exporting.

Parameters
  • data_set (torch.utils.data.Dataset) – The raw torch.utils.data.Dataset to be cached. It can be set to None if all the input samples are already stored within the path folder.

  • path (string) – The path where the dataset samples should be stored/loaded. The path needs to be writable, unless all the samples are already stored. The path can be a GCS path (prefixed with gs://).

  • max_files_per_folder (int) – The maximum number of files to be stored within a single folder. If data_set is None, this value is ignored and taken from the cached metadata. Default: 1000

  • compress (bool) – Whether the saved samples should be compressed. Compression saves space at the expense of the CPU time required to compress/decompress. If data_set is None, this value is ignored and taken from the cached metadata. Default: True
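To make the caching idea and the roles of max_files_per_folder and compress concrete, here is a simplified, stdlib-only sketch of a file-caching wrapper. It is not the CachedDataset implementation: the class name `TinyCachedDataset`, the METADATA.json file, pickle serialization, gzip compression and the `index // max_files_per_folder` folder layout are all assumptions made for illustration, and GCS paths are not handled.

```python
import gzip
import json
import os
import pickle
import tempfile


class TinyCachedDataset:
    """Hypothetical, simplified stand-in for CachedDataset (illustration only)."""

    def __init__(self, data_set, path, max_files_per_folder=1000, compress=True):
        self.data_set = data_set
        self.path = path
        meta_path = os.path.join(path, "METADATA.json")
        if data_set is None:
            # No raw dataset: settings come from the cached metadata.
            with open(meta_path) as f:
                meta = json.load(f)
        else:
            meta = {"count": len(data_set),
                    "max_files_per_folder": max_files_per_folder,
                    "compress": compress}
            os.makedirs(path, exist_ok=True)
            with open(meta_path, "w") as f:
                json.dump(meta, f)
        self.count = meta["count"]
        self.max_files_per_folder = meta["max_files_per_folder"]
        self.compress = meta["compress"]

    def __len__(self):
        return self.count

    def _sample_path(self, index):
        # Sample i goes to folder i // max_files_per_folder, so no folder
        # holds more than max_files_per_folder files.
        folder = str(index // self.max_files_per_folder)
        return os.path.join(self.path, folder, f"{index}.pkl")

    def __getitem__(self, index):
        if not 0 <= index < self.count:
            raise IndexError(index)
        sample_path = self._sample_path(index)
        opener = gzip.open if self.compress else open
        if os.path.exists(sample_path):
            # Cache hit: no raw-sample processing needed.
            with opener(sample_path, "rb") as f:
                return pickle.load(f)
        if self.data_set is None:
            raise FileNotFoundError(sample_path)
        # Cache miss: compute the sample from the raw dataset, then store it.
        sample = self.data_set[index]
        os.makedirs(os.path.dirname(sample_path), exist_ok=True)
        with opener(sample_path, "wb") as f:
            pickle.dump(sample, f)
        return sample

    def warmup(self):
        # Visit every index once so the whole dataset ends up cached.
        for i in range(len(self)):
            self[i]


raw = [x * x for x in range(5)]  # stand-in for an expensive raw dataset
with tempfile.TemporaryDirectory() as cache_dir:
    cached = TinyCachedDataset(raw, cache_dir, max_files_per_folder=2)
    cached.warmup()
    reloaded = TinyCachedDataset(None, cache_dir)  # raw dataset not needed
    print([reloaded[i] for i in range(len(reloaded))])  # [0, 1, 4, 9, 16]
```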


OP Lowering Guide

Background

PyTorch wraps the C++ ATen tensor library, which offers a wide range of operations implemented on GPU and CPU. PyTorch/XLA is a PyTorch extension; one of its purposes is to convert PyTorch operations to XLA operations. Lowering is the process of converting a higher-level representation to a lower-level one; in this document, I refer to converting a PyTorch operation to XLA operations as lowering. The XLA compiler also lowers XlaOp to HLO, but that is beyond the scope of this document. Operations for which we have not yet provided an XLA lowering are forwarded to the CPU and executed with their ATen implementations. This forwarding causes a significant slowdown, so every operation used in a model must be lowered to achieve the best performance.
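As a toy illustration of what lowering means, the sketch below expresses hardsigmoid (defined in the PyTorch documentation as 0 for x ≤ −3, 1 for x ≥ 3, and x/6 + 1/2 otherwise) using only a handful of primitive operations, in the same spirit as a real lowering expresses an ATen operation as a sequence of XlaOps. Plain Python floats stand in for tensors, and the `prim_*` functions are hypothetical stand-ins for XLA primitives; none of these names come from PyTorch/XLA.

```python
# Toy "primitive" ops standing in for XlaOps (hypothetical names).
def prim_add(a, b):
    return a + b


def prim_max(a, b):
    return a if a > b else b


def prim_min(a, b):
    return a if a < b else b


def prim_div(a, b):
    return a / b


def hardsigmoid_reference(x):
    # Piecewise definition from the PyTorch documentation.
    if x <= -3:
        return 0.0
    if x >= 3:
        return 1.0
    return x / 6 + 0.5


def hardsigmoid_lowered(x):
    # Same function as a sequence of primitives: min(max(x + 3, 0), 6) / 6.
    shifted = prim_add(x, 3.0)
    clamped_low = prim_max(shifted, 0.0)
    clamped = prim_min(clamped_low, 6.0)
    return prim_div(clamped, 6.0)


# The lowered form must agree with the operation's definition everywhere.
for x in (-4.0, -1.5, 0.0, 1.5, 4.0):
    assert abs(hardsigmoid_lowered(x) - hardsigmoid_reference(x)) < 1e-12
```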

Before you start

Follow the instructions here to install the required dependencies and build PyTorch and PyTorch/XLA from source. You do not need access to a TPU to implement a lowering; it is recommended to experiment on a workstation configured to use XLA:CPU.

Understanding the operation

You can find the definitions of the C++ ATen operations in native_functions.yaml. After you build PyTorch/XLA from source, you will also find our default implementations (which forward to the native PyTorch CPU implementations) in xla/torch_xla/csrc/aten_xla_type_default.h/cpp. PyTorch operations can usually be mapped easily to the PyTorch tensor API. If that is not the case, searching for the native implementation in the PyTorch repository is recommended. The goal is to lower each PyTorch operation into a sequence of the XLA operations defined here.

File structure

All files mentioned below live under the xla/torch_xla/csrc folder.

  1. aten_xla_type_default.h/.cpp are auto-generated by this script and contain our default implementations of the PyTorch operations. Functions here are used if a lowering is not explicitly defined in aten_xla_type.cpp.

  2. aten_xla_type.h/.cpp are the entry points from PyTorch into the PyTorch/XLA world. We copy operation declarations from aten_xla_type_default.h to here and construct an XLATensor from the input at::Tensor and other parameters. The resulting XLATensor must be converted back to an at::Tensor before returning to the PyTorch world.

  3. tensor.h contains the XLATensor declarations. These declarations map one to one to the at::Tensor operations we declared in aten_xla_type.h.

  4. tensor_methods.cpp contains the implementations of the XLATensor methods declared in tensor.h. Each implementation constructs the corresponding ir::op from the parameters’ ir::Values and wraps it inside an XLATensor. IR stands for intermediate representation.

  5. The ops/ directory contains all ir::ops declarations and definitions. Smaller nodes can be put in ops/ops.h/.cpp; more complicated nodes can be put in a separate file. All ops inherit from ir::ops::Node and provide a way to lower their input ir::Values to a sequence of XlaOps.
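The last step of that layering, ops that inherit from ir::ops::Node and lower their inputs to a sequence of XlaOps, can be mimicked in a few lines of Python. `Node`, `Parameter`, `Add`, `Mul` and the recording `Builder` are hypothetical stand-ins for the C++ classes and the XLA builder; here lowering only records a textual op sequence, and results are memoized per node so a shared sub-expression is emitted once.

```python
class Builder:
    """Records emitted ops; a stand-in for an XLA computation builder."""

    def __init__(self):
        self.ops = []

    def emit(self, opcode, *operands):
        handle = f"%{len(self.ops)}"
        self.ops.append(f"{handle} = {opcode}({', '.join(operands)})")
        return handle


class Node:
    def __init__(self, *operands):
        self.operands = operands

    def lower(self, builder, lowered_operands):
        raise NotImplementedError


class Parameter(Node):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def lower(self, builder, lowered_operands):
        return builder.emit("parameter", self.name)


class Add(Node):
    def lower(self, builder, lowered_operands):
        return builder.emit("add", *lowered_operands)


class Mul(Node):
    def lower(self, builder, lowered_operands):
        return builder.emit("multiply", *lowered_operands)


def lower_graph(root):
    builder = Builder()
    lowered = {}  # node -> handle, so shared nodes are emitted only once

    def visit(node):
        if node not in lowered:
            # Operands are lowered first, then the node itself.
            handles = [visit(op) for op in node.operands]
            lowered[node] = node.lower(builder, handles)
        return lowered[node]

    visit(root)
    return builder.ops


x, y = Parameter("x"), Parameter("y")
s = Add(x, y)
for line in lower_graph(Mul(s, s)):  # (x + y) * (x + y)
    print(line)
# %0 = parameter(x)
# %1 = parameter(y)
# %2 = add(%0, %1)
# %3 = multiply(%2, %2)
```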

Unit Test

Our CircleCI runs the native PyTorch Python tests for every change and every day. Those tests use the XLA implementation if we provide a lowering. We usually don’t need to add Python tests for PyTorch/XLA unless we want to verify XLA-specific behavior (like dynamic shapes) or we skipped the native PyTorch test for some reason. If a Python test is required, add it to xla/test/test_operations.py. We also need to add C++ tests in xla/test/cpp/test_aten_xla_tensor.cpp. These tests should call the PyTorch C++ API and verify that our implementation yields the same result as the native PyTorch implementation. We also need to verify that the XLA implementation is called when the tensor is an XLA tensor by checking the aten::op and xla::op counters.
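The counter check can be pictured with a toy dispatcher: an operation with a registered lowering bumps an `xla::<op>` counter, while an operation without one is forwarded to a reference implementation and bumps `aten::<op>`. Everything here (`LOWERINGS`, `REFERENCE`, `dispatch`, `check_lowered` and the plain counters) is a hypothetical Python stand-in; the real C++ tests read PyTorch/XLA's own counters.

```python
import math
from collections import Counter

counters = Counter()

# Ops that have an "XLA lowering" in this toy world.
LOWERINGS = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
}

# Reference ("native CPU") implementations, also used as the fallback.
REFERENCE = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "atan2": math.atan2,
}


def dispatch(op, *args):
    if op in LOWERINGS:
        counters[f"xla::{op}"] += 1
        return LOWERINGS[op](*args)
    # No lowering: forward to the reference implementation (slow path).
    counters[f"aten::{op}"] += 1
    return REFERENCE[op](*args)


def check_lowered(op, *args):
    """Same result as the reference, and no aten:: fallback was taken."""
    before = counters.copy()
    result = dispatch(op, *args)
    assert math.isclose(result, REFERENCE[op](*args))
    assert counters[f"xla::{op}"] == before[f"xla::{op}"] + 1
    assert counters[f"aten::{op}"] == before[f"aten::{op}"]
    return result


check_lowered("mul", 3.0, 4.0)
dispatch("atan2", 1.0, 1.0)  # falls back: counted under aten::atan2
print(dict(counters))  # {'xla::mul': 1, 'aten::atan2': 1}
```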

Tips

Lowering is the process of breaking down a PyTorch operation into a sequence of XlaOps. To provide a good lowering of a PyTorch operation, one needs a good grasp of what XLA is capable of. Reading the XlaOp documentation and looking at how similar ops are lowered is the best way to achieve that. You can find a minimal op lowering example in this PR, and a slightly more complicated example with backward lowering in this PR.
