torch.distributed.tensor
Note
torch.distributed.tensor is currently in alpha state and under development. We are committed to backward compatibility for most of the APIs listed in this doc, but there might be API changes if necessary.
PyTorch DTensor (Distributed Tensor)
PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed logic, including sharded storage, operator computation and collective communications across devices/hosts. DTensor can be used to build different parallelism solutions and supports sharded state_dict representation when working with multi-dimensional sharding.
Please see the PyTorch native parallelism solutions that are built on top of DTensor for examples.
DTensor follows the SPMD (single program, multiple data) programming model. This lets users write distributed programs as if they were single-device programs, with the same convergence property. It provides a uniform tensor sharding layout (DTensor Layout) by specifying the DeviceMesh and Placement:
- DeviceMesh represents the device topology and the communicators of the cluster using an n-dimensional array.
- Placement describes the sharding layout of the logical tensor on the DeviceMesh. DTensor supports three types of placements: Shard, Replicate and Partial.
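As a rough mental model (not the DTensor API), the three placements can be simulated with plain Python lists. The sketch below makes three assumptions: a hypothetical 1-D mesh of 2 ranks, an evenly divisible tensor, and a "sum" reduction for Partial. The helper names shard, replicate and partial_sum are illustrative only.

```python
from typing import List

Vector = List[float]


def shard(global_t: Vector, mesh_size: int) -> List[Vector]:
    """Shard(0): each rank holds one contiguous piece (even split assumed)."""
    step = len(global_t) // mesh_size
    return [global_t[r * step:(r + 1) * step] for r in range(mesh_size)]


def replicate(global_t: Vector, mesh_size: int) -> List[Vector]:
    """Replicate(): every rank holds a full copy."""
    return [list(global_t) for _ in range(mesh_size)]


def partial_sum(global_t: Vector, mesh_size: int) -> List[Vector]:
    """Partial("sum"): ranks hold values whose element-wise sum is the global tensor.
    Here each rank holds global_t / mesh_size, which is one of many valid splits."""
    return [[v / mesh_size for v in global_t] for _ in range(mesh_size)]


x = [1.0, 2.0, 3.0, 4.0]
sharded = shard(x, 2)          # [[1.0, 2.0], [3.0, 4.0]]
replicated = replicate(x, 2)   # two full copies of x
partial = partial_sum(x, 2)    # [[0.5, 1.0, 1.5, 2.0], [0.5, 1.0, 1.5, 2.0]]

# Recovering the logical tensor from each layout:
from_shard = [v for piece in sharded for v in piece]   # concatenate the shards
from_replicate = replicated[0]                         # any replica will do
from_partial = [sum(vals) for vals in zip(*partial)]   # reduce (sum) across ranks
print(from_shard == from_replicate == from_partial == x)  # True
```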
DTensor Class APIs
DTensor is a torch.Tensor subclass. Once a DTensor is created, it can be used in much the same way as a torch.Tensor. This includes running different types of PyTorch operators as if they were running on a single device, with the distributed computation handled correctly.
In addition to existing torch.Tensor methods, it also offers a set of additional methods to interact with torch.Tensor, redistribute the DTensor Layout to a new DTensor, get the full tensor content on all devices, etc.
- class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)
DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides a single-device-like abstraction to program with multi-device torch.Tensor. It describes the distributed tensor sharding layout (DTensor Layout) through the DeviceMesh and the following types of Placement:
- Shard: tensor sharded on the tensor dimension dim on the devices of the DeviceMesh dimension
- Replicate: tensor replicated on the devices of the DeviceMesh dimension
- Partial: tensor is pending reduction on the devices of the DeviceMesh dimension
When calling PyTorch operators, DTensor overrides the PyTorch operators to perform sharded computation and issue communications whenever necessary. Along with the operator computation, DTensor transforms or propagates the placements (DTensor Layout) based on the semantics of the operator itself, and generates new DTensor outputs.
To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor requires every Tensor argument of the operator to be a DTensor.
- Return type
DTensor
- property device_mesh: DeviceMesh
The DeviceMesh attribute associated with this DTensor object.
Note
device_mesh is a read-only property; it cannot be set.
- static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)
Create a DTensor from a local torch.Tensor on each rank according to the device_mesh and placements specified.
- Parameters
local_tensor (torch.Tensor) – local torch.Tensor on each rank.
device_mesh (DeviceMesh, optional) – DeviceMesh to place the tensor. If not specified, this must be called under a DeviceMesh context manager. Default: None
placements (List[Placement], optional) – the placements that describe how to place the local torch.Tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim.
- Keyword Arguments
run_check (bool, optional) – at the cost of extra communications, perform a sanity check across ranks on each local tensor's meta information to ensure correctness. If placements contains a Replicate, the data on the first rank of the device mesh dimension is broadcast to the other ranks. Default: False
shape (torch.Size, optional) – a list of ints that specifies the size of the DTensor built on top of local_tensor. This needs to be provided if the shape of local_tensor differs across ranks. If not provided, shape is computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
stride (tuple, optional) – a list of ints that specifies the stride of the DTensor. If not provided, stride is computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
- Returns
A DTensor object
- Return type
DTensor
Note
When run_check=False, it is the user's responsibility to ensure the local tensor passed in is correct across ranks. That is, the tensor must be sharded for the Shard(dim) placement or replicated for the Replicate() placement. If not, the behavior of the created DTensor is undefined.
Note
from_local is differentiable; the requires_grad of the created DTensor object depends on whether local_tensor requires grad.
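When shape is not passed, from_local has to infer the global shape from the local one. Below is a stdlib-only sketch of that even-sharding assumption. infer_global_shape and the tuple encoding of placements are hypothetical, not DTensor internals. The sketch shows why shape (and stride) should be passed when local shards are uneven.

```python
from typing import List, Sequence, Tuple

# Placements are modeled as ("shard", dim) or ("replicate",) tuples in this sketch.
PlacementSpec = Tuple


def infer_global_shape(local_shape: Sequence[int],
                       mesh_shape: Sequence[int],
                       placements: Sequence[PlacementSpec]) -> List[int]:
    """Guess the global shape assuming every rank holds an equally sized shard."""
    if len(mesh_shape) != len(placements):
        raise ValueError("placements must have one entry per mesh dimension")
    shape = list(local_shape)
    for mesh_dim_size, placement in zip(mesh_shape, placements):
        if placement[0] == "shard":
            # the sharded tensor dim is split across this mesh dimension
            shape[placement[1]] *= mesh_dim_size
    return shape


# 2x2 mesh: tensor dim 0 sharded on mesh dim 0, replicated on mesh dim 1.
even = infer_global_shape([4, 16], [2, 2], [("shard", 0), ("replicate",)])
print(even)  # [8, 16]

# A length-7 dimension sharded on 4 ranks has local sizes 2, 2, 2, 1.
guesses = [infer_global_shape([n], [4], [("shard", 0)]) for n in (2, 2, 2, 1)]
print(guesses)  # [[8], [8], [8], [4]]: ranks disagree and none recovers 7
```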
- full_tensor(*, grad_placements=None)
Return the full tensor of this DTensor. It performs the necessary collectives to gather the local tensors from the other ranks in its DeviceMesh and concatenate them together. It is syntactic sugar for the following code:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()
- Keyword Arguments
grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the full Tensor returned from this function. full_tensor converts a DTensor to a full torch.Tensor, and later code might not use the returned torch.Tensor with the original replicated DTensor layout. This argument is a hint the user can give to autograd in case the gradient layout of the returned tensor does not match the original replicated DTensor layout. If not specified, we assume the gradient layout of the full tensor is replicated.
- Returns
A torch.Tensor object that represents the full tensor of this DTensor.
- Return type
Tensor
Note
full_tensor is differentiable.
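On a multi-dimensional mesh, full_tensor has to gather along every mesh dimension. The stdlib-only sketch below models a hypothetical 2x2 mesh with placements [Shard(0), Shard(1)] on a 4x4 matrix. local_block and gather_full are illustrative names, and even sharding is assumed.

```python
from typing import Dict, List, Tuple

Matrix = List[List[int]]
Coord = Tuple[int, int]


def local_block(full: Matrix, coord: Coord, mesh_shape: Coord) -> Matrix:
    """Block held at mesh coordinate (i, j) under [Shard(0), Shard(1)] (even split)."""
    i, j = coord
    rows = len(full) // mesh_shape[0]
    cols = len(full[0]) // mesh_shape[1]
    return [row[j * cols:(j + 1) * cols] for row in full[i * rows:(i + 1) * rows]]


def gather_full(blocks: Dict[Coord, Matrix], mesh_shape: Coord) -> Matrix:
    """Conceptually what full_tensor() does here: gather the column blocks along
    mesh dim 1, then stack the resulting row blocks along mesh dim 0."""
    full: Matrix = []
    for i in range(mesh_shape[0]):
        row_blocks = [blocks[(i, j)] for j in range(mesh_shape[1])]
        for r in range(len(row_blocks[0])):
            full.append([v for block in row_blocks for v in block[r]])
    return full


mesh = (2, 2)
x = [[4 * r + c for c in range(4)] for r in range(4)]  # 4x4 matrix holding 0..15
blocks = {(i, j): local_block(x, (i, j), mesh) for i in range(2) for j in range(2)}
print(blocks[(0, 1)])                   # [[2, 3], [6, 7]]
print(gather_full(blocks, mesh) == x)   # True
```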
- property placements: Tuple[Placement, ...]
The placements attribute of this DTensor that describes the layout of this DTensor on its DeviceMesh.
Note
placements is a read-only property; it cannot be set.
- redistribute(device_mesh=None, placements=None, *, async_op=False)
redistribute performs the collective operations needed to move the current DTensor from its current placements to new placements, or from its current DeviceMesh to a new DeviceMesh. For example, we can turn a sharded DTensor into a replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh.
When redistributing from the current placements to new placements on one device mesh dimension, we perform the following operations, each either a communication collective or a local operation:
- Shard(dim) -> Replicate(): all_gather
- Shard(src_dim) -> Shard(dst_dim): all_to_all
- Replicate() -> Shard(dim): local chunking (i.e. torch.chunk)
- Partial() -> Replicate(): all_reduce
- Partial() -> Shard(dim): reduce_scatter
redistribute correctly figures out the necessary redistribute steps for DTensors that are created on either a 1-D or an N-D DeviceMesh.
- Parameters
device_mesh (DeviceMesh, optional) – DeviceMesh to place the DTensor. If not specified, the current DTensor's DeviceMesh is used. Default: None
placements (List[Placement], optional) – the new placements that describe how to place the DTensor into the DeviceMesh; must have the same number of elements as device_mesh.ndim. Default: replicate on all mesh dimensions
- Keyword Arguments
async_op (bool, optional) – whether to perform the DTensor redistribute operation asynchronously or not. Default: False
- Returns
A DTensor object
- Return type
DTensor
Note
redistribute is differentiable, which means users do not need to worry about the backward formula of the redistribute operation.
Note
redistribute currently only supports redistributing a DTensor on the same DeviceMesh. Please file an issue if you need to redistribute a DTensor to a different DeviceMesh.
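The redistribution transitions can be simulated without any devices. The stdlib-only sketch below models a hypothetical 1-D mesh where a Python list holds one local matrix per rank. The function names mirror the collectives but are illustrative, not torch.distributed calls. Splits follow the torch.chunk-style rule described for Shard, so the 3-row example is deliberately uneven.

```python
from typing import List

Matrix = List[List[int]]


def chunk_sizes(n: int, k: int) -> List[int]:
    """torch.chunk-style sizes padded to k ranks (trailing ranks may get 0)."""
    step = -(-n // k)  # ceil(n / k)
    return [max(0, min(step, n - r * step)) for r in range(k)]


def split(seq: list, k: int) -> List[list]:
    out, start = [], 0
    for size in chunk_sizes(len(seq), k):
        out.append(seq[start:start + size])
        start += size
    return out


def shard_cols(m: Matrix, k: int) -> List[Matrix]:
    """Shard(1): rank r keeps column chunk r of every row."""
    per_row = [split(row, k) for row in m]
    return [[pieces[r] for pieces in per_row] for r in range(k)]


def all_gather(locals_: List[Matrix]) -> List[Matrix]:
    """Shard(0) -> Replicate(): every rank gets all row shards concatenated."""
    full = [row for local in locals_ for row in local]
    return [full for _ in locals_]


def all_to_all(locals_: List[Matrix]) -> List[Matrix]:
    """Shard(0) -> Shard(1): rank src sends column chunk dst of its rows to rank dst."""
    k = len(locals_)
    outgoing = [shard_cols(local, k) for local in locals_]  # outgoing[src][dst]
    return [[row for src in range(k) for row in outgoing[src][dst]] for dst in range(k)]


def all_reduce(locals_: List[Matrix]) -> List[Matrix]:
    """Partial("sum") -> Replicate(): element-wise sum, result on every rank."""
    total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*locals_)]
    return [total for _ in locals_]


def reduce_scatter(locals_: List[Matrix]) -> List[Matrix]:
    """Partial("sum") -> Shard(0): reduce, then each rank keeps only its row chunk."""
    return split(all_reduce(locals_)[0], len(locals_))


x = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]  # 3 rows: uneven over 2 ranks
row_shards = split(x, 2)                            # Shard(0): [[row0, row1], [row2]]
# Replicate() -> Shard(0): no communication, each rank slices its own replica.
local_chunks = [split(x, 2)[rank] for rank in range(2)]
replicated = all_gather(row_shards)                 # both ranks hold x
col_shards = all_to_all(row_shards)                 # same as shard_cols(x, 2)

partials = [[[1, 2], [3, 4]], [[10, 20], [30, 40]]]  # Partial("sum") on 2 ranks
reduced = all_reduce(partials)                       # both ranks: [[11, 22], [33, 44]]
scattered = reduce_scatter(partials)                 # [[[11, 22]], [[33, 44]]]
```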
- to_local(*, grad_placements=None)
Get the local tensor of this DTensor on its current rank. For sharding, it returns a local shard of the logical tensor view; for replication, it returns the replica on its current rank.
- Keyword Arguments
grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the Tensor returned from this function. to_local converts a DTensor to a local tensor, and later code might not use the returned local tensor with the original DTensor layout. This argument is a hint the user can give to autograd in case the gradient layout of the returned tensor does not match the original DTensor layout. If not specified, we assume the gradient layout remains the same as the original DTensor and use that for gradient computation.
- Returns
A torch.Tensor or AsyncCollectiveTensor object representing the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, the local tensor is not ready yet (i.e. communication is not finished). In this case, the user needs to call wait to wait for the local tensor to be ready.
- Return type
Tensor
Note
to_local is differentiable; the requires_grad of the returned local tensor depends on whether the DTensor requires grad.
DeviceMesh as the distributed communicator
DeviceMesh was originally built as part of DTensor, as the abstraction that describes the cluster's device topology and represents multi-dimensional communicators (on top of ProcessGroup). To see the details of how to create and use a DeviceMesh, please refer to the DeviceMesh recipe.
DTensor Placement Types
DTensor supports the following types of Placement on each DeviceMesh dimension:
- class torch.distributed.tensor.placement_types.Shard(dim)
The Shard(dim) placement describes DTensor sharding on tensor dimension dim over a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension only holds a shard/piece of the global Tensor. The Shard(dim) placement follows the torch.chunk(dim) semantics. As a result, the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. The Shard placement can be used by all DTensor APIs (e.g. distribute_tensor, from_local, etc.)
- Parameters
dim (int) – the tensor dimension along which the DTensor is sharded over its corresponding DeviceMesh dimension.
Warning
Sharding on a tensor dimension whose size is not evenly divisible by the DeviceMesh dimension size is currently experimental and subject to change.
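A stdlib-only sketch of the resulting per-rank shard sizes along one tensor dimension follows. shard_sizes is a hypothetical helper that models the torch.chunk rule stated above; it is not a DTensor function.

```python
from typing import List


def shard_sizes(dim_size: int, mesh_dim_size: int) -> List[int]:
    """Per-rank shard sizes along one tensor dim under torch.chunk semantics.

    torch.chunk uses chunks of ceil(dim_size / mesh_dim_size) elements and may
    return fewer chunks than there are ranks; the remaining ranks hold empty shards.
    """
    chunk = -(-dim_size // mesh_dim_size)  # ceil division
    return [max(0, min(chunk, dim_size - rank * chunk)) for rank in range(mesh_dim_size)]


print(shard_sizes(8, 4))  # [2, 2, 2, 2]  evenly divisible
print(shard_sizes(5, 4))  # [2, 2, 1, 0]  the last rank holds an empty shard
print(shard_sizes(6, 4))  # [2, 2, 2, 0]  6 elements still leave one rank empty
```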
- class torch.distributed.tensor.placement_types.Replicate
The Replicate() placement describes DTensor replication on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (e.g. distribute_tensor, DTensor.from_local, etc.)
- class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')
The Partial(reduce_op) placement describes a DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a partial value of the global Tensor. Users can redistribute a Partial DTensor to a Replicate or Shard(dim) placement on the specified DeviceMesh dimension using redistribute. This triggers the necessary communication operations under the hood (e.g. all_reduce, reduce_scatter).
- Parameters
reduce_op (str, optional) – the reduction op used on the partial DTensor to produce a Replicated/Sharded DTensor. Only element-wise reduction operations are supported, including: “sum”, “avg”, “product”, “max”, “min”. Default: “sum”.
Note
The Partial placement can be generated as a result of the DTensor operators, and can only be used by the DTensor.from_local API.
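Below is a stdlib-only sketch of what each reduce_op means when a Partial value is resolved across a hypothetical 2-rank mesh dimension. resolve_partial is an illustrative helper, not part of DTensor. Treating "avg" as the mean over ranks is an assumption of this sketch.

```python
import operator
from functools import reduce
from typing import Callable, Dict, List

# Element-wise reductions for the reduce_op strings listed above ("avg" = mean over ranks).
REDUCE_OPS: Dict[str, Callable[[List[float]], float]] = {
    "sum": sum,
    "avg": lambda vals: sum(vals) / len(vals),
    "product": lambda vals: reduce(operator.mul, vals, 1.0),
    "max": max,
    "min": min,
}


def resolve_partial(rank_values: List[List[float]], reduce_op: str = "sum") -> List[float]:
    """Combine per-rank partial values element-wise, as an all_reduce would."""
    fn = REDUCE_OPS[reduce_op]
    return [fn(list(column)) for column in zip(*rank_values)]


partials = [[1.0, 4.0], [3.0, 2.0]]  # values held by rank 0 and rank 1
for op in ("sum", "avg", "product", "max", "min"):
    print(op, resolve_partial(partials, op))
# sum [4.0, 6.0] / avg [2.0, 3.0] / product [3.0, 8.0] / max [3.0, 4.0] / min [1.0, 2.0]
```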
- class torch.distributed.tensor.placement_types.Placement
The base class for the Placement type, which describes how a DTensor is placed onto the DeviceMesh. Placement and DeviceMesh together describe the DTensor Layout. It is the base class of the three main DTensor Placement types: Shard, Replicate, and Partial.
This class is not meant to be used directly; it mainly serves as a typing stub.
Different ways to create a DTensor
There are three ways to construct a DTensor:
- distribute_tensor() creates a DTensor from a logical or “global” torch.Tensor on each rank. This can be used to shard leaf torch.Tensors (i.e. model parameters/buffers and inputs).
- DTensor.from_local() creates a DTensor from a local torch.Tensor on each rank. This can be used to create a DTensor from non-leaf torch.Tensors (i.e. intermediate activation tensors during forward/backward).
- DTensor provides dedicated tensor factory functions (e.g. empty(), ones(), randn(), etc.) that create a DTensor by directly specifying the DeviceMesh and Placement. Compared to distribute_tensor(), this directly materializes the sharded memory on device, instead of sharding after the logical Tensor memory has been initialized.
Create DTensor from a logical torch.Tensor
The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes (e.g. via torchrun) to execute the same program. This means the model inside the program is first initialized on each process separately. For example, the model might be initialized on CPU, on the meta device, or directly on GPU if there is enough memory.
DTensor offers a distribute_tensor() API that can shard the model weights or Tensors to DTensors, creating a DTensor from the “logical” Tensor on each process. This lets the created DTensors comply with single-device semantics, which is critical for numerical correctness.
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None)
Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the placements specified. The rank of device_mesh and placements must be the same. The tensor to distribute is the logical or “global” tensor. The API uses the tensor from the first rank of the DeviceMesh dimension as the source of truth, to preserve single-device semantics. If you want to construct a DTensor in the middle of the Autograd computation, please use DTensor.from_local() instead.
- Parameters
tensor (torch.Tensor) – torch.Tensor to be distributed. Note that if you want to shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, we use torch.chunk semantics to shard the tensor and scatter the shards. The uneven sharding behavior is experimental and subject to change.
device_mesh (DeviceMesh, optional) – DeviceMesh to distribute the tensor. If not specified, this must be called under a DeviceMesh context manager. Default: None
placements (List[Placement], optional) – the placements that describe how to place the tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim. If not specified, by default we replicate the tensor across the device_mesh from the first rank of each dimension of the device_mesh.
- Returns
A DTensor or XLAShardedTensor object.
- Return type
DTensor
Note
When the DeviceMesh is initialized with the xla device_type, distribute_tensor returns an XLAShardedTensor instead. See this issue for more details. The XLA integration is experimental and subject to change.
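To show what "source of truth" means in practice, the stdlib-only sketch below models distribute_tensor on a hypothetical 3-rank 1-D mesh. Each rank starts from a different "global" tensor, as would happen with unseeded random initialization. distribute_model and the placement strings are illustrative. Only rank 0's values survive, either broadcast (Replicate) or chunked and scattered (Shard(0)).

```python
import random
from typing import List


def chunk_sizes(n: int, k: int) -> List[int]:
    step = -(-n // k)  # ceil(n / k), torch.chunk-style
    return [max(0, min(step, n - r * step)) for r in range(k)]


def distribute_model(per_rank_global: List[List[float]], placement: str) -> List[List[float]]:
    """Model of distribute_tensor on a 1-D mesh: rank 0's tensor is the source of truth."""
    source = per_rank_global[0]
    k = len(per_rank_global)
    if placement == "replicate":          # broadcast rank 0's tensor
        return [list(source) for _ in range(k)]
    if placement == "shard0":             # chunk rank 0's tensor and scatter the pieces
        out, start = [], 0
        for size in chunk_sizes(len(source), k):
            out.append(source[start:start + size])
            start += size
        return out
    raise ValueError(f"unknown placement {placement!r}")


# Each rank built its "global" tensor with a different seed, so the copies disagree.
per_rank = [[random.Random(rank).random() for _ in range(5)] for rank in range(3)]
replicated = distribute_model(per_rank, "replicate")
sharded = distribute_model(per_rank, "shard0")
print([len(s) for s in sharded])  # [2, 2, 1]
```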
Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier sharding at the nn.Module level.
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)
This function exposes three functions to control the parameters/inputs/outputs of the module:
1. To perform sharding on the module before runtime execution, specify the partition_fn. This lets the user convert Module parameters to DTensor parameters according to the partition_fn specified.
2. To control the inputs or outputs of the module during runtime execution, specify the input_fn and output_fn. For example, these can convert the input to DTensor and convert the output back to torch.Tensor.
- Parameters
module (nn.Module) – user module to be partitioned.
device_mesh (DeviceMesh) – the device mesh to place the module.
partition_fn (Callable) – the function to partition parameters (i.e. shard certain parameters across the device_mesh). If partition_fn is not specified, by default we replicate all parameters of module across the mesh.
input_fn (Callable) – specify the input distribution, i.e. control how the input of the module is sharded. input_fn will be installed as a module forward_pre_hook (pre forward hook).
output_fn (Callable) – specify the output distribution, i.e. control how the output is sharded, or convert it back to torch.Tensor. output_fn will be installed as a module forward_hook (post forward hook).
- Returns
A module that contains parameters/buffers that are all DTensors.
- Return type
Module
Note
When the DeviceMesh is initialized with the xla device_type, distribute_module returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See this issue for more details. The XLA integration is experimental and subject to change.
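The hook wiring described for distribute_module can be sketched without torch. In the stdlib-only toy below, ToyModule and distribute_module_model are hypothetical stand-ins, not nn.Module or the real implementation. Placements are strings wrapped around plain lists, and the mesh is a placeholder string. The argument orders (name, submodule, mesh) and (module, inputs/outputs, mesh) are assumptions of this sketch. The toy shows three behaviors:
- partition_fn runs once per submodule before any forward call.
- Parameters it leaves untouched default to replication, matching the "all DTensors" return description.
- input_fn and output_fn wrap the forward call.

```python
from typing import Callable, Dict, List, Optional


class ToyModule:
    """Hypothetical stand-in for nn.Module: named params plus pre/post forward hooks."""

    def __init__(self, name: str, params: Dict[str, object],
                 children: Optional[List["ToyModule"]] = None):
        self.name, self.params, self.children = name, params, children or []
        self.pre_hooks: List[Callable] = []
        self.post_hooks: List[Callable] = []

    def named_modules(self):
        yield self.name, self
        for child in self.children:
            yield from child.named_modules()

    def forward(self, x: List[int]) -> List[int]:
        return [v * 2 for v in x]

    def __call__(self, x):
        for hook in self.pre_hooks:    # plays the role of a forward_pre_hook
            x = hook(x)
        out = self.forward(x)
        for hook in self.post_hooks:   # plays the role of a forward_hook
            out = hook(out)
        return out


def distribute_module_model(module, mesh, partition_fn=None, input_fn=None, output_fn=None):
    for name, submod in module.named_modules():
        if partition_fn is not None:
            partition_fn(name, submod, mesh)
        # anything partition_fn did not place ends up replicated
        submod.params = {k: v if isinstance(v, tuple) else ("replicate", v)
                         for k, v in submod.params.items()}
    if input_fn is not None:
        module.pre_hooks.append(lambda inputs: input_fn(module, inputs, mesh))
    if output_fn is not None:
        module.post_hooks.append(lambda outputs: output_fn(module, outputs, mesh))
    return module


calls = []


def shard_weights(name, mod, mesh):
    if "weight" in mod.params:
        mod.params["weight"] = ("shard0", mod.params["weight"])


def log_input(mod, inputs, mesh):
    calls.append("input_fn")
    return inputs


def log_output(mod, outputs, mesh):
    calls.append("output_fn")
    return outputs


root = ToyModule("", {"bias": [0.0]}, [ToyModule("linear", {"weight": [1.0, 2.0]})])
distribute_module_model(root, "mesh", shard_weights, log_input, log_output)
result = root([1, 2])  # [2, 4]; calls == ["input_fn", "output_fn"]
```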
DTensor Factory Functions
DTensor also provides dedicated tensor factory functions for creating a DTensor directly with torch.Tensor-like factory APIs (e.g. torch.ones, torch.empty, etc.). These additionally take the DeviceMesh and Placement for the DTensor created:
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with the scalar value 0.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
- Keyword Arguments
requires_grad (bool, optional) – if autograd should record operations on the returned DTensor. Default: False.
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate.
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with the scalar value 1, with the shape defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – if autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate.
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – if autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate.
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with fill_value according to device_mesh and placements, with the shape defined by the argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor.
fill_value (Scalar) – the value to fill the output DTensor with.
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – if autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate.
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – if autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate.
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with random numbers from a normal distribution with mean 0 and variance 1. The shape of the tensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – if autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate.
- Returns
A DTensor object on each rank
- Return type
DTensor
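Because the DTensor factory functions receive the DeviceMesh and placements up front, each rank only needs to allocate its own shard rather than the full logical tensor. A stdlib-only sketch of the per-rank local shape computation follows. local_shape and the tuple encoding of placements are hypothetical, and the torch.chunk-style split follows the Shard description earlier in this page.

```python
from math import prod
from typing import List, Sequence, Tuple

PlacementSpec = Tuple  # ("shard", dim) or ("replicate",) in this sketch


def local_shape(global_shape: Sequence[int], mesh_shape: Sequence[int],
                coord: Sequence[int], placements: Sequence[PlacementSpec]) -> List[int]:
    """Shape of the shard held at mesh coordinate `coord` (torch.chunk-style splits)."""
    shape = list(global_shape)
    for mesh_dim, placement in enumerate(placements):
        if placement[0] != "shard":
            continue  # a replicated mesh dim leaves the local shape unchanged
        dim = placement[1]
        k, rank = mesh_shape[mesh_dim], coord[mesh_dim]
        chunk = -(-shape[dim] // k)  # ceil division
        shape[dim] = max(0, min(chunk, shape[dim] - rank * chunk))
    return shape


# Hypothetical 2x4 mesh: replicate on mesh dim 0, shard tensor dim 1 on mesh dim 1.
placements = [("replicate",), ("shard", 1)]
shard = local_shape((1024, 4096), (2, 4), (1, 3), placements)
print(shard)                               # [1024, 1024]
print(prod((1024, 4096)) // prod(shard))   # 4: each rank allocates 1/4 of the elements
```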
Debugging
Logging
When launching the program, you can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:
- TORCH_LOGS=+dtensor will display logging.DEBUG messages and all levels above it.
- TORCH_LOGS=dtensor will display logging.INFO messages and above.
- TORCH_LOGS=-dtensor will display logging.WARNING messages and above.
Debugging Tools
To debug a program that uses DTensor, and to understand which collectives happen under the hood, DTensor provides a CommDebugMode:
- class torch.distributed.tensor.debug.CommDebugMode
CommDebugMode is a context manager that counts the number of functional collectives within its context. It does this using a TorchDispatchMode.
Note
Not all collectives are supported yet.
Example usage
from torch.distributed.tensor.debug import CommDebugMode

mod = ...
comm_mode = CommDebugMode()
with comm_mode:
    mod.sum().backward()
print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)
Generates a detailed table displaying operations and collective tracing information on a module level. The amount of information depends on noise_level:
0. prints module-level collective counts
1. prints DTensor operations not included in trivial operations, module information
2. prints operations not included in trivial operations
3. prints all operations
- generate_json_dump(file_name='comm_mode_log.json', noise_level=3)
Creates a JSON file used to build a browser visual. The amount of information depends on noise_level:
0. prints module-level collective counts
1. prints DTensor operations not included in trivial operations
2. prints operations not included in trivial operations
3. prints all operations
To visualize the sharding of a DTensor with fewer than 3 dimensions, DTensor provides visualize_sharding().
Experimental Features
DTensor also provides a set of experimental features. These features are either in the prototyping stage, or their basic functionality is done and we are looking for user feedback. Please submit an issue to PyTorch if you have feedback on these features.
- torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)
context_parallel is an experimental API to enable context parallelism (CP). This API performs two actions:
1) it patches SDPA (torch.nn.functional.scaled_dot_product_attention) with the CP-enabled one;
2) it shards buffers along the sequence dimension, and each rank keeps the corresponding shard according to mesh.
- Parameters
mesh (DeviceMesh) – the device mesh for the context parallelism.
buffers (Optional[List[torch.Tensor]]) – buffers whose usage depends on the sequence dimension. Examples are the input batch, labels and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure accuracy. The sharding happens in place, so the buffer's shape changes within the context. The buffers are restored after the context finishes. no_restore_buffers can be used to specify which buffers don't need to be restored. Note that buffers should not contain any nn.Parameter.
buffer_seq_dims (Optional[List[int]]) – the sequence dimensions of buffers.
no_restore_buffers (Optional[Set[torch.Tensor]]) – buffers in this set won't be restored after the context exits. This set must be a subset of buffers. If some buffers won't be used after the context exits, put them in this set to avoid extra restore time.
- Return type
Generator[None, None, None]
Warning
torch.distributed._tensor.experimental.attention.context_parallel is a prototype feature in PyTorch. The API is subject to change.
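The buffer handling of context_parallel shards buffers in place on entry and restores them on exit, unless they are listed in no_restore_buffers. This can be modeled with plain lists. The stdlib-only sketch is not the real implementation: toy_context_parallel is hypothetical, and it differs from the real API in these ways:
- The sequence dimension is dim 0 of a list.
- The rank is passed explicitly instead of being read from mesh.
- The contiguous even split is an assumption; the real API's shard layout may differ.
- The SDPA patching is not modeled.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional


@contextmanager
def toy_context_parallel(rank: int, world_size: int, buffers: List[list],
                         no_restore_buffers: Optional[List[list]] = None) -> Iterator[None]:
    no_restore = no_restore_buffers or []
    originals = [list(buf) for buf in buffers]        # copies used for restoring
    for buf in buffers:
        step = len(buf) // world_size                # even split assumed
        buf[:] = buf[rank * step:(rank + 1) * step]  # in place: the buffer shrinks
    try:
        yield
    finally:
        for buf, orig in zip(buffers, originals):
            # lists are unhashable, so membership is checked by identity
            if not any(buf is skip for skip in no_restore):
                buf[:] = orig


tokens = list(range(8))
labels = list(range(100, 108))
with toy_context_parallel(rank=1, world_size=2, buffers=[tokens, labels],
                          no_restore_buffers=[labels]):
    seen = (list(tokens), list(labels))  # ([4, 5, 6, 7], [104, 105, 106, 107])
print(tokens)  # [0, 1, 2, 3, 4, 5, 6, 7] (restored)
print(labels)  # [104, 105, 106, 107] (not restored)
```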
- torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)
local_map() is an experimental API that allows users to pass DTensors to a function that is written to be applied on torch.Tensors. It does so by extracting the local components of the DTensors, calling the function, and wrapping the outputs into DTensors according to out_placements.
- Parameters
func (Callable) – the function to be applied on each local shard of the DTensors.
out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – the desired placements of the DTensors in func's flattened output.
If the flattened output is a single value, out_placements should be of type PlacementType. Otherwise, if the flattened output has multiple values, out_placements should be a tuple of PlacementType values mapping 1:1 to the flattened output.
For Tensor output, we use PlacementType as its placements (a Tuple[Placement] value). For non-Tensor output, the PlacementType should be None.
The only exception is when no DTensor argument is passed in. In this case, even if out_placements is not None, the resulting function ignores the desired placements because the function is not running with DTensors.
in_placements (Tuple[PlacementType, …], optional) – the required placements of the DTensors in the flattened inputs of func.
If in_placements is specified, local_map() checks whether the placements of each DTensor argument match the required placements. If they do not match and redistribute_inputs is False, an exception is raised. Otherwise, if redistribute_inputs is True, the argument is first redistributed to the required sharding placements, and then its local tensor is passed to func.
The only exception is when the required placements are not None and the argument is a torch.Tensor. In this case, the placements check is skipped and the argument is passed directly to func.
If in_placements is None, no placements check is performed. Default: None
device_mesh (DeviceMesh, optional) – the device mesh that all the DTensors are placed on. If not specified, this is inferred from the input DTensors' device mesh. local_map requires every DTensor to be placed on the same device mesh. Default: None.
redistribute_inputs (bool, optional) – whether to reshard the input DTensors when their placements differ from the required input placements. If this value is False and some DTensor input has a different placement, an exception is raised. Default: False.
- Returns
A Callable that applies func to each local shard of the input DTensor and returns a DTensor constructed from the return value of func.
- Raises
AssertionError – if the input DTensors are not placed on the same device mesh, or if they are placed on a different device mesh than the device_mesh argument passed in.
AssertionError – for any non-DTensor output, we require its corresponding output placement in out_placements to be None. An AssertionError is raised if this is not the case.
ValueError – if redistribute_inputs=False but the input DTensor needs a redistribution according to in_placements.
Example
>>> import torch
>>> import torch.distributed._functional_collectives as funcol
>>> from torch.distributed.tensor import Replicate, Shard, distribute_tensor
>>> from torch.distributed.tensor.experimental import local_map
>>> # assumes device_mesh is an already-initialized 1-D DeviceMesh
>>> def mm_allreduce_forward(device_mesh, W, X):
...     partial_sum_tensor = torch.mm(W, X)
...     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
...     return reduced_tensor
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)  # single-device reference result
>>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion;
>>> # device_mesh is not a DTensor argument, so its required placement is None
>>> local_mm_allreduce_forward = local_map(
...     mm_allreduce_forward,
...     out_placements=[Replicate()],
...     in_placements=(None, col_wise, row_wise),
...     device_mesh=device_mesh,
... )
>>> W_dt = distribute_tensor(W, device_mesh, col_wise)  # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(X, device_mesh, row_wise)  # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # apply local_mm_allreduce_forward to DTensors
Note
This API is currently experimental and subject to change.
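The mm_allreduce_forward example works because W is sharded on its columns and X on its rows. Each rank's local torch.mm therefore produces a partial sum over its slice of the inner dimension, and the all_reduce sums those partials. Below is a stdlib-only check of that identity with small matrices. The 2-rank split and the matmul/add helpers are illustrative, and an even split of the inner dimension is assumed.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


W = [[1, 2, 3, 4], [5, 6, 7, 8]]       # 2x4, inner dimension 4
X = [[1, 0], [0, 1], [2, 2], [3, -1]]  # 4x2
world_size = 2
step = len(X) // world_size            # even split of the inner dimension assumed

partials = []
for rank in range(world_size):
    lo, hi = rank * step, (rank + 1) * step
    W_local = [row[lo:hi] for row in W]         # Shard(1) of W: a column block
    X_local = X[lo:hi]                          # Shard(0) of X: the matching row block
    partials.append(matmul(W_local, X_local))   # local mm yields a partial sum

reduced = partials[0]
for p in partials[1:]:
    reduced = add(reduced, p)                   # what all_reduce("sum") computes
print(reduced)  # [[19, 4], [43, 12]], equal to matmul(W, X)
```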
- torch.distributed.tensor.experimental.register_sharding(op)
register_sharding() is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensors. It can be useful when:
(1) there is no default sharding strategy for op, e.g. when op is a custom operator that is not supported by DTensor;
(2) users would like to override the default sharding strategies of existing operators.
- Parameters
op (Union[OpOverload, List[OpOverload]]) – an op or a list of ops to register the customized sharding function.
- Returns
A function decorator which can be used to wrap a function that defines the sharding strategy for the operator specified in op. The defined sharding strategy is registered to DTensor and overrides the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op, except that any torch.Tensor arg is replaced by a tensor-like object that DTensor uses internally. The function should return a sequence of 2-tuples, each specifying acceptable output placements and their corresponding input placements.
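Example
>>> import torch
>>> from torch.distributed.tensor import Replicate, Shard
>>> from torch.distributed.tensor.experimental import register_sharding
>>> aten = torch.ops.aten
>>> @register_sharding(aten._softmax.default)
... def custom_softmax_sharding(x, dim, half_to_float):
...     softmax_dim = dim if dim >= 0 else dim + x.ndim
...     acceptable_shardings = []
...
...     # output replicated; input x replicated, non-tensor args (dim, half_to_float) get None
...     all_replicate = ([Replicate()], [Replicate(), None, None])
...     acceptable_shardings.append(all_replicate)
...
...     # sharding any dimension other than the softmax dimension is also acceptable
...     for sharding_dim in range(x.ndim):
...         if sharding_dim != softmax_dim:
...             all_sharded = (
...                 [Shard(sharding_dim)],
...                 [Shard(sharding_dim), None, None],
...             )
...             acceptable_shardings.append(all_sharded)
...
...     return acceptable_shardings
Note
This API is currently experimental and subject to change.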
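The custom_softmax_sharding example only accepts shardings on dimensions other than the softmax dimension. Below is a stdlib-only check of why, using nested lists in place of tensors (softmax_last_dim is an illustrative helper). Softmax along the last dimension, computed shard by shard, matches the single-device result when rows are sharded. It does not match when each row is split, because each rank then normalizes over only part of the row.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax_last_dim(m: Matrix) -> Matrix:
    """Softmax along dim=-1, one row at a time."""
    out = []
    for row in m:
        row_max = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - row_max) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


x = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 1.0, 2.0]]
reference = softmax_last_dim(x)

# Shard(0) (not the softmax dim): each rank holds whole rows, so local softmax is exact.
row_shards = [x[:1], x[1:]]
from_rows = [row for shard in row_shards for row in softmax_last_dim(shard)]

# Shard(1) (the softmax dim): each rank normalizes over only half of every row.
col_shards = [[row[:2] for row in x], [row[2:] for row in x]]
local = [softmax_last_dim(shard) for shard in col_shards]
from_cols = [a + b for a, b in zip(local[0], local[1])]

print(from_rows == reference)                  # True
print([round(sum(r), 6) for r in from_cols])   # [2.0, 2.0]: rows no longer sum to 1
```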