PyTorch on XLA Devices
PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.
Creating an XLA Tensor
PyTorch/XLA adds a new xla device type to PyTorch. This device type works just like other PyTorch device types. For example, here’s how to create and print an XLA tensor:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing torch_xla initializes PyTorch/XLA, and xm.xla_device() returns the current XLA device. This may be a CPU or TPU depending on your environment.
XLA Tensors are PyTorch Tensors
PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.
For example, XLA tensors can be added together:
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
Or matrix multiplied:
print(t0.mm(t1))
Or used with neural network modules:
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
will throw an error since the torch.nn.Linear module is on the CPU.
Running Models on XLA Devices
Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device and multiple devices with XLA multi-processing.
Running on a Single XLA Device
The following snippet shows a network training on a single XLA device:
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
    optimizer.zero_grad()
    data = data.to(device)
    target = target.to(device)
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    # Execute the pending XLA graph for this iteration.
    xm.mark_step()
This snippet highlights how easy it is to switch your model to run on XLA. The model definition, dataloader, optimizer and training loop can work on any device. The only XLA-specific code is a couple of lines that acquire the XLA device and mark the step. Calling xm.mark_step() at the end of each training iteration causes XLA to execute its current graph and update the model’s parameters. See XLA Tensor Deep Dive for more on how XLA creates graphs and runs operations.
Running on Multiple XLA Devices with Multi-processing
PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# MNIST, train_loader, lr and momentum are assumed to be defined elsewhere.
def _mp_fn(index):
    device = xm.xla_device()
    mp_device_loader = pl.MpDeviceLoader(train_loader, device)

    model = MNIST().train().to(device)
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    for data, target in mp_device_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        # Reduce gradients across devices, then step the optimizer.
        xm.optimizer_step(optimizer)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
There are three differences between this multi-device snippet and the previous single device snippet:
- xmp.spawn() creates the processes that each run an XLA device.
- MpDeviceLoader loads the training data onto each device.
- xm.optimizer_step(optimizer) consolidates the gradients between cores and issues the XLA device step computation.
The model definition, optimizer definition and training loop remain the same.
NOTE: When using multi-processing, XLA devices can be retrieved and accessed only from within the target function of xmp.spawn() (or any function which has xmp.spawn() as a parent in the call stack).
See the full multiprocessing example for more on training a network on multiple XLA devices with multi-processing.
XLA Tensor Deep Dive
Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique.
XLA Tensors are Lazy
CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example.
Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device. For more information about our lazy tensor design, you can read this paper.
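The record-then-execute idea described above can be illustrated with a toy model in plain Python. This is only a conceptual sketch: the LazyValue class and const helper are hypothetical names invented for the illustration, and nothing here reflects how PyTorch/XLA actually builds or compiles its graphs.

```python
class LazyValue:
    """Toy stand-in for a lazy tensor: records operations instead of running them."""

    def __init__(self, op, inputs=(), constant=None):
        self.op = op
        self.inputs = inputs
        self.constant = constant

    def __add__(self, other):
        return LazyValue("add", (self, other))

    def __mul__(self, other):
        return LazyValue("mul", (self, other))

    def graph(self):
        """Render the recorded graph as a nested expression string."""
        if self.op == "const":
            return repr(self.constant)
        return f"{self.op}({', '.join(i.graph() for i in self.inputs)})"

    def materialize(self):
        """Execute the whole recorded graph; nothing is computed before this call."""
        if self.op == "const":
            return self.constant
        left, right = (i.materialize() for i in self.inputs)
        return left + right if self.op == "add" else left * right


def const(value):
    """Wrap a plain number as a graph leaf."""
    return LazyValue("const", constant=value)


a, b, c = const(2.0), const(3.0), const(4.0)
pending = a * b + c             # only records the graph, no arithmetic yet
recorded = pending.graph()      # "add(mul(2.0, 3.0), 4.0)"
result = pending.materialize()  # runs the recorded operations: 10.0
```

Because the whole expression is known before anything runs, a real compiler has the chance to fuse or reorder the recorded operations, which is the benefit the section attributes to deferred execution.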
XLA Tensors and bFloat16
PyTorch/XLA can use the bfloat16 datatype when running on TPUs. In fact, PyTorch/XLA handles float types (torch.float and torch.double) differently on TPUs. This behavior is controlled by the XLA_USE_BF16 and XLA_DOWNCAST_BF16 environment variables:
- By default both torch.float and torch.double are torch.float on TPUs.
- If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.
- If XLA_DOWNCAST_BF16 is set, then torch.float is bfloat16 on TPUs and torch.double is float32 on TPUs.
- If a PyTorch tensor has torch.bfloat16 data type, this will be directly mapped to the TPU bfloat16 (XLA BF16 primitive type).
Developers should note that XLA tensors on TPUs will always report their PyTorch datatype regardless of the actual datatype they’re using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU it will be converted from its actual datatype to its PyTorch datatype. Depending on how your code operates, this conversion triggered by the type of processing unit can be important.
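To get a feel for why this hidden conversion can matter, the sketch below emulates the precision of bfloat16 in plain Python. It relies on the fact that bfloat16 keeps the sign, the 8 exponent bits and the top 7 mantissa bits of a float32. The helper names are hypothetical, the rounding mode (round to nearest, ties to even) is an assumption about the hardware, and NaN and overflow handling are ignored.

```python
import struct


def to_bfloat16_bits(x):
    """Round a float32 value to bfloat16 and return the 16-bit pattern.

    Assumes round-to-nearest-even; NaN and overflow are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Adding 0x7FFF plus the lowest kept bit implements ties-to-even rounding.
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF


def from_bfloat16_bits(b):
    """Expand a 16-bit bfloat16 pattern back into a Python float."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]


def round_trip(x):
    """Value that survives after storing x as bfloat16."""
    return from_bfloat16_bits(to_bfloat16_bits(x))


exact = round_trip(1.0)        # 1.0 is exactly representable
approx = round_trip(3.14159)   # only about 3 significant decimal digits survive
```

With these assumptions, round_trip(3.14159) gives 3.140625, so a tensor that reports torch.float but is stored as bfloat16 on the TPU can differ noticeably from the same computation on the CPU.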
Memory Layout
The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor’s memory layout for better performance.
Moving XLA Tensors to and from the CPU
XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved then the data it is viewing is also copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it. Again, depending on how your code operates, appreciating and accommodating this transition can be important.
Saving and Loading XLA Tensors
XLA tensors should be moved to the CPU before saving, as in the following snippet:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
This lets you put the loaded tensors on any available device, not just the one on which they were initialized.
Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it is recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).
A utility API is provided that moves the data to the CPU before saving it:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
xm.save(model.state_dict(), path)
In case of multiple devices, the above API will only save the data for the master device ordinal (0).
When host memory is limited compared to the size of the model parameters, an API is provided that reduces the memory footprint on the host:
import torch_xla.utils.serialization as xser
xser.save(model.state_dict(), path)
This API streams XLA tensors to CPU one at a time, reducing the amount of host memory used, but it requires a matching load API to restore:
import torch_xla.utils.serialization as xser
state_dict = xser.load(path)
model.load_state_dict(state_dict)
Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.
Further Reading
Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.
PyTorch/XLA API
xla_model
torch_xla.core.xla_model.xla_device(n=None, devkind=None)
Returns a given instance of an XLA device.
- Parameters
n (python:int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.
devkind (string..., optional) – If specified, one of TPU, GPU or CPU.
- Returns
A torch.device with the requested instance.
torch_xla.core.xla_model.get_xla_supported_devices(devkind=None, max_devices=None)
Returns a list of supported devices of a given kind.
- Parameters
devkind (string..., optional) – If specified, one of TPU, GPU or CPU (the ‘GPU’ XLA device is currently not implemented).
max_devices (python:int, optional) – The maximum number of devices to be returned of that kind.
- Returns
The list of device strings.
torch_xla.core.xla_model.xla_device_hw(device)
Returns the hardware type of the given device.
- Parameters
device (string or torch.device) – The xla device that will be mapped to the real device.
- Returns
A string representation of the hardware type (CPU, TPU, GPU) of the given device.
torch_xla.core.xla_model.get_ordinal(defval=0)
Retrieves the replication ordinal of the current thread.
The ordinals range from 0 to xrt_world_size() minus 1.
- Parameters
defval (python:int, optional) – The default value to be returned in case there is no replication information available. Ignored for PjRt. Default: 0
- Returns
The replication ordinal of the current thread.
torch_xla.core.xla_model.get_local_ordinal(defval=0)
Retrieves the replication local ordinal of the current thread.
The local ordinals range from 0 to the number of local devices minus 1.
- Parameters
defval (python:int, optional) – The default value to be returned in case there is no replication information available. Ignored for PjRt. Default: 0
- Returns
The replication local ordinal of the current thread.
torch_xla.core.xla_model.is_master_ordinal(local=True)
Checks whether the current process is the master ordinal (0).
- Parameters
local (bool) – Whether the local or global master ordinal should be checked. In case of multi-host replication, there is only one global master ordinal (host 0, device 0), while there are NUM_HOSTS local master ordinals. Default: True
- Returns
A boolean indicating whether the current process is the master ordinal.
torch_xla.core.xla_model.xrt_world_size(defval=1)
Retrieves the number of devices taking part in the replication.
- Parameters
defval (python:int, optional) – The default value to be returned in case there is no replication information available. Default: 1
- Returns
The number of devices taking part in the replication.
torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=None, cctx=None, pin_layout=True)
Performs an inplace reduce operation on the input tensor(s).
- Parameters
reduce_type (string) – One of xm.REDUCE_SUM, xm.REDUCE_MUL, xm.REDUCE_AND, xm.REDUCE_OR, xm.REDUCE_MIN and xm.REDUCE_MAX.
inputs – Either a single torch.Tensor or a list of torch.Tensor to perform the all reduce op on.
scale (python:float) – A default scaling value to be applied after the reduce. Default: 1.0
groups (list, optional) – A list of lists, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication run slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.
- Returns
If a single torch.Tensor is passed, the return value is a torch.Tensor holding the reduced value (across the replicas). If a list/tuple is passed, this function performs an inplace all-reduce op on the input tensors, and returns the list/tuple itself.
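The effect of the groups and scale parameters can be sketched with a small pure-Python simulation. This is an illustration of the semantics described above (sum within each replica group, then multiply by scale), not the library's implementation. The function name is hypothetical and each replica's tensor is reduced to a single float for brevity.

```python
def simulate_all_reduce_sum(values, groups=None, scale=1.0):
    """Return the value each replica would hold after a sum all-reduce.

    values[i] is the local value of replica i. With groups=None all
    replicas form a single group, mirroring the documented default.
    """
    if groups is None:
        groups = [list(range(len(values)))]
    result = list(values)
    for group in groups:
        # Reduce inside the group, then apply the scale after the reduce.
        total = sum(values[r] for r in group) * scale
        for r in group:
            result[r] = total
    return result


local = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]

# Two independent groups: replicas 0-3 and replicas 4-7.
grouped = simulate_all_reduce_sum(local, groups=[[0, 1, 2, 3], [4, 5, 6, 7]])

# A single group with scale=1/N turns the sum into a mean.
averaged = simulate_all_reduce_sum(local, scale=1.0 / len(local))
```

Here grouped holds 10.0 for the first four replicas and 100.0 for the last four, while averaged holds 13.75 everywhere.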
torch_xla.core.xla_model.all_gather(value, dim=0, groups=None, output=None, pin_layout=True)
Performs an all-gather operation along a given dimension.
- Parameters
value (torch.Tensor) – The input tensor.
dim (python:int) – The gather dimension. Default: 0
groups (list, optional) – A list of lists, representing the replica groups for the all_gather() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
output (torch.Tensor) – Optional output tensor.
pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication run slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.
- Returns
A tensor which has, in the dim dimension, all the values from the participating replicas.
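As a rough illustration of the result shape, the sketch below simulates an all-gather along dimension 0 with nested Python lists standing in for tensors. The function name is hypothetical, and the simulation assumes a single group containing all replicas, concatenated in replica order.

```python
def simulate_all_gather(per_replica_rows):
    """Simulate all_gather with dim=0 over a single group of replicas.

    per_replica_rows[i] is the list of rows held by replica i. Every
    replica ends up with the rows of all replicas, in replica order.
    """
    gathered = [row for rows in per_replica_rows for row in rows]
    return [list(gathered) for _ in per_replica_rows]


# Three replicas, each holding one row of a [1, 2] shaped tensor.
shards = [[[1, 2]], [[3, 4]], [[5, 6]]]
gathered_outputs = simulate_all_gather(shards)
```

Each entry of gathered_outputs is the [3, 2] shaped result [[1, 2], [3, 4], [5, 6]].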
torch_xla.core.xla_model.all_to_all(value, split_dimension, concat_dimension, split_count, groups=None, pin_layout=True)
Performs an XLA AllToAll() operation on the input tensor.
See: https://www.tensorflow.org/xla/operation_semantics#alltoall
- Parameters
value (torch.Tensor) – The input tensor.
split_dimension (python:int) – The dimension upon which the split should happen.
concat_dimension (python:int) – The dimension upon which the concat should happen.
split_count (python:int) – The split count.
groups (list, optional) – A list of lists, representing the replica groups for the all_to_all() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
pin_layout (bool, optional) – Whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication run slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.
- Returns
The result torch.Tensor of the all_to_all() operation.
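The split-and-exchange pattern of AllToAll can be sketched for the one-dimensional case, where split_dimension and concat_dimension are both 0. This is a plain-Python simulation under that assumption, with a hypothetical function name. Each replica cuts its data into split_count blocks, block j goes to replica j, and every replica concatenates what it receives in sender order.

```python
def simulate_all_to_all(per_replica, split_count):
    """Simulate a 1-D all-to-all over a single group of replicas.

    per_replica[i] is replica i's list; its length must be divisible by
    split_count, and there must be exactly split_count replicas.
    """
    assert len(per_replica) == split_count
    outputs = []
    for receiver in range(split_count):
        received = []
        for sender_values in per_replica:
            chunk = len(sender_values) // split_count
            # Block number `receiver` of each sender is routed to this replica.
            received.extend(sender_values[receiver * chunk:(receiver + 1) * chunk])
        outputs.append(received)
    return outputs


inputs = [[0, 1], [10, 11]]
exchanged = simulate_all_to_all(inputs, split_count=2)
```

Replica 0 ends up with the first block of every replica ([0, 10]) and replica 1 with the second block of every replica ([1, 11]).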
torch_xla.core.xla_model.add_step_closure(closure, args=(), run_async=False)
Adds a closure to the list of the ones to be run at the end of the step.
During model training there is often a need to report information (print to console, post to TensorBoard, etc.) that requires inspecting the content of intermediate tensors. Inspecting the content of different tensors at different points in the model code requires many executions and typically causes performance issues. Adding a step closure ensures that it runs after the barrier, when all the live tensors, including the ones captured by the closure arguments, have already been materialized to device data. Using add_step_closure() therefore ensures that a single execution is performed, even when multiple closures are queued and multiple tensors need to be inspected. Step closures run sequentially in the order they were queued. Note that even though this API optimizes execution, it is advised to throttle printing/reporting events to once every N steps.
- Parameters
closure (callable) – The function to be called.
args (tuple) – The arguments to be passed to the closure.
run_async – If True, run the closure asynchronously.
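The queue-then-run behavior described for step closures can be modeled with a few lines of plain Python. The ClosureQueue class and its run_at_barrier method are hypothetical names used only for this sketch. It shows the ordering guarantee (closures run in the order queued, all at the barrier) and does not involve any XLA tensors.

```python
class ClosureQueue:
    """Toy model of step closures: queue callbacks, run them at the barrier."""

    def __init__(self):
        self._pending = []

    def add_step_closure(self, closure, args=()):
        # Nothing runs here; the closure is only remembered.
        self._pending.append((closure, args))

    def run_at_barrier(self):
        # Run every queued closure once, in the order it was queued.
        pending, self._pending = self._pending, []
        for closure, args in pending:
            closure(*args)


log = []
queue = ClosureQueue()
queue.add_step_closure(
    lambda step, loss: log.append(f"step {step}: loss={loss:.2f}"), args=(1, 0.5))
queue.add_step_closure(lambda: log.append("second closure"))

log_before_barrier = list(log)   # still empty: nothing has run yet
queue.run_at_barrier()           # both closures run now, in order
```

After the barrier, log contains "step 1: loss=0.50" followed by "second closure".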
torch_xla.core.xla_model.wait_device_ops(devices=[])
Waits for all the async operations on the given devices to complete.
- Parameters
devices (string..., optional) – The devices whose async ops need to be waited for. If empty, all the local devices will be waited for.
torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None, pin_layout=True)
Run the provided optimizer step and issue the XLA device step computation.
- Parameters
optimizer (torch.Optimizer) – The torch.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.
barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False
optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.
groups (list, optional) – A list of lists, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
pin_layout (bool, optional) – whether to pin the layout when reducing gradients. See xm.all_reduce for details.
- Returns
The same value returned by the optimizer.step() call.
torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)
Saves the input data into a file.
The saved data is transferred to PyTorch CPU device before being saved, so a following torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).
- Parameters
data – The input data to be saved. Any nested combination of Python objects (list, tuples, sets, dicts, …).
file_or_path – The destination for the data saving operation. Either a file path or a Python file object. If master_only is False the path or file objects must point to different destinations as otherwise all the writes from the same host will overwrite each other.
master_only (bool, optional) – Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part in the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True
global_master (bool, optional) – When master_only is True this flag controls whether every host’s master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False
torch_xla.core.xla_model.rendezvous(tag, payload=b'', replicas=[])
Waits for all the mesh clients to reach the named rendezvous.
- Parameters
tag (string) – The name of the rendezvous to join.
payload (bytes, optional) – The payload to be sent to the rendezvous.
replicas (list, python:int) – The replica ordinals taking part in the rendezvous. Empty means all replicas in the mesh. Default: []
- Returns
The payloads exchanged by all the other cores, with the payload of core ordinal i at position i in the returned tuple.
torch_xla.core.xla_model.do_on_ordinals(target, data=(), ordinals=(0, ))
Runs a function only on a given set of ordinals.
- Parameters
target (callable) – The function to be run on ordinals.
data – Any input data for the target function which contains tensors. All the XLA tensors used by the target function must be passed in this argument. Every other data used by the function can be captured by the Python interpreter as usual. Default: ()
ordinals (list, python:int) – The list/set of ordinals where the target function should run. Default: (0,)
- Returns
In the ordinals that ran the target function, the function return value, otherwise None.
torch_xla.core.xla_model.mesh_reduce(tag, data, reduce_fn)
Performs an out-of-graph client mesh reduction.
- Parameters
tag (string) – The name of the rendezvous to join.
data – The data to be reduced. The reduce_fn callable will receive a list with the copies of the same data coming from all the mesh client processes (one per core).
reduce_fn (callable) – A function which receives a list of data-like objects and returns the reduced result.
- Returns
The reduced value.
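The contract of reduce_fn is easiest to see with a concrete callable. The sketch below defines a hypothetical mean_accuracy function and applies it to a hand-written list that stands in for the values mesh_reduce would collect, one per process. The real exchange between processes is not simulated here.

```python
def mean_accuracy(per_process_values):
    """Example reduce_fn: receives one value per process, returns their mean."""
    return sum(per_process_values) / len(per_process_values)


# Stand-in for the list mesh_reduce would hand to reduce_fn on 4 processes.
collected = [0.5, 0.75, 1.0, 0.25]
reduced = mean_accuracy(collected)

# Inside a spawned process the call would take this shape (not run here):
#   accuracy = xm.mesh_reduce('eval_accuracy', local_accuracy, mean_accuracy)
```

Any callable that maps a list to a single result fits the same slot, for example the built-in max or min.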
torch_xla.core.xla_model.set_rng_state(seed, device=None)
Sets the random number generator state.
- Parameters
seed (python:integer) – The state to be set.
device (string, optional) – The device where the RNG state needs to be set. If missing the default device seed will be set.
torch_xla.core.xla_model.get_rng_state(device=None)
Gets the current running random number generator state.
- Parameters
device (string, optional) – The device whose RNG state needs to be retrieved. If missing, the default device is used.
- Returns
The RNG state, as integer.
torch_xla.core.xla_model.get_memory_info(device)
Retrieves the device memory information.
- Parameters
device (string) – The device whose memory information is requested.
- Returns
A dictionary with kb_free (free memory in KB) and kb_total (total memory in KB) keys.
torch_xla.core.functions.all_reduce(reduce_type, value, scale=1.0, groups=None)
Performs an inplace reduce operation on the input tensor.
This is the same as xm.all_reduce() but supports autograd differentiation.
- Parameters
reduce_type (string) – One of REDUCE_SUM, REDUCE_MUL, REDUCE_AND, REDUCE_OR, REDUCE_MIN and REDUCE_MAX.
value (torch.Tensor) – The tensor to perform the all reduce op on.
scale (python:float) – A default scaling value to be applied after the reduce. Default: 1.0
groups (list, optional) – A list of lists, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]] defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
- Returns
The reduced value across the selected replicas.
torch_xla.core.functions.all_gather(value, dim=0)
Performs an all-gather operation along a given dimension.
This is the same as xm.all_gather() but supports autograd differentiation.
- Parameters
value (torch.Tensor) – The input tensor.
dim (python:int) – The gather dimension. Default: 0
- Returns
A tensor which has, in the dim dimension, all the values from the participating replicas.
torch_xla.core.functions.nms(boxes, scores, score_threshold, iou_threshold, output_size)
Performs a Non Maximal Suppression operation.
- Parameters
boxes (torch.Tensor) – A torch.Tensor of shape [N, 4] listing the boxes coordinates in (y0, x0, y1, x1) form.
scores (torch.Tensor) – A torch.Tensor of shape [N] listing the scores of each box.
score_threshold (torch.Tensor) – The minimum score for a box to qualify as valid.
iou_threshold (torch.Tensor) – The minimum IOU (Intersection Over Union) score to trigger overlap logic.
output_size (python:int) – The maximum number of returned indices (must be lower or equal to N).
- Returns
A tuple of torch.Tensor with the first element being the selected box indices, and the second element being the number of valid boxes.
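For readers unfamiliar with the operation, the following is a generic greedy non-maximum suppression in plain Python, using the same (y0, x0, y1, x1) box layout and parameter names as above. It is a textbook sketch and makes no claim about how torch_xla computes nms. In particular it returns only a list of indices rather than the tuple of tensors described above, and its exact tie and threshold handling is an assumption.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (y0, x0, y1, x1)."""
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, score_threshold, iou_threshold, output_size):
    """Keep the highest scoring boxes that do not overlap an already kept box."""
    candidates = sorted(
        (i for i, s in enumerate(scores) if s >= score_threshold),
        key=lambda i: scores[i], reverse=True)
    selected = []
    for i in candidates:
        if len(selected) == output_size:
            break
        # A box is suppressed when it overlaps a kept box too strongly.
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in selected):
            selected.append(i)
    return selected


boxes = [(0.0, 0.0, 10.0, 10.0), (1.0, 1.0, 11.0, 11.0),
         (20.0, 20.0, 30.0, 30.0), (0.0, 0.0, 5.0, 5.0)]
scores = [0.9, 0.8, 0.7, 0.1]
kept = greedy_nms(boxes, scores, score_threshold=0.5, iou_threshold=0.5, output_size=3)
```

In this example box 1 is suppressed because it overlaps box 0 with an IOU of about 0.68, box 3 is dropped by the score threshold, and kept is [0, 2].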
distributed
class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, batches_per_execution=1, loader_prefetch_size=8, device_prefetch_size=4)
Wraps an existing PyTorch DataLoader with background data upload.
- Parameters
loader (torch.utils.data.DataLoader) – The PyTorch DataLoader to be wrapped.
devices (torch.device…) – The list of devices where the data has to be sent. The i-th sample returned by the loader will be sent to devices[i % len(devices)].
batchdim (python:int, optional) – The dimension which is holding the batch size. Default: 0
loader_prefetch_size (python:int, optional) – The max capacity of the queue used by the thread which is reading samples from the loader, to be processed by the worker threads which upload data to the devices. Default: 8
device_prefetch_size (python:int, optional) – The max size of the per-device queues, where the worker threads deposit tensors which have already been sent to devices. Default: 4
per_device_loader(device)
Retrieves the loader iterator object for the given device.
- Parameters
device (torch.device) – The device whose loader is being requested.
- Returns
The loader iterator object for the device. This is not a torch.utils.data.DataLoader interface, but a Python iterator which returns the same tensor data structure as returned by the wrapped torch.utils.data.DataLoader, but residing on XLA devices.
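The devices[i % len(devices)] rule stated for the ParallelLoader devices parameter amounts to a round-robin assignment. The sketch below works through it in plain Python. The assign_batches helper is hypothetical and the device names are plain strings used as labels, not real torch.device objects.

```python
def assign_batches(num_batches, devices):
    """Map batch indices to devices using the i % len(devices) rule."""
    assignment = {device: [] for device in devices}
    for i in range(num_batches):
        assignment[devices[i % len(devices)]].append(i)
    return assignment


# Six batches spread over two illustrative device labels.
assignment = assign_batches(6, ["xla:0", "xla:1"])
```

With two devices the even-numbered batches go to the first device and the odd-numbered ones to the second, so each per-device loader sees every second batch of the wrapped DataLoader.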
torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')
Enables multi processing based replication.
- Parameters
fn (callable) – The function to be called for each device which takes part in the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in args.
args (tuple) – The arguments for fn. Default: Empty tuple
nprocs (python:int) – The number of processes/devices for the replication. At the moment, if specified, can be either 1 or the maximum number of devices.
join (bool) – Whether the call should block waiting for the completion of the processes which have been spawned. Default: True
daemon (bool) – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API). Default: False
start_method (string) – The Python multiprocessing process creation method. Default: spawn
- Returns
The same object returned by the torch.multiprocessing.spawn API. If nprocs is 1 the fn function will be called directly, and the API will not return.
class torch_xla.distributed.xla_multiprocessing.MpModelWrapper(model)
Wraps a model to minimize host memory usage when fork method is used.
This class should be used together with the spawn(…, start_method='fork') API to minimize the use of host memory. Instead of creating models on each multiprocessing process, hence replicating the model’s initial host memory, the model is created once at global scope, and then moved into each device inside the spawn() target function. Example:
WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())

def _mp_fn(index, ...):
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    ...

xmp.spawn(_mp_fn, ..., start_method='fork')
This method has two advantages. First it uses only one copy of the memory pages to host the original model weights, and second it serializes the move of the wrapped model into each device, by lowering the load onto the system memory during the process.
class torch_xla.distributed.xla_multiprocessing.MpSerialExecutor
Utility to run a function in a serialized fashion among multi-core processes.
Example:
# At global scope.
SERIAL_EXEC = xmp.MpSerialExecutor()

def load_dataset(path):
    return maybe_download_and_load(path)

def _mp_fn(index, ...):
    # Avoid all cores downloading the same data with the serial executor.
    dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data'))
    ...

xmp.spawn(_mp_fn, ...)
utils
class torch_xla.utils.tf_record_reader.TfRecordReader(path, compression='', buffer_size=16777216, transforms=None)
Reads TfRecords or TfExamples.
- Parameters
path (string) – The path to the file containing TfRecords.
compression (string, optional) – The compression type. The empty string for no compression, otherwise ZLIB or GZIP. Default: No compression.
buffer_size (python:int, optional) – The size of the buffer to be used to read TfRecords. Default: 16 * 1024 * 1024
transforms (dict, optional) – A dictionary with the key matching the TfExample label name, and value which is either a callable which will be called to transform the matching tensor data, or STR for string conversion.
class torch_xla.utils.utils.SampleGenerator(data, sample_count)
Iterator which returns multiple samples of a given input data.
Can be used in place of a PyTorch DataLoader to generate synthetic data.
- Parameters
data – The data which should be returned at each iterator step.
sample_count – The maximum number of data samples to be returned.
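The idea of replaying one fixed batch in place of a real DataLoader can be sketched with a small iterable. The RepeatedSample class below is a hypothetical stand-in written for illustration. It mirrors the two documented parameters but is not the SampleGenerator implementation.

```python
class RepeatedSample:
    """Yields the same data object sample_count times, like a synthetic loader."""

    def __init__(self, data, sample_count):
        self._data = data
        self._sample_count = sample_count

    def __iter__(self):
        for _ in range(self._sample_count):
            yield self._data

    def __len__(self):
        return self._sample_count


# A fake (inputs, label) batch replayed three times.
fake_batch = ([0.0, 0.0, 0.0, 0.0], 1)
synthetic_batches = list(RepeatedSample(fake_batch, sample_count=3))
```

Every element of synthetic_batches is the very same object, which keeps data loading out of the picture when measuring the speed of the model itself.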
class torch_xla.utils.utils.DataWrapper
Utility class to wrap data structures to be sent to device.
torch_xla.utils.serialization.save(data, path, master_only=True, global_master=False)
Saves the input data into a file.
The saved data is transferred to PyTorch CPU device before being saved, so a following torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).
- Parameters
data – The input data to be saved. Any nested combination of Python objects (list, tuples, sets, dicts, …).
path – The destination file for the data saving operation. If master_only is False the path must point to different destinations as otherwise all the writes from the same host will overwrite each other.
master_only (bool, optional) – Whether only the master device should save the data. If False, the path argument should be a different path for each of the ordinals taking part in the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True
global_master (bool, optional) – When master_only is True this flag controls whether every host’s master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False
torch_xla.utils.serialization.load(path)
Loads data previously saved with the save() API.
- Parameters
path (str) – The path passed to the save() API.
- Returns
The loaded data.
torch_xla.utils.gcsfs.open(path, mode='r', encoding=None)
Opens a Google Cloud Storage (GCS) file for reading or writing.
- Parameters
path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.
mode (string, optional) – The open mode, similar to the open() API. Default: ‘r’
encoding (string, optional) – The character encoding to be used to decode bytes into strings when opening in text mode. Default: None
- Returns
The GCS file object.
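The “gs://BUCKET_NAME/PATH” format required by the gcsfs functions can be checked and taken apart with ordinary string handling. The split_gcs_path helper and the example bucket name below are hypothetical. This is a sketch of the path convention only and does not contact GCS or use the gcsfs module.

```python
def split_gcs_path(path):
    """Split 'gs://BUCKET_NAME/PATH' into (bucket_name, blob_path)."""
    prefix = "gs://"
    if not path.startswith(prefix):
        raise ValueError(f"not a GCS path: {path!r}")
    # Everything up to the first '/' is the bucket, the rest is the blob path.
    bucket, _, blob_path = path[len(prefix):].partition("/")
    if not bucket:
        raise ValueError(f"missing bucket name in: {path!r}")
    return bucket, blob_path


bucket_name, blob_path = split_gcs_path("gs://my-bucket/checkpoints/model.pt")
```

For this input the bucket is "my-bucket" and the / delimited path is "checkpoints/model.pt". A local path such as "/tmp/model.pt" raises ValueError, which is the kind of distinction a caller would need to make before choosing between a GCS-specific and a local file API.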
torch_xla.utils.gcsfs.list(path)
Lists the content of a GCS bucket.
- Parameters
path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.
- Returns
A list of GcsBlob objects.
torch_xla.utils.gcsfs.stat(path)
Fetches the information of a GCS file.
- Parameters
path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.
- Returns
A GcsBlob object.
-
torch_xla.utils.gcsfs.
remove
(path)[source]¶ Removes a GCS blob.
- Parameters
path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.
-
torch_xla.utils.gcsfs.
rmtree
(path)[source]¶ Removes all the GCS blobs within a given path.
- Parameters
path (string) – The GCS path of the file pattern or folder. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.
-
torch_xla.utils.gcsfs.
read
(path)[source]¶ Reads the whole content of a GCS blob.
- Parameters
path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.
- Returns
The bytes stored within the GCS blob.
-
torch_xla.utils.gcsfs.
write
(path, content)[source]¶ Write a string/bytes or file into a GCS blob.
- Parameters
path (string) – The GCS path of the file. Must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.
content (string, bytes or file object) – The content to be written into path.
-
torch_xla.utils.gcsfs.
generic_open
(path, mode='r', encoding=None)[source]¶ Opens a file (GCS or not) for reading or writing.
- Parameters
path (string) – The path of the file to be opened. If a GCS path, it must be “gs://BUCKET_NAME/PATH” where BUCKET_NAME is the name of the GCS bucket, and PATH is a / delimited path.
mode (string, optional) – The open mode, similar to the open() API. Default: ‘r’
encoding (string, optional) – The character encoding to be used to decode bytes into strings when opening in text mode. Default: None
- Returns
The opened file object.
-
torch_xla.utils.gcsfs.
generic_read
(path)[source]¶ Reads the whole content of the provided location.
- Parameters
path (string) – The GCS path or local path to be read.
- Returns
The bytes stored within the GCS blob or local file.
-
torch_xla.utils.gcsfs.
generic_write
(output_string, path, makedirs=False)[source]¶ Write a string/bytes or file into a GCS blob or local disk.
Depending on the path passed in, this API can write to local or GCS file. Checks if the path starts with the ‘gs://’ prefix, and uses open otherwise.
- Parameters
output_string (string) – The string to be written to the output.
path (string) – The GCS path or local path of the output.
makedirs (bool) – Whether the path parent folders should be created if missing. Default: False
-
torch_xla.utils.gcsfs.
is_gcs_path
(path)[source]¶ Checks whether a path is a GCS path.
- Parameters
path (string) – The path to be checked.
- Returns
Whether path is a GCS path.
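The “gs://BUCKET_NAME/PATH” convention used throughout the gcsfs functions can be illustrated with a small, standard-library-only sketch. Both helpers below are hypothetical and are not the torch_xla implementation; they only assume the prefix rule described above.

```python
GCS_PREFIX = "gs://"


def looks_like_gcs_path(path):
    """Prefix check only, following the 'gs://' convention described above."""
    return path.startswith(GCS_PREFIX)


def split_gcs_path(path):
    """Split 'gs://BUCKET_NAME/PATH' into (bucket, blob_path). Hypothetical helper."""
    if not looks_like_gcs_path(path):
        raise ValueError(f"not a GCS path: {path!r}")
    bucket, _, blob_path = path[len(GCS_PREFIX):].partition("/")
    if not bucket:
        raise ValueError(f"missing bucket name in {path!r}")
    return bucket, blob_path


print(split_gcs_path("gs://my-bucket/data/train/part-0.pt"))
print(looks_like_gcs_path("/tmp/local.pt"))
```

A generic reader or writer can branch on such a check to decide between the GCS code path and the regular open() call.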
-
class
torch_xla.utils.cached_dataset.
CachedDataset
(data_set, path, max_files_per_folder=1000, compress=True)[source]¶ Wraps an existing torch.utils.data.Dataset by providing file caching.
The CachedDataset can be used to trade the CPU/RAM resources required to process a raw dataset, with storage/network resources. Example:
train_dataset = datasets.MNIST(
    FLAGS.datadir,
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
train_dataset = CachedDataset(train_dataset, FLAGS.dscache_dir)
The CachedDataset will transparently cache the original Dataset samples, so that runs after the first do not incur the CPU/RAM cost of processing the raw samples again. Once a CachedDataset is fully cached, it can be exported (e.g. as a tar.gz) and used on different machines. Just unpack the tar.gz and pass None as the original Dataset. Example:
train_dataset = CachedDataset(None, FLAGS.dscache_dir)
To fully cache a CachedDataset, just run the warmup() API. A CachedDataset saved on GCS has the advantage that it can be used from different machines without explicit exporting.
- Parameters
data_set (torch.utils.data.Dataset) – The raw torch.utils.data.Dataset to be cached. It can be set to None in case all the input samples are stored within the path folder.
path (string) – The path where the dataset samples should be stored/loaded. The path needs to be writeable, unless all the samples are already stored. The path can be a GCS path (prefixed with gs://).
max_files_per_folder (int) – The maximum number of files to be stored within a single folder. If data_set is None this value is ignored and taken from the cached metadata. Default: 1000
compress (bool) – Whether the saved samples should be compressed. Compression saves space at the expense of CPU required to compress/decompress. If data_set is None this value is ignored and taken from the cached metadata. Default: True
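The caching idea can be sketched with the standard library alone. TinyFileCache below is an illustrative stand-in and is not torch_xla's CachedDataset (no compression, no folder sharding, no GCS support); ExpensiveSquares is a toy raw dataset invented for the demonstration. It shows the trade described above: the first pass pays the processing cost and stores samples, and later runs can pass None as the raw dataset.

```python
import os
import pickle
import tempfile


class TinyFileCache:
    """Illustrative caching wrapper; NOT torch_xla's CachedDataset."""

    def __init__(self, data_set, path):
        self._data_set = data_set  # may be None once every sample is cached
        self._path = path
        os.makedirs(path, exist_ok=True)

    def _sample_file(self, index):
        return os.path.join(self._path, f"{index}.pkl")

    def __getitem__(self, index):
        sample_file = self._sample_file(index)
        if os.path.exists(sample_file):
            with open(sample_file, "rb") as f:
                return pickle.load(f)
        if self._data_set is None:
            raise KeyError(f"sample {index} is not cached and no raw dataset was given")
        sample = self._data_set[index]  # the expensive step
        with open(sample_file, "wb") as f:
            pickle.dump(sample, f)
        return sample

    def warmup(self, num_samples):
        """Touch every sample once so that the cache is complete."""
        for i in range(num_samples):
            self[i]


class ExpensiveSquares:
    """Toy raw dataset that counts how often a sample is actually computed."""

    def __init__(self):
        self.calls = 0

    def __getitem__(self, index):
        self.calls += 1
        return index * index


cache_dir = tempfile.mkdtemp()
raw = ExpensiveSquares()
TinyFileCache(raw, cache_dir).warmup(5)    # first pass computes and stores
reloaded = TinyFileCache(None, cache_dir)  # later runs need only the cache
print(reloaded[3], raw.calls)
```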
test¶
Troubleshooting¶
Note that the information in this section may be removed in future releases of the PyTorch/XLA software, since much of it is specific to a given internal implementation which might change.
To diagnose issues, we can use the execution metrics and counters provided by PyTorch/XLA. The first thing to do when a model is slow is to generate a metrics report.
The metrics report is extremely helpful in diagnosing issues. Please include it in your bug report if you have it.
Perform An Auto-Metrics Analysis¶
We provide ways to automatically analyze the metrics report and provide a summary. Simply run your workload with PT_XLA_DEBUG=1. Some example output would be:
pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps
pt-xla-profiler: TransferFromServerTime too frequent: 11 counts during 11 steps
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps
pt-xla-profiler: TransferFromServerTime too frequent: 12 counts during 12 steps
The following sections explain how to get and understand a more detailed metrics report.
Get A Metrics Report¶
Put the following lines in your program to generate a report:
import torch_xla.debug.metrics as met
print(met.metrics_report())
Understand The Metrics Report¶
The report includes things like:
how many times we issue XLA compilations and the time spent issuing them.
how many times we execute and the time spent on execution.
how many device data handles we create/destroy, etc.
This information is reported in terms of percentiles of the samples. An example is:
Metric: CompileTime
TotalSamples: 202
Counter: 06m09s401ms746.001us
ValueRate: 778ms572.062us / second
Rate: 0.425201 / second
Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us
We also provide counters, which are named integer variables which track internal software status. For example:
Counter: CachedSyncTensors
Value: 395
In this report, any counter that starts with aten::
indicates a context switch between the XLA device and CPU, which can be a
potential performance optimization area in the model code.
Counters are useful to understand which operations are routed back to the CPU engine of PyTorch. They are fully qualified with their C++ namespace:
Counter: aten::nonzero
Value: 33
If you see aten::
ops other than nonzero
and _local_scalar_dense
, that usually means a missing
lowering in PyTorch/XLA. Feel free to open a feature request for it on GitHub issues.
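To make the report easier to scan, the sketch below pulls the sample count of each metric and the value of each counter out of report text, then keeps the counters that start with aten::. It assumes the text is laid out exactly like the excerpts above (one "Key: value" pair per line); parse_report is a hypothetical helper, not part of torch_xla.debug.metrics, and REPORT is a hand-assembled sample.

```python
REPORT = """
Metric: CompileTime
  TotalSamples: 202
  Counter: 06m09s401ms746.001us
  ValueRate: 778ms572.062us / second
  Rate: 0.425201 / second
Counter: CachedSyncTensors
  Value: 395
Counter: aten::nonzero
  Value: 33
"""


def parse_report(report):
    """Return ({metric: total_samples}, {counter: value}) from report text."""
    lines = [ln.strip() for ln in report.splitlines() if ln.strip()]
    samples, counters = {}, {}
    metric = None
    for i, line in enumerate(lines):
        key, _, value = line.partition(":")
        value = value.strip()
        nxt = lines[i + 1] if i + 1 < len(lines) else ""
        if key == "Metric":
            metric = value
        elif key == "Counter" and nxt.startswith("Value:"):
            # A named counter. Inside a metric block, "Counter:" holds the
            # accumulated time instead and is not followed by "Value:".
            counters[value] = int(nxt.partition(":")[2])
            metric = None
        elif key == "TotalSamples" and metric is not None:
            samples[metric] = int(value)
    return samples, counters


samples, counters = parse_report(REPORT)
cpu_fallbacks = {k: v for k, v in counters.items() if k.startswith("aten::")}
print(samples)
print(cpu_fallbacks)
```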
Performance Profiling¶
To profile your workload in depth and understand bottlenecks, please check the following resources:
Known Performance Caveats¶
PyTorch/XLA behaves semantically like regular PyTorch and XLA tensors share the full tensor interface with CPU & GPU tensors. However, constraints in XLA/hardware and the lazy evaluation model suggest certain patterns might result in bad performance.
If your model shows bad performance, keep in mind the following caveats:
XLA/TPU yield degraded performance with too many recompilations.
XLA compilation is expensive. PyTorch/XLA automatically recompiles the graph every time new shapes are encountered. Usually models should stabilize within a few steps and you can see huge speedup for the rest of training.
In order to avoid recompilations, not only must shapes be constant, but computations across XLA devices in all hosts should also be constant.
Possible sources:
Direct or indirect uses of nonzero introduce dynamic shapes; for example, masked indexing base[index] where index is a mask tensor.
Loops with a different number of iterations between steps can result in different execution graphs, and thus require recompilations.
Solution:
Tensor shapes should be the same between iterations, or a low number of shape variations should be used.
Pad tensors to fixed sizes when possible.
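One common way to keep the number of shape variations low is to round every sequence length up to one of a few fixed bucket sizes before padding. The sketch below uses plain Python lists to show the idea; the bucket sizes and both helper names are assumptions chosen for illustration, and the same rounding would apply to tensor dimensions.

```python
def bucket_length(length, buckets=(16, 32, 64, 128)):
    """Round a length up to the nearest allowed size (bucket sizes are an assumption)."""
    for size in buckets:
        if length <= size:
            return size
    raise ValueError(f"length {length} exceeds the largest bucket {buckets[-1]}")


def pad_to_bucket(tokens, pad_value=0):
    """Right-pad a sequence so only a handful of distinct shapes ever occur."""
    target = bucket_length(len(tokens))
    return list(tokens) + [pad_value] * (target - len(tokens))


raw_lengths = [3, 17, 20, 31, 33, 60]
padded_lengths = sorted({len(pad_to_bucket([1] * n)) for n in raw_lengths})
print(padded_lengths)
```

Six distinct input lengths collapse to three padded lengths here, so at most three graphs would need to be compiled for these inputs instead of six.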
Certain operations don’t have native translations to XLA.
For these operations PyTorch/XLA automatically transfers to the CPU memory, evaluates on CPU, and transfers the result back to the XLA device. Doing too many such operations during the training step can lead to significant slowdowns.
Possible sources:
The
item()
operation explicitly asks to evaluate the result. Don’t use it unless it’s necessary.
Solution:
For most ops, we can lower them to XLA to fix this. Check the metrics report section to find the missing ops and open a feature request on GitHub.
Even when a PyTorch tensor is known to be a scalar, avoid using tensor.item(). Keep it as a tensor and use tensor operations on it.
Use torch.where to substitute control flow when applicable. E.g. the control flow with item() used in clip_grad_norm_ is problematic and impacts performance, so we have patched clip_grad_norm_ by calling torch.where instead, which gives us a dramatic performance improvement.
...
else:
  device = parameters[0].device
  total_norm = torch.zeros([], device=device if parameters else None)
  for p in parameters:
    param_norm = p.grad.data.norm(norm_type) ** norm_type
    total_norm.add_(param_norm)
  total_norm = (total_norm ** (1. / norm_type))
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
for p in parameters:
  p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
Iterators in torch_xla.distributed.data_parallel may drop the last few batches in the input iterator.
This is to make sure we do the same amount of work on all XLA devices.
Solution:
When the dataset is small and there are too few steps, this may result in a no-op epoch. It is therefore better to use small batch sizes in those cases.
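The effect of batch size on a small dataset can be checked with simple integer arithmetic. The function below is a simplified model of the behaviour described above (full batches only, split evenly across devices, leftovers dropped). It is an assumption made for illustration, not the actual iterator logic.

```python
def steps_per_device(num_samples, batch_size, num_devices):
    """Simplified model: full batches only, split evenly, leftovers dropped."""
    full_batches = num_samples // batch_size
    return full_batches // num_devices


# 1000 samples on 8 devices: a large batch size leaves no work for a full round.
print(steps_per_device(1000, 128, 8))
# A smaller batch size yields several steps per device.
print(steps_per_device(1000, 16, 8))
```

Under this model, a batch size of 128 gives 7 full batches, which cannot be shared across 8 devices, while a batch size of 16 gives 62 full batches and 7 steps per device.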
XLA Tensor Quirks¶
XLA tensor internals are opaque. XLA tensors always appear to be contiguous and without storage. Networks should not try to check the strides of XLA tensors.
XLA tensors should be moved to the CPU before saving them. Saving XLA tensors directly causes them to be loaded back on the device(s) they were saved from. If a device is unavailable at load time then the load will fail. Moving XLA tensors to the CPU before saving them lets you decide which device(s) to put the loaded tensors on. This is necessary if you want to load the tensors on a machine without XLA devices. Care should be taken moving the XLA tensors to the CPU before saving them, however, as moving tensors across device types does not preserve view relationships. Instead, views should be reconstructed as necessary after the tensors are loaded.
Copying an XLA Tensor with Python’s copy.copy returns a deep copy, not a shallow copy. Use a view of an XLA tensor to get a shallow copy of it.
Handling shared weights. Modules can share weights by setting the Parameters of one module to another. This “tying” of module weights should be done AFTER the modules are moved to an XLA device. Otherwise two independent copies of the shared tensor will be made on the XLA device.
More Debugging Tools¶
We don’t expect users to use the tools in this section to debug their models. But we might ask for their output when you submit a bug report, since they provide additional information that the metrics report doesn’t have.
Environment Variables¶
There are also a number of environment variables which control the behavior of the PyTorch/XLA software stack.
Setting such variables will cause different degrees of performance degradation, so they should only be enabled for debugging.
XLA_IR_DEBUG
: Enables the Python stack trace to be captured when creating IR nodes, making it possible to understand which PyTorch operation was responsible for generating the IR.
XLA_HLO_DEBUG
: Enables the Python stack frame captured when XLA_IR_DEBUG is active to be propagated to the XLA HLO metadata.
XLA_SAVE_TENSORS_FILE
: The path to a file which will be used to dump the IR graphs during execution. Note that the file can become very large if the option is left enabled and the PyTorch program is left running for a long time. The graphs are appended to the file, so to start clean from run to run, the file should be explicitly removed.
XLA_SAVE_TENSORS_FMT
: The format of the graphs stored within the XLA_SAVE_TENSORS_FILE file. Can be text (the default), dot (the Graphviz format) or hlo.
XLA_METRICS_FILE
: If set, the path to a local file where the internal metrics will be saved at every step. Metrics will be appended to the file, if it already exists.
XLA_SAVE_HLO_FILE
: If set, the path to a local file where, in case of compilation/execution error, the offending HLO graph will be saved.
XLA_GET_TENSORS_OPBYOP
: Enables pure OpByOp dispatch. The PyTorch/XLA software tries to fuse together many PyTorch operations into a single computation graph, but sometimes, either for debugging, or in case the PyTorch code has a very dynamic nature (in shapes or graph terms), it is better to force the execution in OpByOp mode (every IR node is lowered into a separate XLA computation, and chain-executed). This environment variable, if set to 1, enables OpByOp during the “get tensors” operation (the operation used by PyTorch/XLA to fetch intermediate values back from the TPU device into PyTorch CPU tensors).
XLA_SYNC_TENSORS_OPBYOP
: The same as XLA_GET_TENSORS_OPBYOP but for the “sync tensors” operation (the operation used at the end of a step, to flush pending IR computations and materialize them into TPU device data).
XLA_SYNC_WAIT
: Forces the XLA tensor sync operation to wait for its completion, before moving to the next step.
XLA_USE_BF16
: If set to 1, transforms all the PyTorch Float values into bfloat16 when sending to the TPU device. Note that when using XLA_USE_BF16=1, tensor arithmetic will be done in reduced precision, and so tensors will not be accurate if accumulated over time. For example:
# In reduced bfloat16 precision
>>> torch.tensor(4096, dtype=torch.bfloat16) + torch.tensor(1, dtype=torch.bfloat16)
tensor(4096., dtype=torch.bfloat16)
# Whereas in full float32 precision
>>> torch.tensor(4096) + torch.tensor(1)
tensor(4097)
So to get accurate metrics such as average loss value over many steps, use manual mixed precision where metrics stay in FP32.
XLA_USE_F16
: If set to 1, transforms all the PyTorch Float values into Float16 (PyTorch Half type) when sending to devices which support them.
XLA_USE_32BIT_LONG
: If set to 1, maps PyTorch Long types to XLA 32bit type. On the versions of the TPU HW at the time of writing, 64bit integer computations are expensive, so setting this flag might help. The user should verify that truncating to 32bit values is valid given how the program uses PyTorch Long values.
TF_CPP_LOG_THREAD_ID
: If set to 1, the TF logs will show the thread ID, helping with debugging multithreaded processes.
TF_CPP_VMODULE
: Environment variable used for TF VLOGs; it takes the form TF_CPP_VMODULE=name=value,.... Note that for VLOGs you must set TF_CPP_MIN_LOG_LEVEL=0. For PyTorch/XLA, a configuration like TF_CPP_VMODULE=tensor=5 would enable logging such as:
2019-10-03 17:23:56.419040: I 27891 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 4211381954965020633 on device TPU:3 done!
2019-10-03 17:23:56.419448: I 27890 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 15483856951158150605 on device TPU:5 done!
2019-10-03 17:23:56.419539: I 27896 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 4211381954965020633 on device TPU:4 done!
...
TF_CPP_MIN_LOG_LEVEL
: Level to print messages for. TF_CPP_MIN_LOG_LEVEL=0 will turn on INFO logging, TF_CPP_MIN_LOG_LEVEL=1 WARNING and so on. Our PyTorch/XLA TF_VLOG uses tensorflow::INFO level by default, so to see VLOGs set TF_CPP_MIN_LOG_LEVEL=0.
XLA_DUMP_HLO_GRAPH
: If set to 1, in case of a compilation or execution error the offending HLO graph will be dumped as part of the runtime error raised by xla_util.cc.
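The precision note under XLA_USE_BF16 can be reproduced without a TPU or PyTorch. The sketch below emulates bfloat16 in pure Python by keeping the top 16 bits of a float32 with round-to-nearest-even. This is an emulation for illustration (NaN and infinity are not handled), not how the device performs the conversion, and to_bfloat16 and bf16_add are hypothetical helpers.

```python
import struct


def to_bfloat16(x):
    """Round a float to the nearest bfloat16 value (ties to even); NaN/inf not handled."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round on the 16 bits about to be dropped
    bits &= 0xFFFF0000                   # keep sign, exponent and 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def bf16_add(a, b):
    """Add two values with inputs and result rounded to bfloat16."""
    return to_bfloat16(to_bfloat16(a) + to_bfloat16(b))


# Same case as the example above: the +1 is lost at this magnitude.
print(bf16_add(4096.0, 1.0))

# A running sum of ones stops growing once the spacing between values exceeds 1.
total = 0.0
for _ in range(1000):
    total = bf16_add(total, 1.0)
print(total)
```

The running sum stalls at 256.0 instead of reaching 1000, which is why metrics accumulated over many steps are better kept in FP32.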
Retrieving Stack Traces¶
In the event that the PyTorch process is hanging, it might be useful to include the stack traces together with the GitHub issue.
First thing is to find out which PID the PyTorch process is associated with. Using the ps
command it is possible to find that information. It will be a python process running your
main python file.
In order to allow GDB to attach to a user process, the following command should be run as root:
echo 0 > /proc/sys/kernel/yama/ptrace_scope
The above command remains active until the machine is rebooted.
Then, given the PID, it is possible to grab the stack traces with the following command:
./scripts/dump_stacks.py PID > /tmp/stack-traces.log
Using debug_run.py To Collect Debug Information¶
A utility is provided in scripts/debug_run.py
which can be used to create a tar.gz
archive with the information required to debug PyTorch/XLA executions.
Example:
./scripts/debug_run.py --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT [ARGS...]
The python -u
flag is suggested to disable buffering so that captured logs are correctly
interleaved (otherwise STDOUT will be rendered after all STDERR).
The above command line example will leave the temporary folder containing the archived
information on the filesystem. Use the --tidy
flag to have that removed on exit:
./scripts/debug_run.py --tidy --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT [ARGS...]
The debug_run.tar.gz
file should then be attached to bug reports when necessary.
Since the script will collect a lot of data, it should usually be run for no more than a hundred steps or so.
If the SCRIPT has arguments to control the number of steps, those should be used,
otherwise hitting CTRL^C
will interrupt the run.
It is also suggested to run in single-core mode, to minimize the amount of data. Running in single-core mode is also strongly suggested when debugging execution issues.
Common Issues¶
Missing XLA configuration error message: You need to set XRT_TPU_CONFIG if using TPUs. If using GPUs set GPU_NUM_DEVICES=N for N number of GPUs. If using CPUs set XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" and XRT_WORKERS="localservice:0;grpc://localhost:9002"
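For CPU runs, the two variables can also be set from Python. The sketch below copies the values from the message above and applies them only when they are not already configured; apply_defaults is a hypothetical helper, and it is an assumption here that the variables must be in place before torch_xla is imported.

```python
import os

# Values copied from the message above; adjust them to your environment.
XRT_CPU_DEFAULTS = {
    "XRT_DEVICE_MAP": "CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0",
    "XRT_WORKERS": "localservice:0;grpc://localhost:9002",
}


def apply_defaults(defaults, environ=os.environ):
    """Set each variable only if the user has not already configured it."""
    applied = {}
    for name, value in defaults.items():
        if name not in environ:
            environ[name] = value
            applied[name] = value
    return applied


apply_defaults(XRT_CPU_DEFAULTS)
# Assumption: this runs before `import torch_xla`, so the runtime sees the
# values when it starts up.
print(os.environ["XRT_WORKERS"])
```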
Experimental PjRt Runtime Support¶
The PyTorch/XLA team is currently migrating from the currently-supported XRT runtime to the PjRt runtime used by JAX. Although PjRt may work on TPU v2 and v3, we plan on making PjRt the officially supported runtime for PyTorch/XLA on TPU v4 and future generations of TPU.
PjRt is available as an experimental preview in PyTorch/XLA r1.13. The
PyTorch/XLA team will provide limited support on a best-effort basis during this
preview. If you encounter a bug with PjRt, please file an issue on GitHub with
the runtime
tag.
This preview is mainly targeted at TPU v4. In most cases, we expect that you can re-use your existing PyTorch/XLA code for TPU v4 with no changes. You may be able to adapt your v2 or v3 workload to PjRt with some caveats (see below).
Quickstart¶
To start using PjRt with PyTorch/XLA, all you need to do is set the PJRT_DEVICE environment variable. If you’re working on a TPU v2 or v3, keep reading to learn about the differences between TPU v2/v3 and v4.
CPU¶
On any machine with PyTorch/XLA installed, you can run our MNIST example on CPU like this:
PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data
v4 TPU¶
To create a new TPU with PyTorch/XLA r1.13 installed:
gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-1.13 --zone=us-central2-b --project=$PROJECT
On a v4-8, you can run our ResNet50 example like this:
git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
By default, PjRt will use all TPU chips. To use only one TPU chip, configure
TPU_PROCESS_BOUNDS
and TPU_VISIBLE_CHIPS
:
TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
Note: xmp.spawn’s nprocs argument is not implemented for PjRt.
Pods¶
On TPU Pods, use gcloud
to run your command on each TPU in parallel:
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
GPU¶
Coming soon in a future release!
Key differences from XRT¶
Although we expect PjRt and XRT to work mostly interchangeably from the end-user’s perspective (especially on TPU v4), there are some subtle differences that are important to keep in mind. Importantly, XRT was designed around the TPU Node architecture, so it will always spawn a client and a server process, even on TPU VMs. Thus, every batch of inputs has additional latency from serializing and deserializing data.
PjRt uses the local device directly with no intermediate server process. In the default configuration, PjRt will create one process per TPU chip, or 4 processes per TPU host. See the Cloud TPU documentation for more information about TPU architecture.
Performance gains are possible for workloads constrained by data transfer speeds.
Under XRT, the server process is the only process that interacts with the TPU devices, and client processes don’t have direct access to the TPU devices. When profiling a single-host TPU (e.g. v3-8 or v4-8), you would normally see 8 device traces (one for each TPU core). With PjRt, each process has one chip, and a profile from that process will show only 2 TPU cores.
For the same reason, profiling does not work on TPU Pods with XRT, because the server process runs independently from the user’s model code. PjRt does not have that constraint, so it is possible to profile 2 TPU cores per process in a TPU Pod.
PjRt only supports the TPU VM architecture and we have no plans to support the TPU Node architecture with PjRt.
Runtime configuration is significantly simpler with PjRt.
xla_dist is not required to run TPU Pod workloads. Instead, copy your code to each TPU host (gcloud compute tpus tpu-vm scp, https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp) and run the code on each host in parallel (e.g. gcloud compute tpus tpu-vm ssh --worker=all --command="PJRT_DEVICE=TPU python run.py", https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh).
TPUs v2/v3 vs v4¶
On TPU v4, one TPU chip is represented to PyTorch as one device, while on TPUs
v2/v3, one TPU chip is represented to PyTorch as two devices. It is not
possible to access the same TPU chip from multiple processes, so workloads must
be able to handle two devices per process. The easiest way to handle this is to
spawn two threads per process on TPU v2/v3, which is done automatically by
xmp.spawn
when using PjRt. With multiple threads per process, multiple replicas
will share global state, causing the following known issues:
Threads will share the same torch random seed used for parameter initialization. If you relied on each process having the same random seed for deterministic parameter initialization, you will have to synchronize module parameters via collective broadcasting instead (e.g. pjrt.broadcast_master_param(model)).
torch.distributed uses a global process group and does not support multi-threading, so the xla torch.distributed backend will not work with PjRt and TPU v2 and v3 at this time.
Because the current implementation of xm.rendezvous for PjRt relies on torch.distributed, xm.rendezvous is not supported with PjRt on TPU v2 and v3.
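The process and thread layout described in this section can be summarized as a small lookup. The sketch below only encodes the statements made in the text (one process per chip, 4 chips per host in the default configuration, two devices per chip on v2/v3 and one on v4). pjrt_layout is a hypothetical helper, and treating v4 as one thread per process is an inference from the text rather than a documented guarantee.

```python
CHIPS_PER_HOST = 4  # "4 processes per TPU host" in the default PjRt configuration


def pjrt_layout(tpu_version):
    """Processes and threads per host, as described in the text above."""
    if tpu_version in ("v2", "v3"):
        devices_per_chip = 2  # one chip appears to PyTorch as two devices
    elif tpu_version == "v4":
        devices_per_chip = 1  # one chip appears to PyTorch as one device
    else:
        raise ValueError(f"unsupported TPU version: {tpu_version!r}")
    return {
        "processes_per_host": CHIPS_PER_HOST,     # one process per chip
        "threads_per_process": devices_per_chip,  # one thread per device
        "devices_per_host": CHIPS_PER_HOST * devices_per_chip,
    }


print(pjrt_layout("v3"))
print(pjrt_layout("v4"))
```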
Compatible examples¶
For an overview of the changes required to migrate from TPU v2/v3 to v4, compare our MNIST (XRT, PjRt) and ImageNet (XRT, PjRt) examples.
The PjRt MNIST and ImageNet examples are compatible with all versions of TPU. Use the following commands to run them on a single-host TPU (e.g. v3-8 or v4-8).
PJRT_DEVICE=TPU python3 xla/test/pjrt/test_train_pjrt_mnist.py --fake_data
PJRT_DEVICE=TPU python3 xla/test/pjrt/test_train_pjrt_imagenet.py --fake_data
PjRt and DDP¶
PjRt composes well with the new experimental torch.nn.parallel.DistributedDataParallel feature (see ./ddp.md) on TPU v4. Just run the DDP script as usual but with PJRT_DEVICE=TPU. Here is a full example:
PJRT_DEVICE=TPU MASTER_ADDR=localhost MASTER_PORT=6000 python xla/test/test_train_mp_mnist.py --ddp --fake_data --num_epochs 1
Caveat: for TPU V2 and V3, however, XRT will still be needed to run DDP.
How to do DistributedDataParallel¶
This document shows how to use torch.nn.parallel.DistributedDataParallel in xla, and describes how it differs from the native xla data parallel approach.
Background / Motivation¶
Customers have long requested the ability to use PyTorch’s DistributedDataParallel API with xla. It is now available as an experimental feature.
How to use DistributedDataParallel¶
If you are moving from PyTorch eager mode to XLA, these are all the changes needed to convert an eager DDP model into an XLA model. We assume that you already know how to use XLA on a single device.
Import xla specific distributed packages:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
Init xla process group similar to other process groups such as nccl and gloo.
dist.init_process_group("xla", rank=rank, world_size=world_size)
Use xla specific APIs to get rank and world_size if you need to.
new_rank = xm.get_ordinal()
world_size = xm.xrt_world_size()
Pass
gradient_as_bucket_view=True
to the DDP wrapper.
ddp_model = DDP(model, gradient_as_bucket_view=True)
Finally launch your model with xla specific launcher.
xmp.spawn(demo_fn)
Here we have put everything together (the example is taken from the DDP tutorial). The code is very similar to the eager version, with the xla-specific changes for a single device plus the five changes listed above.
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
# additional imports for xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xm.get_ordinal()
    assert new_rank == rank
    world_size = xm.xrt_world_size()
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    # currently, gradient_as_bucket_view is needed to make DDP work for xla
    ddp_model = DDP(model, gradient_as_bucket_view=True)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()
    cleanup()

def run_demo(demo_fn):
    # xla specific launcher
    xmp.spawn(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)
Benchmarking¶
Resnet50 with fake data¶
The following results are collected with the command: python
test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
on a
TPU VM V3-8 environment with ToT PyTorch and PyTorch/XLA. And the statistical
metrics are produced by using the script in this pull
request. The unit for the rate is
images per second.
Type | Mean | Median | 90th % | Std Dev | CV |
xm.optimizer_step | 418.54 | 419.22 | 430.40 | 9.76 | 0.02 |
DDP | 395.97 | 395.54 | 407.13 | 7.60 | 0.02 |
The performance difference between our native approach for distributed data parallel and DistributedDataParallel wrapper is: 1 - 395.97 / 418.54 = 5.39%. This result seems reasonable given the DDP wrapper introduces extra overheads on tracing the DDP runtime.
MNIST with fake data¶
The following results are collected with the command: python
test/test_train_mp_mnist.py --fake_data
on a TPU VM V3-8 environment with ToT
PyTorch and PyTorch/XLA. And the statistical metrics are produced by using the
script in this pull request. The
unit for the rate is images per second.
Type | Mean | Median | 90th % | Std Dev | CV |
xm.optimizer_step | 17864.19 | 20108.96 | 24351.74 | 5866.83 | 0.33 |
DDP | 10701.39 | 11770.00 | 14313.78 | 3102.92 | 0.29 |
The performance difference between our native approach for distributed data parallel and DistributedDataParallel wrapper is: 1 - 14313.78 / 24351.74 = 41.22%. Here we compare the 90th percentile instead, since the dataset is small and the first few rounds are heavily impacted by data loading. This slowdown is large but makes sense given that the model is small: the additional DDP runtime tracing overhead is hard to amortize.
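The slowdown figures quoted for both benchmarks, and the CV column of the tables, follow from the table values with two one-line formulas. The sketch below recomputes them; the function names are ours, and CV is assumed to be the coefficient of variation (standard deviation divided by mean), which matches the tabulated values.

```python
def slowdown(native_rate, ddp_rate):
    """Relative throughput loss of DDP versus the native approach."""
    return 1 - ddp_rate / native_rate


def coefficient_of_variation(std_dev, mean):
    """Standard deviation relative to the mean."""
    return std_dev / mean


# ResNet50: compared on the mean rate.
print(f"{slowdown(418.54, 395.97):.2%}")
# MNIST: compared on the 90th percentile rate.
print(f"{slowdown(24351.74, 14313.78):.2%}")
# CV column of the MNIST DDP row.
print(f"{coefficient_of_variation(3102.92, 10701.39):.2f}")
```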
MNIST with real data¶
The following results are collected with the command: python
test/test_train_mp_mnist.py --logdir mnist/
on a TPU VM V3-8 environment with
ToT PyTorch and PyTorch/XLA.
And we can observe that the DDP wrapper converges slower than the native XLA approach even though it still achieves a high accuracy rate at 97.48% at the end. (The native approach achieves 99%.)
Disclaimer¶
This feature is still experimental and under active development. Use it with caution and feel free to file any bugs in the xla GitHub repo. For those who are interested in the native xla data parallel approach, here is the tutorial.
Here are some of the known issues that are under investigation:
gradient_as_bucket_view=True needs to be enforced.
There are some issues while being used with torch.utils.data.DataLoader.
test_train_mp_mnist.py with real data crashes before exiting.
Fully Sharded Data Parallel (FSDP) in PyTorch XLA¶
Fully Sharded Data Parallel (FSDP) in PyTorch XLA is a utility for sharding Module parameters across data-parallel workers.
Example usage:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters.
Notes:
The XlaFullyShardedDataParallel class supports both the ZeRO-2 optimizer (sharding gradients and optimizer states) and the ZeRO-3 optimizer (sharding parameters, gradients, and optimizer states) in https://arxiv.org/abs/1910.02054.
The ZeRO-3 optimizer should be implemented via nested FSDP with reshard_after_forward=True. See test/test_train_mp_mnist_fsdp_with_ckpt.py and test/test_train_mp_imagenet_fsdp.py for an example.
For large models that cannot fit into a single TPU memory or the host CPU memory, one should interleave submodule construction with inner FSDP wrapping. See FSDPViTModel (https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py) for an example.
A simple wrapper checkpoint_module is provided (based on torch_xla.utils.checkpoint.checkpoint from https://github.com/pytorch/xla/pull/3524) to perform gradient checkpointing over a given nn.Module instance. See test/test_train_mp_mnist_fsdp_with_ckpt.py and test/test_train_mp_imagenet_fsdp.py for an example.
When stepping the optimizer, directly call optimizer.step and do not call xm.optimizer_step. The latter reduces the gradient across ranks, which is not needed for FSDP (where the parameters are already sharded).
When saving model and optimizer checkpoints during training, each training process needs to save its own checkpoint of the (sharded) model and optimizer state dicts (use master_only=False and set different paths for each rank in xm.save). When resuming, it needs to load the checkpoint for the corresponding rank.
Please also save model.get_shard_metadata() along with model.state_dict() as follows and use consolidate_sharded_model_checkpoints to stitch the sharded model checkpoints together into a full model state dict. See test/test_train_mp_mnist_fsdp_with_ckpt.py for an example.
ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
    'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth'
xm.save(ckpt, ckpt_path, master_only=False)
The checkpoint consolidation script can also be launched from the command line as follows. .. code-block:: bash
# consolidate the saved checkpoints via command line tool python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”
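The relationship between per-rank checkpoint paths and the prefix/suffix pair passed to the consolidation tool can be sketched as follows. The helper sharded_ckpt_path and the example prefix are hypothetical, and treating prefix + suffix as a single glob pattern is an assumption about how the tool locates the files.

```python
import fnmatch


def sharded_ckpt_path(prefix: str, rank: int, world_size: int) -> str:
    """Hypothetical helper: one checkpoint file per rank under a shared prefix."""
    return f"{prefix}_rank-{rank}-of-{world_size}.pth"


ckpt_prefix = "/tmp/example-fsdp/ckpt"  # illustrative path, not from the docs
ckpt_suffix = "_rank-*-of-*.pth"        # suffix pattern from the command above
world_size = 8

# Every rank writes to its own file, so no two ranks overwrite each other.
paths = [sharded_ckpt_path(ckpt_prefix, r, world_size) for r in range(world_size)]

# Assumption: consolidation finds the shards by globbing prefix + suffix.
matched = [p for p in paths if fnmatch.fnmatch(p, ckpt_prefix + ckpt_suffix)]
print(len(matched), "of", len(paths), "files match the pattern")
```

Whatever naming scheme is used, the point is that the rank appears in the file name and the suffix pattern is loose enough to match every rank.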
The implementation of this class is largely inspired by and mostly follows the structure of fairscale.nn.FullyShardedDataParallel
in https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html. One of the biggest differences from fairscale.nn.FullyShardedDataParallel
is that in XLA we don’t have explicit parameter storage, so here we resort to a different approach to free full parameters for ZeRO-3.
Example training scripts on MNIST and ImageNet¶
MNIST: test/test_train_mp_mnist_fsdp_with_ckpt.py (https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py); it also tests checkpoint consolidation
ImageNet: test/test_train_mp_imagenet_fsdp.py (https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py)
Installation¶
FSDP is available in the PyTorch/XLA 1.12 release and in newer nightly builds. Please refer to https://github.com/pytorch/xla#-available-images-and-wheels for the installation guide.
Clone PyTorch/XLA repo¶
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
Train MNIST on v3-8 TPU¶
It reaches around 98.9% accuracy after 2 epochs:
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp --use_gradient_checkpointing
This script automatically tests checkpoint consolidation at the end. You can also manually consolidate the sharded checkpoints via
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
Train ImageNet with ResNet-50 on v3-8 TPU¶
It reaches around 75.9% accuracy after 100 epochs. Download ImageNet-1k to /datasets/imagenet-1k before running:
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
--use_nested_fsdp
You can also add --use_gradient_checkpointing
(which needs to be used along with --use_nested_fsdp
) to apply gradient checkpointing on the residual blocks.
Example training scripts on TPU pod (with 10 billion parameters)¶
To train large models that cannot fit into a single TPU, one should use nested FSDP (wrapping sub-modules with inner FSDP when building the entire model) to implement the ZeRO-3 algorithm.
Please see https://github.com/ronghanghu/vit_10b_fsdp_example for an example of sharded training of a Vision Transformer (ViT) model using this XLA FSDP PR.
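A rough back-of-the-envelope calculation shows why ZeRO-3 matters at this scale. The sketch below uses the per-parameter byte counts from the ZeRO paper's mixed-precision Adam accounting (2 bytes for parameters, 2 for gradients, 12 for optimizer states). Those byte counts, the device count of 128, and the function name are assumptions for illustration; the actual memory use of XLA FSDP depends on dtypes and activations, which are not modeled here.

```python
def per_device_bytes(num_params: int, num_devices: int, stage: int) -> float:
    """Approximate model-state bytes per device (activations not included)."""
    param_b, grad_b, optim_b = 2, 2, 12  # assumed bytes per parameter
    if stage == 0:  # plain data parallel: everything is replicated
        return (param_b + grad_b + optim_b) * num_params
    if stage == 2:  # ZeRO-2: gradients and optimizer states are sharded
        return param_b * num_params + (grad_b + optim_b) * num_params / num_devices
    if stage == 3:  # ZeRO-3: parameters are sharded as well
        return (param_b + grad_b + optim_b) * num_params / num_devices
    raise ValueError(f"unsupported stage: {stage}")


num_params = 10_000_000_000  # 10 billion parameters
num_devices = 128            # hypothetical number of TPU cores

baseline = per_device_bytes(num_params, num_devices, stage=0)
zero2 = per_device_bytes(num_params, num_devices, stage=2)
zero3 = per_device_bytes(num_params, num_devices, stage=3)
for name, value in [("replicated", baseline), ("ZeRO-2", zero2), ("ZeRO-3", zero3)]:
    print(f"{name:>10}: {value / 1e9:.2f} GB per device")
```

Under these assumptions only the ZeRO-3 layout brings the per-device model state down to a small fraction of the replicated footprint, because ZeRO-2 still keeps a full copy of the parameters on every device.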
OP Lowering Guide¶
Background¶
PyTorch wraps the C++ ATen tensor library that offers a wide range of operations implemented on GPU and CPU. PyTorch/XLA is a PyTorch extension; one of its purposes is to convert PyTorch operations to XLA operations. Lowering is the process of converting a higher-level representation to a lower-level representation. In this document, I will refer to the process of converting a PyTorch operation to an XLA operation as the lowering. The XLA compiler will also lower XlaOp to HLO, but that’s beyond the scope of this documentation. Operations for which we have not yet provided an XLA lowering are forwarded to the CPU, where the ATen implementations are called. Operations that are forwarded to the CPU cause a significant slowdown, so we must lower all operations used in the model to achieve the best performance.
Before you start¶
You should follow the instructions here to install the required dependencies and build PyTorch and PyTorch/XLA from source. You do not need access to a TPU to implement the lowering. It is recommended to experiment on a workstation and configure it to use XLA:CPU. You can configure PyTorch/XLA to use XLA:CPU by running
export XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localservice:0;grpc://localhost:51011"
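The same configuration can also be applied from inside a Python script, which is handy for quick experiments. This is a sketch that assumes PyTorch/XLA reads these variables when it is initialized, so they are set before the import; the import itself is left commented out so the snippet runs without PyTorch/XLA installed.

```python
import os

# Values copied from the export command above.
# Assumption: they must be in the environment before torch_xla is imported.
os.environ["XRT_DEVICE_MAP"] = (
    "CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0"
)
os.environ["XRT_WORKERS"] = "localservice:0;grpc://localhost:51011"

# import torch_xla  # uncomment where PyTorch/XLA is installed

for name in ("XRT_DEVICE_MAP", "XRT_WORKERS"):
    print(f"{name}={os.environ[name]}")
```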
Understanding the operation¶
You can find the definition of the C++ ATen operations in native_functions.yaml. After you build PyTorch/XLA from source, you will also find our default implementation (a boxed kernel which forwards calls to PyTorch native CPU) in xla/torch_xla/csrc/aten_cpu_fallback.h/cpp. PyTorch operations can usually be mapped to the PyTorch tensor API easily. If that is not the case, searching for the PyTorch native implementation in the PyTorch repo is recommended. The goal is to lower the PyTorch operations into a sequence of XLA operations defined here.
File structure¶
All files mentioned below live under the xla/torch_xla/csrc folder, with the exception of xla_native_functions.yaml.
- xla_native_functions.yaml contains the list of all operators that are lowered. Each operator name must directly match a PyTorch operator listed in native_functions.yaml. This file serves as the interface for adding new XLA operators, and is an input to PyTorch’s codegen machinery. It generates the following 3 files: XLANativeFunctions.h, RegisterXLA.cpp, and RegisterAutogradXLA.cpp.
- XLANativeFunctions.h and aten_xla_type.cpp are the entry points from PyTorch to the pytorch_xla world, and contain the manually written lowerings to XLA for each operator. XLANativeFunctions.h is auto-generated through a combination of xla_native_functions.yaml and the PyTorch core native_functions.yaml file, and contains declarations for kernels that need to be defined in aten_xla_type.cpp. The kernels written here need to construct an XLATensor using the input at::Tensor and other parameters. The resulting XLATensor needs to be converted back to an at::Tensor before returning to the PyTorch world.
- RegisterXLA.cpp and RegisterAutogradXLA.cpp are auto-generated files that register all lowerings to the PyTorch Dispatcher. They also include auto-generated wrapper implementations of out= and inplace operators.
- aten_cpu_fallback.h/.cpp contain our boxed fallback implementation to CPU. The boxed fallback kernel will be used if a lowering is not explicitly defined in xla_native_functions.yaml + aten_xla_type.cpp, and the operator is not composite.
- tensor.h contains the XLATensor declarations. These declarations are usually a one-to-one mapping of the at::Tensor nodes we declared in XLANativeFunctions.h.
- tensor_methods.cpp contains the implementation of the XLATensor nodes defined in tensor.h. We construct the corresponding ir::op from the parameter’s ir::Value and wrap it inside an XLATensor. IR stands for intermediate representation.
- The ops/ directory contains all ir::ops declarations and definitions. Smaller nodes can be put in ops/ops.h/.cpp. More complicated nodes can be put into a separate file. All ops inherit from ir::ops::Node and provide a way to lower the input ir::Value to a sequence of XlaOp.
Unit Test¶
Our CircleCI runs PyTorch native Python tests for every change and every day. Those tests will use the XLA implementation if we provide a lowering. We usually don’t need to add additional Python tests for PyTorch/XLA unless we want to verify some XLA behaviors (like dynamic shape) or we skipped the PyTorch native test for some reason. If a Python test is required, it should be added to xla/test/test_operations.py. We also need to add C++ tests in xla/test/cpp/test_aten_xla_tensor.cpp. This test should call the PyTorch C++ API and verify that our implementation yields the same result as the PyTorch native implementation. We also need to verify that the XLA implementation is called when the tensor is an XLA tensor by checking the aten::op and xla::op counters.
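The counter check can be illustrated from the Python side with a small helper that separates fallback counters from lowered ones by their prefix. This is a sketch only: split_counters is a hypothetical name, and the sample counter names are made up for illustration rather than taken from a real run.

```python
from typing import Iterable, List, Tuple


def split_counters(counter_names: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (fallback, lowered) counter names based on their prefix."""
    names = list(counter_names)
    # aten:: counters indicate ops that fell back to the CPU implementation.
    fallback = sorted(n for n in names if n.startswith("aten::"))
    # xla:: counters indicate ops that went through an XLA lowering.
    lowered = sorted(n for n in names if n.startswith("xla::"))
    return fallback, lowered


# Illustrative sample; real names would come from the runtime's counters.
sample = ["xla::add", "xla::lerp", "aten::nonzero", "CreateXlaTensor"]
fallback, lowered = split_counters(sample)
print("fallback ops:", fallback)
print("lowered ops:", lowered)
```

In a live session the names could be taken from torch_xla.debug.metrics.counter_names(); an op that has just been lowered should then show up under xla:: and no longer under aten::.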
Tips¶
The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of a PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops are lowered is the best way to achieve that. You can find a minimal op lowering example in this pr. You can also find a slightly more complicated example with backward lowering in this pr.
We have auto-generated wrapper implementations of out= and inplace operators for some operators in RegisterXLA.cpp. We only need to lower the vanilla op in this case. An example is the lerp operator, which has 6 variants in native_functions.yaml. They are
- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor
and will generate function prototypes
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out);
in XLANativeFunctions.h
if we add all of them to the xla_native_functions.yaml
. However if we only lower lerp.Scalar
and lerp.Tensor
and check RegisterXLA.cpp
, we will see
namespace {
at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
// No device check
// DeviceGuard omitted
return torch_xla::lerp(self, end, weight);
}
} // anonymous namespace
at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight);
at::_copy_from(wrapper_Scalar_lerp__tmp, self);
return self;
}
...
m.impl("lerp_.Scalar",
TORCH_FN(wrapper_Scalar_lerp_));
The codegen will automatically generate lowerings for lerp_.Scalar
and lerp.Scalar_out
that use our lerp.Scalar
implementation, without us having to provide an explicit lowering.
In general, if there is an operator in pytorch core that has both an out-of-place and an out= variant, it’s better to write a lowering for the out-of-place variant, since you’ll get a code-generated out= lowering for free.
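The rule of thumb above can be sketched as a small classifier over operator names. It relies on a simplified naming heuristic (a trailing underscore on the base name marks an in-place variant, an overload name ending in _out marks an out= variant) and the function name classify_variant is hypothetical. The real codegen works from the operator schemas, and per the text above it only generates wrappers for some operators.

```python
def classify_variant(op_name: str) -> str:
    """Classify an operator name as 'inplace', 'out', or 'functional'."""
    base, _, overload = op_name.partition(".")
    if base.endswith("_"):
        return "inplace"
    if overload == "out" or overload.endswith("_out"):
        return "out"
    return "functional"


# The six lerp variants listed in the text above.
lerp_variants = [
    "lerp_.Scalar",
    "lerp_.Tensor",
    "lerp.Scalar_out",
    "lerp.Tensor_out",
    "lerp.Scalar",
    "lerp.Tensor",
]

# Only the functional (out-of-place) variants need a hand-written lowering
# under the assumption that wrappers are generated for the other two kinds.
needs_manual_lowering = [v for v in lerp_variants if classify_variant(v) == "functional"]
print(needs_manual_lowering)
```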
For each node we need to pass an ir::OpKind
. Here is an (example). You can find the OpKind
definition in aten_interned_strings.h or interned_strings.h. If the aten symbol is missing, you can submit a PR like this.