
torchrec.distributed.sharding

torchrec.distributed.sharding.cw_sharding

class torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseTwEmbeddingSharding[C, F, T, W]

Base class for column-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
class torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseCwEmbeddingSharding[EmptyShardingContext, SparseFeatures, Tensor, Tensor]

Shards embedding bags column-wise, i.e., a given embedding table is partitioned along its columns and the column slices are placed on specified ranks.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[SparseFeatures]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]
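
The three create_* factories above produce the modules that make up the sharded dataflow: an input distributor, an embedding lookup, and an output distributor. A rough sketch of how they fit together, assuming sharding_infos, env, device, and sparse_features are already prepared (placeholder names, not part of the API above):

sharding = CwPooledEmbeddingSharding(sharding_infos, env, device=device)

input_dist = sharding.create_input_dist(device=device)
lookup = sharding.create_lookup(device=device)
output_dist = sharding.create_output_dist(device=device)

# the input dist returns an awaitable of an awaitable (lengths, then values)
features = input_dist(sparse_features).wait().wait()
local_embs = lookup(features)              # per-rank embedding lookup
pooled = output_dist(local_embs).wait()    # redistribute/merge across ranks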

torchrec.distributed.dist_data

class torchrec.distributed.dist_data.EmbeddingsAllToOne(device: device, world_size: int, cat_dim: int)

Bases: Module

Merges the pooled/sequence embedding tensor on each device into a single tensor.

Parameters:
  • device (torch.device) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

  • cat_dim (int) – which dimension you would like to concatenate on. For pooled embedding it is 1; for sequence embedding it is 0.
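
A minimal usage sketch, assuming two pooled embedding tensors t0 and t1 that live on different devices (placeholder names, not defined above):

merge = EmbeddingsAllToOne(device=device, world_size=2, cat_dim=1)
merged = merge([t0, t1]).wait()   # single tensor on `device`, concatenated along dim 1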

forward(tensors: List[Tensor]) Awaitable[Tensor]

Performs AlltoOne operation on pooled/sequence embeddings tensors.

Parameters:

tensors (List[torch.Tensor]) – list of embedding tensors.

Returns:

awaitable of the merged embeddings.

Return type:

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.dist_data.KJTAllToAll(pg: ProcessGroup, splits: List[int], device: Optional[device] = None, stagger: int = 1, variable_batch_size: bool = False)

Bases: Module

Redistributes KeyedJaggedTensor to a ProcessGroup according to splits.

Implementation utilizes AlltoAll collective as part of torch.distributed. Requires two collective calls, one to transmit final tensor lengths (to allocate correct space), and one to transmit actual sparse values.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • stagger (int) – stagger value to apply to recat tensor, see _recat function for more detail.

  • variable_batch_size (bool) – whether variable batch size in each rank is enabled.

Example:

keys=['A','B','C']
splits=[2,1]
kjtA2A = KJTAllToAll(pg, splits, device)
awaitable = kjtA2A(rank0_input)

# where:
# rank0_input is KeyedJaggedTensor holding

#         0           1           2
# 'A'    [A.V0]       None        [A.V1, A.V2]
# 'B'    None         [B.V0]      [B.V1]
# 'C'    [C.V0]       [C.V1]      None

# rank1_input is KeyedJaggedTensor holding

#         0           1           2
# 'A'     [A.V3]      [A.V4]      None
# 'B'     None        [B.V2]      [B.V3, B.V4]
# 'C'     [C.V2]      [C.V3]      None

rank0_output = awaitable.wait()

# where:
# rank0_output is KeyedJaggedTensor holding

#         0           1           2           3           4           5
# 'A'     [A.V0]      None      [A.V1, A.V2]  [A.V3]      [A.V4]      None
# 'B'     None        [B.V0]    [B.V1]        None        [B.V2]      [B.V3, B.V4]

# rank1_output is KeyedJaggedTensor holding
#         0           1           2           3           4           5
# 'C'     [C.V0]      [C.V1]      None        [C.V2]      [C.V3]      None
forward(input: KeyedJaggedTensor) Awaitable[KJTAllToAllIndicesAwaitable]

Sends input to relevant ProcessGroup ranks. First wait will have lengths results and issue indices/weights AlltoAll. Second wait will have indices/weights results.

Parameters:

input (KeyedJaggedTensor) – input KeyedJaggedTensor of values to distribute.

Returns:

awaitable of a KJTAllToAllIndicesAwaitable, whose own wait yields the redistributed KeyedJaggedTensor.

Return type:

Awaitable[KJTAllToAllIndicesAwaitable]

training: bool
class torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable(pg: ProcessGroup, input: KeyedJaggedTensor, lengths: Tensor, splits: List[int], keys: List[str], recat: Tensor, in_lengths_per_worker: List[int], out_lengths_per_worker: List[int], batch_size_per_rank: List[int])

Bases: Awaitable[KeyedJaggedTensor]

Awaitable for KJT indices and weights All2All.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • input (KeyedJaggedTensor) – input KJT tensor.

  • lengths (torch.Tensor) – output lengths tensor.

  • splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • keys (List[str]) – KJT keys after AlltoAll.

  • recat (torch.Tensor) – recat tensor for reordering tensor order after AlltoAll.

  • in_lengths_per_worker (List[int]) – number of indices each rank will get.

  • out_lengths_per_worker (List[int]) – number of indices per rank in output.

  • batch_size_per_rank (List[int]) – batch size per rank, needed to support variable batch size.

class torchrec.distributed.dist_data.KJTAllToAllLengthsAwaitable(pg: ProcessGroup, input: KeyedJaggedTensor, splits: List[int], keys: List[str], stagger: int, recat: Tensor, variable_batch_size: bool = False)

Bases: Awaitable[KJTAllToAllIndicesAwaitable]

Awaitable for KJT’s lengths AlltoAll.

wait() waits on the lengths AlltoAll, then instantiates a KJTAllToAllIndicesAwaitable from which the indices and weights AlltoAll will be issued.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • input (KeyedJaggedTensor) – input KJT tensor.

  • splits (List[int]) – list of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • keys (List[str]) – KJT keys after AlltoAll.

  • stagger (int) – stagger value to apply to recat tensor, see _recat function for more detail.

  • recat (torch.Tensor) – recat tensor for reordering tensor order after AlltoAll.

  • variable_batch_size (bool) – whether variable batch size is enabled.
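
A minimal sketch of this two-stage wait protocol, assuming pg, splits, device, and a local KeyedJaggedTensor local_kjt are already set up (placeholder names):

kjt_a2a = KJTAllToAll(pg, splits, device)
lengths_awaitable = kjt_a2a(local_kjt)        # issues the lengths AlltoAll
indices_awaitable = lengths_awaitable.wait()  # KJTAllToAllIndicesAwaitable; issues indices/weights AlltoAll
redistributed_kjt = indices_awaitable.wait()  # KeyedJaggedTensor for this rank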

class torchrec.distributed.dist_data.KJTOneToAll(splits: List[int], world_size: int)

Bases: Module

Redistributes KeyedJaggedTensor to all devices.

Implementation utilizes the OnetoAll function, which essentially performs P2P copies of the features to the devices.

Parameters:
  • splits (List[int]) – lengths of features to split the KeyedJaggedTensor features into before copying them.

  • world_size (int) – number of devices in the topology.
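
A minimal usage sketch, assuming kjt is a KeyedJaggedTensor with three features (a placeholder, not defined above):

one_to_all = KJTOneToAll(splits=[2, 1], world_size=2)
kjt_per_device = one_to_all(kjt).wait()   # list with one KeyedJaggedTensor slice per device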

forward(kjt: KeyedJaggedTensor) Awaitable[List[KeyedJaggedTensor]]

Splits features first and then sends the slices to the corresponding devices.

Parameters:

kjt (KeyedJaggedTensor) – the input features.

Returns:

awaitable of KeyedJaggedTensor splits.

Return type:

Awaitable[List[KeyedJaggedTensor]]

training: bool
class torchrec.distributed.dist_data.PooledEmbeddingsAllGather(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

The module class that wraps the all-gather communication primitive for pooled embedding communication.

We have a local input tensor with a layout of [batch_size, dimension]. We need to gather input tensors from all ranks into a flattened output tensor.

The class returns the async Awaitable handle for pooled embeddings tensor. The all-gather is only available for NCCL backend.

Parameters:

pg (dist.ProcessGroup) – The process group that the all-gather communication happens within.

Example:

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
input = torch.randn(2, 2)
m = PooledEmbeddingsAllGather(pg)
output = m(input)
tensor = output.wait()
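
In the example above, each of the 2 ranks contributes a (2, 2) input, so the gathered result is expected to stack the per-rank inputs along the batch dimension into a (4, 2) tensor (assuming equally sized inputs on every rank).
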
forward(local_emb: Tensor) PooledEmbeddingsAwaitable

Performs all-gather operation on the pooled embeddings tensor.

Parameters:

local_emb (torch.Tensor) – local tensor of shape [batch_size, dimension].

Returns:

awaitable of the gathered pooled embeddings tensor, with per-rank tensors stacked along the batch dimension.

Return type:

PooledEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.PooledEmbeddingsAllToAll(pg: ProcessGroup, dim_sum_per_rank: List[int], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

Shards batches and collects keys of tensor with a ProcessGroup according to dim_sum_per_rank.

Implementation utilizes alltoall_pooled operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) – callback functions.

Example:

dim_sum_per_rank = [2, 1]
a2a = PooledEmbeddingsAllToAll(pg, dim_sum_per_rank, device)

t0 = torch.rand((6, 2))
t1 = torch.rand((6, 1))
rank0_output = a2a(t0).wait()
rank1_output = a2a(t1).wait()
print(rank0_output.size())
    # torch.Size([3, 3])
print(rank1_output.size())
    # torch.Size([3, 3])
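
In this example the global batch of 6 rows is split evenly across the 2 ranks (3 rows each), and every rank receives the concatenation of all embedding dimensions (2 + 1 = 3 columns), hence torch.Size([3, 3]) on both ranks.
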
property callbacks: List[Callable[[Tensor], Tensor]]
forward(local_embs: Tensor, batch_size_per_rank: Optional[List[int]] = None) PooledEmbeddingsAwaitable

Performs AlltoAll pooled operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of values to distribute.

  • batch_size_per_rank (Optional[List[int]]) – batch size per rank, to support variable batch size.

Returns:

awaitable of pooled embeddings.

Return type:

PooledEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.PooledEmbeddingsAwaitable(tensor_awaitable: Awaitable[Tensor])

Bases: Awaitable[Tensor]

Awaitable for pooled embeddings after collective operation.

Parameters:

tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.

property callbacks: List[Callable[[Tensor], Tensor]]
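
A small sketch of post-processing the result via the callbacks property, under the assumption (based on the property's type, not stated above) that registered callbacks are applied to the tensor when wait() completes:

awaitable = PooledEmbeddingsAwaitable(tensor_awaitable)
awaitable.callbacks.append(lambda t: t * 2.0)   # hypothetical post-processing hook
pooled = awaitable.wait()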
class torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter(pg: ProcessGroup, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

The module class that wraps the reduce-scatter communication primitive for pooled embedding communication in row-wise and table-wise row-wise (TWRW) sharding.

For pooled embeddings, we have a local model-parallel output tensor with a layout of [num_buckets x batch_size, dimension]. We need to sum over num_buckets dimension across batches. We split tensor along the first dimension into unequal chunks (tensor slices of different buckets) according to input_splits and reduce them into the output tensor and scatter the results for corresponding ranks.

The class returns the async Awaitable handle for pooled embeddings tensor. The reduce-scatter-v is only available for NCCL backend.

Parameters:
  • pg (dist.ProcessGroup) – The process group that the reduce-scatter communication happens within.

  • codecs (Optional[QuantizedCommCodecs]) – quantization codec.

Example:

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
input = torch.randn(2 * 2, 2)
input_splits = [1,3]
m = PooledEmbeddingsReduceScatter(pg)
output = m(input, input_splits)
tensor = output.wait()
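
Here input_splits = [1, 3] means the first bucket (destined for rank 0) contributes 1 row and the second bucket (destined for rank 1) contributes 3 rows; after the reduce-scatter each rank is expected to hold its own slice summed across all ranks, i.e. a (1, 2) tensor on rank 0 and a (3, 2) tensor on rank 1 (assuming a 2-rank group).
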
forward(local_embs: Tensor, input_splits: Optional[List[int]] = None) PooledEmbeddingsAwaitable

Performs reduce scatter operation on pooled embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – tensor of shape [num_buckets x batch_size, dimension].

  • input_splits (Optional[List[int]]) – list of splits for local_embs dim0.

Returns:

awaitable of pooled embeddings of tensor of shape [batch_size, dimension].

Return type:

PooledEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.SeqEmbeddingsAllToOne(device: device, world_size: int)

Bases: Module

Merges the sequence embedding tensors from each device onto a single device.

Parameters:
  • device (torch.device) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

forward(tensors: List[Tensor]) Awaitable[List[Tensor]]

Performs AlltoOne operation on sequence embeddings tensors.

Parameters:

tensors (List[torch.Tensor]) – list of sequence embedding tensors.

Returns:

awaitable of the merged sequence embeddings.

Return type:

Awaitable[List[torch.Tensor]]

training: bool
class torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll(pg: ProcessGroup, features_per_rank: List[int], device: Optional[device] = None, codecs: Optional[QuantizedCommCodecs] = None)

Bases: Module

Redistributes sequence embeddings to a ProcessGroup according to splits.

Parameters:
  • pg (dist.ProcessGroup) – the process group that the AlltoAll communication happens within.

  • features_per_rank (List[int]) – list of number of features per rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

Example:

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
features_per_rank = [4, 4]
m = SequenceEmbeddingsAllToAll(pg, features_per_rank)
local_embs = torch.rand((6, 2))
sharding_ctx: SequenceShardingContext
output = m(
    local_embs=local_embs,
    lengths=sharding_ctx.lengths_after_input_dist,
    input_splits=sharding_ctx.input_splits,
    output_splits=sharding_ctx.output_splits,
    unbucketize_permute_tensor=None,
)
tensor = output.wait()
forward(local_embs: Tensor, lengths: Tensor, input_splits: List[int], output_splits: List[int], unbucketize_permute_tensor: Optional[Tensor] = None) SequenceEmbeddingsAwaitable

Performs AlltoAll operation on sequence embeddings tensor.

Parameters:
  • local_embs (torch.Tensor) – input embeddings tensor.

  • lengths (torch.Tensor) – lengths of sparse features after AlltoAll.

  • input_splits (List[int]) – input splits of AlltoAll.

  • output_splits (List[int]) – output splits of AlltoAll.

  • unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of the KJT bucketize (for row-wise sharding only).

Returns:

awaitable of sequence embeddings.

Return type:

SequenceEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.SequenceEmbeddingsAwaitable(tensor_awaitable: Awaitable[Tensor], unbucketize_permute_tensor: Optional[Tensor], embedding_dim: int)

Bases: Awaitable[Tensor]

Awaitable for sequence embeddings after collective operation.

Parameters:
  • tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.

  • unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of KJT bucketize (for row-wise sharding only).

  • embedding_dim (int) – embedding dimension.

torchrec.distributed.sharding.dp_sharding

class torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None)

Bases: EmbeddingSharding[C, F, T, W]

Base class for data-parallel sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_names_per_rank() List[List[str]]
embedding_shard_metadata() List[Optional[ShardMetadata]]
id_list_feature_names() List[str]
id_score_list_feature_names() List[str]
class torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist

Bases: BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]

Distributes pooled embeddings to be data-parallel.

forward(local_embs: Tensor, sharding_ctx: Optional[EmptyShardingContext] = None) Awaitable[Tensor]

No-op as pooled embeddings are already distributed in data-parallel fashion.

Parameters:

local_embs (torch.Tensor) – output pooled embeddings.

Returns:

awaitable of pooled embeddings tensor.

Return type:

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None)

Bases: BaseDpEmbeddingSharding[EmptyShardingContext, SparseFeatures, Tensor, Tensor]

Shards embedding bags data-parallel with no table sharding, i.e., a given embedding table is replicated across all ranks.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[SparseFeatures]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]
class torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist

Bases: BaseSparseFeaturesDist[SparseFeatures]

Distributes sparse features (input) to be data-parallel.

forward(sparse_features: SparseFeatures) Awaitable[Awaitable[SparseFeatures]]

No-op as sparse features are already distributed in data-parallel fashion.

Parameters:

sparse_features (SparseFeatures) – input sparse features.

Returns:

awaitable of awaitable of SparseFeatures.

Return type:

Awaitable[Awaitable[SparseFeatures]]

training: bool

torchrec.distributed.sharding.rw_sharding

class torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: EmbeddingSharding[C, F, T, W]

Base class for row-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_names_per_rank() List[List[str]]
embedding_shard_metadata() List[Optional[ShardMetadata]]
id_list_feature_names() List[str]
id_score_list_feature_names() List[str]
class torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist(pg: ProcessGroup, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]

Redistributes pooled embedding tensor in RW fashion by performing a reduce-scatter operation.

Parameters:

pg (dist.ProcessGroup) – ProcessGroup for reduce-scatter communication.

forward(local_embs: Tensor, sharding_ctx: Optional[EmptyShardingContext] = None) Awaitable[Tensor]

Performs reduce-scatter pooled operation on pooled embeddings tensor.

Parameters:

local_embs (torch.Tensor) – pooled embeddings tensor to distribute.

Returns:

awaitable of pooled embeddings tensor.

Return type:

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseRwEmbeddingSharding[EmptyShardingContext, SparseFeatures, Tensor, Tensor]

Shards embedding bags row-wise, i.e., a given embedding table is evenly distributed by rows and table slices are placed on all ranks.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[SparseFeatures]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]
class torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist(pg: ProcessGroup, num_id_list_features: int, num_id_score_list_features: int, id_list_feature_hash_sizes: List[int], id_score_list_feature_hash_sizes: List[int], device: Optional[device] = None, is_sequence: bool = False, has_feature_processor: bool = False, need_pos: bool = False)

Bases: BaseSparseFeaturesDist[SparseFeatures]

Bucketizes sparse features in RW fashion and then redistributes with an AlltoAll collective operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • num_id_list_features (int) – total number of id list features.

  • num_id_score_list_features (int) – total number of id score list features.

  • id_list_feature_hash_sizes (List[int]) – hash sizes of id list features.

  • id_score_list_feature_hash_sizes (List[int]) – hash sizes of id score list features.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • is_sequence (bool) – if this is for a sequence embedding.

  • has_feature_processor (bool) – existence of a feature processor (i.e., position-weighted features).
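
To illustrate the bucketization idea (a simplified stand-in, not the library's exact kernel): each feature's id range is split into world_size contiguous blocks, and an id is routed to the rank that owns its block.

import math

def bucketize_ids(ids, hash_size, world_size):
    # block size = ceil(hash_size / world_size); bucket index = owning rank
    block_size = math.ceil(hash_size / world_size)
    return [min(i // block_size, world_size - 1) for i in ids]

bucketize_ids([0, 5, 9], hash_size=10, world_size=2)  # -> [0, 1, 1]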

forward(sparse_features: SparseFeatures) Awaitable[Awaitable[SparseFeatures]]

Bucketizes sparse feature values into world size number of buckets and then performs AlltoAll operation.

Parameters:

sparse_features (SparseFeatures) – sparse features to bucketize and redistribute.

Returns:

awaitable of awaitable of SparseFeatures.

Return type:

Awaitable[Awaitable[SparseFeatures]]

training: bool

torchrec.distributed.sharding.tw_sharding

class torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: EmbeddingSharding[C, F, T, W]

Base class for table-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_names_per_rank() List[List[str]]
embedding_shard_metadata() List[Optional[ShardMetadata]]
id_list_feature_names() List[str]
id_list_feature_names_per_rank() List[List[str]]
id_list_features_per_rank() List[int]
id_score_list_feature_names() List[str]
id_score_list_feature_names_per_rank() List[List[str]]
id_score_list_features_per_rank() List[int]
class torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseTwEmbeddingSharding[EmptyShardingContext, SparseFeaturesList, List[Tensor], Tensor]

Shards embedding bags table-wise for inference.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[SparseFeaturesList]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup[SparseFeaturesList, List[Tensor]]
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmptyShardingContext, List[Tensor], Tensor]
class torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist(device: device, world_size: int)

Bases: BaseEmbeddingDist[EmptyShardingContext, List[Tensor], Tensor]

Merges pooled embedding tensor from each device for inference.

Parameters:
  • device (Optional[torch.device]) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

forward(local_embs: List[Tensor], sharding_ctx: Optional[EmptyShardingContext] = None) Awaitable[Tensor]

Performs AlltoOne operation on pooled embedding tensors.

Parameters:

local_embs (List[torch.Tensor]) – pooled embedding tensors with len(local_embs) == world_size.

Returns:

awaitable of merged pooled embedding tensor.

Return type:

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist(id_list_features_per_rank: List[int], id_score_list_features_per_rank: List[int], world_size: int)

Bases: BaseSparseFeaturesDist[SparseFeaturesList]

Redistributes sparse features to all devices for inference.

Parameters:
  • id_list_features_per_rank (List[int]) – number of id list features to send to each rank.

  • id_score_list_features_per_rank (List[int]) – number of id score list features to send to each rank.

  • world_size (int) – number of devices in the topology.

forward(sparse_features: SparseFeatures) Awaitable[Awaitable[SparseFeaturesList]]

Performs OnetoAll operation on sparse features.

Parameters:

sparse_features (SparseFeatures) – sparse features to redistribute.

Returns:

awaitable of awaitable of SparseFeaturesList.

Return type:

Awaitable[Awaitable[SparseFeaturesList]]

training: bool
class torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist(pg: ProcessGroup, dim_sum_per_rank: List[int], device: Optional[device] = None, callbacks: Optional[List[Callable[[Tensor], Tensor]]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]

Redistributes pooled embedding tensor with an AlltoAll collective operation for table-wise sharding.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

forward(local_embs: Tensor, sharding_ctx: Optional[EmptyShardingContext] = None) Awaitable[Tensor]

Performs AlltoAll operation on pooled embeddings tensor.

Parameters:

local_embs (torch.Tensor) – tensor of values to distribute.

Returns:

awaitable of pooled embeddings.

Return type:

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseTwEmbeddingSharding[EmptyShardingContext, SparseFeatures, Tensor, Tensor]

Shards embedding bags table-wise, i.e., a given embedding table is entirely placed on a selected rank.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[SparseFeatures]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]
class torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist(pg: ProcessGroup, id_list_features_per_rank: List[int], id_score_list_features_per_rank: List[int], device: Optional[device] = None)

Bases: BaseSparseFeaturesDist[SparseFeatures]

Redistributes sparse features with an AlltoAll collective operation for table-wise sharding.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • id_list_features_per_rank (List[int]) – number of id list features to send to each rank.

  • id_score_list_features_per_rank (List[int]) – number of id score list features to send to each rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

forward(sparse_features: SparseFeatures) Awaitable[Awaitable[SparseFeatures]]

Performs AlltoAll operation on sparse features.

Parameters:

sparse_features (SparseFeatures) – sparse features to redistribute.

Returns:

awaitable of awaitable of SparseFeatures.

Return type:

Awaitable[Awaitable[SparseFeatures]]

training: bool

torchrec.distributed.sharding.twcw_sharding

class torchrec.distributed.sharding.twcw_sharding.TwCwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, permute_embeddings: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: CwPooledEmbeddingSharding

Shards embedding bags table-wise then column-wise, i.e., a given embedding table is partitioned along its columns and the table slices are placed on all ranks within a host group.

torchrec.distributed.sharding.twrw_sharding

class torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: EmbeddingSharding[C, F, T, W]

Base class for table-wise row-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_names_per_rank() List[List[str]]
embedding_shard_metadata() List[Optional[ShardMetadata]]
id_list_feature_names() List[str]
id_score_list_feature_names() List[str]
class torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist(cross_pg: ProcessGroup, intra_pg: ProcessGroup, dim_sum_per_node: List[int], device: Optional[device] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]

Redistributes pooled embedding tensor in TWRW fashion by performing a reduce-scatter operation row-wise at the host level and then an AlltoAll operation table-wise at the global level.

Parameters:
  • cross_pg (dist.ProcessGroup) – global level ProcessGroup for AlltoAll communication.

  • intra_pg (dist.ProcessGroup) – host level ProcessGroup for reduce-scatter communication.

  • dim_sum_per_node (List[int]) – number of features (sum of dimensions) of the embedding for each host.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

forward(local_embs: Tensor, sharding_ctx: Optional[EmptyShardingContext] = None) Awaitable[Tensor]

Performs reduce-scatter pooled operation on pooled embeddings tensor followed by AlltoAll pooled operation.

Parameters:

local_embs (torch.Tensor) – pooled embeddings tensor to distribute.

Returns:

awaitable of pooled embeddings tensor.

Return type:

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding(sharding_infos: List[EmbeddingShardingInfo], env: ShardingEnv, device: Optional[device] = None, need_pos: bool = False, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None)

Bases: BaseTwRwEmbeddingSharding[EmptyShardingContext, SparseFeatures, Tensor, Tensor]

Shards embedding bags table-wise then row-wise.

create_input_dist(device: Optional[device] = None) BaseSparseFeaturesDist[SparseFeatures]
create_lookup(device: Optional[device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[BaseGroupedFeatureProcessor] = None) BaseEmbeddingLookup
create_output_dist(device: Optional[device] = None) BaseEmbeddingDist[EmptyShardingContext, Tensor, Tensor]
class torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist(pg: ProcessGroup, intra_pg: ProcessGroup, id_list_features_per_rank: List[int], id_score_list_features_per_rank: List[int], id_list_feature_hash_sizes: List[int], id_score_list_feature_hash_sizes: List[int], device: Optional[device] = None, has_feature_processor: bool = False, need_pos: bool = False)

Bases: BaseSparseFeaturesDist[SparseFeatures]

Bucketizes sparse features in TWRW fashion and then redistributes with an AlltoAll collective operation.

Parameters:
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • intra_pg (dist.ProcessGroup) – ProcessGroup within single host group for AlltoAll communication.

  • id_list_features_per_rank (List[int]) – number of id list features to send to each rank.

  • id_score_list_features_per_rank (List[int]) – number of id score list features to send to each rank.

  • id_list_feature_hash_sizes (List[int]) – hash sizes of id list features.

  • id_score_list_feature_hash_sizes (List[int]) – hash sizes of id score list features.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • has_feature_processor (bool) – existence of a feature processor (i.e., position-weighted features).

Example:

3 features
2 hosts with 2 devices each

Bucketize each feature into 2 buckets
Staggered shuffle with feature splits [2, 1]
AlltoAll operation

NOTE: the results of the staggered shuffle and the AlltoAll operation look the
same after reordering in AlltoAll

Result:
    host 0 device 0:
        feature 0 bucket 0
        feature 1 bucket 0

    host 0 device 1:
        feature 0 bucket 1
        feature 1 bucket 1

    host 1 device 0:
        feature 2 bucket 0

    host 1 device 1:
        feature 2 bucket 1
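
In this layout, the feature splits [2, 1] assign features 0 and 1 to host 0 and feature 2 to host 1; each feature is then bucketized into local-world-size (2) buckets, so every device on a host ends up with one bucket of each of that host's features, as shown above.
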
forward(sparse_features: SparseFeatures) Awaitable[Awaitable[SparseFeatures]]

Bucketizes sparse feature values into local world size number of buckets, performs staggered shuffle on the sparse features, and then performs AlltoAll operation.

Parameters:

sparse_features (SparseFeatures) – sparse features to bucketize and redistribute.

Returns:

awaitable of awaitable of SparseFeatures.

Return type:

Awaitable[Awaitable[SparseFeatures]]

training: bool
