PyTorch/XLA documentation¶
PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs.
PyTorch on XLA Devices¶
PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.
Creating an XLA Tensor¶
PyTorch/XLA adds a new xla
device type to PyTorch. This device type works just
like other PyTorch device types. For example, here’s how to create and
print an XLA tensor:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
This code should look familiar. PyTorch/XLA uses the same interface as regular
PyTorch with a few additions. Importing torch_xla
initializes PyTorch/XLA, and
xm.xla_device()
returns the current XLA device. This may be a CPU or TPU
depending on your environment.
XLA Tensors are PyTorch Tensors¶
PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.
For example, XLA tensors can be added together:
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
Or matrix multiplied:
print(t0.mm(t1))
Or used with neural network modules:
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
will throw an error since the torch.nn.Linear
module is on the CPU.
Running Models on XLA Devices¶
Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device and multiple devices with XLA multi-processing.
Running on a Single XLA Device¶
The following snippet shows a network training on a single XLA device:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  optimizer.step()
  xm.mark_step()
This snippet highlights how easy it is to switch your model to run on XLA. The
model definition, dataloader, optimizer and training loop can work on any device.
The only XLA-specific code is a couple lines that acquire the XLA device and
mark the step. Calling
xm.mark_step()
at the end of each training
iteration causes XLA to execute its current graph and update the model’s
parameters. See XLA Tensor Deep Dive for more on
how XLA creates graphs and runs operations.
Running on Multiple XLA Devices with Multi-processing¶
PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
There are three differences between this multi-device snippet and the previous single-device snippet. Let’s go over them one by one.
xmp.spawn()
Creates the processes that each run an XLA device.
Each process will only be able to access the device assigned to it. For example, on a TPU v4-8 there will be 4 processes spawned and each process will own one TPU device.
Note that if you print xm.xla_device() in each process you will see xla:0 on all of them. This is because each process can only see one device; it does not mean multi-processing is not functioning. The only exception is the PJRT runtime on TPU v2 and TPU v3, where there are #devices/2 processes and each process has 2 threads (check this doc for more details).
MpDeviceLoader
Loads the training data onto each device.
MpDeviceLoader can wrap a torch dataloader. It preloads the data to the device and overlaps the data loading with device execution to improve performance. MpDeviceLoader also calls xm.mark_step for you every batches_per_execution (default: 1) batches yielded.
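As a minimal sketch, reusing the train_loader from the snippet above and assuming MpDeviceLoader forwards keyword arguments such as batches_per_execution to the underlying ParallelLoader, wrapping an existing dataloader looks like:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
device = xm.xla_device()
# Preloads batches onto the device and calls xm.mark_step() every
# `batches_per_execution` batches yielded (assumed keyword, default 1).
mp_device_loader = pl.MpDeviceLoader(train_loader, device, batches_per_execution=1)
for data, target in mp_device_loader:
  ...  # training step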
xm.optimizer_step(optimizer)
Consolidates the gradients between devices and issues the XLA device step computation.
It is essentially all_reduce_gradients + optimizer.step() + mark_step, and it returns the reduced loss.
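Conceptually it expands to roughly the following sketch (xm.reduce_gradients is assumed here as the helper that performs the gradient all-reduce; by default the step barrier is issued by the MpDeviceLoader rather than inside xm.optimizer_step):
# Rough conceptual equivalent of xm.optimizer_step(optimizer) -- a sketch, not the actual implementation.
xm.reduce_gradients(optimizer)  # all-reduce the gradients across participating devices (assumed helper)
optimizer.step()                # apply the consolidated gradients
xm.mark_step()                  # execute the pending graph (only issued here when barrier=True is requested)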
The model definition, optimizer definition and training loop remain the same.
NOTE: When using multi-processing, the user can start retrieving and accessing XLA devices only from within the target function of xmp.spawn() (or any function which has xmp.spawn() as a parent in the call stack).
See the full multiprocessing example for more on training a network on multiple XLA devices with multi-processing.
Running on TPU Pods¶
Multi-host setup can differ significantly between accelerators. This doc covers the device-independent parts of multi-host training and uses the TPU + PJRT runtime (currently available on the 1.13 and 2.x releases) as an example.
Before you begin, please take a look at our user guide here, which explains some Google Cloud basics such as how to use the gcloud command and how to set up your project. You can also check here for all Cloud TPU how-tos. This doc focuses on the PyTorch/XLA perspective of the setup.
Let’s assume you have the MNIST example from the section above in a train_mnist_xla.py. For single-host, multi-device training, you would ssh to the TPU VM and run a command like
PJRT_DEVICE=TPU python3 train_mnist_xla.py
Now, in order to run the same model on a TPU v4-16 (which has 2 hosts, each with 4 TPU devices), you will need to:
1. Make sure each host can access the training script and training data. This is usually done by using the gcloud scp command or gcloud ssh command to copy the training scripts to all hosts.
2. Run the same training command on all hosts at the same time.
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"
The above gcloud ssh command will ssh into all hosts in the TPU VM Pod and run the same command at the same time.
NOTE: You need to run the above gcloud command outside of the TPU VM.
The model code and training script are the same for multi-process training and multi-host training. PyTorch/XLA and the underlying infrastructure will make sure each device is aware of the global topology and each device’s local and global ordinal. Cross-device communication will happen across all devices instead of only the local devices.
For more details regarding the PJRT runtime and how to run it on a pod, please refer to this doc. For more information about PyTorch/XLA on TPU Pods, and for a complete guide to running a ResNet-50 with fake data on a TPU Pod, please refer to this guide.
XLA Tensor Deep Dive¶
Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique.
XLA Tensors are Lazy¶
CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example.
Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device. For more information about our lazy tensor design, you can read this paper.
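A small illustrative sketch of this behavior: the intermediate operations below are only recorded until execution is forced by xm.mark_step() or by copying data back to the CPU.
import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()
a = torch.randn(2, 2, device=device)
b = torch.randn(2, 2, device=device)
c = a @ b + a      # recorded in the graph, not yet executed
xm.mark_step()     # compiles and runs the pending graph on the XLA device
print(c.cpu())     # copying to the CPU also synchronizes with the device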
XLA Tensors and bFloat16¶
PyTorch/XLA can use the
bfloat16
datatype when running on TPUs. In fact, PyTorch/XLA handles float types
(torch.float
and torch.double
) differently on TPUs. This behavior is
controlled by the XLA_USE_BF16
and XLA_DOWNCAST_BF16
environment variables:
By default both torch.float and torch.double are torch.float on TPUs.
If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.
If XLA_DOWNCAST_BF16 is set, then torch.float is bfloat16 on TPUs and torch.double is float32 on TPUs.
If a PyTorch tensor has torch.bfloat16 data type, it will be directly mapped to the TPU bfloat16 (XLA BF16 primitive type).
Developers should note that XLA tensors on TPUs will always report their PyTorch datatype regardless of the actual datatype they’re using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU it will be converted from its actual datatype to its PyTorch datatype. Depending on how your code operates, this conversion triggered by the type of processing unit can be important.
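For example (a sketch assuming a TPU backend and that the process was launched with XLA_USE_BF16=1), the reported dtype does not change even though bfloat16 is used on the device:
import torch
import torch_xla.core.xla_model as xm
# Assumes XLA_USE_BF16=1 was set in the environment before starting the process.
t = torch.randn(2, 2, device=xm.xla_device())
print(t.dtype)        # torch.float32, even though the TPU stores the data as bfloat16
print(t.cpu().dtype)  # torch.float32 again: moving to the CPU converts back to the PyTorch dtype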
Memory Layout¶
The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor’s memory layout for better performance.
Moving XLA Tensors to and from the CPU¶
XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved then the data it's viewing is also copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it. Again, depending on how your code operates, appreciating and accommodating this transition can be important.
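A brief sketch of this behavior with a hypothetical view:
import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()
base = torch.zeros(4, 4)
row = base[0]              # a CPU view of the first row
row_xla = row.to(device)   # the viewed data is copied to the XLA device
base[0] += 1.0             # mutating the CPU base does not affect the copy
print(row_xla.cpu())       # still zeros: the view relationship was not preserved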
Saving and Loading XLA Tensors¶
XLA tensors should be moved to the CPU before saving, as in the following snippet:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
This lets you put the loaded tensors on any available device, not just the one on which they were initialized.
Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it is recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).
A utility API is provided that takes care of moving the data to the CPU before saving it:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
xm.save(model.state_dict(), path)
In case of multiple devices, the above API will only save the data for the master device ordinal (0).
In cases where memory is limited compared to the size of the model parameters, an API is provided that reduces the memory footprint on the host:
import torch_xla.utils.serialization as xser
xser.save(model.state_dict(), path)
This API streams XLA tensors to CPU one at a time, reducing the amount of host memory used, but it requires a matching load API to restore:
import torch_xla.utils.serialization as xser
state_dict = xser.load(path)
model.load_state_dict(state_dict)
Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.
Compilation Caching¶
The XLA compiler converts the traced HLO into an executable which runs on the devices. Compilation can be time consuming, and in cases where the HLO doesn’t change across executions, the compilation result can be persisted to disk for reuse, significantly reducing development iteration time.
Note that if the HLO changes between executions, a recompilation will still occur.
This is currently an experimental opt-in API, which must be activated before
any computations are executed. Initialization is done through the
initialize_cache
API:
import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)
This will initialize a persistent compilation cache at the specified path. The
readonly
parameter can be used to control whether the worker will be able to
write to the cache, which can be useful when a shared cache mount is used for
an SPMD workload.
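For example, on a shared cache mount one possible pattern (a sketch only; the path and the is_writer flag are hypothetical and would come from your own launcher or setup) is:
import torch_xla.runtime as xr
# Let only a single designated worker populate the shared cache; all other workers read from it.
is_writer = True  # hypothetical flag derived from your own setup
xr.initialize_cache('/mnt/shared/xla_cache', readonly=not is_writer)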
Further Reading¶
Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.
PyTorch/XLA API¶
torch_xla¶
- torch_xla.device(index: Optional[int] = None) device [source]¶
Returns a given instance of an XLA device.
If SPMD is enabled, returns a virtual device that wraps all devices available to this process.
- Parameters
index – index of the XLA device to be returned. Corresponds to index in torch_xla.devices().
- Returns
An XLA torch.device.
- torch_xla.devices() List[device] [source]¶
Returns all devices available in the current process.
- Returns
A list of XLA torch.devices.
- torch_xla.device_count() int [source]¶
Returns number of addressable devices in the current process.
- torch_xla.step()[source]¶
Wraps code that should be dispatched to the runtime.
Experimental: xla.step is still a work in progress. Some code that currently works with xla.step but does not follow best practices will become errors in future releases. See https://github.com/pytorch/xla/issues/6751 for context.
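A minimal usage sketch of torch_xla.device(), torch_xla.devices() and torch_xla.device_count():
import torch
import torch_xla
print(torch_xla.devices())       # the XLA torch.devices visible to this process
print(torch_xla.device_count())  # number of addressable devices
t = torch.ones(2, 2, device=torch_xla.device())  # allocate a tensor on the default XLA device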
runtime¶
- torch_xla.runtime.device_type() Optional[str] [source]¶
Returns the current PjRt device type.
Selects a default device if none has been configured
- torch_xla.runtime.local_process_count() int [source]¶
Returns the number of processes running on this host.
- torch_xla.runtime.local_device_count() int [source]¶
Returns the total number of devices on this host.
Assumes each process has the same number of addressable devices.
- torch_xla.runtime.addressable_device_count() int [source]¶
Returns the number of devices visible to this process.
- torch_xla.runtime.global_device_count() int [source]¶
Returns the total number of devices across all processes/hosts.
- torch_xla.runtime.global_runtime_device_count() int [source]¶
Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD.
- torch_xla.runtime.world_size() int [source]¶
Returns the total number of processes participating in the job.
- torch_xla.runtime.global_ordinal() int [source]¶
Returns global ordinal of this thread within all processes.
Global ordinal is in range [0, global_device_count). Global ordinals are not guaranteed to have any predictable relationship to the TPU worker ID nor are they guaranteed to be contiguous on each host.
- torch_xla.runtime.local_ordinal() int [source]¶
Returns local ordinal of this thread within this host.
Local ordinal is in range [0, local_device_count).
- torch_xla.runtime.get_master_ip() str [source]¶
Retrieve the master worker IP for the runtime. This calls into backend-specific discovery APIs.
Returns master worker’s IP address as a string.
- torch_xla.runtime.initialize_cache(path: str, readonly: bool = False)[source]¶
Initializes the persistent compilation cache. This API must be called before any computations have been performed.
- Parameters
path – The path at which to store the persistent cache.
readonly – Whether or not this worker should have write access to the cache.
xla_model¶
- torch_xla.core.xla_model.xla_device(n=None, devkind=None)[source]¶
Returns a given instance of an XLA device.
- Parameters
n (python:int, optional) – The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of devkind will be returned.
devkind (string..., optional) – If specified, device type such as TPU, CUDA, CPU, or custom PJRT device. Deprecated.
- Returns
A torch.device with the requested instance.
- torch_xla.core.xla_model.xla_device_hw(device)[source]¶
Returns the hardware type of the given device.
- Parameters
device (string or torch.device) – The xla device that will be mapped to the real device.
- Returns
A string representation of the hardware type of the given device.
- torch_xla.core.xla_model.is_master_ordinal(local=True)[source]¶
Checks whether the current process is the master ordinal (0).
- Parameters
local (bool) – Whether the local or global master ordinal should be checked. In case of multi-host replication, there is only one global master ordinal (host 0, device 0), while there are NUM_HOSTS local master ordinals. Default: True
- Returns
A boolean indicating whether the current process is the master ordinal.
- torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True)[source]¶
Performs an inplace reduce operation on the input tensor(s).
- Parameters
reduce_type (string) – One of xm.REDUCE_SUM, xm.REDUCE_MUL, xm.REDUCE_AND, xm.REDUCE_OR, xm.REDUCE_MIN and xm.REDUCE_MAX.
inputs – Either a single torch.Tensor or a list of torch.Tensor to perform the all reduce op on.
scale (python:float) – A default scaling value to be applied after the reduce. Default: 1.0
groups (list, optional) –
A list of lists, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]]
defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
pin_layout (bool, optional) – whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication have slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.
- Returns
If a single torch.Tensor is passed, the return value is a torch.Tensor holding the reduced value (across the replicas). If a list/tuple is passed, this function performs an inplace all-reduce op on the input tensors, and returns the list/tuple itself.
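For example, a sketch of summing a tensor across replicas (to be run inside each replicated process, e.g. the xmp.spawn target function):
import torch
import torch_xla.core.xla_model as xm
t = torch.ones(2, 2, device=xm.xla_device())
reduced = xm.all_reduce(xm.REDUCE_SUM, t)  # a tensor holding the sum of `t` over all replicas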
- torch_xla.core.xla_model.all_gather(value, dim=0, groups=None, output=None, pin_layout=True)[source]¶
Performs an all-gather operation along a given dimension.
- Parameters
value (torch.Tensor) – The input tensor.
dim (python:int) – The gather dimension. Default: 0
groups (list, optional) –
A list of lists, representing the replica groups for the all_gather() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]]
defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
output (torch.Tensor) – Optional output tensor.
pin_layout (bool, optional) – whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication have slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.
- Returns
A tensor which has, in the
dim
dimension, all the values from the participating replicas.
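For example, a sketch gathering a per-replica tensor along dim 0 (to be run inside each replicated process):
import torch
import torch_xla.core.xla_model as xm
t = torch.ones(2, 3, device=xm.xla_device())
gathered = xm.all_gather(t, dim=0)  # shape (2 * num_replicas, 3), holding every replica's values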
- torch_xla.core.xla_model.all_to_all(value, split_dimension, concat_dimension, split_count, groups=None, pin_layout=True)[source]¶
Performs an XLA AllToAll() operation on the input tensor.
See: https://www.tensorflow.org/xla/operation_semantics#alltoall
- Parameters
value (torch.Tensor) – The input tensor.
split_dimension (python:int) – The dimension upon which the split should happen.
concat_dimension (python:int) – The dimension upon which the concat should happen.
split_count (python:int) – The split count.
groups (list, optional) –
A list of lists, representing the replica groups for the all_to_all() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]]
defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
pin_layout (bool, optional) – whether to pin the layout for this communication op. Layout pinning can prevent potential data corruption when the processes that participate in the communication have slightly different programs, but it might cause some XLA compilations to fail. Unpin the layout when you see an error message like “HloModule has a mix of layout constrained”.
- Returns
The result torch.Tensor of the all_to_all() operation.
- torch_xla.core.xla_model.add_step_closure(closure, args=(), run_async=False)[source]¶
Adds a closure to the list of the ones to be run at the end of the step.
Many times during model training there is the need to print/report (print to console, post to tensorboard, etc…) information which requires the content of intermediary tensors to be inspected. Inspecting different tensors’ content at different points of the model code requires many executions and typically causes performance issues. Adding a step closure ensures that it will be run after the barrier, when all the live tensors (including the ones captured by the closure arguments) will already have been materialized to device data. So using add_step_closure() ensures that a single execution is performed, even when multiple closures are queued requiring multiple tensors to be inspected. Step closures will be run sequentially in the order in which they have been queued. Note that even though the execution is optimized with this API, it is advised to throttle the printing/reporting events to once every N steps.
- Parameters
closure (callable) – The function to be called.
args (tuple) – The arguments to be passed to the closure.
run_async – If True, run the closure asynchronously.
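A short sketch of queuing a reporting closure (the loss tensor here is a stand-in for a real training loss):
import torch
import torch_xla.core.xla_model as xm
def _report(step, loss):
  # Runs after the step barrier, once `loss` has been materialized as device data.
  print(f'step {step}: loss {loss.item():.4f}')
device = xm.xla_device()
loss = (torch.randn(4, device=device) ** 2).mean()  # stand-in for a real loss
xm.add_step_closure(_report, args=(0, loss))        # queued for the end of the step
xm.mark_step()                                      # the closure runs after this barrier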
- torch_xla.core.xla_model.wait_device_ops(devices=[])[source]¶
Waits for all the async operations on the given devices to complete.
- Parameters
devices (string..., optional) – The devices whose async ops need to be waited for. If empty, all the local devices will be waited for.
- torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None, pin_layout=True)[source]¶
Run the provided optimizer step and issue the XLA device step computation.
- Parameters
optimizer (torch.Optimizer) – The torch.Optimizer instance whose step() function needs to be called. The step() function will be called with the optimizer_args named arguments.
barrier (bool, optional) – Whether the XLA tensor barrier should be issued in this API. If using the PyTorch XLA ParallelLoader or DataParallel support, this is not necessary as the barrier will be issued by the XLA data loader iterator next() call. Default: False
optimizer_args (dict, optional) – Named arguments dictionary for the optimizer.step() call.
groups (list, optional) –
A list of lists, representing the replica groups for the all_reduce() operation. Example: [[0, 1, 2, 3], [4, 5, 6, 7]]
defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.
pin_layout (bool, optional) – whether to pin the layout when reducing gradients. See xm.all_reduce for details.
- Returns
The same value returned by the optimizer.step() call.
- torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)[source]¶
Saves the input data into a file.
The saved data is transferred to PyTorch CPU device before being saved, so a following torch.load() will load CPU data. Care must be taken when working with views. Instead of saving views it’s recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).
- Parameters
data – The input data to be saved. Any nested combination of Python objects (list, tuples, sets, dicts, …).
file_or_path – The destination for the data saving operation. Either a file path or a Python file object. If master_only is False, the path or file objects must point to different destinations, as otherwise all the writes from the same host will override each other.
master_only (bool, optional) – Whether only the master device should save the data. If False, the file_or_path argument should be a different file or path for each of the ordinals taking part in the replication, otherwise all the replicas on the same host will be writing to the same location. Default: True
global_master (bool, optional) – When master_only is True, this flag controls whether every host’s master (if global_master is False) saves the content, or only the global master (ordinal 0). Default: False
sync (bool, optional) – Whether to synchronize all replicas after saving tensors. If True, all replicas must call xm.save or the main process will hang.
- torch_xla.core.xla_model.rendezvous(tag, payload=b'', replicas=[])[source]¶
Waits for all the mesh clients to reach the named rendezvous.
Note: PJRT does not support the XRT mesh server, so this is effectively an alias to xla_rendezvous.
- Parameters
tag (string) – The name of the rendezvous to join.
payload (bytes, optional) – The payload to be sent to the rendezvous.
replicas (list, python:int) – The replica ordinals taking part of the rendezvous. Empty means all replicas in the mesh. Default: []
- Returns
The payloads exchanged by all the other cores, with the payload of core ordinal i at position i in the returned tuple.
- torch_xla.core.xla_model.mesh_reduce(tag, data, reduce_fn)[source]¶
Performs an out-of-graph client mesh reduction.
- Parameters
tag (string) – The name of the rendezvous to join.
data – The data to be reduced. The reduce_fn callable will receive a list with the copies of the same data coming from all the mesh client processes (one per core).
reduce_fn (callable) – A function which receives a list of data-like objects and returns the reduced result.
- Returns
The reduced value.
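For example, a sketch averaging a per-process scalar across the mesh (to be run inside each replicated process):
import numpy as np
import torch_xla.core.xla_model as xm
loss_value = 0.5  # per-process value, e.g. loss.item()
mean_loss = xm.mesh_reduce('train_loss', loss_value, np.mean)  # reduce_fn receives the list of all per-process values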
- torch_xla.core.xla_model.set_rng_state(seed, device=None)[source]¶
Sets the random number generator state.
- Parameters
seed (python:integer) – The state to be set.
device (string, optional) – The device where the RNG state needs to be set. If missing the default device seed will be set.
- torch_xla.core.xla_model.get_rng_state(device=None)[source]¶
Gets the current running random number generator state.
- Parameters
device (string, optional) – The device whose RNG state needs to be retrieved. If missing, the RNG state of the default device will be retrieved.
- Returns
The RNG state, as integer.
- torch_xla.core.xla_model.get_memory_info(device: device) MemoryInfo [source]¶
Retrieves the device memory usage.
- Parameters
device – The device whose memory information are requested.
- Returns
MemoryInfo dict with memory usage for the given device.
- torch_xla.core.xla_model.get_stablehlo(tensors=None) str [source]¶
Get StableHLO for the computation graph in string format.
If tensors is not empty, the graph with tensors as outputs will be dumped. If tensors is empty, the whole computation graph will be dumped. TODO(lsy323): When tensors is empty, some intermediate tensors will also be dumped as outputs. Need further investigation.
For an inference graph, it is recommended to pass the model outputs as tensors. For a training graph, it is not straightforward to identify the “outputs”, so using empty tensors is recommended.
To enable source line info in StableHLO, please set env var XLA_HLO_DEBUG=1.
- Parameters
tensors (list[torch.Tensor], optional) – Tensors that represent the output/root of the StableHLO graph.
- Returns
StableHLO Module in string format.
- torch_xla.core.xla_model.get_stablehlo_bytecode(tensors=None) bytes [source]¶
Get StableHLO for the computation graph in bytecode format.
If tensors is not empty, the graph with tensors as outputs will be dumped. If tensors is empty, the whole computation graph will be dumped. TODO(lsy323): When tensors is empty, some intermediate tensors will also be dumped as outputs. Need further investigation.
For an inference graph, it is recommended to pass the model outputs as tensors. For a training graph, it is not straightforward to identify the “outputs”, so using empty tensors is recommended.
- Parameters
tensors (list[torch.Tensor], optional) – Tensors that represent the output/root of the StableHLO graph.
- Returns
StableHLO Module in bytecode format.
distributed¶
- class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, batches_per_execution=1, loader_prefetch_size=8, device_prefetch_size=4, host_to_device_transfer_threads=1, input_sharding=None)[source]¶
Wraps an existing PyTorch DataLoader with background data upload.
- Parameters
loader (
torch.utils.data.DataLoader
) – The PyTorch DataLoader to be wrapped.devices (torch.device…) – The list of devices where the data has to be sent. The i-th sample returned by the loader will be sent to devices[i % len(devices)].
batchdim (python:int, optional) – The dimension which is holding the batch size. Default: 0
loader_prefetch_size (python:int, optional) – The max capacity of the queue used by the thread which is reading samples from the loader, to be processed by the worker threads which upload data to the devices. Default: 8
device_prefetch_size (python:int, optional) – The max size of the per-device queues, where the worker threads deposit tensors which have already been sent to devices. Default: 4
host_to_device_transfer_threads (python:int, optional) – The number of threads that work in parallel to transfer data from loader queue to device queue. Default: 1
input_sharding (ShardingSpec, optional) – Sharding spec to apply to compatible input tensors after loading. Default: None
- per_device_loader(device)[source]¶
Retrieves the loader iterator object for the given device.
- Parameters
device (torch.device) – The device whose loader is being requested.
- Returns
The loader iterator object for the device. This is not a torch.utils.data.DataLoader interface, but a Python iterator which returns the same tensor data structure as returned by the wrapped torch.utils.data.DataLoader, but residing on XLA devices.
- torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]¶
Enables multi processing based replication.
- Parameters
fn (callable) – The function to be called for each device which takes part in the replication. The function will be called with a first argument being the global index of the process within the replication, followed by the arguments passed in args.
args (tuple) – The arguments for fn. Default: Empty tuple
nprocs (python:int) – The number of processes/devices for the replication. At the moment, if specified, can be either 1 or the maximum number of devices.
join (bool) – Whether the call should block, waiting for the completion of the processes which have been spawned. Default: True
daemon (bool) – Whether the processes being spawned should have the daemon flag set (see Python multi-processing API). Default: False
start_method (string) – The Python multiprocessing process creation method. Default: spawn
- Returns
The same object returned by the torch.multiprocessing.spawn API. If nprocs is 1 the fn function will be called directly, and the API will return None.
spmd¶
- torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Optional[Union[Tuple, int, str]]], use_dynamo_custom_op: bool = False) XLAShardedTensor [source]¶
Annotates the tensor provided with an XLA partition spec. Internally, it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
- Parameters
t (Union[torch.Tensor, XLAShardedTensor]) – input tensor to be annotated with partition_spec.
mesh (Mesh) – describes the logical XLA device topology and the underlying device IDs.
partition_spec – A tuple of device_mesh dimension index or None. Each index is an int, a str if the mesh axis is named, or a tuple of int or str. This specifies how each input rank is sharded (index into mesh_shape) or replicated (None). When a tuple is specified, the corresponding input tensor axis will be sharded along all logical axes in the tuple. Note that the order in which the mesh axes are specified in the tuple will impact the resulting sharding. For example, we can shard an 8x10 tensor 4-way row-wise and replicate it column-wise:
>> input = torch.randn(8, 10)
>> mesh_shape = (4, 2)
>> partition_spec = (0, None)
dynamo_custom_op (bool) – if set to True, it calls the dynamo custom op variant of mark_sharding to make itself recognizable and traceable by dynamo.
Examples:
mesh_shape = (4, 2)
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

# 4-way data parallel
input = torch.randn(8, 32).to(xm.xla_device())
xs.mark_sharding(input, mesh, (0, None))

# 2-way model parallel
linear = nn.Linear(32, 10).to(xm.xla_device())
xs.mark_sharding(linear.weight, mesh, (None, 1))
- torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) Tensor [source]¶
Clears the sharding annotation from the input tensor and returns a CPU-cast tensor.
- class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None)[source]¶
Describe the logical XLA device topology mesh and the underlying resources.
- Parameters
device_ids (Union[np.ndarray, List]) – A raveled list of devices (IDs) in a custom order. The list is reshaped to a mesh_shape array, filling the elements using C-like index order.
mesh_shape (Tuple[python:int, ...]) – An int tuple describing the logical topology shape of the device mesh; each element describes the number of devices in the corresponding axis.
axis_names (Tuple[str, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
Example:
mesh_shape = (4, 2)
num_devices = len(xm.get_xla_supported_devices())
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh.get_logical_mesh()
>> array([[0, 1],
          [2, 3],
          [4, 5],
          [6, 7]])
mesh.shape()
>> OrderedDict([('x', 4), ('y', 2)])
- class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: Tuple[int, ...], dcn_mesh_shape: Optional[Tuple[int, ...]] = None, axis_names: Optional[Tuple[str, ...]] = None)[source]¶
- Creates a hybrid device mesh of devices connected with ICI and DCN networks.
The shape of the logical mesh should be ordered by increasing network intensity, e.g. [replica, data, model], where 'model' has the most network communication requirements.
- Parameters
ici_mesh_shape – shape of the logical mesh for inner connected devices.
dcn_mesh_shape – shape of logical mesh for outer connected devices.
Example
# This example is assuming 2 slices of v4-8.
ici_mesh_shape = (1, 4, 1)  # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)
mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape,
                  axis_names=('data', 'fsdp', 'tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
- class torch_xla.distributed.spmd.ShardingSpec(mesh: torch_xla.distributed.spmd.xla_sharding.Mesh, partition_spec: Tuple[Optional[int]], minibatch: Optional[bool] = False)[source]¶
experimental¶
debug¶
- torch_xla.debug.metrics.metrics_report()[source]¶
Retrieves a string containing the full metrics and counters report.
- torch_xla.debug.metrics.short_metrics_report(counter_names: Optional[list] = None, metric_names: Optional[list] = None)[source]¶
Retrieves a string containing the full metrics and counters report.
- Parameters
counter_names (list) – The list of counter names whose data needs to be printed.
metric_names (list) – The list of metric names whose data needs to be printed.
- torch_xla.debug.metrics.counter_value(name)[source]¶
Returns the value of an active counter.
- Parameters
name (string) – The name of the counter whose value needs to be retrieved.
- Returns
The counter value as integer.
- torch_xla.debug.metrics.metric_data(name)[source]¶
Returns the data of an active metric.
- Parameters
name (string) – The name of the metric whose data needs to be retrieved.
- Returns
The metric data, which is a tuple of (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES). The TOTAL_SAMPLES is the total number of samples which have been posted to the metric. A metric retains only a given number of samples (in a circular buffer). The ACCUMULATOR is the sum of the samples over TOTAL_SAMPLES. The SAMPLES is a list of (TIME, VALUE) tuples.
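For example, a sketch of inspecting the metrics (the metric name 'CompileTime' is used here as an assumed example; substitute any active metric listed in the report):
import torch_xla.debug.metrics as met
print(met.short_metrics_report())      # condensed metrics/counters report
data = met.metric_data('CompileTime')  # assumed example metric name
if data is not None:
  total_samples, accumulator, samples = data  # (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES)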