PyTorch/XLA SPMD User Guide
In this user guide, we discuss how GSPMD is integrated in PyTorch/XLA, and provide a design overview to illustrate how the SPMD sharding annotation API and its constructs work.
What is PyTorch/XLA SPMD?
GSPMD is an automatic parallelization system for common ML workloads. Based on user-provided sharding hints, the XLA compiler transforms a single-device program into a partitioned one with the proper collectives. This lets developers write PyTorch programs as if they were running on a single large device, without writing custom sharded computation ops or collective communication to scale.
*Figure 1. Comparison of two different execution strategies, (a) for non-SPMD and (b) for SPMD.*
How to use PyTorch/XLA SPMD?
Here is a simple example of using SPMD:
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh. Together with the partition spec and the input tensor shape, it defines the shape of each shard.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))
t = torch.randn(8, 4).to(xm.xla_device())
# Mesh partitioning: dim 0 is split along 'data', so each device holds 1/num_devices of the input (1/8 on 8 devices).
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)
Let’s explain these concepts one by one.
SPMD Mode
To use SPMD, you need to enable it via xr.use_spmd(). In SPMD mode there is only one logical device. Distributed computation and collectives are generated by the compiler from the sharding annotations you add with mark_sharding. Note that SPMD cannot be mixed with other distributed libraries.
Mesh
For a given cluster of devices, a physical mesh is a representation of the interconnect topology.
mesh_shape is a tuple whose elements multiply to the total number of physical devices.
device_ids is almost always np.array(range(num_devices)).
Users are also encouraged to give each mesh dimension a name. In the above example, the first mesh dimension is the data dimension and the second mesh dimension is the model dimension.
You can also inspect the mesh via mesh.shape(); for example, on a host with 4 devices:
>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])
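As a rough illustration of how device_ids, mesh_shape and the axis names fit together, the pure-Python sketch below lays the device IDs out row-major over the mesh shape and reports the per-axis sizes in the same order mesh.shape() does. It does not use torch_xla; the helper build_mesh and the row-major layout are assumptions made for illustration, not the Mesh implementation.

```python
from collections import OrderedDict
from math import prod


def build_mesh(device_ids, mesh_shape, axis_names):
    """Hypothetical helper: arrange device_ids row-major into mesh_shape.

    Returns (axis_sizes, layout): axis_sizes mirrors mesh.shape(), and
    layout is a nested list giving the device ID at each mesh coordinate.
    """
    device_ids = list(device_ids)
    if prod(mesh_shape) != len(device_ids):
        raise ValueError("mesh_shape must multiply to the number of devices")
    if len(axis_names) != len(mesh_shape):
        raise ValueError("need exactly one name per mesh dimension")

    def nest(ids, shape):
        # Split the flat ID list into shape[0] equal chunks, recursively.
        if len(shape) == 1:
            return list(ids)
        step = len(ids) // shape[0]
        return [nest(ids[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]

    axis_sizes = OrderedDict(zip(axis_names, mesh_shape))
    return axis_sizes, nest(device_ids, list(mesh_shape))


# Same mesh as the example above, on a host with 4 devices.
sizes, layout = build_mesh(range(4), (4, 1), ('data', 'model'))
print(list(sizes.items()))  # [('data', 4), ('model', 1)]
print(layout)               # [[0], [1], [2], [3]]
```

With a 2x2 mesh the same helper gives [[0, 1], [2, 3]], i.e. devices 0 and 1 share the first row of the mesh.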
Partition Spec
partition_spec has the same rank as the input tensor. Each dimension describes how the corresponding input tensor dimension is sharded across the device mesh. In the above example, the first dimension of tensor t is sharded along the data mesh dimension and the second dimension along the model mesh dimension.
Users can also shard tensors whose rank differs from that of the mesh:
device = xm.xla_device()
t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)
# First dimension is replicated; the other two are sharded along 'data' and 'model'.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))
# First dimension is sharded along the data dimension.
# The model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))
# Alternatively, shard the first dimension across both mesh axes.
xs.mark_sharding(t2, mesh, (('data', 'model'),))
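Because the mesh, the partition spec and the tensor shape together determine each shard's shape, the per-device shard shape can be worked out by hand. The pure-Python sketch below does that arithmetic for the examples in this section, assuming the 4-device mesh shown earlier ({'data': 4, 'model': 1}); shard_shape is a hypothetical helper, and rounding uneven dimensions up is an assumption about padding.

```python
from math import prod


def shard_shape(tensor_shape, mesh_axes, partition_spec):
    """Hypothetical helper: per-device shard shape for a tiled sharding.

    mesh_axes maps mesh axis name -> size, e.g. {'data': 4, 'model': 1}.
    Each partition_spec entry is None (replicated), an axis name, or a
    tuple of axis names (sharded across all of them).
    """
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec must have the same rank as the tensor")
    shape = []
    for dim, spec in zip(tensor_shape, partition_spec):
        if spec is None:
            num_shards = 1
        elif isinstance(spec, tuple):
            num_shards = prod(mesh_axes[name] for name in spec)
        else:
            num_shards = mesh_axes[spec]
        # Ceiling division: uneven dims are rounded up (assumed padding).
        shape.append(-(-dim // num_shards))
    return tuple(shape)


mesh_axes = {'data': 4, 'model': 1}
print(shard_shape((8, 4), mesh_axes, ('data', 'model')))            # (2, 4)
print(shard_shape((8, 8, 16), mesh_axes, (None, 'data', 'model')))  # (8, 2, 16)
print(shard_shape((8,), mesh_axes, (('data', 'model'),)))           # (2,)
```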
Fully Sharded Data Parallel (FSDP) via SPMD
Fully Sharded Data Parallel via SPMD, or FSDPv2, is a utility that re-expresses the well-known FSDP algorithm in SPMD. This is an experimental feature that aims to offer a familiar interface so that users can enjoy all the benefits SPMD brings to the table. The design doc is here.
Please review the SPMD user guide before proceeding. You can also find a minimal runnable example here.
Example usage:
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
# FSDPv2 runs in SPMD mode.
xr.use_spmd()
# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# Note that the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))
# Shard the input, and assume x is a 2D tensor.
# (my_module, x and y stand for your own model and data.)
x = xs.mark_sharding(x, mesh, ('fsdp', None))
# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. Here is an example that auto-wraps each DecoderLayer:
import functools
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
# Apply FSDP sharding on each DecoderLayer layer.
# decoder_only_model is the module that defines your model's DecoderLayer class.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        decoder_only_model.DecoderLayer
    },
)
model = FSDPv2(
    model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)
Sharding output
To ensure the XLA compiler correctly implements the FSDP algorithm, we need to shard both weights and activations. This means sharding the output of the forward method. Since the forward function output can vary, we offer shard_output to shard activations in cases where your module output doesn’t fall into one of these categories:
A single tensor
A tuple of tensors where the 0th element is the activation.
Example usage:
def shard_output(output, mesh):
    # Assumes the module returns an object whose .logits is a 3D activation tensor.
    xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))
model = FSDPv2(my_module, mesh, shard_output)
Gradient checkpointing
Currently, gradient checkpointing needs to be applied to the module before the FSDP wrapper. Otherwise, recursively looping into child modules will end up in an infinite loop. We will fix this issue in a future release.
Example usage:
from torch_xla.distributed.fsdp import checkpoint_module
model = FSDPv2(checkpoint_module(my_module), mesh)
PyTorch/XLA SPMD advanced topics
In this doc we cover some advanced topics on GSPMD. Please read the SPMD user guide before proceeding to this doc.
PyTorch/XLA SPMD takes a single-device program, shards it, and executes it in parallel. SPMD execution requires using the native PyTorch DataLoader, which transfers data synchronously from the host to XLA devices. This blocks training during the input data transfer at every step. To improve native data loading performance, we made the PyTorch/XLA ParallelLoader support input sharding directly (src) when it is passed the optional kwarg input_sharding:
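import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
device = xm.xla_device()
# input_mesh is an xs.Mesh built as in the SPMD user guide.
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
    train_loader,  # wraps PyTorch DataLoader
    device,
    # assume 4d input and we want to shard at the batch dimension.
    input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))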
It is also possible to specify a different input_sharding for each element of the batch if they have different shapes:
# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
    train_loader,  # wraps PyTorch DataLoader
    device,
    # specify different sharding for each input of the batch.
    input_sharding={
        'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
        'y': xs.ShardingSpec(input_mesh, ('data', None))
    }
)
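Sharding dimension 0 along 'data' means each shard along that mesh axis receives a contiguous block of rows from the global batch; since both 'x' and 'y' above shard dimension 0 along 'data', they are split the same way. The pure-Python sketch below computes those row ranges, assuming the batch size divides evenly and shards are indexed 0..n-1 along the 'data' axis; batch_rows_for_device is a hypothetical name, not a torch_xla API.

```python
def batch_rows_for_device(global_batch_size, num_data_shards, shard_index):
    """Hypothetical helper: rows of the global batch held by one data shard."""
    if global_batch_size % num_data_shards != 0:
        raise ValueError("this sketch assumes the batch divides evenly")
    per_shard = global_batch_size // num_data_shards
    start = shard_index * per_shard
    return range(start, start + per_shard)


# A global batch of 8 examples sharded over a 'data' axis of size 4.
for i in range(4):
    print(i, list(batch_rows_for_device(8, 4, i)))
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]
```

Which physical device holds shard i depends on the device_ids order used to build the mesh.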
PyTorch/XLA normally transfers tensor data asynchronously from host to device once the tensor is defined. This overlaps the data transfer with graph tracing. However, because GSPMD allows the user to modify the tensor sharding after the tensor has been defined, we need an optimization to prevent unnecessary transfers of tensor data back and forth between host and device. We introduce Virtual Device Optimization, a technique that places the tensor data on a virtual device, SPMD:0, first, and uploads it to the physical devices only once all sharding decisions are finalized. All tensor data in SPMD mode is placed on the virtual device SPMD:0. The virtual device is exposed to the user as an XLA device XLA:0, with the actual shards on physical devices such as TPU:0, TPU:1, etc.
Hybrid Mesh
Mesh nicely abstracts how the physical device mesh is constructed. Users can arrange devices in any shape and order using the logical mesh. However, one can define a more performant mesh based on the physical topology, especially when it involves Data Center Network (DCN) cross-slice connections. HybridMesh creates a mesh that gives good performance out of the box for such multislice environments. It accepts ici_mesh_shape and dcn_mesh_shape, which denote the logical mesh shapes of the inner and outer networks.
from torch_xla.distributed.spmd import HybridMesh
# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)
mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
# Output: OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
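The printed shape is the element-wise product of ici_mesh_shape and dcn_mesh_shape: each logical axis spans its ICI size within a slice times its DCN size across slices. The pure-Python sketch below reproduces only that arithmetic, not HybridMesh's device ordering; hybrid_mesh_shape is a hypothetical helper.

```python
def hybrid_mesh_shape(ici_mesh_shape, dcn_mesh_shape):
    """Hypothetical helper: logical mesh shape implied by a hybrid mesh.

    Each logical axis spans ici_size devices inside a slice times
    dcn_size slices, so the sizes multiply element-wise.
    """
    if len(ici_mesh_shape) != len(dcn_mesh_shape):
        raise ValueError("ici and dcn shapes must have the same rank")
    return tuple(i * d for i, d in zip(ici_mesh_shape, dcn_mesh_shape))


axis_names = ('data', 'fsdp', 'tensor')
shape = hybrid_mesh_shape((1, 4, 1), (2, 1, 1))  # 2 slices of v4-8
print(dict(zip(axis_names, shape)))  # {'data': 2, 'fsdp': 4, 'tensor': 1}
```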
Running SPMD on TPU Pod
No code change is required to go from a single TPU host to a TPU Pod if you construct your mesh and partition spec based on the number of devices instead of a hardcoded constant. To run the PyTorch/XLA workload on a TPU Pod, please refer to the Pods section of our PJRT guide.
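For example, deriving the mesh shape from the device count keeps the same script valid on a single host and on a Pod. The sketch below assumes a 2D ('data', 'model') mesh and a hypothetical helper mesh_shape_for; in a real script num_devices would come from xr.global_runtime_device_count().

```python
def mesh_shape_for(num_devices, model_parallelism=1):
    """Hypothetical helper: derive a 2D (data, model) mesh shape from the
    device count instead of hardcoding it."""
    if num_devices % model_parallelism != 0:
        raise ValueError("model_parallelism must divide the device count")
    return (num_devices // model_parallelism, model_parallelism)


# The same call adapts to a single host or a larger Pod.
for num_devices in (4, 8, 32):
    print(num_devices, mesh_shape_for(num_devices, model_parallelism=2))
# 4 (2, 2)
# 8 (4, 2)
# 32 (16, 2)
```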
XLAShardedTensor
xs.mark_sharding is an in-place op that attaches the sharding annotation to the input tensor, but it also returns an XLAShardedTensor Python object.
The main use case for XLAShardedTensor [RFC] is to annotate a native torch.tensor (on a single device) with a sharding spec. The annotation takes place immediately, but the actual sharding of the tensor is delayed because the computation is carried out lazily, except for input tensors, which are sharded without delay. Once a tensor is annotated and wrapped inside an XLAShardedTensor, it can be passed to existing PyTorch ops and nn.Module layers as a torch.Tensor. This ensures that the same PyTorch layers and tensor ops can be stacked together with XLAShardedTensor, so the user does not need to rewrite existing ops and model code for sharded computation. Namely, XLAShardedTensor will satisfy the following requirements:
XLAShardedTensor is a torch.Tensor subclass and works directly with native torch ops and nn.Module layers. We use __torch_dispatch__ to send XLAShardedTensor to the XLA backend. PyTorch/XLA retrieves attached sharding annotations to trace the graph and invokes the XLA SPMDPartitioner.
Internally, XLAShardedTensor (and its global_tensor input) is backed by XLATensor with a special data structure holding references to the sharded device data.
The sharded tensor after lazy execution may be gathered and materialized back to the host as global_tensor when requested on the host (e.g., printing the value of the global tensor).
The handles to the local shards are materialized strictly after the lazy execution.
XLAShardedTensor exposes local_shards to return the local shards on addressable devices as List[XLAShard] (XLAShard is defined at https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12).
There is also an ongoing effort to integrate XLAShardedTensor into the DistributedTensor API to support the XLA backend [RFC].
DTensor Integration
PyTorch released DTensor as a prototype in 2.1.
We are integrating PyTorch/XLA SPMD into the DTensor API (RFC). We have a proof-of-concept integration for distribute_tensor, which calls the mark_sharding annotation API to shard a tensor and its computation using XLA:
import torch
import torch_xla.runtime as xr
from torch.distributed import DeviceMesh, Shard, distribute_tensor
# Enable XLA SPMD execution mode.
xr.use_spmd()
# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
This feature is experimental; stay tuned for more updates, examples, and tutorials in upcoming releases.
Activation Sharding for torch.compile
In the 2.3 release, PyTorch/XLA added the custom op dynamo_mark_sharding, which can be used to perform activation sharding in a torch.compile region. This is part of our ongoing effort to make torch.compile + GSPMD the recommended way of doing model inference with PyTorch/XLA. Example of using this custom op:
# Activation output sharding, e.g. inside a module's forward(), where `output`
# is the activation to shard and self.num_devices is the device count.
device_ids = [i for i in range(self.num_devices)]  # List[int]
mesh_shape = [self.num_devices // 2, 2]  # List[int], one entry per axis name
axis_names = "('data', 'model')"  # string version of axis_names
partition_spec = "('data', 'model')"  # string version of partition spec
torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, partition_spec)
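The axis_names and partition_spec strings above look like the str() of the tuples you would otherwise pass to mark_sharding. The pure-Python sketch below builds the arguments that way and checks that the mesh shape covers all devices with one name per axis; dynamo_sharding_args is hypothetical, and the string format is inferred from the example rather than from an API specification.

```python
from math import prod


def dynamo_sharding_args(num_devices, mesh_shape, axis_names, partition_spec):
    """Hypothetical helper: build list/str arguments shaped like the example.

    device_ids and mesh_shape are plain lists of ints; axis_names and
    partition_spec are passed as the str() of a Python tuple.
    """
    mesh_shape = list(mesh_shape)
    if prod(mesh_shape) != num_devices:
        raise ValueError("mesh_shape must multiply to num_devices")
    if len(axis_names) != len(mesh_shape):
        raise ValueError("need one axis name per mesh dimension")
    return (list(range(num_devices)), mesh_shape,
            str(tuple(axis_names)), str(tuple(partition_spec)))


args = dynamo_sharding_args(8, [4, 2], ('data', 'model'), ('data', 'model'))
print(args)
# ([0, 1, 2, 3, 4, 5, 6, 7], [4, 2], "('data', 'model')", "('data', 'model')")
```

Replicated dimensions come out as None in the string, e.g. "(None, 'model')".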
SPMD Debugging Tool
We provide a shard placement visualization debug tool for PyTorch/XLA SPMD users on TPU/GPU/CPU, single-host or multi-host: you can use visualize_tensor_sharding to visualize a sharded tensor, or visualize_sharding to visualize a sharding string. Here are two code examples on a TPU single host (v4-8), using visualize_tensor_sharding and visualize_sharding:
Code snippet using visualize_tensor_sharding, and the visualization result:
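import numpy as np
import rich
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
xr.use_spmd()
# Here, mesh is a 2x2 mesh with axes 'x' and 'y' (assumes 4 devices, as on a v4-8)
mesh = xs.Mesh(np.array(range(4)), (2, 2), ('x', 'y'))
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))
# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
generated_table = visualize_tensor_sharding(t, use_color=False)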
Code snippet using visualize_sharding, and the visualization result:
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
You can use these examples on TPU/GPU/CPU single-host and modify them to run on multi-host. You can also modify them to use the tiled, partial_replication, and replicated sharding styles.
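In a tiled sharding string such as '{devices=[2,2]0,1,2,3}', the bracketed list gives the tile grid (here 2x2) and the trailing list assigns devices to tiles in row-major order. The pure-Python sketch below decodes that simple form, which can help when checking what visualize_sharding displays; parse_tiled_sharding is hypothetical and does not handle partial replication or '{replicated}'.

```python
import re
from itertools import product
from math import prod


def parse_tiled_sharding(sharding):
    """Hypothetical helper: map each tile index to its device for a simple
    tiled HLO sharding string such as '{devices=[2,2]0,1,2,3}'."""
    match = re.fullmatch(r"\{devices=\[([\d,]+)\]([\d,]+)\}", sharding)
    if match is None:
        raise ValueError(f"unsupported sharding string: {sharding!r}")
    tile_shape = [int(x) for x in match.group(1).split(",")]
    devices = [int(x) for x in match.group(2).split(",")]
    if len(devices) != prod(tile_shape):
        raise ValueError("device list does not match the tile grid")
    # Devices are listed in row-major order over the tile grid.
    tiles = product(*(range(n) for n in tile_shape))
    return dict(zip(tiles, devices))


print(parse_tiled_sharding('{devices=[2,2]0,1,2,3}'))
# {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}
```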
Auto-Sharding
We are introducing a new PyTorch/XLA SPMD feature called auto-sharding (RFC). This is an experimental feature in r2.3 and nightly that supports XLA:TPU and a single TPUVM host.
PyTorch/XLA auto-sharding can be enabled by one of the following:
Setting the environment variable XLA_AUTO_SPMD=1
Calling the SPMD API at the beginning of your code:
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
Calling torch.distributed._tensor.distribute_module with auto_policy and xla:
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
# Currently, model should be loaded to xla device via distribute_module.
model = MyModule()  # nn.Module; MyModule stands for your own model
sharded_model = distribute_module(model, device_mesh, auto_policy)
Optionally, one can set the following options/env-vars to control the behavior of the XLA-based auto-sharding pass:
XLA_AUTO_USE_GROUP_SHARDING: group resharding of the parameters. Set by default.
XLA_AUTO_SPMD_MESH: logical mesh shape to be used for auto-sharding. For example, XLA_AUTO_SPMD_MESH=2,2 corresponds to a 2-by-2 mesh with 4 global devices. If unset, a default device mesh shape of num_devices,1 will be used.
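To make the XLA_AUTO_SPMD_MESH semantics concrete, the pure-Python sketch below interprets the variable as documented above: a comma-separated shape, falling back to (num_devices, 1) when unset. It only mirrors the documented behavior and is not how torch_xla parses the variable; auto_spmd_mesh_shape and its environ parameter are hypothetical, and the check that the shape covers every device is an added assumption.

```python
import os
from math import prod


def auto_spmd_mesh_shape(num_devices, environ=os.environ):
    """Hypothetical helper: interpret XLA_AUTO_SPMD_MESH as documented.

    '2,2' -> (2, 2); if the variable is unset, fall back to (num_devices, 1).
    """
    value = environ.get("XLA_AUTO_SPMD_MESH")
    if not value:
        return (num_devices, 1)
    shape = tuple(int(x) for x in value.split(","))
    # Assumed sanity check: the logical mesh should cover every device.
    if prod(shape) != num_devices:
        raise ValueError(f"mesh {shape} does not cover {num_devices} devices")
    return shape


print(auto_spmd_mesh_shape(4, {"XLA_AUTO_SPMD_MESH": "2,2"}))  # (2, 2)
print(auto_spmd_mesh_shape(4, {}))                              # (4, 1)
```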
Distributed Checkpointing
PyTorch/XLA SPMD is compatible with the torch.distributed.checkpoint library through a dedicated Planner instance. Users are able to synchronously save and load checkpoints through this common interface.
The SPMDSavePlanner and SPMDLoadPlanner (src) classes enable the save and load functions to operate directly on the shards of an XLAShardedTensor, enabling all of the benefits of distributed checkpointing in SPMD training.
Here is a demonstration of the synchronous distributed checkpointing API:
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc
# Saving a state_dict
state_dict = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),
}
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)
...
# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
    "model": model.state_dict(),
}
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])
The experimental CheckpointManager interface provides a higher-level API over the torch.distributed.checkpoint functions to enable a few key features:
Managed checkpoints: Each checkpoint taken by the CheckpointManager is identified by the step at which it was taken. All tracked steps are accessible through the CheckpointManager.all_steps method, and any tracked step can be restored using CheckpointManager.restore.
Asynchronous checkpointing: Checkpoints taken through the CheckpointManager.save_async API are written to persistent storage asynchronously to unblock training for the duration of the checkpoint. The input sharded state_dict is first moved to CPU before the checkpoint is dispatched to a background thread.
Auto-checkpointing on preemption: On Cloud TPU, preemptions can be detected and a checkpoint taken before the process is terminated. To use it, ensure your TPU is provisioned through a QueuedResource with Autocheckpointing enabled, and ensure the chkpt_on_preemption parameter is set when constructing the CheckpointManager (this option is enabled by default).
FSSpec Support: CheckpointManager uses an fsspec storage backend to enable checkpointing directly to any fsspec-compatible filesystem, including GCS.
Example usage of the CheckpointManager is below:
from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer
# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)
# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
    # Choose the highest step
    best_step = max(tracked_steps)
    # Before restoring the checkpoint, the optimizer state must be primed
    # to allow state to be loaded into it.
    prime_optimizer(optim)
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    chkpt_mgr.restore(best_step, state_dict)
    model.load_state_dict(state_dict['model'])
    optim.load_state_dict(state_dict['optim'])
# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
    ...
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.save_async(step, state_dict):
        print(f'Checkpoint taken at step {step}')
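The step bookkeeping in this example (a checkpoint every 10 steps, restoring from the highest tracked step) can be modeled in a few lines of pure Python. This is a sketch assuming a checkpoint is due whenever the step is a multiple of the interval; should_checkpoint and latest_step are hypothetical names, and the real CheckpointManager may also checkpoint on preemption.

```python
def should_checkpoint(step, save_interval):
    """Hypothetical: a checkpoint is due when step is a multiple of the interval."""
    return step % save_interval == 0


def latest_step(tracked_steps):
    """Pick the most recent tracked step, or None if nothing was saved yet."""
    return max(tracked_steps) if tracked_steps else None


saved = [step for step in range(25) if should_checkpoint(step, 10)]
print(saved)               # [0, 10, 20]
print(latest_step(saved))  # 20
print(latest_step([]))     # None
```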
In distributed checkpointing, the state_dicts are loaded in-place, and only the required shards of the checkpoint are loaded. Since optimizer states are lazily created, the state isn’t present until the first optimizer.step call, and attempts to load an unprimed optimizer will fail.
The utility method prime_optimizer is provided for this: it runs a fake train step by setting all gradients to zero and calling optimizer.step. This is a destructive method and will touch both model parameters and optimizer state, so it should only be called just prior to restoration.
To use torch.distributed APIs such as distributed checkpointing, a process group is required. In SPMD mode, the xla backend is not supported since the compiler is responsible for all collectives. Instead, a CPU process group such as gloo must be used. On TPUs, the xla:// init_method is still supported to discover the master IP, global world size, and host rank. An example initialization is below:
import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr
xr.use_spmd()
# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')