
torchrec.distributed

Torchrec Distributed

Torchrec distributed provides the necessary modules and operations to enable model parallelism.

These include:

  • model parallelism through DistributedModelParallel (a minimal usage sketch follows this list).

  • collective operations for comms, including All-to-All and Reduce-Scatter.

    • collective operation wrappers for sparse features, KJT, and various embedding types.

  • sharded implementations of various modules, including ShardedEmbeddingBag for nn.EmbeddingBag and ShardedEmbeddingBagCollection for EmbeddingBagCollection.

    • embedding sharders that define sharding for any sharded module implementation.

    • support for various compute kernels, which are optimized for compute device (CPU/GPU) and may include batching together embedding tables and/or optimizer fusion.

  • pipelined training through TrainPipelineSparseDist that overlaps data loading device transfer (copy to GPU), inter-device communications (input_dist), and computation (forward, backward) for increased performance.

  • quantization support for reduced precision training and inference.
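
The sketch below illustrates the typical entry point mentioned in the first item: wrapping an EmbeddingBagCollection in DistributedModelParallel so that its tables are sharded across ranks. It is a minimal illustration, assuming the process group environment has been set up externally (e.g. by torchrun) and CUDA is available; the table configuration is hypothetical.

import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.model_parallel import DistributedModelParallel

# Assumes torchrun set the usual rendezvous env vars; NCCL backend for GPU training.
dist.init_process_group(backend="nccl")

# A single (hypothetical) embedding table, created on the meta device so that
# DistributedModelParallel materializes only the local shards.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=64,
            num_embeddings=1_000_000,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),
)

# DistributedModelParallel shards the tables and wires up the input_dist /
# output_dist collectives described above.
model = DistributedModelParallel(ebc, device=torch.device("cuda"))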

torchrec.distributed.collective_utils

This file contains utilities for constructing collective-based control flows.

torchrec.distributed.collective_utils.invoke_on_rank_and_broadcast_result(pg: ProcessGroup, rank: int, func: Callable[[...], T], *args: Any, **kwargs: Any) T

Invokes a function on the designated rank and broadcasts the result to all members within the group.

Example:

id = invoke_on_rank_and_broadcast_result(pg, 0, allocate_id)
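
allocate_id above is a placeholder for any callable. A slightly fuller (hypothetical) sketch, assuming torch.distributed is already initialized: the function runs only on rank 0 and its result is broadcast to every rank.

import torch.distributed as dist
from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result

def allocate_id() -> int:
    # Hypothetical stand-in; only rank 0 actually executes this.
    return 42

pg = dist.group.WORLD  # default process group
new_id = invoke_on_rank_and_broadcast_result(pg, 0, allocate_id)  # same value on every rank
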
torchrec.distributed.collective_utils.is_leader(pg: Optional[ProcessGroup], leader_rank: int = 0) bool

Checks if the current process is the leader.

Parameters:
  • pg (Optional[dist.ProcessGroup]) – the process’s rank within the pg is used to determine if the process is the leader. pg being None implies that the process is the only member in the group (e.g. a single process program).

  • leader_rank (int) – the definition of leader (defaults to 0). The caller can override it with a context-specific definition.
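
For example, a minimal sketch of gating one-time work on the leader (save_checkpoint and model are hypothetical; pg is assumed to be an already-initialized process group):

import torch.distributed as dist
from torchrec.distributed.collective_utils import is_leader

pg = dist.group.WORLD        # assumes torch.distributed is already initialized
if is_leader(pg):            # with the default leader_rank, this is rank 0
    save_checkpoint(model)   # hypothetical helper; runs on exactly one rank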

torchrec.distributed.collective_utils.run_on_leader(pg: ProcessGroup, rank: int)

torchrec.distributed.comm

torchrec.distributed.comm.get_group_rank(world_size: Optional[int] = None, rank: Optional[int] = None) int

Gets the group rank of the worker group. Also available via the GROUP_RANK environment variable; a number between 0 and get_num_groups() (see https://pytorch.org/docs/stable/elastic/run.html).

torchrec.distributed.comm.get_local_rank(world_size: Optional[int] = None, rank: Optional[int] = None) int

Gets the local rank of the local process (see https://pytorch.org/docs/stable/elastic/run.html). This is usually the rank of the worker on its node.

torchrec.distributed.comm.get_local_size(world_size: Optional[int] = None) int
torchrec.distributed.comm.get_num_groups(world_size: Optional[int] = None) int

Gets the number of worker groups. Usually equivalent to max_nnodes (see https://pytorch.org/docs/stable/elastic/run.html).
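
To illustrate how these helpers relate, a sketch assuming a typical torchrun layout of 2 nodes with 8 workers each; the exact values come from the torchelastic environment variables referenced above:

from torchrec.distributed.comm import (
    get_group_rank,
    get_local_rank,
    get_local_size,
    get_num_groups,
)

world_size, rank = 16, 11  # global rank 11 lives on the second node

local_size = get_local_size(world_size)        # typically 8: workers per node
local_rank = get_local_rank(world_size, rank)  # typically 3: rank within its node
group_rank = get_group_rank(world_size, rank)  # typically 1: index of the worker group (node)
num_groups = get_num_groups(world_size)        # typically 2: number of worker groups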

torchrec.distributed.comm.intra_and_cross_node_pg(device: Optional[device] = None, backend: Optional[str] = None) Tuple[Optional[ProcessGroup], Optional[ProcessGroup]]

Creates sub-process groups (intra-node and cross-node).
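
A usage sketch, assuming torch.distributed is already initialized and a CUDA device is selected:

import torch
from torchrec.distributed.comm import intra_and_cross_node_pg

# intra_pg spans the workers on this node; cross_pg spans the same local rank
# across nodes.
intra_pg, cross_pg = intra_and_cross_node_pg(
    device=torch.device("cuda"), backend="nccl"
)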

torchrec.distributed.comm_ops

class torchrec.distributed.comm_ops.All2AllDenseInfo(output_splits: List[int], batch_size: int, input_shape: List[int], input_splits: List[int])

Bases: object

The data class that collects the attributes when calling the alltoall_dense operation.

batch_size: int
input_shape: List[int]
input_splits: List[int]
output_splits: List[int]
class torchrec.distributed.comm_ops.All2AllPooledInfo(batch_size_per_rank: List[int], dim_sum_per_rank: List[int], dim_sum_per_rank_tensor: Optional[Tensor], cumsum_dim_sum_per_rank_tensor: Optional[Tensor], codecs: Optional[QuantizedCommCodecs] = None)

Bases: object

The data class that collects the attributes when calling the alltoall_pooled operation.

batch_size_per_rank

batch size in each rank

Type:

List[int]

dim_sum_per_rank

number of features (sum of dimensions) of the embedding in each rank.

Type:

List[int]

dim_sum_per_rank_tensor

the tensor version of dim_sum_per_rank; this is only used by the fast kernel of _recat_pooled_embedding_grad_out.

Type:

Optional[Tensor]

cumsum_dim_sum_per_rank_tensor

cumulative sum of dim_sum_per_rank; this is only used by the fast kernel of _recat_pooled_embedding_grad_out.

Type:

Optional[Tensor]

codecs

quantized communication codecs.

Type:

Optional[QuantizedCommCodecs]

batch_size_per_rank: List[int]
codecs: Optional[QuantizedCommCodecs] = None
cumsum_dim_sum_per_rank_tensor: Optional[Tensor]
dim_sum_per_rank: List[int]
dim_sum_per_rank_tensor: Optional[Tensor]
class torchrec.distributed.comm_ops.All2AllSequenceInfo(embedding_dim: int, lengths_after_sparse_data_all2all: Tensor, forward_recat_tensor: Optional[Tensor], backward_recat_tensor: Tensor, input_splits: List[int], output_splits: List[int], variable_batch_size: bool = False, codecs: Optional[QuantizedCommCodecs] = None, permuted_lengths_after_sparse_data_all2all: Optional[Tensor] = None)

Bases: object

The data class that collects the attributes when calling the alltoall_sequence operation.

embedding_dim

embedding dimension.

Type:

int

lengths_after_sparse_data_all2all

lengths of sparse features after AlltoAll.

Type:

Tensor

forward_recat_tensor

recat tensor for forward.

Type:

Optional[Tensor]

backward_recat_tensor

recat tensor for backward.

Type:

Tensor

input_splits

input splits.

Type:

List[int]

output_splits

output splits.

Type:

List[int]

variable_batch_size

whether variable batch size is enabled.

Type:

bool

codecs

quantized communication codecs.

Type:

Optional[QuantizedCommCodecs]

permuted_lengths_after_sparse_data_all2all

lengths of sparse features before AlltoAll.

Type:

Optional[Tensor]

backward_recat_tensor: Tensor
codecs: Optional[QuantizedCommCodecs] = None
embedding_dim: int
forward_recat_tensor: Optional[Tensor]
input_splits: List[int]
lengths_after_sparse_data_all2all: Tensor
output_splits: List[int]
permuted_lengths_after_sparse_data_all2all: Optional[Tensor] = None
variable_batch_size: bool = False
class torchrec.distributed.comm_ops.All2AllVInfo(dims_sum_per_rank: ~typing.List[int], B_global: int, B_local: int, B_local_list: ~typing.List[int], D_local_list: ~typing.List[int], input_split_sizes: ~typing.List[int] = <factory>, output_split_sizes: ~typing.List[int] = <factory>, codecs: ~typing.Optional[~torchrec.distributed.types.QuantizedCommCodecs] = None)

Bases: object

The data class that collects the attributes when calling the alltoallv operation.

dims_sum_per_rank

number of features (sum of dimensions) of the embedding in each rank.

Type:

List[int]

B_global

global batch size for each rank.

Type:

int

B_local

local batch size before scattering.

Type:

int

B_local_list

local batch sizes for each embedding table on the current rank.

Type:

List[int]

D_local_list

embedding dimension of each embedding table on the current rank.

Type:

List[int]

input_split_sizes

The input split sizes for each rank; this remembers how to split the input when doing the all_to_all_single operation.

Type:

List[int]

output_split_sizes

The output split sizes for each rank; this remembers how to fill the output when doing the all_to_all_single operation.

Type:

List[int]

B_global: int
B_local: int
B_local_list: List[int]
D_local_list: List[int]
codecs: Optional[QuantizedCommCodecs] = None
dims_sum_per_rank: List[int]
input_split_sizes: List[int]
output_split_sizes: List[int]
class torchrec.distributed.comm_ops.All2All_Pooled_Req(*args, **kwargs)

Bases: Function

static backward(ctx, *unused) Tuple[None, None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], a2ai: All2AllPooledInfo, input_embeddings: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.All2All_Pooled_Wait(*args, **kwargs)

Bases: Function

static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.All2All_Seq_Req(*args, **kwargs)

Bases: Function

static backward(ctx, *unused) Tuple[None, None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], a2ai: All2AllSequenceInfo, sharded_input_embeddings: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.All2All_Seq_Req_Wait(*args, **kwargs)

Bases: Function

static backward(ctx, sharded_grad_output: Tensor) Tuple[None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.All2Allv_Req(*args, **kwargs)

Bases: Function

static backward(ctx, *grad_output)

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], a2ai: All2AllVInfo, inputs: List[Tensor]) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.All2Allv_Wait(*args, **kwargs)

Bases: Function

static backward(ctx, *grad_outputs) Tuple[None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tuple[Tensor]

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.AllGatherBaseInfo(input_size: Size, codecs: Optional[QuantizedCommCodecs] = None)

Bases: object

The data class that collects the attributes when calling the all_gather_base_pooled operation.

input_size

the size of the input tensor.

Type:

torch.Size

codecs: Optional[QuantizedCommCodecs] = None
input_size: Size
class torchrec.distributed.comm_ops.AllGatherBase_Req(*args, **kwargs)

Bases: Function

static backward(ctx, *unused: Tensor) Tuple[Optional[Tensor], ...]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], agi: AllGatherBaseInfo, input: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.AllGatherBase_Wait(*args, **kwargs)

Bases: Function

static backward(ctx, grad_outputs: Tensor) Tuple[None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.ReduceScatterBaseInfo(input_sizes: Size, codecs: Optional[QuantizedCommCodecs] = None)

Bases: object

The data class that collects the attributes when calling the reduce_scatter_base_pooled operation.

input_sizes

the sizes of the flattened input tensor.

Type:

torch.Size

codecs: Optional[QuantizedCommCodecs] = None
input_sizes: Size
class torchrec.distributed.comm_ops.ReduceScatterBase_Req(*args, **kwargs)

Bases: Function

static backward(ctx, *unused: Tensor) Tuple[Optional[Tensor], ...]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], rsi: ReduceScatterBaseInfo, inputs: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.ReduceScatterBase_Wait(*args, **kwargs)

Bases: Function

static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_Tensor: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.ReduceScatterInfo(input_sizes: List[Size], codecs: Optional[QuantizedCommCodecs] = None)

Bases: object

The data class that collects the attributes when calling the reduce_scatter_pooled operation.

input_sizes

the sizes of the input tensors. This remembers the sizes of the input tensors when running the backward pass and producing the gradient.

Type:

List[torch.Size]

codecs: Optional[QuantizedCommCodecs] = None
input_sizes: List[Size]
class torchrec.distributed.comm_ops.ReduceScatterVInfo(input_sizes: List[Size], input_splits: List[int], equal_splits: bool, total_input_size: List[int], codecs: Optional[QuantizedCommCodecs])

Bases: object

The data class that collects the attributes when calling the reduce_scatter_v_pooled operation.

input_sizes

the sizes of the input tensors. This saves the sizes of the input tensors when running the backward pass and producing the gradient.

Type:

List[torch.Size]

input_splits

the splits of the input tensors along dim 0.

Type:

List[int]

total_input_size

total input size.

Type:

List[int]

codecs: Optional[QuantizedCommCodecs]
equal_splits: bool
input_sizes: List[Size]
input_splits: List[int]
total_input_size: List[int]
class torchrec.distributed.comm_ops.ReduceScatterV_Req(*args, **kwargs)

Bases: Function

static backward(ctx, *unused: Tensor) Tuple[Optional[Tensor], ...]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], rsi: ReduceScatterVInfo, input: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.ReduceScatterV_Wait(*args, **kwargs)

Bases: Function

static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.ReduceScatter_Req(*args, **kwargs)

Bases: Function

static backward(ctx, *unused: Tensor) Tuple[Optional[Tensor], ...]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], rsi: ReduceScatterInfo, *inputs: Any) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.ReduceScatter_Wait(*args, **kwargs)

Bases: Function

static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.Request(pg: ProcessGroup, device: device)

Bases: Awaitable[W]

Defines a collective operation request for a process group on a tensor.

Parameters:

pg (dist.ProcessGroup) – The process group the request is for.

class torchrec.distributed.comm_ops.VariableBatchAll2AllPooledInfo(batch_size_per_rank_per_feature: List[List[int]], batch_size_per_feature_pre_a2a: List[int], emb_dim_per_rank_per_feature: List[List[int]], codecs: Optional[QuantizedCommCodecs] = None, input_splits: Optional[List[int]] = None, output_splits: Optional[List[int]] = None)

Bases: object

The data class that collects the attributes when calling the variable_batch_alltoall_pooled operation.

batch_size_per_rank_per_feature

batch size per rank per feature.

Type:

List[List[int]]

batch_size_per_feature_pre_a2a

local batch size before scattering.

Type:

List[int]

emb_dim_per_rank_per_feature

embedding dimension per rank per feature

Type:

List[List[int]]

codecs

quantized communication codecs.

Type:

Optional[QuantizedCommCodecs]

input_splits

input splits of tensor all to all.

Type:

Optional[List[int]]

output_splits

output splits of tensor all to all.

Type:

Optional[List[int]]

batch_size_per_feature_pre_a2a: List[int]
batch_size_per_rank_per_feature: List[List[int]]
codecs: Optional[QuantizedCommCodecs] = None
emb_dim_per_rank_per_feature: List[List[int]]
input_splits: Optional[List[int]] = None
output_splits: Optional[List[int]] = None
class torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Req(*args, **kwargs)

Bases: Function

static backward(ctx, *unused) Tuple[None, None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], a2ai: VariableBatchAll2AllPooledInfo, input_embeddings: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

class torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Wait(*args, **kwargs)

Bases: Function

static backward(ctx, grad_output: Tensor) Tuple[None, None, Tensor]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, pg: ProcessGroup, myreq: Request[Tensor], *dummy_tensor: Tensor) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used for in jvp.

torchrec.distributed.comm_ops.all_gather_base_pooled(input: Tensor, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor]

All-gathers tensors from all processes in a group to form a flattened pooled embeddings tensor. Input tensor is of size output_tensor_size / world_size.

Parameters:
  • input (Tensor) – tensor to gather.

  • group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.

Returns:

async work handle (Awaitable), which can be wait() later to get the resulting tensor.

Return type:

Awaitable[Tensor]

Warning

all_gather_base_pooled is experimental and subject to change.
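
A minimal calling sketch, assuming the default process group is already initialized and every rank contributes an equally sized slice (shapes are illustrative):

import torch
from torchrec.distributed.comm_ops import all_gather_base_pooled

local_pooled = torch.randn(4, 8, device="cuda")   # this rank's slice
awaitable = all_gather_base_pooled(local_pooled)  # uses the default process group
gathered = awaitable.wait()                       # world_size times larger than the input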

torchrec.distributed.comm_ops.alltoall_pooled(a2a_pooled_embs_tensor: Tensor, batch_size_per_rank: List[int], dim_sum_per_rank: List[int], dim_sum_per_rank_tensor: Optional[Tensor] = None, cumsum_dim_sum_per_rank_tensor: Optional[Tensor] = None, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor]

Performs AlltoAll operation for a single pooled embedding tensor. Each process splits the input pooled embeddings tensor based on the world size, and then scatters the split list to all processes in the group. Then concatenates the received tensors from all processes in the group and returns a single output tensor.

Parameters:
  • a2a_pooled_embs_tensor (Tensor) – input pooled embeddings. Must be pooled together before passing into this function. Its shape is B x D_local_sum, where D_local_sum is the dimension sum of all the local embedding tables.

  • batch_size_per_rank (List[int]) – batch size in each rank.

  • dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.

  • dim_sum_per_rank_tensor (Optional[Tensor]) – the tensor version of dim_sum_per_rank; this is only used by the fast kernel of _recat_pooled_embedding_grad_out.

  • cumsum_dim_sum_per_rank_tensor (Optional[Tensor]) – cumulative sum of dim_sum_per_rank; this is only used by the fast kernel of _recat_pooled_embedding_grad_out.

  • group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Returns:

async work handle (Awaitable), which can be wait() later to get the resulting tensor.

Return type:

Awaitable[Tensor]

Warning

alltoall_pooled is experimental and subject to change.
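
A calling sketch for a hypothetical 2-rank job, following the parameter descriptions above (shapes are illustrative): each rank starts with the pooled output of its local tables for the global batch and ends with its own batch slice across all tables.

import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import alltoall_pooled

batch_size_per_rank = [4, 4]  # local batch size on rank 0 and rank 1
dim_sum_per_rank = [8, 16]    # dimension sum of the tables hosted on each rank

rank = dist.get_rank()
# B x D_local_sum: the global batch pooled over this rank's local tables only.
local_pooled = torch.randn(
    sum(batch_size_per_rank), dim_sum_per_rank[rank], device="cuda"
)

awaitable = alltoall_pooled(local_pooled, batch_size_per_rank, dim_sum_per_rank)
out = awaitable.wait()  # expected: batch_size_per_rank[rank] x sum(dim_sum_per_rank)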

torchrec.distributed.comm_ops.alltoall_sequence(a2a_sequence_embs_tensor: Tensor, forward_recat_tensor: Tensor, backward_recat_tensor: Tensor, lengths_after_sparse_data_all2all: Tensor, input_splits: List[int], output_splits: List[int], variable_batch_size: bool = False, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor]

Performs AlltoAll operation for sequence embeddings. Each process splits the input tensor based on the world size, and then scatters the split list to all processes in the group. Then concatenates the received tensors from all processes in the group and returns a single output tensor.

Note

AlltoAll operator for Sequence embedding tensors. Does not support mixed dimensions.

Parameters:
  • a2a_sequence_embs_tensor (Tensor) – input embeddings.

  • forward_recat_tensor (Tensor) – recat tensor for forward.

  • backward_recat_tensor (Tensor) – recat tensor for backward.

  • lengths_after_sparse_data_all2all (Tensor) – lengths of sparse features after AlltoAll.

  • input_splits (List[int]) – input splits.

  • output_splits (List[int]) – output splits.

  • variable_batch_size (bool) – whether variable batch size is enabled.

  • group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Returns:

async work handle (Awaitable), which can be wait() later to get the resulting tensor.

Return type:

Awaitable[List[Tensor]]

Warning

alltoall_sequence is experimental and subject to change.

torchrec.distributed.comm_ops.alltoallv(inputs: List[Tensor], out_split: Optional[List[int]] = None, per_rank_split_lengths: Optional[List[int]] = None, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[List[Tensor]]

Performs alltoallv operation for a list of input embeddings. Each process scatters the list to all processes in the group.

Parameters:
  • inputs (List[Tensor]) – list of tensors to scatter, one per rank. The tensors in the list usually have different lengths.

  • out_split (Optional[List[int]]) – output split sizes (or dim_sum_per_rank). If not specified, per_rank_split_lengths is used to construct an output split, under the assumption that all the embeddings have the same dimension.

  • per_rank_split_lengths (Optional[List[int]]) – split lengths per rank. If not specified, the out_split must be specified.

  • group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Returns:

async work handle (Awaitable), on which wait() can be called later to get the resulting list of tensors.

Return type:

Awaitable[List[Tensor]]

Warning

alltoallv is experimental and subject to change.
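
A hypothetical sketch (not part of the original docs), assuming an already-initialized 2-rank process group; tensor sizes are illustrative only:

import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import alltoallv

dim_sum_per_rank = [8, 16]  # used as out_split (embedding dims owned per rank)

# One tensor per destination rank; the tensors may have different sizes.
inputs = [torch.randn(4, 8, device="cuda"), torch.randn(4, 16, device="cuda")]

awaitable = alltoallv(inputs, out_split=dim_sum_per_rank)
received = awaitable.wait()  # list of tensors, one received from each rank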

torchrec.distributed.comm_ops.get_gradient_division() bool
torchrec.distributed.comm_ops.reduce_scatter_base_pooled(input: Tensor, group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor]

Reduces then scatters a flattened pooled embeddings tensor to all processes in a group. Input tensor is of size output_tensor_size * world_size.

Parameters:
  • input (Tensor) – flattened tensor to scatter.

  • group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Returns:

async work handle (Awaitable), on which wait() can be called later to get the resulting tensor.

Return type:

Awaitable[Tensor]

Warning

reduce_scatter_base_pooled is experimental and subject to change.

torchrec.distributed.comm_ops.reduce_scatter_pooled(inputs: List[Tensor], group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor]

Performs reduce-scatter operation for a pooled embeddings tensor split into world size number of chunks. The result of the reduce operation gets scattered to all processes in the group.

Parameters:
  • inputs (List[Tensor]) – list of tensors to scatter, one per rank.

  • group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Returns:

async work handle (Awaitable), on which wait() can be called later to get the resulting tensor.

Return type:

Awaitable[Tensor]

Warning

reduce_scatter_pooled is experimental and subject to change.
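
A hypothetical sketch (not part of the original docs), assuming an already-initialized 2-rank NCCL process group:

import torch
import torch.distributed as dist
from torchrec.distributed.comm_ops import reduce_scatter_pooled

world_size = dist.get_world_size()
# One chunk per rank; chunk i is summed across ranks and delivered to rank i.
inputs = [torch.randn(4, 8, device="cuda") for _ in range(world_size)]

awaitable = reduce_scatter_pooled(inputs)
out = awaitable.wait()  # this rank's reduced chunk, shape [4, 8]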

torchrec.distributed.comm_ops.reduce_scatter_v_pooled(input: Tensor, input_splits: List[int], group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor]

Performs reduce-scatter-v operation for a pooled embeddings tensor split unevenly into world size number of chunks. The result of the reduce operation gets scattered to all processes in the group according to input_splits.

Parameters:
  • input (Tensor) – tensor to scatter.

  • input_splits (List[int]) – input splits.

  • group (Optional[dist.ProcessGroup]) – the process group to work on. If None, the default process group will be used.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Returns:

async work handle (Awaitable), on which wait() can be called later to get the resulting tensor.

Return type:

Awaitable[Tensor]

Warning

reduce_scatter_v_pooled is experimental and subject to change.

torchrec.distributed.comm_ops.set_gradient_division(val: bool) None
torchrec.distributed.comm_ops.variable_batch_alltoall_pooled(a2a_pooled_embs_tensor: Tensor, batch_size_per_rank_per_feature: List[List[int]], batch_size_per_feature_pre_a2a: List[int], emb_dim_per_rank_per_feature: List[List[int]], group: Optional[ProcessGroup] = None, codecs: Optional[QuantizedCommCodecs] = None) Awaitable[Tensor]

torchrec.distributed.dist_data

class torchrec.distributed.dist_data.EmbeddingsAllToOne(device: device, world_size: int, cat_dim: int)

Bases: Module

Merges the pooled/sequence embedding tensor on each device into a single tensor.

Parameters:
  • device (torch.device) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

  • cat_dim (int) – which dimension you would like to concatenate on. For pooled embedding it is 1; for sequence embedding it is 0.
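
A minimal hypothetical sketch of the parameters above in use (assuming two CUDA devices and pooled embeddings, hence cat_dim=1):

import torch
from torchrec.distributed.dist_data import EmbeddingsAllToOne

m = EmbeddingsAllToOne(device=torch.device("cuda:0"), world_size=2, cat_dim=1)
t0 = torch.randn(4, 8, device="cuda:0")   # pooled embeddings computed on device 0
t1 = torch.randn(4, 16, device="cuda:1")  # pooled embeddings computed on device 1
merged = m([t0, t1])  # single tensor of shape [4, 24] on cuda:0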

forward(tensors: List[Tensor]) Tensor

Performs AlltoOne operation on pooled/sequence embeddings tensors.

Parameters:

tensors (List[torch.Tensor]) – list of embedding tensors.

Returns:

the merged embeddings tensor.

Return type:

torch.Tensor

set_device(device_str: str) None
training: bool
class torchrec.distributed.dist_data.EmbeddingsAllToOneReduce(device: device, world_size: int, cat_dim: int)

Bases: Module

Merges the pooled/sequence embedding tensor on each device into a single tensor.

Parameters:
  • device (torch.device) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

  • cat_dim (int) – which dimension you would like to concatenate on. For pooled embedding it is 1; for sequence embedding it is 0.

forward(tensors: List[Tensor]) Tensor

Performs AlltoOne operation with Reduce on pooled/sequence embeddings tensors.

Parameters:

tensors (List[torch.Tensor]) – list of embedding tensors.

Returns:

the reduced embeddings tensor.

Return type:

torch.Tensor

training: bool
class torchrec.distributed.dist_data.KJTAllToAll(pg: ProcessGroup, splits: List[int], stagger: int = 1)

Bases: Module

Redistributes KeyedJaggedTensor to a ProcessGroup according to splits.

Implementation utilizes AlltoAll collective as part of torch.distributed.

The input provides the necessary tensors and input splits to distribute. The first collective call in KJTAllToAllSplitsAwaitable will transmit output splits (to allocate correct space for tensors) and batch size per rank. The following collective calls in KJTAllToAllTensorsAwaitable will transmit the actual tensors asynchronously.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • stagger (int) – stagger value to apply to recat tensor, see _get_recat function for more detail.

Example:

keys=['A','B','C']
splits=[2,1]
kjtA2A = KJTAllToAll(pg, splits)
awaitable = kjtA2A(rank0_input)

# where:
# rank0_input is KeyedJaggedTensor holding

#         0           1           2
# 'A'    [A.V0]       None        [A.V1, A.V2]
# 'B'    None         [B.V0]      [B.V1]
# 'C'    [C.V0]       [C.V1]      None

# rank1_input is KeyedJaggedTensor holding

#         0           1           2
# 'A'     [A.V3]      [A.V4]      None
# 'B'     None        [B.V2]      [B.V3, B.V4]
# 'C'     [C.V2]      [C.V3]      None

rank0_output = awaitable.wait()

# where:
# rank0_output is KeyedJaggedTensor holding

#         0           1           2           3           4           5
# 'A'     [A.V0]      None      [A.V1, A.V2]  [A.V3]      [A.V4]      None
# 'B'     None        [B.V0]    [B.V1]        None        [B.V2]      [B.V3, B.V4]

# rank1_output is KeyedJaggedTensor holding
#         0           1           2           3           4           5
# 'C'     [C.V0]      [C.V1]      None        [C.V2]      [C.V3]      None
forward(input: KeyedJaggedTensor) Awaitable[KJTAllToAllTensorsAwaitable]

Sends input to relevant ProcessGroup ranks.

The first wait will get the output splits for the provided tensors and issue tensors AlltoAll. The second wait will get the tensors.

Parameters:

input (KeyedJaggedTensor) – KeyedJaggedTensor of values to distribute.

Returns:

awaitable of a KJTAllToAllTensorsAwaitable.

Return type:

Awaitable[KJTAllToAllTensorsAwaitable]

training: bool
class torchrec.distributed.dist_data.KJTAllToAllSplitsAwaitable(pg: ProcessGroup, input: KeyedJaggedTensor, splits: List[int], labels: List[str], tensor_splits: List[List[int]], input_tensors: List[Tensor], keys: List[str], device: device, stagger: int)

Bases: Awaitable[KJTAllToAllTensorsAwaitable]

Awaitable for KJT tensors splits AlltoAll.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • input (KeyedJaggedTensor) – input KJT.

  • splits (List[int]) – list of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • labels (List[str]) – labels for each provided tensor.

  • tensor_splits (List[List[int]]) – tensor splits provided by the input KJT.

  • input_tensors (List[torch.Tensor]) – provided KJT tensors (ie. lengths, values) to redistribute according to splits.

  • keys (List[str]) – KJT keys after AlltoAll.

  • device (torch.device) – device on which buffers will be allocated.

  • stagger (int) – stagger value to apply to recat tensor.

class torchrec.distributed.dist_data.KJTAllToAllTensorsAwaitable(pg: ProcessGroup, input: KeyedJaggedTensor, splits: List[int], input_splits: List[List[int]], output_splits: List[List[int]], input_tensors: List[Tensor], labels: List[str], keys: List[str], device: device, stagger: int, stride_per_rank: Optional[List[int]])

Bases: Awaitable[KeyedJaggedTensor]

Awaitable for KJT tensors AlltoAll.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • input (KeyedJaggedTensor) – input KJT.

  • splits (List[int]) – list of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • input_splits (List[List[int]]) – input splits (number of values each rank will get) for each tensor in AlltoAll.

  • output_splits (List[List[int]]) – output splits (number of values per rank in output) for each tensor in AlltoAll.

  • input_tensors (List[torch.Tensor]) – provided KJT tensors (ie. lengths, values) to redistribute according to splits.

  • labels (List[str]) – labels for each provided tensor.

  • keys (List[str]) – KJT keys after AlltoAll.

  • device (torch.device) – device on which buffers will be allocated.

  • stagger (int) – stagger value to apply to recat tensor.

  • stride_per_rank (Optional[List[int]]) – stride per rank in the non variable batch per feature case.

class torchrec.distributed.dist_data.KJTOneToAll(splits: List[int], world_size: int, device: Optional[device] = None)

Bases: Module

Redistributes KeyedJaggedTensor to all devices.

Implementation utilizes the OnetoAll function, which essentially performs P2P copies of the features to the devices.

Parameters:
  • splits (List[int]) – lengths of features to split the KeyedJaggedTensor features into before copying them.

  • world_size (int) – number of devices in the topology.

  • device (torch.device) – the device on which the KJTs will be allocated.
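
A hypothetical sketch of the parameters above (assuming two devices and a 3-feature KJT, with 2 features sent to device 0 and 1 feature to device 1):

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.dist_data import KJTOneToAll

kjt = KeyedJaggedTensor(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),  # batch size 2 per feature
)
splitter = KJTOneToAll(splits=[2, 1], world_size=2)
kjt_list = splitter(kjt)  # KJTList: ["f1", "f2"] for device 0, ["f3"] for device 1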

forward(kjt: KeyedJaggedTensor) KJTList

Splits features first and then sends the slices to the corresponding devices.

Parameters:

kjt (KeyedJaggedTensor) – the input features.

Returns:

the KeyedJaggedTensor splits, one per device.

Return type:

KJTList

training: bool
class torchrec.distributed.dist_data.PooledEmbeddingsAllGather(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

The module class that wraps the all-gather communication primitive for pooled embedding communication.

Provided a local input tensor with a layout of [batch_size, dimension], we want to gather input tensors from all ranks into a flattened output tensor.

The class returns the async Awaitable handle for pooled embeddings tensor. The all-gather is only available for NCCL backend.

Parameters:

pg (dist.ProcessGroup) – the process group that the all-gather communication happens within.

Example:

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
input = torch.randn(2, 2)
m = PooledEmbeddingsAllGather(pg)
output = m(input)
tensor = output.wait()
forward(local_emb: Tensor) PooledEmbeddingsAwaitable

Performs all-gather operation on the local pooled embeddings tensor.

Parameters:

local_emb (torch.Tensor) – local tensor of shape [batch_size, dimension].

Returns:

awaitable of the flattened pooled embeddings gathered from all ranks.

Return type:

PooledEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.PooledEmbeddingsAllToAll(pg: ProcessGroup, dim_sum_per_rank: List[int], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

Shards batches and collects keys of a tensor across a ProcessGroup according to dim_sum_per_rank.

Implementation utilizes alltoall_pooled operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) – callback functions.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Example:

dim_sum_per_rank = [2, 1]
a2a = PooledEmbeddingsAllToAll(pg, dim_sum_per_rank, device)

t0 = torch.rand((6, 2))
t1 = torch.rand((6, 1))
rank0_output = a2a(t0).wait()
rank1_output = a2a(t1).wait()
print(rank0_output.size())
    # torch.Size([3, 3])
print(rank1_output.size())
    # torch.Size([3, 3])
property callbacks: List[Callable[[Tensor], Tensor]]
forward(local_embs: Tensor, batch_size_per_rank: Optional[List[int]] = None) PooledEmbeddingsAwaitable

Performs AlltoAll pooled operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of values to distribute.

  • batch_size_per_rank (Optional[List[int]]) – batch size per rank, to support variable batch size.

Returns:

awaitable of pooled embeddings.

Return type:

PooledEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.PooledEmbeddingsAwaitable(tensor_awaitable: Awaitable[Tensor])

Bases: Awaitable[Tensor]

Awaitable for pooled embeddings after collective operation.

Parameters:

tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.

property callbacks: List[Callable[[Tensor], Tensor]]
class torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

The module class that wraps reduce-scatter communication primitives for pooled embedding communication in row-wise and twrw sharding.

For pooled embeddings, we have a local model-parallel output tensor with a layout of [num_buckets x batch_size, dimension]. We need to sum over num_buckets dimension across batches. We split the tensor along the first dimension into unequal chunks (tensor slices of different buckets) according to input_splits and reduce them into the output tensor and scatter the results for corresponding ranks.

The class returns the async Awaitable handle for pooled embeddings tensor. The reduce-scatter-v operation is only available for NCCL backend.

Parameters:
  • pg (dist.ProcessGroup) – the process group that the reduce-scatter communication happens within.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.
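
A usage sketch in the same style as the PooledEmbeddingsAllGather example above (a hypothetical 2-rank NCCL setup with 2 buckets and a local batch size of 2):

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
input = torch.randn(2 * 2, 2)  # [num_buckets * batch_size, dimension]
m = PooledEmbeddingsReduceScatter(pg)
output = m(input)
tensor = output.wait()  # [batch_size, dimension] on this rank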

forward(local_embs: Tensor, input_splits: Optional[List[int]] = None) PooledEmbeddingsAwaitable

Performs reduce scatter operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of shape [num_buckets * batch_size, dimension].

  • input_splits (Optional[List[int]]) – list of splits for local_embs dim 0.

Returns:

awaitable of pooled embeddings of tensor of shape [batch_size, dimension].

Return type:

PooledEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.SeqEmbeddingsAllToOne(device: device, world_size: int)

Bases: Module

Merges the pooled/sequence embedding tensor on each device into a single tensor.

Parameters:
  • device (torch.device) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

forward(tensors: List[Tensor]) List[Tensor]

Performs AlltoOne operation on pooled embeddings tensors.

Parameters:

tensors (List[torch.Tensor]) – list of pooled embedding tensors.

Returns:

the merged pooled embeddings tensors.

Return type:

List[torch.Tensor]

set_device(device_str: str) None
training: bool
class torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll(pg: ProcessGroup, features_per_rank: List[int], device: Optional[device] = None, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

Redistributes sequence embeddings to a ProcessGroup according to splits.

Parameters:
  • pg (dist.ProcessGroup) – the process group that the AlltoAll communication happens within.

  • features_per_rank (List[int]) – list of number of features per rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Example:

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
features_per_rank = [4, 4]
m = SequenceEmbeddingsAllToAll(pg, features_per_rank)
local_embs = torch.rand((6, 2))
sharding_ctx: SequenceShardingContext
output = m(
    local_embs=local_embs,
    lengths=sharding_ctx.lengths_after_input_dist,
    input_splits=sharding_ctx.input_splits,
    output_splits=sharding_ctx.output_splits,
    unbucketize_permute_tensor=None,
)
tensor = output.wait()
forward(local_embs: Tensor, lengths: Tensor, input_splits: List[int], output_splits: List[int], unbucketize_permute_tensor: Optional[Tensor] = None, batch_size_per_rank: Optional[List[int]] = None, sparse_features_recat: Optional[Tensor] = None) SequenceEmbeddingsAwaitable

Performs AlltoAll operation on sequence embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – input embeddings tensor.

  • lengths (torch.Tensor) – lengths of sparse features after AlltoAll.

  • input_splits (List[int]) – input splits of AlltoAll.

  • output_splits (List[int]) – output splits of AlltoAll.

  • unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of the KJT bucketize (for row-wise sharding only).

  • batch_size_per_rank (Optional[List[int]]) – batch size per rank.

  • sparse_features_recat (Optional[torch.Tensor]) – recat tensor used for sparse feature input dist. Must be provided if using variable batch size.

Returns:

awaitable of sequence embeddings.

Return type:

SequenceEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.SequenceEmbeddingsAwaitable(tensor_awaitable: Awaitable[Tensor], unbucketize_permute_tensor: Optional[Tensor], embedding_dim: int)

Bases: Awaitable[Tensor]

Awaitable for sequence embeddings after collective operation.

Parameters:
  • tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.

  • unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of KJT bucketize (for row-wise sharding only).

  • embedding_dim (int) – embedding dimension.

class torchrec.distributed.dist_data.SplitsAllToAllAwaitable(input_tensors: List[Tensor], pg: ProcessGroup)

Bases: Awaitable[List[List[int]]]

Awaitable for splits AlltoAll.

Parameters:
  • input_tensors (List[torch.Tensor]) – tensor of splits to redistribute.

  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

class torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll(pg: ProcessGroup, emb_dim_per_rank_per_feature: List[List[int]], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

Shards batches and collects keys of a tensor across a ProcessGroup according to emb_dim_per_rank_per_feature.

Implementation utilizes variable_batch_alltoall_pooled operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • emb_dim_per_rank_per_feature (List[List[int]]) – embedding dimensions per rank per feature.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) – callback functions.

  • codecs (Optional[QuantizedCommCodecs]) – quantized communication codecs.

Example:

emb_dim_per_rank_per_feature = [[2], [3, 3]]
a2a = VariableBatchPooledEmbeddingsAllToAll(
    pg, emb_dim_per_rank_per_feature, device
)

t0 = torch.rand(6) # 2 * (2 + 1)
t1 = torch.rand(24) # 3 * (1 + 3) + 3 * (2 + 2)
r0_batch_size_per_rank_per_feature = [[2, 1]]
r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]
r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
r1_batch_size_per_feature_pre_a2a = [1, 2, 2]

rank0_output = a2a(
    t0, r0_batch_size_per_rank_per_feature, r0_batch_size_per_feature_pre_a2a
).wait()
rank1_output = a2a(
    t1, r1_batch_size_per_rank_per_feature, r1_batch_size_per_feature_pre_a2a
).wait()

# input splits:
#   r0: [2*2, 1*2]
#   r1: [1*3 + 3*3, 2*3 + 2*3]

# output splits:
#   r0: [2*2, 1*3 + 3*3]
#   r1: [1*2, 2*3 + 2*3]

print(rank0_output.size())
    # torch.Size([16])
    # 2*2 + 1*3 + 3*3
print(rank1_output.size())
    # torch.Size([14])
    # 1*2 + 2*3 + 2*3
property callbacks: List[Callable[[Tensor], Tensor]]
forward(local_embs: Tensor, batch_size_per_rank_per_feature: List[List[int]], batch_size_per_feature_pre_a2a: List[int]) PooledEmbeddingsAwaitable

Performs AlltoAll pooled operation with variable batch size per feature on a pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of values to distribute.

  • batch_size_per_rank_per_feature (List[List[int]]) – batch size per rank per feature, post a2a. Used to get the input splits.

  • batch_size_per_feature_pre_a2a (List[int]) – local batch size before scattering, used to get the output splits. Ordered by rank_0 feature, rank_1 feature, …

Returns:

awaitable of pooled embeddings.

Return type:

PooledEmbeddingsAwaitable

training: bool

torchrec.distributed.embedding

class torchrec.distributed.embedding.EmbeddingCollectionAwaitable(*args, **kwargs)

Bases: LazyAwaitable[Dict[str, JaggedTensor]]

class torchrec.distributed.embedding.EmbeddingCollectionContext(sharding_contexts: List[torchrec.distributed.sharding.sequence_sharding.SequenceShardingContext] = <factory>, input_features: List[torchrec.sparse.jagged_tensor.KeyedJaggedTensor] = <factory>, reverse_indices: List[torch.Tensor] = <factory>)

Bases: Multistreamable

input_features: List[KeyedJaggedTensor]
record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

reverse_indices: List[Tensor]
sharding_contexts: List[SequenceShardingContext]
class torchrec.distributed.embedding.EmbeddingCollectionSharder(fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, use_index_dedup: bool = False)

Bases: BaseEmbeddingSharder[EmbeddingCollection]

property module_type: Type[EmbeddingCollection]
shard(module: EmbeddingCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedEmbeddingCollection

Does the actual sharding. It will allocate parameters on the requested locations as specified by the corresponding ParameterSharding.

The default implementation is data-parallel replication.

Parameters:
  • module (M) – module to shard.

  • params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.

  • env (ShardingEnv) – sharding environment that has the process group.

  • device (torch.device) – compute device.

Returns:

sharded module implementation.

Return type:

ShardedModule[Any, Any, Any]

shardable_parameters(module: EmbeddingCollection) Dict[str, Parameter]

List of parameters that can be sharded.

sharding_types(compute_device_type: str) List[str]

List of supported sharding types. See ShardingType for well-known examples.

class torchrec.distributed.embedding.ShardedEmbeddingCollection(module: EmbeddingCollection, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, use_index_dedup: bool = False)

Bases: ShardedEmbeddingModule[KJTList, List[Tensor], Dict[str, JaggedTensor], EmbeddingCollectionContext], FusedOptimizerModule

Sharded implementation of EmbeddingCollection. This is part of the public API to allow for manual data dist pipelining.

compute(ctx: EmbeddingCollectionContext, dist_input: KJTList) List[Tensor]
compute_and_output_dist(ctx: EmbeddingCollectionContext, input: KJTList) LazyAwaitable[Dict[str, JaggedTensor]]

In the case of multiple output distributions, it makes sense to override this method and initiate the output distribution as soon as the corresponding compute completes.

create_context() EmbeddingCollectionContext
property fused_optimizer: KeyedOptimizer
input_dist(ctx: EmbeddingCollectionContext, features: KeyedJaggedTensor) Awaitable[Awaitable[KJTList]]
output_dist(ctx: EmbeddingCollectionContext, output: List[Tensor]) LazyAwaitable[Dict[str, JaggedTensor]]
reset_parameters() None
training: bool
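
The staged methods of ShardedEmbeddingCollection above are what manual data dist pipelining composes. A rough, hypothetical sketch of the flow (assuming sharded_ec is an already-constructed ShardedEmbeddingCollection and features is a KeyedJaggedTensor on the appropriate device):

# Hypothetical staged flow; real pipelines overlap these stages across
# batches/streams instead of running them back to back.
ctx = sharded_ec.create_context()
splits_awaitable = sharded_ec.input_dist(ctx, features)  # Awaitable[Awaitable[KJTList]]
tensors_awaitable = splits_awaitable.wait()              # splits AlltoAll done
dist_input = tensors_awaitable.wait()                    # tensors AlltoAll done
embs = sharded_ec.compute(ctx, dist_input)               # List[Tensor]
lazy_out = sharded_ec.output_dist(ctx, embs)             # LazyAwaitable[Dict[str, JaggedTensor]]
jt_dict = lazy_out.wait()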
torchrec.distributed.embedding.create_embedding_sharding(sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None) EmbeddingSharding[SequenceShardingContext, KeyedJaggedTensor, Tensor, Tensor]
torchrec.distributed.embedding.create_sharding_infos_by_sharding(module: EmbeddingCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], fused_params: Optional[Dict[str, Any]]) Dict[str, List[EmbeddingShardingInfo]]
torchrec.distributed.embedding.get_ec_index_dedup() bool
torchrec.distributed.embedding.set_ec_index_dedup(val: bool) None

torchrec.distributed.embedding_lookup

class torchrec.distributed.embedding_lookup.CommOpGradientScaling(*args, **kwargs)

Bases: Function

static backward(ctx: FunctionCtx, grad_output: Tensor) Tuple[Tensor, None]

Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx: FunctionCtx, input_tensor: Tensor, scale_gradient_factor: int) Tensor

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See combining-forward-context for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See extending-autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
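
As a concrete illustration of "Usage 1" above, a hedged sketch of invoking this op (assuming, per its signature, that forward passes the input through and backward scales the incoming gradient by scale_gradient_factor):

import torch
from torchrec.distributed.embedding_lookup import CommOpGradientScaling

x = torch.randn(4, 8, requires_grad=True)
# autograd.Function subclasses are invoked via .apply(...)
y = CommOpGradientScaling.apply(x, 2)
y.sum().backward()
# x.grad is expected to equal the unscaled gradient multiplied by 2.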

class torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup(grouped_configs: List[GroupedEmbeddingConfig], pg: Optional[ProcessGroup] = None, device: Optional[device] = None)

Bases: BaseEmbeddingLookup[KeyedJaggedTensor, Tensor]

Lookup modules for Sequence embeddings (i.e. Embeddings)

forward(sparse_features: KeyedJaggedTensor) Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

load_state_dict(state_dict: OrderedDict[str, Union[torch.Tensor, ShardedTensor]], strict: bool = True) _IncompatibleKeys

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
named_parameters_by_table() Iterator[Tuple[str, TableBatchedEmbeddingSlice]]

Like named_parameters(), but yields table_name and embedding_weights which are wrapped in TableBatchedEmbeddingSlice. For a single table with multiple shards (i.e. CW) these are combined into one table/weight. Used in composability.
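
For example, a hypothetical sketch of inspecting per-table weights (assuming lookup is an already-constructed GroupedEmbeddingsLookup):

for table_name, weight in lookup.named_parameters_by_table():
    # weight is a TableBatchedEmbeddingSlice covering all shards of the table
    print(table_name, tuple(weight.shape))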

state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool
class torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup(grouped_configs: List[GroupedEmbeddingConfig], device: Optional[device] = None, pg: Optional[ProcessGroup] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, scale_weight_gradients: bool = True)

Bases: BaseEmbeddingLookup[KeyedJaggedTensor, Tensor]

Lookup modules for Pooled embeddings (i.e. EmbeddingBags)

forward(sparse_features: KeyedJaggedTensor) Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

load_state_dict(state_dict: OrderedDict[str, Union[ShardedTensor, torch.Tensor]], strict: bool = True) _IncompatibleKeys

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
named_parameters_by_table() Iterator[Tuple[str, TableBatchedEmbeddingSlice]]

Like named_parameters(), but yields table_name and embedding_weights which are wrapped in TableBatchedEmbeddingSlice. For a single table with multiple shards (i.e. CW) these are combined into one table/weight. Used in composability.

prefetch(sparse_features: KeyedJaggedTensor, forward_stream: Optional[Stream] = None) None
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool
class torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup(grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)

Bases: InferGroupedLookupMixin, BaseEmbeddingLookup[KJTList, List[Tensor]], TBEToRegisterMixIn

get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]
training: bool
class torchrec.distributed.embedding_lookup.InferGroupedLookupMixin

Bases: ABC

forward(sparse_features: KJTList) List[Tensor]
load_state_dict(state_dict: OrderedDict[str, torch.Tensor], strict: bool = True) _IncompatibleKeys
named_buffers(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Tensor]]
named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Parameter]]
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]
class torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup(grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], world_size: int, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)

Bases: InferGroupedLookupMixin, BaseEmbeddingLookup[KJTList, List[Tensor]], TBEToRegisterMixIn

get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]
training: bool
class torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup(grouped_configs: List[GroupedEmbeddingConfig], device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None)

Bases: BaseEmbeddingLookup[KeyedJaggedTensor, Tensor], TBEToRegisterMixIn

Meta embedding lookup module for inference. Since the inference lookup holds references to multiple TBE ops across all GPU workers, the inference grouped embedding lookup module contains meta modules allocated over the GPU workers.

forward(sparse_features: KeyedJaggedTensor) Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]
load_state_dict(state_dict: OrderedDict[str, Union[ShardedTensor, torch.Tensor]], strict: bool = True) _IncompatibleKeys

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool
class torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup(grouped_configs: List[GroupedEmbeddingConfig], device: Optional[device] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None, fused_params: Optional[Dict[str, Any]] = None)

Bases: BaseEmbeddingLookup[KeyedJaggedTensor, Tensor], TBEToRegisterMixIn

Meta embedding bag lookup module for inference. Since the inference lookup holds references to multiple TBE ops across all GPU workers, the inference grouped embedding bag lookup module contains meta modules allocated over the GPU workers.

forward(sparse_features: KeyedJaggedTensor) Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_tbes_to_register() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]
load_state_dict(state_dict: OrderedDict[str, Union[ShardedTensor, torch.Tensor]], strict: bool = True) _IncompatibleKeys

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool
torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle(embeddings: List[Tensor], dummy_embs_tensor: Tensor, dim: int = 0) Tensor
torchrec.distributed.embedding_lookup.fx_wrap_tensor_view2d(x: Tensor, dim0: int, dim1: int) Tensor

torchrec.distributed.embedding_sharding

class torchrec.distributed.embedding_sharding.BaseEmbeddingDist(*args, **kwargs)

Bases: ABC, Module, Generic[C, T, W]

Converts output of EmbeddingLookup from model-parallel to data-parallel.

abstract forward(local_embs: T, sharding_ctx: Optional[C] = None) Union[Awaitable[W], W]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist(*args, **kwargs)

Bases: ABC, Module, Generic[F]

Converts input from data-parallel to model-parallel.

abstract forward(sparse_features: KeyedJaggedTensor) Union[Awaitable[Awaitable[F]], F]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.distributed.embedding_sharding.EmbeddingSharding(qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: ABC, Generic[C, F, T, W], FeatureShardingMixIn

Used to implement different sharding types for EmbeddingBagCollection, e.g. table_wise.

abstract create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[F]
abstract create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup[F, T]
abstract create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[C, T, W]
abstract embedding_dims() List[int]
abstract embedding_names() List[str]
abstract embedding_names_per_rank() List[List[str]]
abstract embedding_shard_metadata() List[Optional[ShardMetadata]]
embedding_tables() List[ShardedEmbeddingTable]
property qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]]
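
The factory methods above compose into the usual input dist, lookup, and output dist flow. A rough, hypothetical sketch (assuming sharding is a concrete training-time EmbeddingSharding whose sparse features dist returns a nested awaitable, and features is a KeyedJaggedTensor):

# Hypothetical composition of an EmbeddingSharding's factory methods.
input_dist = sharding.create_input_dist(device)
lookup = sharding.create_lookup(device)
output_dist = sharding.create_output_dist(device)

dist_features = input_dist(features).wait().wait()  # data-parallel -> model-parallel
local_embs = lookup(dist_features)                  # lookup on local shards
output = output_dist(local_embs).wait()             # model-parallel -> data-parallel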
class torchrec.distributed.embedding_sharding.EmbeddingShardingContext(batch_size_per_rank: List[int] = <factory>, batch_size_per_rank_per_feature: List[List[int]] = <factory>, batch_size_per_feature_pre_a2a: List[int] = <factory>, variable_batch_per_feature: bool = False)

Bases: Multistreamable

batch_size_per_feature_pre_a2a: List[int]
batch_size_per_rank: List[int]
batch_size_per_rank_per_feature: List[List[int]]
record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

variable_batch_per_feature: bool = False
class torchrec.distributed.embedding_sharding.EmbeddingShardingInfo(embedding_config: torchrec.modules.embedding_configs.EmbeddingTableConfig, param_sharding: torchrec.distributed.types.ParameterSharding, param: torch.Tensor, fused_params: Union[Dict[str, Any], NoneType] = None)

Bases: object

embedding_config: EmbeddingTableConfig
fused_params: Optional[Dict[str, Any]] = None
param: Tensor
param_sharding: ParameterSharding
class torchrec.distributed.embedding_sharding.FusedKJTListSplitsAwaitable(requests: List[KJTListSplitsAwaitable[C]], contexts: List[C], pg: Optional[ProcessGroup])

Bases: Awaitable[List[KJTListAwaitable]]

class torchrec.distributed.embedding_sharding.KJTListAwaitable(awaitables: List[Awaitable[KeyedJaggedTensor]], ctx: C)

Bases: Awaitable[KJTList]

Awaitable of KJTList.

Parameters:
  • awaitables (List[Awaitable[KeyedJaggedTensor]]) – list of Awaitable of sparse features.

  • ctx (C) – sharding context to save the batch size info from the KJT for the embedding AlltoAll.

class torchrec.distributed.embedding_sharding.KJTListSplitsAwaitable(awaitables: List[Awaitable[Awaitable[KeyedJaggedTensor]]], ctx: C)

Bases: Awaitable[Awaitable[KJTList]], Generic[C]

Awaitable of Awaitable of KJTList.

Parameters:
  • awaitables (List[Awaitable[Awaitable[KeyedJaggedTensor]]]) – result from calling forward on KJTAllToAll with sparse features to redistribute.

  • ctx (C) – sharding context in which to save the metadata from the input dist for the embedding AlltoAll.

class torchrec.distributed.embedding_sharding.KJTSplitsAllToAllMeta(pg: torch.distributed.distributed_c10d.ProcessGroup, _input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, splits: List[int], splits_tensors: List[torch.Tensor], input_splits: List[List[int]], input_tensors: List[torch.Tensor], labels: List[str], keys: List[str], device: torch.device, stagger: int, splits_cumsum: List[int])

Bases: object

device: device
input_splits: List[List[int]]
input_tensors: List[Tensor]
keys: List[str]
labels: List[str]
pg: ProcessGroup
splits: List[int]
splits_cumsum: List[int]
splits_tensors: List[Tensor]
stagger: int
class torchrec.distributed.embedding_sharding.ListOfKJTListAwaitable(awaitables: List[Awaitable[KJTList]])

Bases: Awaitable[ListOfKJTList]

This module handles the input feature distribution for table-wise sharding during inference.

Parameters:

awaitables (List[Awaitable[KJTList]]) – list of Awaitable of KJTList.

class torchrec.distributed.embedding_sharding.ListOfKJTListSplitsAwaitable(awaitables: List[Awaitable[Awaitable[KJTList]]])

Bases: Awaitable[Awaitable[ListOfKJTList]]

Awaitable of Awaitable of ListOfKJTList.

Parameters:

awaitables (List[Awaitable[Awaitable[KJTList]]]) – list of Awaitable of Awaitable of sparse features list.

torchrec.distributed.embedding_sharding.bucketize_kjt_before_all2all(kjt: KeyedJaggedTensor, num_buckets: int, block_sizes: Tensor, output_permute: bool = False, bucketize_pos: bool = False) Tuple[KeyedJaggedTensor, Optional[Tensor]]

Bucketizes the values in a KeyedJaggedTensor into num_buckets buckets; lengths are readjusted based on the bucketization results.

Note: This function should be used only for row-wise sharding before calling KJTAllToAll.

Parameters:
  • num_buckets (int) – number of buckets to bucketize the values into.

  • block_sizes (torch.Tensor) – bucket sizes for the keyed dimension.

  • output_permute (bool) – output the memory location mapping from the unbucketized values to bucketized values or not.

  • bucketize_pos (bool) – output the changed position of the bucketized values or not.

Returns:

the bucketized KeyedJaggedTensor and the optional permute mapping from the unbucketized values to the bucketized values.

Return type:

Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]
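A minimal usage sketch, assuming an existing KeyedJaggedTensor kjt with two features; the per-feature block sizes below are illustrative:

import torch

block_sizes = torch.tensor([5, 5], dtype=kjt.values().dtype)  # one bucket width per key
bucketized_kjt, unbucketize_permute = bucketize_kjt_before_all2all(
    kjt,
    num_buckets=2,
    block_sizes=block_sizes,
    output_permute=True,  # also return the unbucketized -> bucketized mapping
)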

torchrec.distributed.embedding_sharding.group_tables(tables_per_rank: List[List[ShardedEmbeddingTable]]) List[List[GroupedEmbeddingConfig]]

Groups tables by DataType, PoolingType, and EmbeddingComputeKernel.

Parameters:

tables_per_rank (List[List[ShardedEmbeddingTable]]) – list of sharded embedding tables per rank with consistent weightedness.

Returns:

per rank list of GroupedEmbeddingConfig for features.

Return type:

List[List[GroupedEmbeddingConfig]]

torchrec.distributed.embedding_types

class torchrec.distributed.embedding_types.BaseEmbeddingLookup(*args, **kwargs)

Bases: ABC, Module, Generic[F, T]

Interface implemented by different embedding implementations: e.g. one, which relies on nn.EmbeddingBag or table-batched one, etc.

abstract forward(sparse_features: F) T

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.distributed.embedding_types.BaseEmbeddingSharder(fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: ModuleSharder[M]

compute_kernels(sharding_type: str, compute_device_type: str) List[str]

List of supported compute kernels for a given sharding type and compute device.

property fused_params: Optional[Dict[str, Any]]
sharding_types(compute_device_type: str) List[str]

List of supported sharding types. See ShardingType for well-known examples.

storage_usage(tensor: Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int]

List of system resources and corresponding usage given a compute device and compute kernel.

class torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor(*args, **kwargs)

Bases: Module

Abstract base class for grouped feature processors.

abstract forward(features: KeyedJaggedTensor) KeyedJaggedTensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder(fused_params: Optional[Dict[str, Any]] = None, shardable_params: Optional[List[str]] = None)

Bases: ModuleSharder[M]

compute_kernels(sharding_type: str, compute_device_type: str) List[str]

List of supported compute kernels for a given sharding type and compute device.

property fused_params: Optional[Dict[str, Any]]
shardable_parameters(module: M) Dict[str, Parameter]

List of parameters that can be sharded.

sharding_types(compute_device_type: str) List[str]

List of supported sharding types. See ShardingType for well-known examples.

storage_usage(tensor: Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int]

List of system resources and corresponding usage given a compute device and compute kernel.

class torchrec.distributed.embedding_types.EmbeddingAttributes(compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel = <EmbeddingComputeKernel.DENSE: 'dense'>)

Bases: object

compute_kernel: EmbeddingComputeKernel = 'dense'
class torchrec.distributed.embedding_types.EmbeddingComputeKernel(value)

Bases: Enum

An enumeration.

DENSE = 'dense'
FUSED = 'fused'
FUSED_UVM = 'fused_uvm'
FUSED_UVM_CACHING = 'fused_uvm_caching'
QUANT = 'quant'
QUANT_UVM = 'quant_uvm'
QUANT_UVM_CACHING = 'quant_uvm_caching'
class torchrec.distributed.embedding_types.FeatureShardingMixIn

Bases: object

Feature Sharding Interface to provide sharding-aware feature metadata.

feature_names() List[str]
feature_names_per_rank() List[List[str]]
features_per_rank() List[int]
class torchrec.distributed.embedding_types.GroupedEmbeddingConfig(data_type: torchrec.distributed.types.DataType, pooling: torchrec.modules.embedding_configs.PoolingType, is_weighted: bool, has_feature_processor: bool, compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel, embedding_tables: List[torchrec.distributed.embedding_types.ShardedEmbeddingTable], fused_params: Union[Dict[str, Any], NoneType] = None)

Bases: object

compute_kernel: EmbeddingComputeKernel
data_type: DataType
dim_sum() int
embedding_dims() List[int]
embedding_names() List[str]
embedding_shard_metadata() List[Optional[ShardMetadata]]
embedding_tables: List[ShardedEmbeddingTable]
feature_hash_sizes() List[int]
feature_names() List[str]
fused_params: Optional[Dict[str, Any]] = None
has_feature_processor: bool
is_weighted: bool
num_features() int
pooling: PoolingType
class torchrec.distributed.embedding_types.KJTList(features: List[KeyedJaggedTensor])

Bases: Multistreamable

record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

class torchrec.distributed.embedding_types.ListOfKJTList(features: List[KJTList])

Bases: Multistreamable

record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

class torchrec.distributed.embedding_types.ModuleShardingMixIn

Bases: object

The interface to access a sharded module’s sharding scheme.

property shardings: Dict[str, FeatureShardingMixIn]
class torchrec.distributed.embedding_types.OptimType(value)

Bases: Enum

An enumeration.

ADAGRAD = 'ADAGRAD'
ADAM = 'ADAM'
LAMB = 'LAMB'
LARS_SGD = 'LARS_SGD'
LION = 'LION'
PARTIAL_ROWWISE_ADAM = 'PARTIAL_ROWWISE_ADAM'
PARTIAL_ROWWISE_LAMB = 'PARTIAL_ROWWISE_LAMB'
ROWWISE_ADAGRAD = 'ROWWISE_ADAGRAD'
SGD = 'SGD'
SHAMPOO = 'SHAMPOO'
class torchrec.distributed.embedding_types.ShardedConfig(local_rows: int = 0, local_cols: int = 0)

Bases: object

local_cols: int = 0
local_rows: int = 0
class torchrec.distributed.embedding_types.ShardedEmbeddingModule(qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: ShardedModule[CompIn, DistOut, Out, ShrdCtx], ModuleShardingMixIn

All model-parallel embedding modules implement this interface. Inputs and outputs are data-parallel.

Parameters:

qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]) – mapping of CommOp name to QuantizedCommCodecs.

extra_repr() str

Pretty-prints the representation of the module’s lookup modules, input_dists, and output_dists.

prefetch(dist_input: KJTList, forward_stream: Optional[Stream] = None) None

Prefetch input features for each lookup module.

training: bool
class torchrec.distributed.embedding_types.ShardedEmbeddingTable(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.distributed.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>, is_weighted: bool = False, has_feature_processor: bool = False, embedding_names: List[str] = <factory>, compute_kernel: torchrec.distributed.embedding_types.EmbeddingComputeKernel = <EmbeddingComputeKernel.DENSE: 'dense'>, local_rows: int = 0, local_cols: int = 0, local_metadata: Union[torch.distributed._shard.metadata.ShardMetadata, NoneType] = None, global_metadata: Union[torch.distributed._shard.sharded_tensor.metadata.ShardedTensorMetadata, NoneType] = None, fused_params: Union[Dict[str, Any], NoneType] = None)

Bases: ShardedMetaConfig, EmbeddingAttributes, EmbeddingTableConfig

fused_params: Optional[Dict[str, Any]] = None
class torchrec.distributed.embedding_types.ShardedMetaConfig(local_rows: int = 0, local_cols: int = 0, local_metadata: Union[torch.distributed._shard.metadata.ShardMetadata, NoneType] = None, global_metadata: Union[torch.distributed._shard.sharded_tensor.metadata.ShardedTensorMetadata, NoneType] = None)

Bases: ShardedConfig

global_metadata: Optional[ShardedTensorMetadata] = None
local_metadata: Optional[ShardMetadata] = None
torchrec.distributed.embedding_types.compute_kernel_to_embedding_location(compute_kernel: EmbeddingComputeKernel) EmbeddingLocation

torchrec.distributed.embeddingbag

class torchrec.distributed.embeddingbag.EmbeddingAwaitable(*args, **kwargs)

Bases: LazyAwaitable[Tensor]

class torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable(*args, **kwargs)

Bases: LazyAwaitable[KeyedTensor]

class torchrec.distributed.embeddingbag.EmbeddingBagCollectionContext(sharding_contexts: List[Union[torchrec.distributed.embedding_sharding.EmbeddingShardingContext, NoneType]] = <factory>)

Bases: Multistreamable

record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

sharding_contexts: List[Optional[EmbeddingShardingContext]]
class torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder(fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseEmbeddingSharder[EmbeddingBagCollection]

This implementation uses non-fused EmbeddingBagCollection.

property module_type: Type[EmbeddingBagCollection]
shard(module: EmbeddingBagCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedEmbeddingBagCollection

Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.

Default implementation is data-parallel replication.

Parameters:
  • module (M) – module to shard.

  • params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.

  • env (ShardingEnv) – sharding environment that has the process group.

  • device (torch.device) – compute device.

Returns:

sharded module implementation.

Return type:

ShardedModule[Any, Any, Any]

shardable_parameters(module: EmbeddingBagCollection) Dict[str, Parameter]

List of parameters that can be sharded.

class torchrec.distributed.embeddingbag.EmbeddingBagSharder(fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseEmbeddingSharder[EmbeddingBag]

This implementation uses non-fused nn.EmbeddingBag.

property module_type: Type[EmbeddingBag]
shard(module: EmbeddingBag, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedEmbeddingBag

Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.

Default implementation is data-parallel replication.

Parameters:
  • module (M) – module to shard.

  • params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.

  • env (ShardingEnv) – sharding environment that has the process group.

  • device (torch.device) – compute device.

Returns:

sharded module implementation.

Return type:

ShardedModule[Any, Any, Any]

shardable_parameters(module: EmbeddingBag) Dict[str, Parameter]

List of parameters that can be sharded.

class torchrec.distributed.embeddingbag.ShardedEmbeddingBag(module: EmbeddingBag, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)

Bases: ShardedEmbeddingModule[KeyedJaggedTensor, Tensor, Tensor, NullShardedModuleContext], FusedOptimizerModule

Sharded implementation of nn.EmbeddingBag. This is part of the public API to allow for manual data dist pipelining.

compute(ctx: NullShardedModuleContext, dist_input: KeyedJaggedTensor) Tensor
create_context() NullShardedModuleContext
property fused_optimizer: KeyedOptimizer
input_dist(ctx: NullShardedModuleContext, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) Awaitable[Awaitable[KeyedJaggedTensor]]
load_state_dict(state_dict: OrderedDict[str, torch.Tensor], strict: bool = True) _IncompatibleKeys

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_modules(memo: Optional[Set[Module]] = None, prefix: str = '', remove_duplicate: bool = True) Iterator[Tuple[str, Module]]

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:
  • memo – a memo to store the set of modules already added to the result

  • prefix – a prefix that will be added to the name of the module

  • remove_duplicate – whether to remove the duplicated module instances in the result or not

Yields:

(str, Module) – Tuple of name and module

Note

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
output_dist(ctx: NullShardedModuleContext, output: Tensor) LazyAwaitable[Tensor]
sharded_parameter_names(prefix: str = '') Iterator[str]
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool
class torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection(module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: ShardedEmbeddingModule[KJTList, List[Tensor], KeyedTensor, EmbeddingBagCollectionContext], FusedOptimizerModule

Sharded implementation of EmbeddingBagCollection. This is part of the public API to allow for manual data dist pipelining.

compute(ctx: EmbeddingBagCollectionContext, dist_input: KJTList) List[Tensor]
compute_and_output_dist(ctx: EmbeddingBagCollectionContext, input: KJTList) LazyAwaitable[KeyedTensor]

In the case of multiple output distributions, it makes sense to override this method and initiate each output distribution as soon as the corresponding compute completes.

create_context() EmbeddingBagCollectionContext
property fused_optimizer: KeyedOptimizer
input_dist(ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor) Awaitable[Awaitable[KJTList]]
output_dist(ctx: EmbeddingBagCollectionContext, output: List[Tensor]) LazyAwaitable[KeyedTensor]
reset_parameters() None
training: bool
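A hedged sketch of the manual data dist pipelining flow exposed by this class, assuming a sharded_ebc (ShardedEmbeddingBagCollection) and input features (a KeyedJaggedTensor) already exist:

ctx = sharded_ebc.create_context()
# input_dist returns Awaitable[Awaitable[KJTList]], hence the two waits.
dist_input = sharded_ebc.input_dist(ctx, features).wait().wait()
embeddings = sharded_ebc.compute(ctx, dist_input)         # List[Tensor]
pooled = sharded_ebc.output_dist(ctx, embeddings).wait()  # KeyedTensor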
torchrec.distributed.embeddingbag.construct_output_kt(embeddings: List[Tensor], embedding_names: List[str], embedding_dims: List[int]) KeyedTensor
torchrec.distributed.embeddingbag.create_embedding_bag_sharding(sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None) EmbeddingSharding[EmbeddingShardingContext, KeyedJaggedTensor, Tensor, Tensor]
torchrec.distributed.embeddingbag.create_sharding_infos_by_sharding(module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], prefix: str, fused_params: Optional[Dict[str, Any]], suffix: Optional[str] = 'weight') Dict[str, List[EmbeddingShardingInfo]]
torchrec.distributed.embeddingbag.replace_placement_with_meta_device(sharding_infos: List[EmbeddingShardingInfo]) None

The placement device and tensor device can mismatch in some scenarios, e.g. when a meta device is passed to DMP but cuda is passed to EmbeddingShardingPlanner. This makes the devices consistent after the sharding plan is obtained.

torchrec.distributed.grouped_position_weighted

class torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule(max_feature_lengths: Dict[str, int], device: Optional[device] = None)

Bases: BaseGroupedFeatureProcessor

forward(features: KeyedJaggedTensor) KeyedJaggedTensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool

torchrec.distributed.model_parallel

class torchrec.distributed.model_parallel.DataParallelWrapper

Bases: ABC

Interface implemented by custom data parallel wrappers.

abstract wrap(dmp: DistributedModelParallel, env: ShardingEnv, device: device) None
class torchrec.distributed.model_parallel.DefaultDataParallelWrapper(bucket_cap_mb: int = 25, static_graph: bool = True, find_unused_parameters: bool = False, allreduce_comm_precision: Optional[str] = None)

Bases: DataParallelWrapper

Default data parallel wrapper, which applies data parallel to all unsharded modules.

wrap(dmp: DistributedModelParallel, env: ShardingEnv, device: device) None
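A hedged sketch of customizing how the unsharded part of a model is wrapped in DDP, by passing a configured wrapper to DistributedModelParallel (documented below); model is assumed to exist and the argument values are illustrative:

dmp = DistributedModelParallel(
    module=model,
    data_parallel_wrapper=DefaultDataParallelWrapper(
        bucket_cap_mb=50,
        static_graph=False,
        find_unused_parameters=True,
    ),
)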
class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None)

Bases: Module, FusedOptimizerModule

Entry point to model parallelism.

Parameters:
  • module (nn.Module) – module to wrap.

  • env (Optional[ShardingEnv]) – sharding environment that has the process group.

  • device (Optional[torch.device]) – compute device, defaults to cpu.

  • plan (Optional[ShardingPlan]) – plan to use when sharding, defaults to EmbeddingShardingPlanner.collective_plan().

  • sharders (Optional[List[ModuleSharder[nn.Module]]]) – ModuleSharders available to shard with, defaults to EmbeddingBagCollectionSharder().

  • init_data_parallel (bool) – data-parallel modules can be lazy, i.e. they delay parameter initialization until the first forward pass. Pass True to delay initialization of data parallel modules. Do first forward pass and then call DistributedModelParallel.init_data_parallel().

  • init_parameters (bool) – initialize parameters for modules still on meta device.

  • data_parallel_wrapper (Optional[DataParallelWrapper]) – custom wrapper for data parallel modules.

Example:

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
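A hedged end-to-end setup sketch, assuming a torchrun-style launch (so LOCAL_RANK is set and the default process group can be initialized) and a hypothetical MyModel that contains TorchRec embedding modules:

import os
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)

model = DistributedModelParallel(MyModel(device="meta"), device=device)
print(model.plan)            # ShardingPlan produced by the collective planner
opt = model.fused_optimizer  # optimizer fused into the sharded embedding kernels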
bare_named_parameters(prefix: str = '', recurse: bool = True) Iterator[Tuple[str, Parameter]]
copy(device: device) DistributedModelParallel

Recursively copies submodules to a new device by calling each module’s customized copy process, since some modules need to keep the original references (like ShardedModule for inference).

forward(*args, **kwargs) Any

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

property fused_optimizer: KeyedOptimizer
init_data_parallel() None

See the init_data_parallel constructor argument for usage. It is safe to call this method multiple times.

load_state_dict(state_dict: OrderedDict[str, torch.Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

property module: Module

Property to directly access sharded module, which will not be wrapped in DDP, FSDP, DMP, or any other parallelism wrappers.

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
property plan: ShardingPlan
sparse_grad_parameter_names(destination: Optional[List[str]] = None, prefix: str = '') List[str]
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool
torchrec.distributed.model_parallel.get_module(module: Module) Module

Unwraps DMP module.

Does not unwrap data parallel wrappers (i.e. DDP/FSDP), so implementations overridden by those wrappers remain in effect.

torchrec.distributed.model_parallel.get_unwrapped_module(module: Module) Module

Unwraps module wrapped by DMP, DDP, or FSDP.

torchrec.distributed.quant_embeddingbag

class torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder(fused_params: Optional[Dict[str, Any]] = None, shardable_params: Optional[List[str]] = None)

Bases: BaseQuantEmbeddingSharder[EmbeddingBagCollection]

property module_type: Type[EmbeddingBagCollection]
shard(module: EmbeddingBagCollection, params: Dict[str, ParameterSharding], env: ShardingEnv, device: Optional[device] = None) ShardedQuantEmbeddingBagCollection

Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.

Default implementation is data-parallel replication.

Parameters:
  • module (M) – module to shard.

  • params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.

  • env (ShardingEnv) – sharding environment that has the process group.

  • device (torch.device) – compute device.

Returns:

sharded module implementation.

Return type:

ShardedModule[Any, Any, Any]

class torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection(module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], env: ShardingEnv, fused_params: Optional[Dict[str, Any]] = None, device: Optional[device] = None)

Bases: ShardedQuantEmbeddingModuleState[ListOfKJTList, List[List[Tensor]], KeyedTensor, NullShardedModuleContext]

Sharded implementation of EmbeddingBagCollection. This is part of the public API to allow for manual data dist pipelining.

compute(ctx: NullShardedModuleContext, dist_input: ListOfKJTList) List[List[Tensor]]
compute_and_output_dist(ctx: NullShardedModuleContext, input: ListOfKJTList) KeyedTensor

In the case of multiple output distributions, it makes sense to override this method and initiate each output distribution as soon as the corresponding compute completes.

copy(device: device) Module
create_context() NullShardedModuleContext
embedding_bag_configs() List[EmbeddingBagConfig]
forward(*input, **kwargs) KeyedTensor

Executes the input dist, compute, and output dist steps.

Parameters:
  • *input – input.

  • **kwargs – keyword arguments.

Returns:

awaitable of output from output dist.

Return type:

LazyAwaitable[Out]

input_dist(ctx: NullShardedModuleContext, features: KeyedJaggedTensor) ListOfKJTList
output_dist(ctx: NullShardedModuleContext, output: List[List[Tensor]]) KeyedTensor
sharding_type_to_sharding_infos() Dict[str, List[EmbeddingShardingInfo]]
property shardings: Dict[str, FeatureShardingMixIn]
tbes_configs() Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]
training: bool
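A hedged inference sketch: per the forward signature above, the call returns a KeyedTensor directly rather than an awaitable. sharded_quant_ebc and features are assumed to exist, and KeyedTensor.to_dict() is assumed from torchrec.sparse:

keyed_tensor = sharded_quant_ebc(features)  # runs input dist, compute, and output dist
per_feature = keyed_tensor.to_dict()        # feature name -> pooled embedding tensor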
torchrec.distributed.quant_embeddingbag.create_infer_embedding_bag_sharding(sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv) EmbeddingSharding[NullShardingContext, KJTList, List[Tensor], Tensor]

torchrec.distributed.train_pipeline

NOTE: Due to an internal packaging issue, train_pipeline.py must be compatible with older versions of TorchRec. Importing new modules from other files may break model publishing flows.

class torchrec.distributed.train_pipeline.ArgInfo(input_attrs: List[str], is_getitems: List[bool], name: Optional[str])

Bases: object

Representation of args from a node.

input_attrs

attributes of input batch, e.g. batch.attr1.attr2 will produce [“attr1”, “attr2”].

Type:

List[str]

is_getitems

batch[attr1].attr2 will produce [True, False].

Type:

List[bool]

name

name for kwarg of pipelined forward() call or None for a positional arg.

Type:

Optional[str]

input_attrs: List[str]
is_getitems: List[bool]
name: Optional[str]
class torchrec.distributed.train_pipeline.BaseForward(name: str, args: List[ArgInfo], module: ShardedModule, context: TrainPipelineContext, stream: Optional[Stream])

Bases: object

property args: List[ArgInfo]
property name: str
class torchrec.distributed.train_pipeline.FusedKJTListSplitsAwaitable(requests: List[KJTListSplitsAwaitable[C]], contexts: List[C], pg: Optional[ProcessGroup])

Bases: Awaitable[List[KJTListAwaitable]]

class torchrec.distributed.train_pipeline.KJTAllToAllForward(pg: ProcessGroup, splits: List[int], stagger: int = 1)

Bases: object

class torchrec.distributed.train_pipeline.KJTSplitsAllToAllMeta(pg: torch.distributed.distributed_c10d.ProcessGroup, _input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, splits: List[int], splits_tensors: List[torch.Tensor], input_splits: List[List[int]], input_tensors: List[torch.Tensor], labels: List[str], keys: List[str], device: torch.device, stagger: int)

Bases: object

device: device
input_splits: List[List[int]]
input_tensors: List[Tensor]
keys: List[str]
labels: List[str]
pg: ProcessGroup
splits: List[int]
splits_tensors: List[Tensor]
stagger: int
class torchrec.distributed.train_pipeline.PipelinedForward(name: str, args: List[ArgInfo], module: ShardedModule, context: TrainPipelineContext, stream: Optional[Stream])

Bases: BaseForward

class torchrec.distributed.train_pipeline.PrefetchPipelinedForward(name: str, args: List[ArgInfo], module: ShardedModule, context: PrefetchTrainPipelineContext, prefetch_stream: Optional[Stream])

Bases: BaseForward

class torchrec.distributed.train_pipeline.PrefetchTrainPipelineContext(input_dist_splits_requests: Dict[str, torchrec.distributed.types.Awaitable[Any]] = <factory>, input_dist_tensors_requests: Dict[str, torchrec.distributed.types.Awaitable[Any]] = <factory>, module_contexts: Dict[str, torchrec.streamable.Multistreamable] = <factory>, module_contexts_next_batch: Dict[str, torchrec.streamable.Multistreamable] = <factory>, fused_splits_awaitables: List[Tuple[List[str], torchrec.distributed.train_pipeline.FusedKJTListSplitsAwaitable]] = <factory>, module_input_post_prefetch: Dict[str, torchrec.streamable.Multistreamable] = <factory>, module_contexts_post_prefetch: Dict[str, torchrec.streamable.Multistreamable] = <factory>)

Bases: TrainPipelineContext

module_contexts_post_prefetch: Dict[str, Multistreamable]
module_input_post_prefetch: Dict[str, Multistreamable]
class torchrec.distributed.train_pipeline.PrefetchTrainPipelineSparseDist(model: Module, optimizer: Optimizer, device: device, execute_all_batches: bool = True, apply_jit: bool = False)

Bases: TrainPipelineSparseDist[In, Out]

This pipeline overlaps device transfer, ShardedModule.input_dist(), and cache prefetching with forward and backward. This helps hide the all2all latency while preserving the training forward / backward ordering.

stage 4: forward, backward - uses default CUDA stream
stage 3: prefetch - uses prefetch CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream

ShardedModule.input_dist() is only done for top-level modules in the call graph. To be considered a top-level module, a module can only depend on ‘getattr’ calls on input.

Input model must be symbolically traceable with the exception of ShardedModule and DistributedDataParallel modules.

Parameters:
  • model (torch.nn.Module) – model to pipeline.

  • optimizer (torch.optim.Optimizer) – optimizer to use.

  • device (torch.device) – device where device transfer, sparse data dist, prefetch, and forward/backward pass will happen.

  • execute_all_batches (bool) – executes remaining batches in pipeline after exhausting dataloader iterator.

  • apply_jit (bool) – apply torch.jit.script to non-pipelined (unsharded) modules.

progress(dataloader_iter: Iterator[In]) Out
class torchrec.distributed.train_pipeline.SplitsAllToAllAwaitable(input_tensors: List[Tensor], pg: ProcessGroup)

Bases: Awaitable[List[List[int]]]

class torchrec.distributed.train_pipeline.Tracer(leaf_modules: Optional[List[str]] = None)

Bases: Tracer

Disables proxying buffers during tracing. Ideally, proxying buffers would be disabled, but some models are currently mutating buffer values, which causes errors during tracing. If those models can be rewritten to not do that, we can likely remove this line.

graph: Graph
is_leaf_module(m: Module, module_qualified_name: str) bool

A method to specify whether a given nn.Module is a “leaf” module.

Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.

Parameters:
  • m (Module) – The module being queried about

  • module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.

Note

Backwards-compatibility for this API is guaranteed.

module_stack: OrderedDict[str, str]
node_name_to_scope: Dict[str, Tuple[str, type]]
proxy_buffer_attributes: bool = False
scope: Scope
class torchrec.distributed.train_pipeline.TrainPipeline(*args, **kwds)

Bases: ABC, Generic[In, Out]

abstract progress(dataloader_iter: Iterator[In]) Out
class torchrec.distributed.train_pipeline.TrainPipelineBase(model: Module, optimizer: Optimizer, device: device)

Bases: TrainPipeline[In, Out]

This class runs training iterations using a pipeline of two stages, each as a CUDA stream, namely, the current (default) stream and self._memcpy_stream. For each iteration, self._memcpy_stream moves the input from host (CPU) memory to GPU memory, and the default stream runs forward, backward, and optimization.

progress(dataloader_iter: Iterator[In]) Out
class torchrec.distributed.train_pipeline.TrainPipelineContext(input_dist_splits_requests: ~typing.Dict[str, ~torchrec.distributed.types.Awaitable[~typing.Any]] = <factory>, input_dist_tensors_requests: ~typing.Dict[str, ~torchrec.distributed.types.Awaitable[~typing.Any]] = <factory>, module_contexts: ~typing.Dict[str, ~torchrec.streamable.Multistreamable] = <factory>, module_contexts_next_batch: ~typing.Dict[str, ~torchrec.streamable.Multistreamable] = <factory>, fused_splits_awaitables: ~typing.List[~typing.Tuple[~typing.List[str], ~torchrec.distributed.train_pipeline.FusedKJTListSplitsAwaitable]] = <factory>)

Bases: object

Context information for a TrainPipelineSparseDist instance.

input_dist_splits_requests

Stores input dist requests in the splits awaitable stage, which occurs after starting the input dist.

Type:

Dict[str, Awaitable[Any]]

input_dist_tensors_requests

Stores input dist requests in the tensors awaitable stage, which occurs after calling wait() on the splits awaitable.

Type:

Dict[str, Awaitable[Any]]

module_contexts

Stores module contexts from the input dist for the current batch.

Type:

Dict[str, Multistreamable]

module_contexts_next_batch

Stores module contexts from the input dist for the next batch.

Type:

Dict[str, Multistreamable]

fused_splits_awaitables

List of fused splits input dist awaitable and the corresponding module names of each awaitable.

Type:

List[Tuple[List[str], FusedKJTListSplitsAwaitable]]

fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]]
input_dist_splits_requests: Dict[str, Awaitable[Any]]
input_dist_tensors_requests: Dict[str, Awaitable[Any]]
module_contexts: Dict[str, Multistreamable]
module_contexts_next_batch: Dict[str, Multistreamable]
class torchrec.distributed.train_pipeline.TrainPipelineSparseDist(model: Module, optimizer: Optimizer, device: device, execute_all_batches: bool = True, apply_jit: bool = False)

Bases: TrainPipeline[In, Out]

This pipeline overlaps device transfer and ShardedModule.input_dist() with forward and backward. This helps hide the all2all latency while preserving the training forward / backward ordering.

stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream

ShardedModule.input_dist() is only done for top-level modules in the call graph. To be considered a top-level module, a module can only depend on ‘getattr’ calls on input.

Input model must be symbolically traceable with the exception of ShardedModule and DistributedDataParallel modules.

Parameters:
  • model (torch.nn.Module) – model to pipeline.

  • optimizer (torch.optim.Optimizer) – optimizer to use.

  • device (torch.device) – device where device transfer, sparse data dist, and forward/backward pass will happen.

  • execute_all_batches (bool) – executes remaining batches in pipeline after exhausting dataloader iterator.

  • apply_jit (bool) – apply torch.jit.script to non-pipelined (unsharded) modules.

progress(dataloader_iter: Iterator[In]) Out
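A hedged training-loop sketch, assuming model is a DistributedModelParallel-wrapped model, optimizer covers its dense parameters, dataloader yields pipeline-compatible batches, and progress re-raises StopIteration once the iterator is exhausted:

import torch

pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=torch.device("cuda"),
)
dataloader_iter = iter(dataloader)
while True:
    try:
        output = pipeline.progress(dataloader_iter)  # one overlapped fwd/bwd/optimizer step
    except StopIteration:
        break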

torchrec.distributed.types

class torchrec.distributed.types.Awaitable

Bases: ABC, Generic[W]

property callbacks: List[Callable[[W], W]]
wait() W
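A hedged sketch, assuming that callbacks registered via the callbacks property are applied in order to the value produced by wait(); awaitable is assumed to exist:

awaitable.callbacks.append(lambda out: out.to("cpu"))  # post-process the waited value
result = awaitable.wait()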
class torchrec.distributed.types.BoundsCheckMode(value)

Bases: Enum

An enumeration.

FATAL = 0
IGNORE = 2
NONE = 3
WARNING = 1
class torchrec.distributed.types.CacheAlgorithm(value)

Bases: Enum

An enumeration.

LFU = 1
LRU = 0
class torchrec.distributed.types.CacheParams(algorithm: Union[torchrec.distributed.types.CacheAlgorithm, NoneType] = None, load_factor: Union[float, NoneType] = None, reserved_memory: Union[float, NoneType] = None, precision: Union[torchrec.distributed.types.DataType, NoneType] = None)

Bases: object

algorithm: Optional[CacheAlgorithm] = None
load_factor: Optional[float] = None
precision: Optional[DataType] = None
reserved_memory: Optional[float] = None
class torchrec.distributed.types.CommOp(value)

Bases: Enum

An enumeration.

POOLED_EMBEDDINGS_ALL_TO_ALL = 'pooled_embeddings_all_to_all'
POOLED_EMBEDDINGS_REDUCE_SCATTER = 'pooled_embeddings_reduce_scatter'
SEQUENCE_EMBEDDINGS_ALL_TO_ALL = 'sequence_embeddings_all_to_all'
class torchrec.distributed.types.ComputeKernel(value)

Bases: Enum

An enumeration.

DEFAULT = 'default'
class torchrec.distributed.types.DataType(value)

Bases: Enum

Our fused implementation supports only certain data types, so it makes sense to restrict the non-fused version to the same set as well.

BF16 = 'BF16'
FP16 = 'FP16'
FP32 = 'FP32'
INT2 = 'INT2'
INT32 = 'INT32'
INT4 = 'INT4'
INT64 = 'INT64'
INT8 = 'INT8'
UINT8 = 'UINT8'
class torchrec.distributed.types.EmbeddingModuleShardingPlan

Bases: ModuleShardingPlan, Dict[str, ParameterSharding]

Map of ParameterSharding per parameter (usually a table). This describes the sharding plan for a torchrec module (e.g. EmbeddingBagCollection)

class torchrec.distributed.types.GenericMeta

Bases: type

class torchrec.distributed.types.LazyAwaitable(*args, **kwargs)

Bases: Awaitable[W]

LazyAwaitable exposes a wait() API; concrete subclasses control how to initialize and how wait() should behave in order to achieve a specific async operation.

This base LazyAwaitable type is a “lazy” async type: it delays wait() as late as possible (see __torch_function__ for details). This helps the model automatically overlap computation and communication; the model author does not need to call wait() manually if the result is used by a PyTorch function or by other Python operations (NOTE: the corresponding magic methods, such as __getattr__, need to be implemented).

Some caveats:

  • This works with PyTorch functions, but not with arbitrary methods; if you want to perform arbitrary Python operations, you need to implement the corresponding magic methods.

  • If a function takes two or more LazyAwaitable arguments, the lazy-wait mechanism cannot ensure perfect computation/communication overlap (i.e. the first may be waited on quickly, with a long wait on the second).
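A hedged sketch of the lazy-wait behavior, assuming sharded_module returns a LazyAwaitable (as sharded TorchRec modules do) and features exists:

out_awaitable = sharded_module(features)  # communication may still be in flight
out = out_awaitable.wait()                # explicit wait; passing the awaitable into a
                                          # torch op would instead trigger the lazy wait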

class torchrec.distributed.types.LazyNoWait(*args, **kwargs)

Bases: LazyAwaitable[W]

class torchrec.distributed.types.ModuleSharder(qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: ABC, Generic[M]

ModuleSharder is defined per module that supports sharding, e.g. EmbeddingBagCollection.

Parameters:

qcomm_codecs_registry (Optional[Dict[str, QuantizedCommCodecs]]) – mapping of CommOp name to QuantizedCommCodecs.

compute_kernels(sharding_type: str, compute_device_type: str) List[str]

List of supported compute kernels for a given sharding type and compute device.

abstract property module_type: Type[M]
property qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]]
abstract classmethod shard(module: M, params: EmbeddingModuleShardingPlan, env: ShardingEnv, device: Optional[device] = None) ShardedModule[Any, Any, Any, Any]

Does the actual sharding. It will allocate parameters on the requested locations as specified by corresponding ParameterSharding.

Default implementation is data-parallel replication.

Parameters:
  • module (M) – module to shard.

  • params (EmbeddingModuleShardingPlan) – dict of fully qualified parameter names (module path + parameter name, ‘.’-separated) to its sharding spec.

  • env (ShardingEnv) – sharding environment that has the process group.

  • device (torch.device) – compute device.

Returns:

sharded module implementation.

Return type:

ShardedModule[Any, Any, Any]

shardable_parameters(module: M) Dict[str, Parameter]

List of parameters that can be sharded.

sharding_types(compute_device_type: str) List[str]

List of supported sharding types. See ShardingType for well-known examples.

storage_usage(tensor: Tensor, compute_device_type: str, compute_kernel: str) Dict[str, int]

List of system resources and corresponding usage given a compute device and compute kernel.

class torchrec.distributed.types.ModuleShardingPlan

Bases: object

class torchrec.distributed.types.NoOpQuantizedCommCodec(*args, **kwds)

Bases: Generic[QuantizationContext]

Default No-Op implementation of QuantizedCommCodec

calc_quantized_size(input_len: int, ctx: Optional[QuantizationContext] = None) int
create_context() Optional[QuantizationContext]
decode(input_grad: Tensor, ctx: Optional[QuantizationContext] = None) Tensor
encode(input_tensor: Tensor, ctx: Optional[QuantizationContext] = None) Tensor
quantized_dtype() dtype
class torchrec.distributed.types.NoWait(obj: W)

Bases: Awaitable[W]

class torchrec.distributed.types.NullShardedModuleContext

Bases: Multistreamable

record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

class torchrec.distributed.types.NullShardingContext

Bases: Multistreamable

record_stream(stream: Stream) None

See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

class torchrec.distributed.types.ParameterSharding(sharding_type: str, compute_kernel: str, ranks: Optional[List[int]] = None, sharding_spec: Optional[ShardingSpec] = None, cache_params: Optional[CacheParams] = None, enforce_hbm: Optional[bool] = None, stochastic_rounding: Optional[bool] = None, bounds_check_mode: Optional[BoundsCheckMode] = None)

Bases: object

Describes the sharding of the parameter.

sharding_type (str): how this parameter is sharded. See ShardingType for well-known types.

compute_kernel (str): compute kernel to be used by this parameter.

ranks (Optional[List[int]]): rank of each shard.

sharding_spec (Optional[ShardingSpec]): list of ShardMetadata for each shard.

cache_params (Optional[CacheParams]): cache params for embedding lookup.

enforce_hbm (Optional[bool]): whether to use HBM.

stochastic_rounding (Optional[bool]): whether to use stochastic rounding.

bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode.

Note

ShardingType.TABLE_WISE - rank where this embedding is placed

ShardingType.COLUMN_WISE - rank where the embedding shards are placed, seen as individual tables

ShardingType.TABLE_ROW_WISE - first rank when this embedding is placed

ShardingType.ROW_WISE, ShardingType.DATA_PARALLEL - unused

bounds_check_mode: Optional[BoundsCheckMode] = None
cache_params: Optional[CacheParams] = None
compute_kernel: str
enforce_hbm: Optional[bool] = None
ranks: Optional[List[int]] = None
sharding_spec: Optional[ShardingSpec] = None
sharding_type: str
stochastic_rounding: Optional[bool] = None
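A hedged sketch of a single table-wise entry; the string values correspond to ShardingType and EmbeddingComputeKernel members documented in this file and are illustrative:

plan_entry = ParameterSharding(
    sharding_type="table_wise",  # ShardingType.TABLE_WISE.value
    compute_kernel="fused",      # EmbeddingComputeKernel.FUSED.value
    ranks=[0],
)
# An EmbeddingModuleShardingPlan maps parameter (table) names to such entries,
# e.g. {"table_0": plan_entry}.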
class torchrec.distributed.types.ParameterStorage(value)

Bases: Enum

Well-known physical resources, which can be used as constraints by ShardingPlanner.

DDR = 'ddr'
HBM = 'hbm'
class torchrec.distributed.types.QuantizedCommCodec(*args, **kwds)

Bases: Generic[QuantizationContext]

Provides an implementation to quantize, or apply mixed precision to, the tensors used in collective calls (pooled_all_to_all, reduce_scatter, etc.). The dtype is the dtype of the tensor passed to encode.

This assumes that the input tensor has type torch.float32.

Example:

    quantized_tensor = quantized_comm_codec.encode(input_tensor)
    assert quantized_tensor.dtype == quantized_comm_codec.quantized_dtype
    collective_call(output_tensors, input_tensors=quantized_tensor)
    output_tensor = quantized_comm_codec.decode(output_tensors)
    torch.testing.assert_close(input_tensor, output_tensor)

calc_quantized_size(input_len: int, ctx: Optional[QuantizationContext] = None) int