torchrec.distributed¶
Torchrec Distributed
Torchrec distributed provides the necessary modules and operations to enable model parallelism.
These include:
model parallelism through DistributedModelParallel (a minimal sketch follows this list).
collective operations for comms, including All-to-All and Reduce-Scatter.
collective operations wrappers for sparse features, KJT, and various embedding types.
sharded implementations of various modules, including ShardedEmbeddingBag for nn.EmbeddingBag and ShardedEmbeddingBagCollection for EmbeddingBagCollection.
embedding sharders that define sharding for any sharded module implementation.
support for various compute kernels, which are optimized for compute device (CPU/GPU) and may include batching together embedding tables and/or optimizer fusion.
pipelined training through TrainPipelineSparseDist that overlaps dataloading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance.
quantization support for reduced precision training and inference.
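A minimal end-to-end sketch of the DistributedModelParallel entry point follows. This is a hedged example rather than the canonical recipe: it assumes a single-host NCCL setup launched via torchrun, and the table and feature names are illustrative:
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank())
torch.cuda.set_device(device)

# Unsharded module defined on the meta device; DistributedModelParallel decides
# how to shard its tables across ranks and materializes the shards.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=16, num_embeddings=100, feature_names=["f1"]
        )
    ],
    device=torch.device("meta"),
)
model = DistributedModelParallel(ebc, device=device)

# A toy batch of sparse ids for feature "f1" (batch size 2).
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([1, 2, 3], device=device),
    lengths=torch.tensor([1, 2], device=device),
)
pooled = model(kjt).wait()  # KeyedTensor of pooled embeddings, keyed by feature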
torchrec.distributed.collective_utils¶
This file contains utilities for constructing collective based control flows.
- torchrec.distributed.collective_utils.invoke_on_rank_and_broadcast_result(pg: torch._C._distributed_c10d.ProcessGroup, rank: int, func: Callable[[...], torchrec.distributed.collective_utils.T], *args: Any, **kwargs: Any) torchrec.distributed.collective_utils.T ¶
Invokes a function on the designated rank and broadcasts the result to all members within the group.
Example:
id = invoke_on_rank_and_broadcast_result(pg, 0, allocate_id)
- torchrec.distributed.collective_utils.is_leader(pg: Optional[torch._C._distributed_c10d.ProcessGroup], leader_rank: int = 0) bool ¶
Checks if the current process is the leader.
- Parameters
pg (Optional[dist.ProcessGroup]) – the process’s rank within the pg is used to determine if the process is the leader. pg being None implies that the process is the only member in the group (e.g. a single process program).
leader_rank (int) – the definition of leader (defaults to 0). The caller can override it with a context-specific definition.
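A hedged usage sketch of the helpers above (assumes torch.distributed is already initialized; allocate_id is a hypothetical function that must run exactly once across the group):
import torch.distributed as dist
from torchrec.distributed.collective_utils import (
    invoke_on_rank_and_broadcast_result,
    is_leader,
)

def allocate_id() -> int:
    # Hypothetical allocation that should only happen on one rank.
    return 42

pg = dist.group.WORLD  # assumes dist.init_process_group(...) was called earlier

# Run allocate_id on rank 0 only and broadcast the result to every rank.
new_id = invoke_on_rank_and_broadcast_result(pg, 0, allocate_id)

if is_leader(pg):  # True on rank 0 by default
    print(f"leader allocated id {new_id}")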
- torchrec.distributed.collective_utils.run_on_leader(pg: torch._C._distributed_c10d.ProcessGroup, rank: int)¶
torchrec.distributed.comm¶
- torchrec.distributed.comm.get_group_rank(world_size: Optional[int] = None, rank: Optional[int] = None) int ¶
Gets the group rank of the worker group. Also available via the GROUP_RANK environment variable. A number between 0 and get_num_groups() (see https://pytorch.org/docs/stable/elastic/run.html).
- torchrec.distributed.comm.get_local_rank(world_size: Optional[int] = None, rank: Optional[int] = None) int ¶
Gets the local rank of the local processes (see https://pytorch.org/docs/stable/elastic/run.html). This is usually the rank of the worker on its node.
- torchrec.distributed.comm.get_local_size(world_size: Optional[int] = None) int ¶
- torchrec.distributed.comm.get_num_groups(world_size: Optional[int] = None) int ¶
Gets the number of worker groups. Usually equivalent to max_nnodes (See https://pytorch.org/docs/stable/elastic/run.html)
- torchrec.distributed.comm.intra_and_cross_node_pg(device: Optional[torch.device] = None, backend: str = 'nccl') Tuple[Optional[torch._C._distributed_c10d.ProcessGroup], Optional[torch._C._distributed_c10d.ProcessGroup]] ¶
Creates sub-process groups (intra-node and cross-node).
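A hedged sketch of these helpers under a typical torchrun launch (assumes the default process group has been initialized; the comments restate the descriptions above, and exact values depend on the WORLD_SIZE / RANK / LOCAL_* environment variables set by the launcher):
from torchrec.distributed.comm import (
    get_group_rank,
    get_local_rank,
    get_local_size,
    get_num_groups,
)

local_rank = get_local_rank()   # rank of this worker on its node
local_size = get_local_size()   # number of workers on this node
group_rank = get_group_rank()   # which worker group (node) this rank belongs to
num_groups = get_num_groups()   # total number of worker groups (roughly nnodes)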
torchrec.distributed.comm_ops¶
- class torchrec.distributed.comm_ops.All2AllDenseInfo(output_splits: List[int], batch_size: int, input_shape: List[int], input_splits: List[int])¶
Bases:
object
The data class that collects the attributes when calling the alltoall_dense operation.
- batch_size: int¶
- input_shape: List[int]¶
- input_splits: List[int]¶
- output_splits: List[int]¶
- class torchrec.distributed.comm_ops.All2AllPooledInfo(batch_size_per_rank: List[int], dim_sum_per_rank: List[int], dim_sum_per_rank_tensor: Optional[torch.Tensor], cumsum_dim_sum_per_rank_tensor: Optional[torch.Tensor])¶
Bases:
object
The data class that collects the attributes when calling the alltoall_pooled operation.
- batch_size_per_rank¶
batch size in each rank
- Type
List[int]
- dim_sum_per_rank¶
number of features (sum of dimensions) of the embedding in each rank.
- Type
List[int]
- dim_sum_per_rank_tensor¶
the tensor version of dim_sum_per_rank, this is only used by the fast kernel of _recat_pooled_embedding_grad_out.
- Type
Optional[Tensor]
- cumsum_dim_sum_per_rank_tensor¶
cumulative sum of dim_sum_per_rank, this is only used by the fast kernel of _recat_pooled_embedding_grad_out.
- Type
Optional[Tensor]
- B_local¶
local batch size before scattering.
- Type
int
- batch_size_per_rank: List[int]¶
- cumsum_dim_sum_per_rank_tensor: Optional[torch.Tensor]¶
- dim_sum_per_rank: List[int]¶
- dim_sum_per_rank_tensor: Optional[torch.Tensor]¶
- class torchrec.distributed.comm_ops.All2AllSequenceInfo(embedding_dim: int, lengths_after_sparse_data_all2all: torch.Tensor, forward_recat_tensor: torch.Tensor, backward_recat_tensor: torch.Tensor, input_splits: List[int], output_splits: List[int], permuted_lengths_after_sparse_data_all2all: Optional[torch.Tensor] = None)¶
Bases:
object
The data class that collects the attributes when calling the alltoall_sequence operation.
- embedding_dim¶
embedding dimension.
- Type
int
- lengths_after_sparse_data_all2all¶
lengths of sparse features after AlltoAll.
- Type
Tensor
- forward_recat_tensor¶
recat tensor for forward.
- Type
Tensor
- backward_recat_tensor¶
recat tensor for backward.
- Type
Tensor
- input_splits¶
input splits.
- Type
List[int]
- output_splits¶
output splits.
- Type
List[int]
- permuted_lengths_after_sparse_data_all2all¶
lengths of sparse features before AlltoAll.
- Type
Optional[Tensor]
- backward_recat_tensor: torch.Tensor¶
- embedding_dim: int¶
- forward_recat_tensor: torch.Tensor¶
- input_splits: List[int]¶
- lengths_after_sparse_data_all2all: torch.Tensor¶
- output_splits: List[int]¶
- permuted_lengths_after_sparse_data_all2all: Optional[torch.Tensor] = None¶
- class torchrec.distributed.comm_ops.All2AllVInfo(dims_sum_per_rank: typing.List[int], B_global: int, B_local: int, B_local_list: typing.List[int], D_local_list: typing.List[int], input_split_sizes: typing.List[int] = <factory>, output_split_sizes: typing.List[int] = <factory>)¶
Bases:
object
The data class that collects the attributes when calling the alltoallv operation.
- dim_sum_per_rank¶
number of features (sum of dimensions) of the embedding in each rank.
- Type
List[int]
- B_global¶
global batch size for each rank.
- Type
int
- B_local¶
local batch size before scattering.
- Type
int
- B_local_list¶
local batch sizes for each embedding table locally (in my current rank).
- Type
List[int]
- D_local_list¶
embedding dimension of each embedding table locally (in my current rank).
- Type
List[int]
- input_split_sizes¶
The input split sizes for each rank, this remembers how to split the input when doing the all_to_all_single operation.
- Type
List[int]
- output_split_sizes¶
The output split sizes for each rank, this remembers how to fill the output when doing the all_to_all_single operation.
- Type
List[int]
- B_global: int¶
- B_local: int¶
- B_local_list: List[int]¶
- D_local_list: List[int]¶
- dims_sum_per_rank: List[int]¶
- input_split_sizes: List[int]¶
- output_split_sizes: List[int]¶
- class torchrec.distributed.comm_ops.All2All_Pooled_Req(*args, **kwargs)¶
Bases:
torch.autograd.function.Function
- static backward(ctx, *unused) Tuple[None, None, None, torch.Tensor] ¶
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: torch._C._distributed_c10d.ProcessGroup, myreq: torchrec.distributed.comm_ops.Request[torch.Tensor], a2ai: torchrec.distributed.comm_ops.All2AllPooledInfo, input_embeddings: torch.Tensor) torch.Tensor ¶
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2All_Pooled_Wait(*args, **kwargs)¶
Bases:
torch.autograd.function.Function
- static backward(ctx, grad_output: torch.Tensor) Tuple[None, None, torch.Tensor] ¶
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: torch._C._distributed_c10d.ProcessGroup, myreq: torchrec.distributed.comm_ops.Request[torch.Tensor], sharded_output_embeddings: torch.Tensor) torch.Tensor ¶
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2All_Seq_Req(*args, **kwargs)¶
Bases:
torch.autograd.function.Function
- static backward(ctx, *unused) Tuple[None, None, None, torch.Tensor] ¶
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: torch._C._distributed_c10d.ProcessGroup, myreq: torchrec.distributed.comm_ops.Request[torch.Tensor], a2ai: torchrec.distributed.comm_ops.All2AllSequenceInfo, sharded_input_embeddings: torch.Tensor) torch.Tensor ¶
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2All_Seq_Req_Wait(*args, **kwargs)¶
Bases:
torch.autograd.function.Function
- static backward(ctx, sharded_grad_output: torch.Tensor) Tuple[None, None, torch.Tensor] ¶
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: torch._C._distributed_c10d.ProcessGroup, myreq: torchrec.distributed.comm_ops.Request[torch.Tensor], sharded_output_embeddings: torch.Tensor) torch.Tensor ¶
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2Allv_Req(*args, **kwargs)¶
Bases:
torch.autograd.function.Function
- static backward(ctx, *grad_output)¶
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: torch._C._distributed_c10d.ProcessGroup, myreq: torchrec.distributed.comm_ops.Request[torch.Tensor], a2ai: torchrec.distributed.comm_ops.All2AllVInfo, inputs: List[torch.Tensor]) torch.Tensor ¶
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.All2Allv_Wait(*args, **kwargs)¶
Bases:
torch.autograd.function.Function
- static backward(ctx, *grad_outputs) Tuple[None, None, torch.Tensor] ¶
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: torch._C._distributed_c10d.ProcessGroup, myreq, output) Tuple[torch.Tensor] ¶
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.ReduceScatterInfo(input_sizes: List[int])¶
Bases:
object
The data class that collects the attributes when calling the reduce_scatter_pooled operation.
- input_sizes¶
the sizes of the input tensors. This remembers the sizes of the input tensors when running the backward pass and producing the gradient.
- Type
List[int]
- input_sizes: List[int]¶
- class torchrec.distributed.comm_ops.ReduceScatter_Req(*args, **kwargs)¶
Bases:
torch.autograd.function.Function
- static backward(ctx, *unused: torch.Tensor) Tuple[Optional[torch.Tensor], ...] ¶
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: torch._C._distributed_c10d.ProcessGroup, myreq: torchrec.distributed.comm_ops.Request[torch.Tensor], rsi: torchrec.distributed.comm_ops.ReduceScatterInfo, *inputs: Any) torch.Tensor ¶
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.ReduceScatter_Wait(*args, **kwargs)¶
Bases:
torch.autograd.function.Function
- static backward(ctx, grad_output: torch.Tensor) Tuple[None, None, torch.Tensor] ¶
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, pg: torch._C._distributed_c10d.ProcessGroup, myreq: torchrec.distributed.comm_ops.Request[torch.Tensor], output: torch.Tensor) torch.Tensor ¶
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class torchrec.distributed.comm_ops.Request(pg: torch._C._distributed_c10d.ProcessGroup)¶
Bases:
torchrec.distributed.types.Awaitable[torchrec.distributed.comm_ops.W]
Defines a collective operation request for a process group on a tensor.
- Parameters
pg (dist.ProcessGroup) – The process group the request is for.
- torchrec.distributed.comm_ops.alltoall_pooled(a2a_pooled_embs_tensor: torch.Tensor, batch_size_per_rank: List[int], dim_sum_per_rank: List[int], dim_sum_per_rank_tensor: Optional[torch.Tensor] = None, cumsum_dim_sum_per_rank_tensor: Optional[torch.Tensor] = None, group: Optional[torch._C._distributed_c10d.ProcessGroup] = None) torchrec.distributed.types.Awaitable[torch.Tensor] ¶
Performs AlltoAll operation for a single pooled embedding tensor. Each process splits the input pooled embeddings tensor based on the world size, and then scatters the split list to all processes in the group. Then concatenates the received tensors from all processes in the group and returns a single output tensor.
- Parameters
a2a_pooled_embs_tensor (Tensor) – input pooled embeddings. Must be pooled together before passing into this function. Its shape is B x D_local_sum, where D_local_sum is the dimension sum of all the local embedding tables.
batch_size_per_rank (List[int]) – batch size in each rank.
dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.
dim_sum_per_rank_tensor (Optional[Tensor]) – the tensor version of dim_sum_per_rank, this is only used by the fast kernel of _recat_pooled_embedding_grad_out.
cumsum_dim_sum_per_rank_tensor (Optional[Tensor]) – cumulative sum of dim_sum_per_rank, this is only used by the fast kernel of _recat_pooled_embedding_grad_out.
group (Optional[dist.ProcessGroup]) – The process group to work on. If None, the default process group will be used.
- Returns
async work handle (Awaitable), which can be awaited (via wait()) later to get the resulting tensor.
- Return type
Awaitable[List[Tensor]]
Warning
alltoall_pooled is experimental and subject to change.
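A hedged two-rank sketch of alltoall_pooled (assumes an NCCL default process group is already initialized; batch sizes and dimension sums are illustrative):
import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import alltoall_pooled

rank = dist.get_rank()           # assumes dist.init_process_group("nccl") ran earlier
device = torch.device("cuda", rank)

B = 4                            # global batch size, split as 2 + 2 across ranks
batch_size_per_rank = [2, 2]
dim_sum_per_rank = [8, 4]        # rank 0 hosts 8 embedding dims, rank 1 hosts 4

# Locally pooled embeddings with shape B x D_local_sum for this rank.
local_pooled = torch.randn(B, dim_sum_per_rank[rank], device=device)

awaitable = alltoall_pooled(
    local_pooled,
    batch_size_per_rank=batch_size_per_rank,
    dim_sum_per_rank=dim_sum_per_rank,
)
# wait() yields this rank's batch slice with the full dimension sum,
# i.e. shape [batch_size_per_rank[rank], sum(dim_sum_per_rank)].
output = awaitable.wait()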
- torchrec.distributed.comm_ops.alltoall_sequence(a2a_sequence_embs_tensor: torch.Tensor, forward_recat_tensor: torch.Tensor, backward_recat_tensor: torch.Tensor, lengths_after_sparse_data_all2all: torch.Tensor, input_splits: List[int], output_splits: List[int], group: Optional[torch._C._distributed_c10d.ProcessGroup] = None) torchrec.distributed.types.Awaitable[torch.Tensor] ¶
Performs AlltoAll operation for sequence embeddings. Each process splits the input tensor based on the world size, and then scatters the split list to all processes in the group. Then concatenates the received tensors from all processes in the group and returns a single output tensor.
Note
AlltoAll operator for (T * B * L_i, D) tensors. Does not support mixed dimensions.
- Parameters
a2a_sequence_embs_tensor (Tensor) – input embeddings. Usually with the shape of (T * B * L_i, D), where B - batch size, T - number of embedding tables, D - embedding dimension.
forward_recat_tensor (Tensor) – recat tensor for forward.
backward_recat_tensor (Tensor) – recat tensor for backward.
lengths_after_sparse_data_all2all (Tensor) – lengths of sparse features after AlltoAll.
input_splits (List[int]) – input splits.
output_splits (List[int]) – output splits.
group (Optional[dist.ProcessGroup]) – The process group to work on. If None, the default process group will be used.
- Returns
async work handle (Awaitable), which can be awaited (via wait()) later to get the resulting tensor.
- Return type
Awaitable[List[Tensor]]
Warning
alltoall_sequence is experimental and subject to change.
- torchrec.distributed.comm_ops.alltoallv(inputs: List[torch.Tensor], out_split: Optional[List[int]] = None, per_rank_split_lengths: Optional[List[int]] = None, group: Optional[torch._C._distributed_c10d.ProcessGroup] = None) torchrec.distributed.types.Awaitable[List[torch.Tensor]] ¶
Performs alltoallv operation for a list of input embeddings. Each process scatters the list to all processes in the group.
- Parameters
input (List[Tensor]) – list of tensors to scatter, one per rank. The tensors in the list usually have different lengths.
out_split (Optional[List[int]]) – output split sizes (or dim_sum_per_rank). If not specified, per_rank_split_lengths is used to construct an output split, under the assumption that all the embeddings have the same dimension.
per_rank_split_lengths (Optional[List[int]]) – split lengths per rank. If not specified, the out_split must be specified.
group (Optional[dist.ProcessGroup]) – The process group to work on. If None, the default process group will be used.
- Returns
async work handle (Awaitable), which can be awaited (via wait()) later to get the resulting list of tensors.
- Return type
Awaitable[List[Tensor]]
Warning
alltoallv is experimental and subject to change.
- torchrec.distributed.comm_ops.reduce_scatter_pooled(inputs: List[torch.Tensor], group: Optional[torch._C._distributed_c10d.ProcessGroup] = None) torchrec.distributed.types.Awaitable[torch.Tensor] ¶
Performs reduce-scatter operation for a pooled embeddings tensor split into world size number of chunks. The result of the reduce operation gets scattered to all processes in the group. Then concatenates the received tensors from all processes in the group and returns a single output tensor.
- Parameters
inputs (List[Tensor]) – list of tensors to scatter, one per rank.
group (Optional[dist.ProcessGroup]) – The process group to work on. If None, the default process group will be used.
- Returns
async work handle (Awaitable), which can be awaited (via wait()) later to get the resulting tensor.
- Return type
Awaitable[List[Tensor]]
Warning
reduce_scatter_pooled is experimental and subject to change.
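A hedged sketch of reduce_scatter_pooled (assumes an NCCL default process group; each rank passes one chunk per destination rank, and rank i receives the element-wise sum of every rank's i-th chunk):
import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import reduce_scatter_pooled

rank = dist.get_rank()                 # assumes dist.init_process_group("nccl")
world_size = dist.get_world_size()
device = torch.device("cuda", rank)

# One chunk per destination rank; the [local_batch, dim] shapes are illustrative.
inputs = [torch.randn(2, 4, device=device) for _ in range(world_size)]

awaitable = reduce_scatter_pooled(inputs)
reduced = awaitable.wait()             # shape [2, 4], reduced across all ranks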
- torchrec.distributed.comm_ops.set_gradient_division(val: bool) None ¶
torchrec.distributed.dist_data¶
- class torchrec.distributed.dist_data.EmbeddingsAllToOne(device: torch.device, world_size: int, cat_dim: int)¶
Bases:
torch.nn.modules.module.Module
Merges the pooled/sequence embedding tensor on each device into a single tensor.
- Parameters
device (torch.device) – device on which buffer will be allocated
world_size (int) – number of devices in the topology.
cat_dim (int) – which dimension to concatenate on. For pooled embeddings it is 1; for sequence embeddings it is 0.
- forward(tensors: List[torch.Tensor]) torchrec.distributed.types.Awaitable[torch.Tensor] ¶
Performs AlltoOne operation on pooled embeddings tensors.
- Parameters
tensors (List[torch.Tensor]) – list of pooled embedding tensors.
- Returns
awaitable of the merged pooled embeddings.
- Return type
Awaitable[torch.Tensor]
- training: bool¶
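A hedged sketch of EmbeddingsAllToOne for pooled embeddings (assumes two CUDA devices are available; shapes are illustrative):
import torch
from torchrec.distributed.dist_data import EmbeddingsAllToOne

# Concatenate per-device pooled embeddings along dim 1 onto cuda:0.
m = EmbeddingsAllToOne(device=torch.device("cuda:0"), world_size=2, cat_dim=1)

t0 = torch.randn(4, 8, device="cuda:0")
t1 = torch.randn(4, 4, device="cuda:1")
merged = m([t0, t1]).wait()  # tensor of shape [4, 12] on cuda:0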
- class torchrec.distributed.dist_data.KJTAllToAll(pg: torch._C._distributed_c10d.ProcessGroup, splits: List[int], device: Optional[torch.device] = None, stagger: int = 1, variable_batch_size: bool = False)¶
Bases:
torch.nn.modules.module.Module
Redistributes KeyedJaggedTensor to a ProcessGroup according to splits.
Implementation utilizes AlltoAll collective as part of torch.distributed. Requires two collective calls, one to transmit final tensor lengths (to allocate correct space), and one to transmit actual sparse values.
- Parameters
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.
device (Optional[torch.device]) – device on which buffers will be allocated.
stagger (int) – stagger value to apply to recat tensor, see _recat function for more detail.
variable_batch_size (bool) – variable batch size in each rank
Example:
keys = ['A', 'B', 'C']
splits = [2, 1]
kjtA2A = KJTAllToAll(pg, splits, device)
awaitable = kjtA2A(rank0_input)

# where:
# rank0_input is KeyedJaggedTensor holding
#         0           1           2
# 'A'    [A.V0]       None        [A.V1, A.V2]
# 'B'    None         [B.V0]      [B.V1]
# 'C'    [C.V0]       [C.V1]      None
# rank1_input is KeyedJaggedTensor holding
#         0           1           2
# 'A'    [A.V3]       [A.V4]      None
# 'B'    None         [B.V2]      [B.V3, B.V4]
# 'C'    [C.V2]       [C.V3]      None

rank0_output = awaitable.wait()

# where:
# rank0_output is KeyedJaggedTensor holding
#         0           1            2             3          4           5
# 'A'    [A.V0]       None         [A.V1, A.V2]  [A.V3]     [A.V4]      None
# 'B'    None         [B.V0]       [B.V1]        None       [B.V2]      [B.V3, B.V4]
# rank1_output is KeyedJaggedTensor holding
#         0           1            2             3          4           5
# 'C'    [C.V0]       [C.V1]       None          [C.V2]     [C.V3]      None
- forward(input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.distributed.types.Awaitable[torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable] ¶
Sends input to relevant ProcessGroup ranks. The first wait gets the lengths AlltoAll results and issues the indices/weights AlltoAll; the second wait gets the indices/weights results.
- Parameters
input (KeyedJaggedTensor) – input KeyedJaggedTensor of values to distribute.
- Returns
awaitable of a KeyedJaggedTensor.
- Return type
- training: bool¶
- class torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable(pg: torch._C._distributed_c10d.ProcessGroup, input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, lengths: torch.Tensor, splits: List[int], keys: List[str], recat: torch.Tensor, in_lengths_per_worker: List[int], out_lengths_per_worker: List[int], batch_size_per_rank: List[int])¶
Bases:
torchrec.distributed.types.Awaitable[torchrec.sparse.jagged_tensor.KeyedJaggedTensor]
Awaitable for KJT indices and weights All2All.
- Parameters
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
input (KeyedJaggedTensor) – Input KJT tensor.
lengths – Output lengths tensor
splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.
keys (List[str]) – KJT keys after AlltoAll.
recat (torch.Tensor) – recat tensor for reordering tensor order after AlltoAll.
in_lengths_per_worker (List[int]) – number of indices each rank will get.
- class torchrec.distributed.dist_data.KJTAllToAllLengthsAwaitable(pg: torch._C._distributed_c10d.ProcessGroup, input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, splits: List[int], keys: List[str], stagger: int, recat: torch.Tensor, variable_batch_size: bool = False)¶
Bases:
torchrec.distributed.types.Awaitable[torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable]
Awaitable for KJT’s lengths AlltoAll.
wait() waits on lengths AlltoAll, then instantiates KJTAllToAllIndicesAwaitable awaitable where indices and weights AlltoAll will be issued.
- Parameters
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
input (KeyedJaggedTensor) – Input KJT tensor
splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.
keys (List[str]) – KJT keys after AlltoAll
recat (torch.Tensor) – recat tensor for reordering tensor order after AlltoAll.
- class torchrec.distributed.dist_data.KJTOneToAll(splits: List[int], world_size: int)¶
Bases:
torch.nn.modules.module.Module
Redistributes KeyedJaggedTensor to all devices.
Implementation utilizes OnetoAll function, which essentially P2P copies the feature to the devices.
- Parameters
splits (List[int]) – lengths of features to split the KeyedJaggedTensor features into before copying them.
world_size (int) – number of devices in the topology.
recat (torch.Tensor) – recat tensor for reordering tensor order after AlltoAll.
- forward(kjt: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.distributed.types.Awaitable[List[torchrec.sparse.jagged_tensor.KeyedJaggedTensor]] ¶
Splits features first and then sends the slices to the corresponding devices.
- Parameters
kjt (KeyedJaggedTensor) – the input features.
- Returns
awaitable of KeyedJaggedTensor splits.
- Return type
Awaitable[List[KeyedJaggedTensor]]
- training: bool¶
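A hedged sketch of KJTOneToAll (assumes two CUDA devices are available; the feature layout is illustrative):
import torch
from torchrec.distributed.dist_data import KJTOneToAll
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features with batch size 2; 'f1' is routed to device 0, 'f2' to device 1.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1]),
)

m = KJTOneToAll(splits=[1, 1], world_size=2)
per_device_kjts = m(kjt).wait()  # list of 2 KeyedJaggedTensors, one per device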
- class torchrec.distributed.dist_data.PooledEmbeddingsAllToAll(pg: torch._C._distributed_c10d.ProcessGroup, dim_sum_per_rank: List[int], device: Optional[torch.device] = None, callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None)¶
Bases:
torch.nn.modules.module.Module
Shards batches and collects keys of tensor with a ProcessGroup according to dim_sum_per_rank.
Implementation utilizes alltoall_pooled operation.
- Parameters
pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.
dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.
device (Optional[torch.device]) – device on which buffers will be allocated.
callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) –
Example:
dim_sum_per_rank = [2, 1]
a2a = PooledEmbeddingsAllToAll(pg, dim_sum_per_rank, device)

t0 = torch.rand((6, 2))
t1 = torch.rand((6, 1))
rank0_output = a2a(t0).wait()
rank1_output = a2a(t1).wait()
print(rank0_output.size())
# torch.Size([3, 3])
print(rank1_output.size())
# torch.Size([3, 3])
- property callbacks: List[Callable[[torch.Tensor], torch.Tensor]]¶
- forward(local_embs: torch.Tensor, batch_size_per_rank: Optional[List[int]] = None) torchrec.distributed.dist_data.PooledEmbeddingsAwaitable ¶
Performs AlltoAll pooled operation on pooled embeddings tensor.
- Parameters
local_embs (torch.Tensor) – tensor of values to distribute.
- Returns
awaitable of pooled embeddings.
- Return type
- training: bool¶
- class torchrec.distributed.dist_data.PooledEmbeddingsAwaitable(tensor_awaitable: torchrec.distributed.types.Awaitable[torch.Tensor])¶
Bases:
torchrec.distributed.types.Awaitable[torch.Tensor]
Awaitable for pooled embeddings after collective operation.
- Parameters
tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.
- property callbacks: List[Callable[[torch.Tensor], torch.Tensor]]¶
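A hedged sketch that extends the PooledEmbeddingsAllToAll example above: a callback can be appended before wait() and is applied to the resulting pooled embeddings tensor (the rescaling here is purely illustrative):
awaitable = a2a(t0)                            # a2a and t0 as in the example above
awaitable.callbacks.append(lambda t: t * 0.5)  # applied when wait() resolves
rank0_output = awaitable.wait()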
- class torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter(pg: torch._C._distributed_c10d.ProcessGroup)¶
Bases:
torch.nn.modules.module.Module
The module class that wraps reduce-scatter communication primitive for pooled embedding communication in row-wise and twrw sharding.
For pooled embeddings, we have a local model-parallel output tensor with a layout of [num_buckets x batch_size, dimension]. We need to sum over num_buckets dimension across batches. We split tensor along the first dimension into equal chunks (tensor slices of different buckets) and reduce them into the output tensor and scatter the results for corresponding ranks.
The class returns the async Awaitable handle for pooled embeddings tensor. The reduce-scatter is only available for NCCL backend.
- Parameters
pg (dist.ProcessGroup) – The process group that the reduce-scatter communication happens within.
Example:
init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
input = torch.randn(2 * 2, 2)
m = PooledEmbeddingsReduceScatter(pg)
output = m(input)
tensor = output.wait()
- forward(local_embs: torch.Tensor) torchrec.distributed.dist_data.PooledEmbeddingsAwaitable ¶
Performs reduce scatter operation on pooled embeddings tensor.
- Parameters
local_embs (torch.Tensor) – tensor of shape [num_buckets x batch_size, dimension].
- Returns
awaitable of pooled embeddings of tensor of shape [batch_size, dimension].
- Return type
- training: bool¶
- class torchrec.distributed.dist_data.SequenceEmbeddingAllToAll(pg: torch._C._distributed_c10d.ProcessGroup, features_per_rank: List[int], device: Optional[torch.device] = None)¶
Bases:
torch.nn.modules.module.Module
Redistributes sequence embedding to a ProcessGroup according to splits.
- Parameters
pg (dist.ProcessGroup) – the process group that the AlltoAll communication happens within.
features_per_rank (List[int]) – List of number of features per rank.
device (Optional[torch.device]) – device on which buffers will be allocated.
Example:
init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
features_per_rank = [4, 4]
m = SequenceEmbeddingAllToAll(pg, features_per_rank)
local_embs = torch.rand((6, 2))
sharding_ctx: SequenceShardingContext
output = m(
    local_embs=local_embs,
    lengths=sharding_ctx.lengths_after_input_dist,
    input_splits=sharding_ctx.input_splits,
    output_splits=sharding_ctx.output_splits,
    unbucketize_permute_tensor=None,
)
tensor = output.wait()
- forward(local_embs: torch.Tensor, lengths: torch.Tensor, input_splits: List[int], output_splits: List[int], unbucketize_permute_tensor: Optional[torch.Tensor] = None) torchrec.distributed.dist_data.SequenceEmbeddingsAwaitable ¶
Performs AlltoAll operation on sequence embeddings tensor.
- Parameters
local_embs (torch.Tensor) – input embeddings tensor.
lengths (torch.Tensor) – lengths of sparse features after AlltoAll.
input_splits (List[int]) – input splits of AlltoAll.
output_splits (List[int]) – output splits of AlltoAll.
unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of the KJT bucketize (for row-wise sharding only).
- Returns
SequenceEmbeddingsAwaitable
- training: bool¶
- class torchrec.distributed.dist_data.SequenceEmbeddingsAwaitable(tensor_awaitable: torchrec.distributed.types.Awaitable[torch.Tensor], unbucketize_permute_tensor: Optional[torch.Tensor], embedding_dim: int)¶
Bases:
torchrec.distributed.types.Awaitable[torch.Tensor]
Awaitable for sequence embeddings after collective operation.
- Parameters
tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.
unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of KJT bucketize (for row-wise sharding only).
torchrec.distributed.embedding¶
- class torchrec.distributed.embedding.EmbeddingCollectionAwaitable(*args, **kwargs)¶
Bases:
torchrec.distributed.types.LazyAwaitable[Dict[str, torchrec.sparse.jagged_tensor.JaggedTensor]]
- class torchrec.distributed.embedding.EmbeddingCollectionContext(sharding_contexts: List[torchrec.distributed.sharding.sequence_sharding.SequenceShardingContext])¶
Bases:
torchrec.distributed.types.ShardedModuleContext
- record_stream(stream: torch.cuda.streams.Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- sharding_contexts: List[torchrec.distributed.sharding.sequence_sharding.SequenceShardingContext]¶
- class torchrec.distributed.embedding.EmbeddingCollectionSharder(fused_params: Optional[Dict[str, Any]] = None)¶
Bases:
torchrec.distributed.embedding_types.BaseEmbeddingSharder[torchrec.modules.embedding_modules.EmbeddingCollection]
This implementation uses non-fused EmbeddingCollection.
- property module_type: Type[torchrec.modules.embedding_modules.EmbeddingCollection]¶
- shard(module: torchrec.modules.embedding_modules.EmbeddingCollection, params: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None) torchrec.distributed.embedding.ShardedEmbeddingCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters
module (M) – module to shard.
params (Dict[str, ParameterSharding]) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns
sharded module implementation.
- Return type
ShardedModule[Any, Any, Any]
- shardable_parameters(module: torchrec.modules.embedding_modules.EmbeddingCollection) Dict[str, torch.nn.parameter.Parameter] ¶
List of parameters that can be sharded.
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
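A hedged sketch of supplying this sharder explicitly to DistributedModelParallel (assumes an initialized NCCL process group; the table and feature names are illustrative):
import torch
import torch.distributed as dist
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank())

ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="t1", embedding_dim=16, num_embeddings=100, feature_names=["f1"]
        )
    ],
    device=torch.device("meta"),
)

# Passing the sharder explicitly overrides the default sharder list and lets
# fused_params (e.g. fused optimizer settings) be threaded through to the kernels.
model = DistributedModelParallel(
    ec, device=device, sharders=[EmbeddingCollectionSharder()]
)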
- class torchrec.distributed.embedding.ShardedEmbeddingCollection(module: torchrec.modules.embedding_modules.EmbeddingCollection, table_name_to_parameter_sharding: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None)¶
Bases:
torchrec.distributed.types.ShardedModule[torchrec.distributed.embedding_types.SparseFeaturesList, List[torch.Tensor], Dict[str, torch.Tensor]], torchrec.optim.fused.FusedOptimizerModule
Sharded implementation of EmbeddingCollection. This is part of the public API to allow for manual data dist pipelining.
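The stages below illustrate, as a hedged sketch, how the manual data dist pipelining hooks fit together (sharded_ec is assumed to be a ShardedEmbeddingCollection, e.g. produced by EmbeddingCollectionSharder.shard, and features an input KeyedJaggedTensor batch):
# ctx carries per-iteration sharding state between the stages.
ctx = sharded_ec.create_context()

# Stage 1: redistribute input features to the ranks that own their tables.
dist_input = sharded_ec.input_dist(ctx, features).wait()

# Stage 2: run the local embedding lookups.
local_embs = sharded_ec.compute(ctx, dist_input)

# Stage 3: redistribute lookup results back; the lazy awaitable resolves to a
# per-feature dictionary of embeddings.
out = sharded_ec.output_dist(ctx, local_embs).wait()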
- compute(ctx: torchrec.distributed.types.ShardedModuleContext, dist_input: torchrec.distributed.embedding_types.SparseFeaturesList) List[torch.Tensor] ¶
- compute_and_output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, input: torchrec.distributed.embedding_types.SparseFeaturesList) torchrec.distributed.types.LazyAwaitable[Dict[str, torch.Tensor]] ¶
In case of multiple output distributions, it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.
- create_context() torchrec.distributed.types.ShardedModuleContext ¶
- property fused_optimizer: torchrec.optim.keyed.KeyedOptimizer¶
- input_dist(ctx: torchrec.distributed.embedding.EmbeddingCollectionContext, features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeaturesList] ¶
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_modules(memo: Optional[Set[torch.nn.modules.module.Module]] = None, prefix: str = '', remove_duplicate: bool = True) Iterator[Tuple[str, torch.nn.modules.module.Module]] ¶
Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- Parameters
memo – a memo to store the set of modules already added to the result
prefix – a prefix that will be added to the name of the module
remove_duplicate – whether to remove the duplicated module instances in the result or not
- Yields
(string, Module) – Tuple of name and module
Note
Duplicate modules are returned only once. In the following example, l will be returned only once.
Example:
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
        print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
- named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, output: List[torch.Tensor]) torchrec.distributed.types.LazyAwaitable[Dict[str, torch.Tensor]] ¶
- sharded_parameter_names(prefix: str = '') Iterator[str] ¶
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- torchrec.distributed.embedding.create_embedding_sharding(sharding_type: str, embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.EmbeddingSharding[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor] ¶
torchrec.distributed.embedding_lookup¶
- class torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup(grouped_configs: List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig], pg: Optional[torch._C._distributed_c10d.ProcessGroup] = None, device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None)¶
Bases:
torchrec.distributed.embedding_types.BaseEmbeddingLookup[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]
- forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torch.Tensor ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- class torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup(grouped_configs: List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig], grouped_score_configs: List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig], device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, pg: Optional[torch._C._distributed_c10d.ProcessGroup] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None)¶
Bases:
torchrec.distributed.embedding_types.BaseEmbeddingLookup[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]
- forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torch.Tensor ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- class torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup(grouped_configs_per_rank: List[List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None)¶
Bases: torchrec.distributed.embedding_lookup.InferGroupedLookupMixin, torchrec.distributed.embedding_types.BaseEmbeddingLookup[torchrec.distributed.embedding_types.SparseFeaturesList, List[torch.Tensor]]
- training: bool¶
- class torchrec.distributed.embedding_lookup.InferGroupedLookupMixin¶
Bases:
abc.ABC
- forward(sparse_features: torchrec.distributed.embedding_types.SparseFeaturesList) List[torch.Tensor] ¶
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
- class torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup(grouped_configs_per_rank: List[List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig]], grouped_score_configs_per_rank: List[List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None)¶
Bases: torchrec.distributed.embedding_lookup.InferGroupedLookupMixin, torchrec.distributed.embedding_types.BaseEmbeddingLookup[torchrec.distributed.embedding_types.SparseFeaturesList, List[torch.Tensor]]
- training: bool¶
- class torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup(grouped_configs: List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig], device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None)¶
Bases: torchrec.distributed.embedding_types.BaseEmbeddingLookup[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]
Meta embedding lookup module for inference, since the inference lookup holds references to multiple TBE ops spanning all GPU workers. The inference grouped embedding lookup module contains meta modules allocated over GPU workers.
- forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torch.Tensor ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- class torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup(grouped_configs: List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig], grouped_score_configs: List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig], device: Optional[torch.device] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None, fused_params: Optional[Dict[str, Any]] = None)¶
Bases: torchrec.distributed.embedding_types.BaseEmbeddingLookup[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]
Meta embedding bag lookup module for inference, since the inference lookup holds references to multiple TBE ops spanning all GPU workers. The inference grouped embedding bag lookup module contains meta modules allocated over GPU workers.
- forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torch.Tensor ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
torchrec.distributed.embedding_sharding¶
- class torchrec.distributed.embedding_sharding.BaseEmbeddingDist¶
Bases: abc.ABC, torch.nn.modules.module.Module, Generic[torchrec.distributed.embedding_sharding.T]
Converts output of EmbeddingLookup from model-parallel to data-parallel.
- abstract forward(local_embs: torchrec.distributed.embedding_sharding.T) torchrec.distributed.types.Awaitable[torch.Tensor] ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
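For orientation, a hedged sketch of how a concrete output-dist module such as the class above is typically driven; output_dist and local_embs are illustrative names, with output_dist assumed to come from EmbeddingSharding.create_output_dist():
# `output_dist` is assumed to be a concrete BaseEmbeddingDist[torch.Tensor];
# `local_embs` is the model-parallel tensor produced by the embedding lookup.
awaitable = output_dist(local_embs)    # issues the collective (e.g. AlltoAll)
pooled_embs = awaitable.wait()         # data-parallel torch.Tensor on every rank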
- class torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist¶
Bases: abc.ABC, torch.nn.modules.module.Module, Generic[torchrec.distributed.embedding_sharding.F]
Converts input from data-parallel to model-parallel.
- abstract forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torchrec.distributed.types.Awaitable[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_sharding.F]] ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
- class torchrec.distributed.embedding_sharding.EmbeddingSharding¶
Bases: abc.ABC, Generic[torchrec.distributed.embedding_sharding.F, torchrec.distributed.embedding_sharding.T]
Used to implement different sharding types for EmbeddingBagCollection, e.g. table_wise.
- abstract create_input_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_sharding.F] ¶
- abstract create_lookup(device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None) torchrec.distributed.embedding_types.BaseEmbeddingLookup[torchrec.distributed.embedding_sharding.F, torchrec.distributed.embedding_sharding.T] ¶
- abstract create_output_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torchrec.distributed.embedding_sharding.T] ¶
- abstract embedding_dims() List[int] ¶
- abstract embedding_names() List[str] ¶
- abstract embedding_shard_metadata() List[Optional[torch.distributed._shard.metadata.ShardMetadata]] ¶
- abstract id_list_feature_names() List[str] ¶
- abstract id_score_list_feature_names() List[str] ¶
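The abstract methods above act as factories for the three stages of a sharded embedding forward pass. A hedged sketch of how a concrete EmbeddingSharding instance is typically consumed (sharding, sparse_features, and device are illustrative names):
# `sharding` is assumed to be a concrete EmbeddingSharding subclass instance,
# e.g. one returned by create_embedding_bag_sharding().
input_dist = sharding.create_input_dist(device=device)      # data-parallel -> model-parallel
lookup = sharding.create_lookup(device=device)               # lookup on local embedding shards
output_dist = sharding.create_output_dist(device=device)     # model-parallel -> data-parallel

dist_features = input_dist(sparse_features).wait().wait()    # two-layer awaitable: lengths, then indices
local_embs = lookup(dist_features)
pooled = output_dist(local_embs).wait()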
- class torchrec.distributed.embedding_sharding.ListOfSparseFeaturesListAwaitable(awaitables: List[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeaturesList]])¶
Bases: torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.ListOfSparseFeaturesList]
This module handles the table-wise sharding input features distribution for inference. For inference, we currently do not separate lengths from indices.
- Parameters
awaitables (List[Awaitable[SparseFeaturesList]]) – list of Awaitable of SparseFeaturesList.
- class torchrec.distributed.embedding_sharding.SparseFeaturesAllToAll(pg: torch._C._distributed_c10d.ProcessGroup, id_list_features_per_rank: List[int], id_score_list_features_per_rank: List[int], device: Optional[torch.device] = None, stagger: int = 1, variable_batch_size: bool = False)¶
Bases:
torch.nn.modules.module.Module
Redistributes sparse features to a ProcessGroup utilizing an AlltoAll collective.
- Parameters
pg (dist.ProcessGroup) – process group for AlltoAll communication.
id_list_features_per_rank (List[int]) – number of id list features to send to each rank.
id_score_list_features_per_rank (List[int]) – number of id score list features to send to each rank
device (Optional[torch.device]) – device on which buffers will be allocated.
stagger (int) – stagger value to apply to recat tensor, see _recat function for more detail.
variable_batch_size (bool) – variable batch size in each rank.
Example:
id_list_features_per_rank = [2, 1]
id_score_list_features_per_rank = [1, 3]
sfa2a = SparseFeaturesAllToAll(
    pg,
    id_list_features_per_rank,
    id_score_list_features_per_rank,
)
awaitable = sfa2a(rank0_input)  # rank0_input is a SparseFeatures

# where:
# rank0_input.id_list_features is KeyedJaggedTensor holding
#         0        1        2
# 'A'  [A.V0]   None     [A.V1, A.V2]
# 'B'  None     [B.V0]   [B.V1]
# 'C'  [C.V0]   [C.V1]   None
# rank1_input.id_list_features is KeyedJaggedTensor holding
#         0        1        2
# 'A'  [A.V3]   [A.V4]   None
# 'B'  None     [B.V2]   [B.V3, B.V4]
# 'C'  [C.V2]   [C.V3]   None
# rank0_input.id_score_list_features is KeyedJaggedTensor holding
#         0        1        2
# 'A'  [A.V0]   None     [A.V1, A.V2]
# 'B'  None     [B.V0]   [B.V1]
# 'C'  [C.V0]   [C.V1]   None
# 'D'  None     [D.V0]   None
# rank1_input.id_score_list_features is KeyedJaggedTensor holding
#         0        1        2
# 'A'  [A.V3]   [A.V4]   None
# 'B'  None     [B.V2]   [B.V3, B.V4]
# 'C'  [C.V2]   [C.V3]   None
# 'D'  [D.V1]   [D.V2]   [D.V3, D.V4]

rank0_output: SparseFeatures = awaitable.wait()

# rank0_output.id_list_features is KeyedJaggedTensor holding
#         0        1        2              3        4        5
# 'A'  [A.V0]   None     [A.V1, A.V2]   [A.V3]   [A.V4]   None
# 'B'  None     [B.V0]   [B.V1]         None     [B.V2]   [B.V3, B.V4]
# rank1_output.id_list_features is KeyedJaggedTensor holding
#         0        1        2              3        4        5
# 'C'  [C.V0]   [C.V1]   None           [C.V2]   [C.V3]   None
# rank0_output.id_score_list_features is KeyedJaggedTensor holding
#         0        1        2              3        4        5
# 'A'  [A.V0]   None     [A.V1, A.V2]   [A.V3]   [A.V4]   None
# rank1_output.id_score_list_features is KeyedJaggedTensor holding
#         0        1        2              3        4        5
# 'B'  None     [B.V0]   [B.V1]         None     [B.V2]   [B.V3, B.V4]
# 'C'  [C.V0]   [C.V1]   None           [C.V2]   [C.V3]   None
# 'D'  None     [D.V0]   None           [D.V1]   [D.V2]   [D.V3, D.V4]
- forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_sharding.SparseFeaturesIndicesAwaitable] ¶
Sends sparse features to relevant ProcessGroup ranks. Instantiates lengths AlltoAll. First wait will get lengths AlltoAll results, then issues indices AlltoAll. Second wait will get indices AlltoAll results.
- Parameters
sparse_features (SparseFeatures) – sparse features to redistribute.
- Returns
awaitable of SparseFeatures.
- Return type
- training: bool¶
- class torchrec.distributed.embedding_sharding.SparseFeaturesIndicesAwaitable(id_list_features_awaitable: Optional[torchrec.distributed.types.Awaitable[torchrec.sparse.jagged_tensor.KeyedJaggedTensor]], id_score_list_features_awaitable: Optional[torchrec.distributed.types.Awaitable[torchrec.sparse.jagged_tensor.KeyedJaggedTensor]])¶
Bases: torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures]
Awaitable of sparse features redistributed with AlltoAll collective.
- Parameters
id_list_features_awaitable (Optional[Awaitable[KeyedJaggedTensor]]) – awaitable of sharded id list features.
id_score_list_features_awaitable (Optional[Awaitable[KeyedJaggedTensor]]) – awaitable of sharded id score list features.
- class torchrec.distributed.embedding_sharding.SparseFeaturesLengthsAwaitable(id_list_features_awaitable: Optional[torchrec.distributed.types.Awaitable[torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable]], id_score_list_features_awaitable: Optional[torchrec.distributed.types.Awaitable[torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable]])¶
Bases: torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_sharding.SparseFeaturesIndicesAwaitable]
Awaitable of sparse features indices distribution.
- Parameters
id_list_features_awaitable (Optional[Awaitable[KJTAllToAllIndicesAwaitable]]) – awaitable of sharded id list features indices AlltoAll. Waiting on this value will trigger indices AlltoAll (waiting again will yield final AlltoAll results).
id_score_list_features_awaitable (Optional[Awaitable[KJTAllToAllIndicesAwaitable]]) – awaitable of sharded id score list features indices AlltoAll. Waiting on this value will trigger indices AlltoAll (waiting again will yield the final AlltoAll results).
- class torchrec.distributed.embedding_sharding.SparseFeaturesListAwaitable(awaitables: List[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures]])¶
Bases: torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeaturesList]
Awaitable of SparseFeaturesList.
- Parameters
awaitables (List[Awaitable[SparseFeatures]]) – list of Awaitable of sparse features.
- class torchrec.distributed.embedding_sharding.SparseFeaturesListIndicesAwaitable(awaitables: List[torchrec.distributed.types.Awaitable[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures]]])¶
Bases: torchrec.distributed.types.Awaitable[List[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures]]]
Handles the first wait for a list of two-layer awaitables of SparseFeatures. Waiting on this module gets the lengths AlltoAll results for each SparseFeatures and instantiates its indices AlltoAll.
- Parameters
awaitables (List[Awaitable[Awaitable[SparseFeatures]]]) – list of Awaitable of Awaitable sparse features.
- class torchrec.distributed.embedding_sharding.SparseFeaturesOneToAll(id_list_features_per_rank: List[int], id_score_list_features_per_rank: List[int], world_size: int)¶
Bases:
torch.nn.modules.module.Module
Redistributes sparse features to all devices.
- Parameters
id_list_features_per_rank (List[int]) – number of id list features to send to each rank.
id_score_list_features_per_rank (List[int]) – number of id score list features to send to each rank
world_size (int) – number of devices in the topology.
- forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeaturesList] ¶
Performs OnetoAll operation on sparse features.
- Parameters
sparse_features (SparseFeatures) – sparse features to redistribute.
- Returns
awaitable of SparseFeatures.
- Return type
- training: bool¶
- torchrec.distributed.embedding_sharding.bucketize_kjt_before_all2all(kjt: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, num_buckets: int, block_sizes: torch.Tensor, output_permute: bool = False, bucketize_pos: bool = False) Tuple[torchrec.sparse.jagged_tensor.KeyedJaggedTensor, Optional[torch.Tensor]] ¶
Bucketizes the values in a KeyedJaggedTensor into num_buckets buckets; lengths are readjusted based on the bucketization results.
Note: This function should be used only for row-wise sharding before calling SparseFeaturesAllToAll.
- Parameters
num_buckets (int) – number of buckets to bucketize the values into.
block_sizes (torch.Tensor) – bucket sizes for the keyed dimension.
output_permute (bool) – output the memory location mapping from the unbucketized values to bucketized values or not.
bucketize_pos (bool) – output the changed position of the bucketized values or not.
- Returns
the bucketized KeyedJaggedTensor and the optional permute mapping from the unbucketized values to bucketized value.
- Return type
Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]
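A hedged usage sketch of bucketize_kjt_before_all2all for row-wise sharding; the feature values, bucket count, and block sizes below are illustrative assumptions:
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all

# Two samples of feature "f1" with row ids into a table assumed to have 12 rows.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([0, 3, 5, 9]),
    lengths=torch.tensor([2, 2]),
)

num_buckets = 4                    # e.g. world size for row-wise sharding
block_sizes = torch.tensor([3])    # assumed per-key rows per bucket (12 / 4)
bucketized_kjt, permute = bucketize_kjt_before_all2all(
    kjt,
    num_buckets=num_buckets,
    block_sizes=block_sizes,
    output_permute=True,           # also return the unbucketized -> bucketized mapping
)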
- torchrec.distributed.embedding_sharding.group_tables(tables_per_rank: List[List[torchrec.distributed.embedding_types.ShardedEmbeddingTable]]) Tuple[List[List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig]], List[List[torchrec.distributed.embedding_types.GroupedEmbeddingConfig]]] ¶
Groups tables by DataType, PoolingType, Weighted, and EmbeddingComputeKernel.
- Parameters
tables_per_rank (List[List[ShardedEmbeddingTable]]) – list of sharding embedding tables per rank.
- Returns
per rank list of GroupedEmbeddingConfig for unscored and scored features.
- Return type
Tuple[List[List[GroupedEmbeddingConfig]], List[List[GroupedEmbeddingConfig]]]
torchrec.distributed.embedding_types¶
- class torchrec.distributed.embedding_types.BaseEmbeddingLookup¶
Bases: abc.ABC, torch.nn.modules.module.Module, Generic[torchrec.distributed.embedding_types.F, torchrec.distributed.embedding_types.T]
Interface implemented by different embedding implementations, e.g. one that relies on nn.EmbeddingBag, a table-batched one, etc.
- abstract forward(sparse_features: torchrec.distributed.embedding_types.F) torchrec.distributed.embedding_types.T ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- training: bool¶
- class torchrec.distributed.embedding_types.BaseEmbeddingSharder(fused_params: Optional[Dict[str, Any]] = None)¶
Bases: torchrec.distributed.types.ModuleSharder[torchrec.distributed.embedding_types.M]
- compute_kernels(sharding_type: str, compute_device_type: str) List[str] ¶
List of supported compute kernels for a given sharding type and compute device.
- property fused_params: Optional[Dict[str, Any]]¶
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
- storage_usage(tensor: torch.Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int] ¶
List of system resources and corresponding usage given a compute device and compute kernel
- class torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor¶
Bases:
torch.nn.modules.module.Module
Abstract base class for grouped feature processor
- abstract forward(features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.sparse.jagged_tensor.KeyedJaggedTensor ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- training: bool¶
- class torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder(fused_params: Optional[Dict[str, Any]] = None)¶
Bases: torchrec.distributed.types.ModuleSharder[torchrec.distributed.embedding_types.M]
- compute_kernels(sharding_type: str, compute_device_type: str) List[str] ¶
List of supported compute kernels for a given sharding type and compute device.
- property fused_params: Optional[Dict[str, Any]]¶
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
- storage_usage(tensor: torch.Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int] ¶
List of system resources and corresponding usage given a compute device and compute kernel
- class torchrec.distributed.embedding_types.EmbeddingAttributes(compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel = <EmbeddingComputeKernel.DENSE: 'dense'>)¶
Bases:
object
- compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel = 'dense'¶
- class torchrec.distributed.embedding_types.EmbeddingComputeKernel(value)¶
Bases:
enum.Enum
An enumeration.
- BATCHED_DENSE = 'batched_dense'¶
- BATCHED_FUSED = 'batched_fused'¶
- BATCHED_FUSED_UVM = 'batched_fused_uvm'¶
- BATCHED_FUSED_UVM_CACHING = 'batched_fused_uvm_caching'¶
- BATCHED_QUANT = 'batched_quant'¶
- BATCHED_QUANT_UVM = 'batched_quant_uvm'¶
- BATCHED_QUANT_UVM_CACHING = 'batched_quant_uvm_caching'¶
- DENSE = 'dense'¶
- SPARSE = 'sparse'¶
- class torchrec.distributed.embedding_types.GroupedEmbeddingConfig(data_type: torchrec.modules.embedding_configs.DataType, pooling: torchrec.modules.embedding_configs.PoolingType, is_weighted: bool, has_feature_processor: bool, compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel, embedding_tables: List[torchrec.distributed.embedding_types.ShardedEmbeddingTable])¶
Bases:
object
- compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel¶
- data_type: torchrec.modules.embedding_configs.DataType¶
- dim_sum() int ¶
- embedding_dims() List[int] ¶
- embedding_names() List[str] ¶
- embedding_shard_metadata() List[Optional[torch.distributed._shard.metadata.ShardMetadata]] ¶
- embedding_tables: List[torchrec.distributed.embedding_types.ShardedEmbeddingTable]¶
- feature_hash_sizes() List[int] ¶
- feature_names() List[str] ¶
- has_feature_processor: bool¶
- is_weighted: bool¶
- num_features() int ¶
- class torchrec.distributed.embedding_types.ListOfSparseFeaturesList(features: List[torchrec.distributed.embedding_types.SparseFeaturesList])¶
Bases:
torchrec.streamable.Multistreamable
- record_stream(stream: torch.cuda.streams.Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.embedding_types.OptimType(value)¶
Bases:
enum.Enum
An enumeration.
- ADAGRAD = 'ADAGRAD'¶
- ADAM = 'ADAM'¶
- LAMB = 'LAMB'¶
- LARS_SGD = 'LARS_SGD'¶
- PARTIAL_ROWWISE_ADAM = 'PARTIAL_ROWWISE_ADAM'¶
- PARTIAL_ROWWISE_LAMB = 'PARTIAL_ROWWISE_LAMB'¶
- ROWWISE_ADAGRAD = 'ROWWISE_ADAGRAD'¶
- SGD = 'SGD'¶
- class torchrec.distributed.embedding_types.ShardedConfig(local_rows: int = 0, local_cols: int = 0)¶
Bases:
object
- local_cols: int = 0¶
- local_rows: int = 0¶
- class torchrec.distributed.embedding_types.ShardedEmbeddingTable(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.modules.embedding_configs.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Optional[float] = None, weight_init_min: Optional[float] = None, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>, is_weighted: bool = False, has_feature_processor: bool = False, embedding_names: List[str] = <factory>, compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel = <EmbeddingComputeKernel.DENSE: 'dense'>, local_rows: int = 0, local_cols: int = 0, local_metadata: Optional[torch.distributed._shard.metadata.ShardMetadata] = None, global_metadata: Optional[torch.distributed._shard.sharded_tensor.metadata.ShardedTensorMetadata] = None)¶
Bases: torchrec.distributed.embedding_types.ShardedMetaConfig, torchrec.distributed.embedding_types.EmbeddingAttributes, torchrec.modules.embedding_configs.EmbeddingTableConfig
- class torchrec.distributed.embedding_types.ShardedMetaConfig(local_rows: int = 0, local_cols: int = 0, local_metadata: Optional[torch.distributed._shard.metadata.ShardMetadata] = None, global_metadata: Optional[torch.distributed._shard.sharded_tensor.metadata.ShardedTensorMetadata] = None)¶
Bases:
torchrec.distributed.embedding_types.ShardedConfig
- global_metadata: Optional[torch.distributed._shard.sharded_tensor.metadata.ShardedTensorMetadata] = None¶
- local_metadata: Optional[torch.distributed._shard.metadata.ShardMetadata] = None¶
- class torchrec.distributed.embedding_types.SparseFeatures(id_list_features: Optional[torchrec.sparse.jagged_tensor.KeyedJaggedTensor] = None, id_score_list_features: Optional[torchrec.sparse.jagged_tensor.KeyedJaggedTensor] = None)¶
Bases:
torchrec.streamable.Multistreamable
- id_list_features: Optional[torchrec.sparse.jagged_tensor.KeyedJaggedTensor] = None¶
- id_score_list_features: Optional[torchrec.sparse.jagged_tensor.KeyedJaggedTensor] = None¶
- record_stream(stream: torch.cuda.streams.Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
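A hedged sketch of wrapping id-list features in a SparseFeatures container and recording them on a CUDA stream; the feature values here are illustrative only:
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.embedding_types import SparseFeatures

features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1]),
)
sparse_features = SparseFeatures(id_list_features=features)

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    sparse_features.record_stream(stream)  # mark the underlying tensors as used on this stream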
- class torchrec.distributed.embedding_types.SparseFeaturesList(features: List[torchrec.distributed.embedding_types.SparseFeatures])¶
Bases:
torchrec.streamable.Multistreamable
- record_stream(stream: torch.cuda.streams.Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- torchrec.distributed.embedding_types.compute_kernel_to_embedding_location(compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel) fbgemm_gpu.split_table_batched_embeddings_ops.EmbeddingLocation ¶
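A hedged sketch of mapping a compute kernel to an fbgemm_gpu embedding location; the exact mapping, and which kernels are accepted, depends on the installed torchrec/fbgemm_gpu versions:
from torchrec.distributed.embedding_types import (
    EmbeddingComputeKernel,
    compute_kernel_to_embedding_location,
)

# Fused TBE kernels are expected to map to a device-resident location and
# UVM variants to managed memory; dense kernels may not be accepted.
location = compute_kernel_to_embedding_location(EmbeddingComputeKernel.BATCHED_FUSED)
print(location)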
torchrec.distributed.embeddingbag¶
- class torchrec.distributed.embeddingbag.EmbeddingAwaitable(*args, **kwargs)¶
Bases: torchrec.distributed.types.LazyAwaitable[torch.Tensor]
- class torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable(*args, **kwargs)¶
Bases: torchrec.distributed.types.LazyAwaitable[torchrec.sparse.jagged_tensor.KeyedTensor]
- class torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder(fused_params: Optional[Dict[str, Any]] = None)¶
Bases: torchrec.distributed.embedding_types.BaseEmbeddingSharder[torchrec.modules.embedding_modules.EmbeddingBagCollection]
This implementation uses non-fused EmbeddingBagCollection.
- property module_type: Type[torchrec.modules.embedding_modules.EmbeddingBagCollection]¶
- shard(module: torchrec.modules.embedding_modules.EmbeddingBagCollection, params: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None) torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters
module (M) – module to shard.
params (Dict[str, ParameterSharding]) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns
sharded module implementation.
- Return type
ShardedModule[Any, Any, Any]
- shardable_parameters(module: torchrec.modules.embedding_modules.EmbeddingBagCollection) Dict[str, torch.nn.parameter.Parameter] ¶
List of parameters that can be sharded.
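A hedged sketch of registering this sharder with DistributedModelParallel; it assumes a process group has already been initialized and that model contains an EmbeddingBagCollection submodule:
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel

sharded_model = DistributedModelParallel(
    model,                                       # assumed user model with an EmbeddingBagCollection
    sharders=[EmbeddingBagCollectionSharder()],  # shard EBC submodules with this sharder
)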
- class torchrec.distributed.embeddingbag.EmbeddingBagSharder(fused_params: Optional[Dict[str, Any]] = None)¶
Bases: torchrec.distributed.embedding_types.BaseEmbeddingSharder[torch.nn.modules.sparse.EmbeddingBag]
This implementation uses non-fused nn.EmbeddingBag.
- property module_type: Type[torch.nn.modules.sparse.EmbeddingBag]¶
- shard(module: torch.nn.modules.sparse.EmbeddingBag, params: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None) torchrec.distributed.embeddingbag.ShardedEmbeddingBag ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters
module (M) – module to shard.
params (Dict[str, ParameterSharding]) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns
sharded module implementation.
- Return type
ShardedModule[Any, Any, Any]
- shardable_parameters(module: torch.nn.modules.sparse.EmbeddingBag) Dict[str, torch.nn.parameter.Parameter] ¶
List of parameters that can be sharded.
- class torchrec.distributed.embeddingbag.ShardedEmbeddingBag(module: torch.nn.modules.sparse.EmbeddingBag, table_name_to_parameter_sharding: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None)¶
Bases: torchrec.distributed.types.ShardedModule[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor, torch.Tensor], torchrec.optim.fused.FusedOptimizerModule
Sharded implementation of nn.EmbeddingBag. This is part of the public API to allow for manual data dist pipelining.
- compute(ctx: torchrec.distributed.types.ShardedModuleContext, dist_input: torchrec.distributed.embedding_types.SparseFeatures) torch.Tensor ¶
- property fused_optimizer: torchrec.optim.keyed.KeyedOptimizer¶
- input_dist(ctx: torchrec.distributed.types.ShardedModuleContext, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, per_sample_weights: Optional[torch.Tensor] = None) torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures] ¶
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_modules(memo: Optional[Set[torch.nn.modules.module.Module]] = None, prefix: str = '', remove_duplicate: bool = True) Iterator[Tuple[str, torch.nn.modules.module.Module]] ¶
Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- Parameters
memo – a memo to store the set of modules already added to the result
prefix – a prefix that will be added to the name of the module
remove_duplicate – whether to remove the duplicated module instances in the result or not
- Yields
(string, Module) – Tuple of name and module
Note
Duplicate modules are returned only once. In the following example,
l
will be returned only once.Example:
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
        print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, output: torch.Tensor) torchrec.distributed.types.LazyAwaitable[torch.Tensor] ¶
- sharded_parameter_names(prefix: str = '') Iterator[str] ¶
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- class torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection(module: torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None)¶
Bases: torchrec.distributed.types.ShardedModule[torchrec.distributed.embedding_types.SparseFeaturesList, List[torch.Tensor], torchrec.sparse.jagged_tensor.KeyedTensor], torchrec.optim.fused.FusedOptimizerModule
Sharded implementation of EmbeddingBagCollection. This is part of the public API to allow for manual data dist pipelining.
- compute(ctx: torchrec.distributed.types.ShardedModuleContext, dist_input: torchrec.distributed.embedding_types.SparseFeaturesList) List[torch.Tensor] ¶
- compute_and_output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, input: torchrec.distributed.embedding_types.SparseFeaturesList) torchrec.distributed.types.LazyAwaitable[torchrec.sparse.jagged_tensor.KeyedTensor] ¶
In the case of multiple output distributions it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.
- property fused_optimizer: torchrec.optim.keyed.KeyedOptimizer¶
- input_dist(ctx: torchrec.distributed.types.ShardedModuleContext, features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeaturesList] ¶
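A hedged sketch of the manual data-dist pipeline these methods expose; ctx, sharded_ebc, and features are assumed to come from the surrounding training code, and each stage follows the signatures listed in this class:
dist_input = sharded_ebc.input_dist(ctx, features).wait()   # KeyedJaggedTensor -> SparseFeaturesList
local_embs = sharded_ebc.compute(ctx, dist_input)           # model-parallel List[torch.Tensor]
pooled = sharded_ebc.output_dist(ctx, local_embs).wait()    # data-parallel KeyedTensor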
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_modules(memo: Optional[Set[torch.nn.modules.module.Module]] = None, prefix: str = '', remove_duplicate: bool = True) Iterator[Tuple[str, torch.nn.modules.module.Module]] ¶
Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- Parameters
memo – a memo to store the set of modules already added to the result
prefix – a prefix that will be added to the name of the module
remove_duplicate – whether to remove the duplicated module instances in the result or not
- Yields
(string, Module) – Tuple of name and module
Note
Duplicate modules are returned only once. In the following example,
l
will be returned only once.Example:
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
        print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, output: List[torch.Tensor]) torchrec.distributed.types.LazyAwaitable[torchrec.sparse.jagged_tensor.KeyedTensor] ¶
- sharded_parameter_names(prefix: str = '') Iterator[str] ¶
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- torchrec.distributed.embeddingbag.create_embedding_bag_sharding(sharding_type: str, embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None, permute_embeddings: bool = False) torchrec.distributed.embedding_sharding.EmbeddingSharding[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor] ¶
- torchrec.distributed.embeddingbag.create_embedding_configs_by_sharding(module: torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, torchrec.distributed.types.ParameterSharding], prefix: str) Dict[str, List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]]] ¶
- torchrec.distributed.embeddingbag.filter_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], name: str) collections.OrderedDict[str, torch.Tensor] ¶
- torchrec.distributed.embeddingbag.replace_placement_with_meta_device(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]]) None ¶
Placement device and tensor device can mismatch in some scenarios, e.g. when a meta device is passed to DMP while cuda is passed to EmbeddingShardingPlanner. We need to make the devices consistent after the sharding plan is obtained.
torchrec.distributed.grouped_position_weighted¶
- class torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule(max_feature_lengths: Dict[str, int], device: Optional[torch.device] = None)¶
Bases:
torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor
- forward(features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.sparse.jagged_tensor.KeyedJaggedTensor ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
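A hedged usage sketch of GroupedPositionWeightedModule; the max_feature_lengths and feature values are illustrative:
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.grouped_position_weighted import GroupedPositionWeightedModule

fp = GroupedPositionWeightedModule(max_feature_lengths={"f1": 10})
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([3, 7, 1]),
    lengths=torch.tensor([2, 1]),
)
weighted_features = fp(features)  # returns a KeyedJaggedTensor, per the forward signature above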
torchrec.distributed.model_parallel¶
- class torchrec.distributed.model_parallel.DataParallelWrapper¶
Bases:
abc.ABC
Interface implemented by custom data parallel wrappers.
- abstract wrap(dmp: torchrec.distributed.model_parallel.DistributedModelParallel, env: torchrec.distributed.types.ShardingEnv, device: torch.device) None ¶
- class torchrec.distributed.model_parallel.DefaultDataParallelWrapper¶
Bases:
torchrec.distributed.model_parallel.DataParallelWrapper
Default data parallel wrapper, which applies data parallelism to all unsharded modules.
- wrap(dmp: torchrec.distributed.model_parallel.DistributedModelParallel, env: torchrec.distributed.types.ShardingEnv, device: torch.device) None ¶
- class torchrec.distributed.model_parallel.DistributedModelParallel(module: torch.nn.modules.module.Module, env: Optional[torchrec.distributed.types.ShardingEnv] = None, device: Optional[torch.device] = None, plan: Optional[torchrec.distributed.types.ShardingPlan] = None, sharders: Optional[List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[torchrec.distributed.model_parallel.DataParallelWrapper] = None)¶
Bases: torch.nn.modules.module.Module, torchrec.optim.fused.FusedOptimizerModule
Entry point to model parallelism.
- Parameters
module (nn.Module) – module to wrap.
env (Optional[ShardingEnv]) – sharding environment that has the process group.
device (Optional[torch.device]) – compute device, defaults to cpu.
plan (Optional[ShardingPlan]) – plan to use when sharding, defaults to EmbeddingShardingPlanner.collective_plan().
sharders (Optional[List[ModuleSharder[nn.Module]]]) – ModuleSharders available to shard with, defaults to EmbeddingBagCollectionSharder().
init_data_parallel (bool) – data-parallel modules can be lazy, i.e. they delay parameter initialization until the first forward pass. Pass True to delay initialization of data parallel modules. Do first forward pass and then call DistributedModelParallel.init_data_parallel().
init_parameters (bool) – initialize parameters for modules still on meta device.
data_parallel_wrapper (Optional[DataParallelWrapper]) – custom wrapper for data parallel modules.
Example:
@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
- bare_named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
- copy(device: torch.device) torchrec.distributed.model_parallel.DistributedModelParallel ¶
Recursively copies submodules to a new device by calling each module's customized copy process, since some modules need to keep the original references (like ShardedModule for inference).
- forward(*args, **kwargs) Any ¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- property fused_optimizer: torchrec.optim.keyed.KeyedOptimizer¶
- init_data_parallel() None ¶
See the init_data_parallel constructor argument for usage. It’s safe to call this method multiple times.
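A minimal sketch of deferred data-parallel initialization; MyModel and batch are placeholders, not part of the API:
from torchrec.distributed.model_parallel import DistributedModelParallel

m = DistributedModelParallel(MyModel(device="meta"), init_data_parallel=False)
out = m(batch)          # first forward pass initializes any lazy parameters
m.init_data_parallel()  # wrap remaining data-parallel modules; safe to call again later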
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], prefix: str = '', strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- property module: torch.nn.modules.module.Module¶
Property to directly access sharded module, which will not be wrapped in DDP, FSDP, DMP, or any other parallelism wrappers.
- named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.Tensor]] ¶
Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters
prefix (str) – prefix to prepend to all buffer names.
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields
(string, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, torch.nn.parameter.Parameter]] ¶
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields
(string, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- property plan: torchrec.distributed.types.ShardingPlan¶
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- torchrec.distributed.model_parallel.get_default_sharders() List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]] ¶
- torchrec.distributed.model_parallel.get_module(module: torch.nn.modules.module.Module) torch.nn.modules.module.Module ¶
Unwraps DMP module.
Does not unwrap data parallel wrappers (i.e. DDP/FSDP), so overriding implementations by the wrappers can be used.
- torchrec.distributed.model_parallel.get_unwrapped_module(module: torch.nn.modules.module.Module) torch.nn.modules.module.Module ¶
Unwraps module wrapped by DMP, DDP, or FSDP.
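A short sketch contrasting the two unwrapping helpers; MyModel is a placeholder, not part of the API:
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
    get_module,
    get_unwrapped_module,
)

dmp = DistributedModelParallel(MyModel(device="meta"))  # MyModel is hypothetical
inner = get_module(dmp)            # strips DMP only; DDP/FSDP wrappers are kept
plain = get_unwrapped_module(dmp)  # strips DMP and any DDP/FSDP wrappers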
torchrec.distributed.quant_embeddingbag¶
- class torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder(fused_params: Optional[Dict[str, Any]] = None)¶
Bases:
torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder[torchrec.quant.embedding_modules.EmbeddingBagCollection]
- property module_type: Type[torchrec.quant.embedding_modules.EmbeddingBagCollection]¶
- shard(module: torchrec.quant.embedding_modules.EmbeddingBagCollection, params: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None) torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters
module (M) – module to shard.
params (Dict[str, ParameterSharding]) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns
sharded module implementation.
- Return type
ShardedModule[Any, Any, Any]
- shardable_parameters(module: torchrec.quant.embedding_modules.EmbeddingBagCollection) Dict[str, torch.nn.parameter.Parameter] ¶
List of parameters that can be sharded.
- class torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection(module: torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, fused_params: Optional[Dict[str, Any]] = None)¶
Bases:
torchrec.distributed.types.ShardedModule[torchrec.distributed.embedding_types.ListOfSparseFeaturesList, List[List[torch.Tensor]], torchrec.sparse.jagged_tensor.KeyedTensor]
Sharded implementation of EmbeddingBagCollection. This is part of the public API to allow for manual data dist pipelining.
- compute(ctx: torchrec.distributed.types.ShardedModuleContext, dist_input: torchrec.distributed.embedding_types.ListOfSparseFeaturesList) List[List[torch.Tensor]] ¶
- compute_and_output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, input: torchrec.distributed.embedding_types.ListOfSparseFeaturesList) torchrec.distributed.types.LazyAwaitable[torchrec.sparse.jagged_tensor.KeyedTensor] ¶
In case of multiple output distributions, it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.
- copy(device: torch.device) torch.nn.modules.module.Module ¶
- input_dist(ctx: torchrec.distributed.types.ShardedModuleContext, features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.ListOfSparseFeaturesList] ¶
- load_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], strict: bool = True) torch.nn.modules.module._IncompatibleKeys ¶
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, output: List[List[torch.Tensor]]) torchrec.distributed.types.LazyAwaitable[torchrec.sparse.jagged_tensor.KeyedTensor] ¶
- state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any] ¶
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- training: bool¶
- torchrec.distributed.quant_embeddingbag.create_infer_embedding_bag_sharding(sharding_type: str, embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv) torchrec.distributed.embedding_sharding.EmbeddingSharding[torchrec.distributed.embedding_types.SparseFeaturesList, List[torch.Tensor]] ¶
torchrec.distributed.train_pipeline¶
- class torchrec.distributed.train_pipeline.ArgInfo(input_attrs: List[str], is_getitems: List[bool], name: Optional[str])¶
Bases:
object
- input_attrs: List[str]¶
- is_getitems: List[bool]¶
- name: Optional[str]¶
- class torchrec.distributed.train_pipeline.PipelinedForward(name: str, args: List[torchrec.distributed.train_pipeline.ArgInfo], module: torchrec.distributed.types.ShardedModule[torchrec.distributed.train_pipeline.DistIn, torchrec.distributed.train_pipeline.DistOut, torchrec.distributed.train_pipeline.Out], context: torchrec.distributed.train_pipeline.TrainPipelineContext, dist_stream: Optional[torch.cuda.streams.Stream])¶
Bases:
Generic[torchrec.distributed.train_pipeline.DistIn, torchrec.distributed.train_pipeline.DistOut, torchrec.distributed.train_pipeline.Out]
- property args: List[torchrec.distributed.train_pipeline.ArgInfo]¶
- property name: str¶
- class torchrec.distributed.train_pipeline.Tracer(unsharded_module_names: List[str])¶
Bases:
torch.fx._symbolic_trace.Tracer
- graph: torch.fx.graph.Graph¶
- is_leaf_module(m: torch.nn.modules.module.Module, module_qualified_name: str) bool ¶
A method to specify whether a given nn.Module is a “leaf” module.
Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.
- Parameters
m (Module) – The module being queried about
module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.
Note
Backwards-compatibility for this API is guaranteed.
- proxy_buffer_attributes: bool = False¶
- class torchrec.distributed.train_pipeline.TrainPipeline¶
Bases:
abc.ABC, Generic[torchrec.distributed.train_pipeline.In, torchrec.distributed.train_pipeline.Out]
- abstract progress(dataloader_iter: Iterator[torchrec.distributed.train_pipeline.In]) torchrec.distributed.train_pipeline.Out ¶
- class torchrec.distributed.train_pipeline.TrainPipelineBase(model: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, device: torch.device)¶
Bases:
torchrec.distributed.train_pipeline.TrainPipeline[torchrec.distributed.train_pipeline.In, torchrec.distributed.train_pipeline.Out]
This class runs training iterations using a pipeline of two stages, each as a CUDA stream, namely, the current (default) stream and self._memcpy_stream. For each iteration, self._memcpy_stream moves the input from host (CPU) memory to GPU memory, and the default stream runs forward, backward, and optimization.
- progress(dataloader_iter: Iterator[torchrec.distributed.train_pipeline.In]) torchrec.distributed.train_pipeline.Out ¶
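A minimal training-loop sketch; model, optimizer, and dataloader are assumed to already exist and are not part of the API shown here:
import torch
from torchrec.distributed.train_pipeline import TrainPipelineBase

pipeline = TrainPipelineBase(model, optimizer, device=torch.device("cuda"))
batches = iter(dataloader)
while True:
    try:
        out = pipeline.progress(batches)  # copy of the next batch overlaps with compute of the current one
    except StopIteration:
        break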
- class torchrec.distributed.train_pipeline.TrainPipelineContext(input_dist_requests: Dict[str, torchrec.distributed.types.Awaitable[Any]] = <factory>, module_contexts: Dict[str, torchrec.distributed.types.ShardedModuleContext] = <factory>)¶
Bases:
object
- input_dist_requests: Dict[str, torchrec.distributed.types.Awaitable[Any]]¶
- module_contexts: Dict[str, torchrec.distributed.types.ShardedModuleContext]¶
- class torchrec.distributed.train_pipeline.TrainPipelineSparseDist(model: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, device: torch.device)¶
Bases:
torchrec.distributed.train_pipeline.TrainPipeline[torchrec.distributed.train_pipeline.In, torchrec.distributed.train_pipeline.Out]
This pipeline overlaps device transfer and ShardedModule.input_dist() with forward and backward. This helps hide the all2all latency while preserving the training forward / backward ordering.
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
ShardedModule.input_dist() is only done for top-level modules in the call graph. To be considered a top-level module, a module can only depend on ‘getattr’ calls on input.
Input model must be symbolically traceable with the exception of ShardedModule and DistributedDataParallel modules.
- progress(dataloader_iter: Iterator[torchrec.distributed.train_pipeline.In]) torchrec.distributed.train_pipeline.Out ¶
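A hedged usage sketch: the model is assumed to be wrapped in DistributedModelParallel so that its sharded modules expose input_dist(); MyModel, dataloader, and num_batches are placeholders:
import torch
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

model = DistributedModelParallel(MyModel(device="meta"))  # MyModel is hypothetical
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # sharded-embedding optimizers are typically fused instead
pipeline = TrainPipelineSparseDist(model, optimizer, device=torch.device("cuda"))

batches = iter(dataloader)
for _ in range(num_batches):
    out = pipeline.progress(batches)  # overlaps input_dist and device transfer with forward/backward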
torchrec.distributed.types¶
- class torchrec.distributed.types.Awaitable¶
Bases:
abc.ABC, Generic[torchrec.distributed.types.W]
- property callbacks: List[Callable[[torchrec.distributed.types.W], torchrec.distributed.types.W]]¶
- wait() torchrec.distributed.types.W ¶
- class torchrec.distributed.types.ComputeKernel(value)¶
Bases:
enum.Enum
An enumeration.
- DEFAULT = 'default'¶
- class torchrec.distributed.types.EmptyContext¶
Bases:
torchrec.distributed.types.ShardedModuleContext
- record_stream(stream: torch.cuda.streams.Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- class torchrec.distributed.types.GenericMeta¶
Bases:
type
- class torchrec.distributed.types.LazyAwaitable(*args, **kwargs)¶
Bases:
torchrec.distributed.types.Awaitable[torchrec.distributed.types.W]
The LazyAwaitable type exposes a wait() API; concrete types can control how to initialize and how wait() should behave in order to achieve a specific async operation.
This base LazyAwaitable type is a “lazy” async type, which means it delays wait() as late as possible; see details in __torch_function__ below. This helps the model automatically overlap computation and communication: the model author doesn’t need to manually call wait() if the result is used by a PyTorch function or by other Python operations (NOTE: the corresponding magic methods, like __getattr__ below, need to be implemented).
Some caveats:
This works with PyTorch functions, but not with arbitrary Python methods; if you would like to perform arbitrary Python operations, you need to implement the corresponding magic methods.
If a function has two or more arguments that are LazyAwaitable, the lazy-wait mechanism cannot ensure perfect computation/communication overlap (i.e. the first may be waited on quickly while the wait on the second is long).
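A minimal sketch of the lazy-wait behavior; sharded_ebc and features are assumed to exist (e.g. a sharded EmbeddingBagCollection inside DistributedModelParallel and a KeyedJaggedTensor batch) and are not part of the API shown here:
awaitable = sharded_ebc(features)  # LazyAwaitable[KeyedTensor]; comms are kicked off, no wait yet
kt = awaitable.wait()              # explicit wait (a PyTorch op consuming the result would wait lazily)
pooled = kt.values()               # pooled embeddings as a dense tensor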
- class torchrec.distributed.types.LazyNoWait(*args, **kwargs)¶
Bases:
torchrec.distributed.types.LazyAwaitable[torchrec.distributed.types.W]
- class torchrec.distributed.types.ModuleCopyMixin¶
Bases:
object
A mixin to allow modules to override copy behaviors in DMP.
- copy(device: torch.device) torch.nn.modules.module.Module ¶
- class torchrec.distributed.types.ModuleSharder¶
Bases:
abc.ABC, Generic[torchrec.distributed.types.M]
A ModuleSharder is defined per module type that supports sharding, e.g. EmbeddingBagCollection.
- compute_kernels(sharding_type: str, compute_device_type: str) List[str] ¶
List of supported compute kernels for a given sharding type and compute device.
- abstract property module_type: Type[torchrec.distributed.types.M]¶
- abstract classmethod shard(module: torchrec.distributed.types.M, params: Dict[str, torchrec.distributed.types.ParameterSharding], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None) torchrec.distributed.types.ShardedModule[Any, Any, Any] ¶
Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.
Default implementation is data-parallel replication.
- Parameters
module (M) – module to shard.
params (Dict[str, ParameterSharding]) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.
env (ShardingEnv) – sharding environment that has the process group.
device (torch.device) – compute device.
- Returns
sharded module implementation.
- Return type
ShardedModule[Any, Any, Any]
- shardable_parameters(module: torchrec.distributed.types.M) Dict[str, torch.nn.parameter.Parameter] ¶
List of parameters that can be sharded.
- sharding_types(compute_device_type: str) List[str] ¶
List of supported sharding types. See ShardingType for well-known examples.
- storage_usage(tensor: torch.Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int] ¶
List of system resources and corresponding usage given a compute device and compute kernel.
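As an illustration, a concrete sharder can be queried for what it supports on a given compute device; the exact strings returned depend on the torchrec version:
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.types import ShardingType

sharder = EmbeddingBagCollectionSharder()
print(sharder.sharding_types("cuda"))                                  # e.g. table_wise, row_wise, ...
print(sharder.compute_kernels(ShardingType.TABLE_WISE.value, "cuda"))  # kernels valid for that sharding type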
- class torchrec.distributed.types.NoWait(obj: torchrec.distributed.types.W)¶
Bases:
torchrec.distributed.types.Awaitable[torchrec.distributed.types.W]
- class torchrec.distributed.types.ParameterSharding(sharding_type: str, compute_kernel: str, ranks: Optional[List[int]] = None, sharding_spec: Optional[torch.distributed._shard.sharding_spec.api.ShardingSpec] = None)¶
Bases:
object
Describes the sharding of the parameter.
sharding_type (str): how this parameter is sharded. See ShardingType for well-known types.
compute_kernel (str): compute kernel to be used by this parameter.
ranks (Optional[List[int]]): rank of each shard.
sharding_spec (Optional[ShardingSpec]): list of ShardMetadata for each shard.
Note
ShardingType.TABLE_WISE - rank where this embedding is placed
ShardingType.COLUMN_WISE - rank where the embedding shards are placed, seen as individual tables
ShardingType.TABLE_ROW_WISE - first rank when this embedding is placed
ShardingType.ROW_WISE, ShardingType.DATA_PARALLEL - unused
- compute_kernel: str¶
- ranks: Optional[List[int]] = None¶
- sharding_spec: Optional[torch.distributed._shard.sharding_spec.api.ShardingSpec] = None¶
- sharding_type: str¶
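A hypothetical spec as a sketch; the compute kernel string is an assumption here, see ModuleSharder.compute_kernels() for what a given sharder actually supports:
from torchrec.distributed.types import ParameterSharding, ShardingType

spec = ParameterSharding(
    sharding_type=ShardingType.TABLE_WISE.value,  # place the whole table on one rank
    compute_kernel="dense",                       # assumed kernel name
    ranks=[0],
)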
- class torchrec.distributed.types.ParameterStorage(value)¶
Bases:
enum.Enum
Well-known physical resources, which can be used as constraints by ShardingPlanner.
- DDR = 'ddr'¶
- HBM = 'hbm'¶
- class torchrec.distributed.types.ShardedModule¶
Bases:
abc.ABC, torch.nn.modules.module.Module, Generic[torchrec.distributed.types.CompIn, torchrec.distributed.types.DistOut, torchrec.distributed.types.Out], torchrec.distributed.types.ModuleCopyMixin
All model-parallel modules implement this interface. Inputs and outputs are data-parallel.
Note
‘input_dist’ / ‘output_dist’ are responsible for transforming inputs / outputs from data-parallel to model-parallel and vice versa.
- abstract compute(ctx: torchrec.distributed.types.ShardedModuleContext, dist_input: torchrec.distributed.types.CompIn) torchrec.distributed.types.DistOut ¶
- compute_and_output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, input: torchrec.distributed.types.CompIn) torchrec.distributed.types.LazyAwaitable[torchrec.distributed.types.Out] ¶
In case of multiple output distributions, it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.
- create_context() torchrec.distributed.types.ShardedModuleContext ¶
- forward(*input, **kwargs) torchrec.distributed.types.LazyAwaitable[torchrec.distributed.types.Out] ¶
Executes the input dist, compute, and output dist steps.
- Parameters
*input – input.
**kwargs – keyword arguments.
- Returns
awaitable of output from output dist.
- Return type
LazyAwaitable[Out]
- abstract input_dist(ctx: torchrec.distributed.types.ShardedModuleContext, *input, **kwargs) torchrec.distributed.types.Awaitable[torchrec.distributed.types.CompIn] ¶
- abstract output_dist(ctx: torchrec.distributed.types.ShardedModuleContext, output: torchrec.distributed.types.DistOut) torchrec.distributed.types.LazyAwaitable[torchrec.distributed.types.Out] ¶
- sharded_parameter_names(prefix: str = '') Iterator[str] ¶
- sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str] ¶
- training: bool¶
- class torchrec.distributed.types.ShardedModuleContext¶
Bases:
torchrec.streamable.Multistreamable
- class torchrec.distributed.types.ShardingEnv(world_size: int, rank: int, pg: Optional[torch._C._distributed_c10d.ProcessGroup] = None)¶
Bases:
object
Provides an abstraction over torch.distributed.ProcessGroup, which practically enables DistributedModelParallel to be used during inference.
- classmethod from_local(world_size: int, rank: int) torchrec.distributed.types.ShardingEnv ¶
Creates a local host-based sharding environment.
Note
Typically used during single host inference.
- classmethod from_process_group(pg: torch._C._distributed_c10d.ProcessGroup) torchrec.distributed.types.ShardingEnv ¶
Creates ProcessGroup-based sharding environment.
Note
Typically used during training.
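A short sketch of both constructors; for the training case the default process group is assumed to have been initialized already:
import torch.distributed as dist
from torchrec.distributed.types import ShardingEnv

# Training: build the env from the default process group
# (dist.init_process_group(...) is assumed to have been called).
env = ShardingEnv.from_process_group(dist.group.WORLD)

# Single-host inference: no process group needed.
infer_env = ShardingEnv.from_local(world_size=1, rank=0)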
- class torchrec.distributed.types.ShardingPlan(plan: Dict[str, Dict[str, torchrec.distributed.types.ParameterSharding]])¶
Bases:
object
Representation of sharding plan.
- plan¶
dict keyed by module path of dict of parameter sharding specs keyed by parameter name.
- Type
Dict[str, Dict[str, ParameterSharding]]
- get_plan_for_module(module_path: str) Optional[Dict[str, torchrec.distributed.types.ParameterSharding]] ¶
- Parameters
module_path (str) –
- Returns
dict of parameter sharding specs keyed by parameter name. None if sharding specs do not exist for given module_path.
- Return type
Optional[Dict[str, ParameterSharding]]
- plan: Dict[str, Dict[str, torchrec.distributed.types.ParameterSharding]]¶
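A hypothetical plan keyed by module path and then by parameter (table) name; the module path, table name, and compute kernel string are placeholders, not values prescribed by the API:
from torchrec.distributed.types import ParameterSharding, ShardingPlan, ShardingType

plan = ShardingPlan(
    plan={
        "sparse.ebc": {
            "t1": ParameterSharding(
                sharding_type=ShardingType.ROW_WISE.value,
                compute_kernel="dense",  # assumed kernel name
            ),
        }
    }
)
specs = plan.get_plan_for_module("sparse.ebc")  # None if no specs exist for this path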
- class torchrec.distributed.types.ShardingPlanner¶
Bases:
abc.ABC
Plans sharding. This plan can be saved and re-used to ensure sharding stability.
- abstract collective_plan(module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]) torchrec.distributed.types.ShardingPlan ¶
Calls self.plan(…) on rank 0 and broadcasts.
- Parameters
module (nn.Module) – module that sharding is planned for.
sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.
- Returns
the computed sharding plan.
- Return type
ShardingPlan
- abstract plan(module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]) torchrec.distributed.types.ShardingPlan ¶
Plans sharding for provided module and given sharders.
- Parameters
module (nn.Module) – module that sharding is planned for.
sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.
- Returns
the computed sharding plan.
- Return type
ShardingPlan
- class torchrec.distributed.types.ShardingType(value)¶
Bases:
enum.Enum
Well-known sharding types, used by inter-module optimizations.
- COLUMN_WISE = 'column_wise'¶
- DATA_PARALLEL = 'data_parallel'¶
- ROW_WISE = 'row_wise'¶
- TABLE_COLUMN_WISE = 'table_column_wise'¶
- TABLE_ROW_WISE = 'table_row_wise'¶
- TABLE_WISE = 'table_wise'¶
- torchrec.distributed.types.scope(method)¶
torchrec.distributed.utils¶
- torchrec.distributed.utils.add_prefix_to_state_dict(state_dict: Dict[str, Any], prefix: str) None ¶
Adds prefix to all keys in state dict, in place.
- Parameters
state_dict (Dict[str, Any]) – input state dict to update.
prefix (str) – prefix to add to all keys in the state dict.
- Returns
None.
- torchrec.distributed.utils.append_prefix(prefix: str, name: str) str ¶
Appends provided prefix to provided name.
- torchrec.distributed.utils.filter_state_dict(state_dict: collections.OrderedDict[str, torch.Tensor], name: str) collections.OrderedDict[str, torch.Tensor] ¶
Filters the state dict for keys that start with the provided name, and strips that name from the beginning of each key in the resulting state dict.
- Parameters
state_dict (OrderedDict[str, torch.Tensor]) – input state dict to filter.
name (str) – name to filter from state dict keys.
- Returns
filtered state dict.
- Return type
OrderedDict[str, torch.Tensor]
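For example (model and the module path "sparse" are illustrative placeholders):
from torchrec.distributed.utils import filter_state_dict

# Keep only the keys under "sparse" and strip that prefix from the result.
sparse_state = filter_state_dict(model.state_dict(), "sparse")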
- torchrec.distributed.utils.get_unsharded_module_names(model: torch.nn.modules.module.Module) List[str] ¶
Retrieves names of top level modules that do not contain any sharded sub-modules.
- Parameters
model (torch.nn.Module) – model to retrieve unsharded module names from.
- Returns
list of names of modules that don’t have sharded sub-modules.
- Return type
List[str]
- class torchrec.distributed.utils.sharded_model_copy(device: Optional[Union[str, int, torch.device]])¶
Bases:
object
Allows copying of DistributedModelParallel module to a target device.
Example:
# Copying model to CPU.
m = DistributedModelParallel(m)
with sharded_model_copy("cpu"):
    m_cpu = copy.deepcopy(m)