
torchrec.distributed.sharding

torchrec.distributed.sharding.cw_sharding

class torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None, permute_embeddings: bool = False)

Bases: torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding[torchrec.distributed.sharding.cw_sharding.F, torchrec.distributed.sharding.cw_sharding.T]

Base class for column-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
class torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None, permute_embeddings: bool = False)

Bases: torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]

Shards embedding bags column-wise, i.e., a given embedding table is distributed by a specified number of columns and the table slices are placed on all ranks.

create_input_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]
create_lookup(device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None) torchrec.distributed.embedding_types.BaseEmbeddingLookup
create_output_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]
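
The create_input_dist, create_lookup, and create_output_dist methods above compose into the standard input-dist -> lookup -> output-dist pipeline shared by all sharding types. A minimal sketch of how they fit together, assuming sharding is an already-constructed CwPooledEmbeddingSharding and device / sparse_features are placeholders (none of these names come from the API itself):

# hypothetical pipeline wiring; `sharding`, `device`, `sparse_features` are assumed
input_dist = sharding.create_input_dist(device)
lookup = sharding.create_lookup(device)
output_dist = sharding.create_output_dist(device)

features = input_dist(sparse_features).wait().wait()  # redistribute sparse input
local_embs = lookup(features)                         # run the embedding lookup
pooled = output_dist(local_embs).wait()               # redistribute pooled output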

torchrec.distributed.dist_data

class torchrec.distributed.dist_data.EmbeddingsAllToOne(device: torch.device, world_size: int, cat_dim: int)

Bases: torch.nn.modules.module.Module

Merges the pooled/sequence embedding tensor from each device into a single tensor.

Parameters
  • device (torch.device) – device on which buffer will be allocated

  • world_size (int) – number of devices in the topology.

  • cat_dim (int) – which dimension to concatenate on. For pooled embeddings it is 1; for sequence embeddings it is 0.

forward(tensors: List[torch.Tensor]) torchrec.distributed.types.Awaitable[torch.Tensor]

Performs an AlltoOne operation on pooled embedding tensors.

Parameters

tensors (List[torch.Tensor]) – list of pooled embedding tensors.

Returns

awaitable of the merged pooled embeddings.

Return type

Awaitable[torch.Tensor]

training: bool
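
A minimal sketch of the merge for pooled embeddings (cat_dim=1), assuming two CUDA devices are available; the tensors t0 and t1 are illustrative:

device = torch.device("cuda:0")
a2one = EmbeddingsAllToOne(device=device, world_size=2, cat_dim=1)
t0 = torch.rand(4, 8, device="cuda:0")
t1 = torch.rand(4, 8, device="cuda:1")
# merge the per-device pooled embeddings into one tensor on cuda:0
merged = a2one([t0, t1]).wait()  # shape: torch.Size([4, 16])
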
class torchrec.distributed.dist_data.KJTAllToAll(pg: torch._C._distributed_c10d.ProcessGroup, splits: List[int], device: Optional[torch.device] = None, stagger: int = 1, variable_batch_size: bool = False)

Bases: torch.nn.modules.module.Module

Redistributes KeyedJaggedTensor to a ProcessGroup according to splits.

Implementation utilizes AlltoAll collective as part of torch.distributed. Requires two collective calls, one to transmit final tensor lengths (to allocate correct space), and one to transmit actual sparse values.

Parameters
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • stagger (int) – stagger value to apply to recat tensor, see _recat function for more detail.

  • variable_batch_size (bool) – whether the batch size may vary across ranks.

Example:

keys=['A','B','C']
splits=[2,1]
kjtA2A = KJTAllToAll(pg, splits, device)
awaitable = kjtA2A(rank0_input)

# where:
# rank0_input is KeyedJaggedTensor holding

#         0           1           2
# 'A'    [A.V0]       None        [A.V1, A.V2]
# 'B'    None         [B.V0]      [B.V1]
# 'C'    [C.V0]       [C.V1]      None

# rank1_input is KeyedJaggedTensor holding

#         0           1           2
# 'A'     [A.V3]      [A.V4]      None
# 'B'     None        [B.V2]      [B.V3, B.V4]
# 'C'     [C.V2]      [C.V3]      None

rank0_output = awaitable.wait()

# where:
# rank0_output is KeyedJaggedTensor holding

#         0           1           2           3           4           5
# 'A'     [A.V0]      None      [A.V1, A.V2]  [A.V3]      [A.V4]      None
# 'B'     None        [B.V0]    [B.V1]        None        [B.V2]      [B.V3, B.V4]

# rank1_output is KeyedJaggedTensor holding
#         0           1           2           3           4           5
# 'C'     [C.V0]      [C.V1]      None        [C.V2]      [C.V3]      None
forward(input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.distributed.types.Awaitable[torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable]

Sends input to the relevant ProcessGroup ranks. The first wait returns the lengths results and issues the indices/weights AlltoAll; the second wait returns the indices/weights results.

Parameters

input (KeyedJaggedTensor) – input KeyedJaggedTensor of values to distribute.

Returns

awaitable of an awaitable of a KeyedJaggedTensor (the first wait completes the lengths AlltoAll, the second the indices/weights AlltoAll).

Return type

Awaitable[KJTAllToAllIndicesAwaitable]

training: bool
class torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable(pg: torch._C._distributed_c10d.ProcessGroup, input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, lengths: torch.Tensor, splits: List[int], keys: List[str], recat: torch.Tensor, in_lengths_per_worker: List[int], out_lengths_per_worker: List[int], batch_size_per_rank: List[int])

Bases: torchrec.distributed.types.Awaitable[torchrec.sparse.jagged_tensor.KeyedJaggedTensor]

Awaitable for KJT indices and weights All2All.

Parameters
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • input (KeyedJaggedTensor) – Input KJT tensor.

  • lengths (torch.Tensor) – output lengths tensor.

  • splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • keys (List[str]) – KJT keys after AlltoAll.

  • recat (torch.Tensor) – recat tensor for reordering tensor order after AlltoAll.

  • in_lengths_per_worker (List[int]) – number of indices each rank will receive.

class torchrec.distributed.dist_data.KJTAllToAllLengthsAwaitable(pg: torch._C._distributed_c10d.ProcessGroup, input: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, splits: List[int], keys: List[str], stagger: int, recat: torch.Tensor, variable_batch_size: bool = False)

Bases: torchrec.distributed.types.Awaitable[torchrec.distributed.dist_data.KJTAllToAllIndicesAwaitable]

Awaitable for KJT’s lengths AlltoAll.

wait() waits on the lengths AlltoAll, then instantiates a KJTAllToAllIndicesAwaitable, where the indices and weights AlltoAll will be issued.

Parameters
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • input (KeyedJaggedTensor) – Input KJT tensor

  • splits (List[int]) – List of len(pg.size()) which indicates how many features to send to each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. Same for all ranks.

  • keys (List[str]) – KJT keys after AlltoAll

  • recat (torch.Tensor) – recat tensor for reordering tensor order after AlltoAll.
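
The two-stage wait can be made explicit; a hedged sketch based on the forward signature of KJTAllToAll (which returns this lengths awaitable) and the wait() behavior described above, where kjtA2A and rank0_input refer to the KJTAllToAll example earlier:

lengths_awaitable = kjtA2A(rank0_input)       # KJTAllToAllLengthsAwaitable
indices_awaitable = lengths_awaitable.wait()  # lengths AlltoAll done; indices/weights AlltoAll issued
rank0_output = indices_awaitable.wait()       # indices/weights AlltoAll done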

class torchrec.distributed.dist_data.KJTOneToAll(splits: List[int], world_size: int)

Bases: torch.nn.modules.module.Module

Redistributes KeyedJaggedTensor to all devices.

Implementation utilizes the OnetoAll function, which essentially performs P2P copies of the features to the devices.

Parameters
  • splits (List[int]) – lengths of features to split the KeyedJaggedTensor features into before copying them.

  • world_size (int) – number of devices in the topology.

forward(kjt: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.distributed.types.Awaitable[List[torchrec.sparse.jagged_tensor.KeyedJaggedTensor]]

Splits features first and then sends the slices to the corresponding devices.

Parameters

kjt (KeyedJaggedTensor) – the input features.

Returns

awaitable of KeyedJaggedTensor splits.

Return type

Awaitable[List[KeyedJaggedTensor]]

training: bool
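
A minimal sketch, assuming two devices and a KeyedJaggedTensor kjt holding four features (kjt and the split sizes are illustrative):

splits = [2, 2]                        # first two features to device 0, last two to device 1
one2all = KJTOneToAll(splits=splits, world_size=2)
kjts_per_device = one2all(kjt).wait()  # List[KeyedJaggedTensor], one per device
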
class torchrec.distributed.dist_data.PooledEmbeddingsAllToAll(pg: torch._C._distributed_c10d.ProcessGroup, dim_sum_per_rank: List[int], device: Optional[torch.device] = None, callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None)

Bases: torch.nn.modules.module.Module

Shards batches and collects keys of tensor with a ProcessGroup according to dim_sum_per_rank.

Implementation utilizes alltoall_pooled operation.

Parameters
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]) – optional callbacks applied to the merged pooled embeddings tensor when the returned awaitable resolves.

Example:

dim_sum_per_rank = [2, 1]
a2a = PooledEmbeddingsAllToAll(pg, dim_sum_per_rank, device)

t0 = torch.rand((6, 2))
t1 = torch.rand((6, 1))
rank0_output = a2a(t0).wait()
rank1_output = a2a(t1).wait()
print(rank0_output.size())
    # torch.Size([3, 3])
print(rank1_output.size())
    # torch.Size([3, 3])
property callbacks: List[Callable[[torch.Tensor], torch.Tensor]]
forward(local_embs: torch.Tensor, batch_size_per_rank: Optional[List[int]] = None) torchrec.distributed.dist_data.PooledEmbeddingsAwaitable

Performs AlltoAll pooled operation on pooled embeddings tensor.

Parameters

local_embs (torch.Tensor) – tensor of values to distribute.

Returns

awaitable of pooled embeddings.

Return type

PooledEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.PooledEmbeddingsAwaitable(tensor_awaitable: torchrec.distributed.types.Awaitable[torch.Tensor])

Bases: torchrec.distributed.types.Awaitable[torch.Tensor]

Awaitable for pooled embeddings after collective operation.

Parameters

tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.

property callbacks: List[Callable[[torch.Tensor], torch.Tensor]]
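
A hedged sketch of the callbacks hook, assuming callbacks appended to the property are applied to the merged tensor when wait() resolves; a2a and t0 refer to the PooledEmbeddingsAllToAll example above, and the scaling function is purely illustrative:

awaitable = a2a(t0)                            # PooledEmbeddingsAwaitable
awaitable.callbacks.append(lambda t: t * 0.5)  # post-collective transform (assumed semantics)
rank0_output = awaitable.wait()
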
class torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter(pg: torch._C._distributed_c10d.ProcessGroup)

Bases: torch.nn.modules.module.Module

The module class that wraps the reduce-scatter communication primitive for pooled embedding communication in row-wise and table-wise-row-wise (TWRW) sharding.

For pooled embeddings, each rank holds a local model-parallel output tensor with a layout of [num_buckets x batch_size, dimension]. We need to sum over the num_buckets dimension across ranks: the tensor is split along the first dimension into equal chunks (tensor slices for different buckets), the chunks are reduced across ranks, and each rank is scattered the result for its corresponding bucket.

The class returns the async Awaitable handle for pooled embeddings tensor. The reduce-scatter is only available for NCCL backend.

Parameters

pg (dist.ProcessGroup) – The process group that the reduce-scatter communication happens within.

Example:

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
input = torch.randn(2 * 2, 2)
m = PooledEmbeddingsReduceScatter(pg)
output = m(input)
tensor = output.wait()
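
In this sketch the input has shape [2 * 2, 2], i.e. [num_buckets x batch_size, dimension] with two buckets for the two ranks; after the reduce-scatter each rank receives a [2, 2] tensor holding its bucket summed across both ranks.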
forward(local_embs: torch.Tensor) torchrec.distributed.dist_data.PooledEmbeddingsAwaitable

Performs a reduce-scatter operation on the pooled embeddings tensor.

Parameters

local_embs (torch.Tensor) – tensor of shape [num_buckets x batch_size, dimension].

Returns

awaitable of the pooled embeddings tensor of shape [batch_size, dimension].

Return type

PooledEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.SequenceEmbeddingAllToAll(pg: torch._C._distributed_c10d.ProcessGroup, features_per_rank: List[int], device: Optional[torch.device] = None)

Bases: torch.nn.modules.module.Module

Redistributes sequence embeddings to a ProcessGroup according to splits.

Parameters
  • pg (dist.ProcessGroup) – the process group that the AlltoAll communication happens within.

  • features_per_rank (List[int]) – List of number of features per rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

Example:

init_distributed(rank=rank, size=2, backend="nccl")
pg = dist.new_group(backend="nccl")
features_per_rank = [4, 4]
m = SequenceEmbeddingAllToAll(pg, features_per_rank)
local_embs = torch.rand((6, 2))
sharding_ctx: SequenceShardingContext
output = m(
    local_embs=local_embs,
    lengths=sharding_ctx.lengths_after_input_dist,
    input_splits=sharding_ctx.input_splits,
    output_splits=sharding_ctx.output_splits,
    unbucketize_permute_tensor=None,
)
tensor = output.wait()
forward(local_embs: torch.Tensor, lengths: torch.Tensor, input_splits: List[int], output_splits: List[int], unbucketize_permute_tensor: Optional[torch.Tensor] = None) torchrec.distributed.dist_data.SequenceEmbeddingsAwaitable

Performs AlltoAll operation on sequence embeddings tensor.

Parameters
  • local_embs (torch.Tensor) – input embeddings tensor.

  • lengths (torch.Tensor) – lengths of sparse features after AlltoAll.

  • input_splits (List[int]) – input splits of AlltoAll.

  • output_splits (List[int]) – output splits of AlltoAll.

  • unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of the KJT bucketize (for row-wise sharding only).

Returns

SequenceEmbeddingsAwaitable

training: bool
class torchrec.distributed.dist_data.SequenceEmbeddingsAwaitable(tensor_awaitable: torchrec.distributed.types.Awaitable[torch.Tensor], unbucketize_permute_tensor: Optional[torch.Tensor], embedding_dim: int)

Bases: torchrec.distributed.types.Awaitable[torch.Tensor]

Awaitable for sequence embeddings after collective operation.

Parameters
  • tensor_awaitable (Awaitable[torch.Tensor]) – awaitable of concatenated tensors from all the processes in the group after collective.

  • unbucketize_permute_tensor (Optional[torch.Tensor]) – stores the permute order of KJT bucketize (for row-wise sharding only).

torchrec.distributed.sharding.dp_sharding

class torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.embedding_sharding.EmbeddingSharding[torchrec.distributed.sharding.dp_sharding.F, torchrec.distributed.sharding.dp_sharding.T]

Base class for data-parallel sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_shard_metadata() List[Optional[torch.distributed._shard.metadata.ShardMetadata]]
id_list_feature_names() List[str]
id_score_list_feature_names() List[str]
class torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist

Bases: torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]

Distributes pooled embeddings to be data-parallel.

forward(local_embs: torch.Tensor) torchrec.distributed.types.Awaitable[torch.Tensor]

No-op as pooled embeddings are already distributed in data-parallel fashion.

Call Args:

local_embs (torch.Tensor): output pooled embeddings.

Returns

awaitable of pooled embeddings tensor.

Return type

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]

Shards embedding bags data-parallel, with no table sharding, i.e., a given embedding table is replicated across all ranks.

create_input_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]
create_lookup(device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None) torchrec.distributed.embedding_types.BaseEmbeddingLookup
create_output_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]
class torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist

Bases: torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]

Distributes sparse features (input) to be data-parallel.

forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torchrec.distributed.types.Awaitable[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures]]

No-op as sparse features are already distributed in data-parallel fashion.

Call Args:

sparse_features (SparseFeatures): input sparse features.

Returns

wait twice to get sparse features.

Return type

Awaitable[Awaitable[SparseFeatures]]

training: bool
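
Since both data-parallel dist modules are no-ops, composing them amounts to unwrapping awaitables. A minimal sketch, assuming dp_input_dist is a DpSparseFeaturesDist, dp_output_dist is a DpPooledEmbeddingDist, and sparse_features / local_embs are placeholders:

features = dp_input_dist(sparse_features).wait().wait()  # no communication, wait twice
pooled = dp_output_dist(local_embs).wait()               # no communication, pass-through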

torchrec.distributed.sharding.rw_sharding

class torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.embedding_sharding.EmbeddingSharding[torchrec.distributed.sharding.rw_sharding.F, torchrec.distributed.sharding.rw_sharding.T]

Base class for row-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_shard_metadata() List[Optional[torch.distributed._shard.metadata.ShardMetadata]]
id_list_feature_names() List[str]
id_score_list_feature_names() List[str]
class torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist(pg: torch._C._distributed_c10d.ProcessGroup)

Bases: torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]

Redistributes pooled embedding tensor in RW fashion by performing a reduce-scatter operation.

Parameters

pg (dist.ProcessGroup) – ProcessGroup for reduce-scatter communication.

forward(local_embs: torch.Tensor) torchrec.distributed.types.Awaitable[torch.Tensor]

Performs reduce-scatter pooled operation on pooled embeddings tensor.

Parameters

local_embs (torch.Tensor) – pooled embeddings tensor to distribute.

Returns

awaitable of pooled embeddings tensor.

Return type

Awaitable[torch.Tensor]

training: bool
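
A minimal sketch mirroring the PooledEmbeddingsReduceScatter example above, since this module wraps the same reduce-scatter primitive; pg and the 2-rank NCCL setup are assumed:

rw_dist = RwPooledEmbeddingDist(pg)
local_embs = torch.randn(2 * 2, 2)   # [num_buckets x batch_size, dimension]
pooled = rw_dist(local_embs).wait()  # [batch_size, dimension] on each rank
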
class torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]

Shards embedding bags row-wise, i.e., a given embedding table is evenly distributed by rows and table slices are placed on all ranks.

create_input_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]
create_lookup(device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None) torchrec.distributed.embedding_types.BaseEmbeddingLookup
create_output_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]
class torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist(pg: torch._C._distributed_c10d.ProcessGroup, num_id_list_features: int, num_id_score_list_features: int, id_list_feature_hash_sizes: List[int], id_score_list_feature_hash_sizes: List[int], device: Optional[torch.device] = None, is_sequence: bool = False, has_feature_processor: bool = False)

Bases: torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]

Bucketizes sparse features in RW fashion and then redistributes with an AlltoAll collective operation.

Parameters
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • num_id_list_features (int) – total number of id list features.

  • num_id_score_list_features (int) – total number of id score list features.

  • id_list_feature_hash_sizes (List[int]) – hash sizes of id list features.

  • id_score_list_feature_hash_sizes (List[int]) – hash sizes of id score list features.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • is_sequence (bool) – if this is for a sequence embedding.

  • has_feature_processor (bool) – existence of a feature processor (i.e. position weighted features).

forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torchrec.distributed.types.Awaitable[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures]]

Bucketizes sparse feature values into world size number of buckets, and then performs AlltoAll operation.

Parameters

sparse_features (SparseFeatures) – sparse features to bucketize and redistribute.

Returns

awaitable of an awaitable of SparseFeatures.

Return type

Awaitable[Awaitable[SparseFeatures]]

training: bool
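
A hedged construction sketch, assuming a 2-rank process group pg, a single id list feature with hash size 100, and placeholders device and sparse_features (all values are illustrative):

rw_input_dist = RwSparseFeaturesDist(
    pg=pg,
    num_id_list_features=1,
    num_id_score_list_features=0,
    id_list_feature_hash_sizes=[100],
    id_score_list_feature_hash_sizes=[],
    device=device,
)
# bucketize into world-size buckets, then AlltoAll; wait twice for the result
bucketized = rw_input_dist(sparse_features).wait().wait()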

torchrec.distributed.sharding.tw_sharding

class torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.embedding_sharding.EmbeddingSharding[torchrec.distributed.sharding.tw_sharding.F, torchrec.distributed.sharding.tw_sharding.T]

Base class for table-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_shard_metadata() List[Optional[torch.distributed._shard.metadata.ShardMetadata]]
id_list_feature_names() List[str]
id_score_list_feature_names() List[str]
class torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding[torchrec.distributed.embedding_types.SparseFeaturesList, List[torch.Tensor]]

Shards embedding bags table-wise for inference.

create_input_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeaturesList]
create_lookup(device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None) torchrec.distributed.embedding_types.BaseEmbeddingLookup[torchrec.distributed.embedding_types.SparseFeaturesList, List[torch.Tensor]]
create_output_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseEmbeddingDist[List[torch.Tensor]]
class torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist(device: torch.device, world_size: int)

Bases: torchrec.distributed.embedding_sharding.BaseEmbeddingDist[List[torch.Tensor]]

Merges pooled embedding tensor from each device for inference.

Parameters
  • device (Optional[torch.device]) – device on which buffer will be allocated.

  • world_size (int) – number of devices in the topology.

forward(local_embs: List[torch.Tensor]) torchrec.distributed.types.Awaitable[torch.Tensor]

Performs AlltoOne operation on pooled embedding tensors.

Call Args:

local_embs (List[torch.Tensor]): pooled embedding tensors with len(local_embs) == world_size.

Returns

awaitable of merged pooled embedding tensor.

Return type

Awaitable[torch.Tensor]

training: bool
class torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist(id_list_features_per_rank: List[int], id_score_list_features_per_rank: List[int], world_size: int)

Bases: torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeaturesList]

Redistributes sparse features to all devices for inference.

Parameters
  • id_list_features_per_rank (List[int]) – number of id list features to send to each rank.

  • id_score_list_features_per_rank (List[int]) – number of id score list features to send to each rank.

  • world_size (int) – number of devices in the topology.

forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torchrec.distributed.types.Awaitable[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeaturesList]]

Performs OnetoAll operation on sparse features.

Call Args:

sparse_features (SparseFeatures): sparse features to redistribute.

Returns

awaitable of awaitable of SparseFeatures.

Return type

Awaitable[Awaitable[SparseFeatures]]

training: bool
class torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist(pg: torch._C._distributed_c10d.ProcessGroup, dim_sum_per_rank: List[int], device: Optional[torch.device] = None, callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None)

Bases: torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]

Redistributes pooled embedding tensor in TW fashion with an AlltoAll collective operation.

Parameters
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • dim_sum_per_rank (List[int]) – number of features (sum of dimensions) of the embedding in each rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

forward(local_embs: torch.Tensor) torchrec.distributed.types.Awaitable[torch.Tensor]

Performs AlltoAll operation on pooled embeddings tensor.

Call Args:

local_embs (torch.Tensor): tensor of values to distribute.

Returns

awaitable of pooled embeddings.

Return type

Awaitable[torch.Tensor]

training: bool
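
A minimal sketch mirroring the PooledEmbeddingsAllToAll example above, since this module performs the same AlltoAll on pooled embeddings; pg, device, and the 2-rank, batch-size-6 setup are assumed:

tw_dist = TwPooledEmbeddingDist(pg, dim_sum_per_rank=[2, 1], device=device)
local_embs = torch.rand(6, 2)        # rank 0 holds its dim-2 features for the full batch
pooled = tw_dist(local_embs).wait()  # each rank ends up with a [3, 3] slice of the batch
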
class torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]

Shards embedding bags table-wise, i.e., a given embedding table is entirely placed on a selected rank.

create_input_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]
create_lookup(device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None) torchrec.distributed.embedding_types.BaseEmbeddingLookup
create_output_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]
class torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist(pg: torch._C._distributed_c10d.ProcessGroup, id_list_features_per_rank: List[int], id_score_list_features_per_rank: List[int], device: Optional[torch.device] = None)

Bases: torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]

Redistributes sparse features in TW fashion with an AlltoAll collective operation.

Parameters
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • id_list_features_per_rank (List[int]) – number of id list features to send to each rank.

  • id_score_list_features_per_rank (List[int]) – number of id score list features to send to each rank.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torchrec.distributed.types.Awaitable[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures]]

Performs AlltoAll operation on sparse features.

Call Args:

sparse_features (SparseFeatures): sparse features to redistribute.

Returns

awaitable of an awaitable of SparseFeatures.

Return type

Awaitable[Awaitable[SparseFeatures]]

training: bool

torchrec.distributed.sharding.twcw_sharding

class torchrec.distributed.sharding.twcw_sharding.TwCwPooledEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None, permute_embeddings: bool = False)

Bases: torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding

Shards embedding bags table-wise column-wise, i.e., a given embedding table is distributed by a specified number of columns and the table slices are placed on all ranks within a host group.

torchrec.distributed.sharding.twrw_sharding

class torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.embedding_sharding.EmbeddingSharding[torchrec.distributed.sharding.twrw_sharding.F, torchrec.distributed.sharding.twrw_sharding.T]

Base class for table-wise-row-wise sharding.

embedding_dims() List[int]
embedding_names() List[str]
embedding_shard_metadata() List[Optional[torch.distributed._shard.metadata.ShardMetadata]]
id_list_feature_names() List[str]
id_score_list_feature_names() List[str]
class torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist(cross_pg: torch._C._distributed_c10d.ProcessGroup, intra_pg: torch._C._distributed_c10d.ProcessGroup, dim_sum_per_node: List[int], device: Optional[torch.device] = None)

Bases: torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]

Redistributes pooled embedding tensor in TWRW fashion by performing a reduce-scatter operation row-wise at the host level and then an AlltoAll operation table-wise at the global level.

Parameters
  • cross_pg (dist.ProcessGroup) – global level ProcessGroup for AlltoAll communication.

  • intra_pg (dist.ProcessGroup) – host level ProcessGroup for reduce-scatter communication.

  • dim_sum_per_node (List[int]) – number of features (sum of dimensions) of the embedding for each host.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

forward(local_embs: torch.Tensor) torchrec.distributed.types.Awaitable[torch.Tensor]

Performs reduce-scatter pooled operation on pooled embeddings tensor followed by AlltoAll pooled operation.

Call Args:

local_embs (torch.Tensor): pooled embeddings tensor to distribute.

Returns

awaitable of pooled embeddings tensor.

Return type

Awaitable[torch.Tensor]

training: bool
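
A hedged construction sketch, assuming cross_pg spans one rank per host, intra_pg spans the ranks of the local host, dim_sum_per_node=[4] for a single host, and local_embs is a placeholder (all names and values are illustrative):

twrw_dist = TwRwPooledEmbeddingDist(
    cross_pg=cross_pg,
    intra_pg=intra_pg,
    dim_sum_per_node=[4],
    device=device,
)
# reduce-scatter within the host, then AlltoAll across hosts
pooled = twrw_dist(local_embs).wait()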
class torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding(embedding_configs: List[Tuple[torchrec.modules.embedding_configs.EmbeddingTableConfig, torchrec.distributed.types.ParameterSharding, torch.Tensor]], env: torchrec.distributed.types.ShardingEnv, device: Optional[torch.device] = None)

Bases: torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding[torchrec.distributed.embedding_types.SparseFeatures, torch.Tensor]

Shards embedding bags table-wise then row-wise.

create_input_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]
create_lookup(device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, feature_processor: Optional[torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor] = None) torchrec.distributed.embedding_types.BaseEmbeddingLookup
create_output_dist(device: Optional[torch.device] = None) torchrec.distributed.embedding_sharding.BaseEmbeddingDist[torch.Tensor]
class torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist(pg: torch._C._distributed_c10d.ProcessGroup, intra_pg: torch._C._distributed_c10d.ProcessGroup, id_list_features_per_rank: List[int], id_score_list_features_per_rank: List[int], id_list_feature_hash_sizes: List[int], id_score_list_feature_hash_sizes: List[int], device: Optional[torch.device] = None, has_feature_processor: bool = False)

Bases: torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist[torchrec.distributed.embedding_types.SparseFeatures]

Bucketizes sparse features in TWRW fashion and then redistributes with an AlltoAll collective operation.

Parameters
  • pg (dist.ProcessGroup) – ProcessGroup for AlltoAll communication.

  • intra_pg (dist.ProcessGroup) – ProcessGroup within the single host group for AlltoAll communication.

  • id_list_features_per_rank (List[int]) – number of id list features to send to each rank.

  • id_score_list_features_per_rank (List[int]) – number of id score list features to send to each rank.

  • id_list_feature_hash_sizes (List[int]) – hash sizes of id list features.

  • id_score_list_feature_hash_sizes (List[int]) – hash sizes of id score list features.

  • device (Optional[torch.device]) – device on which buffers will be allocated.

  • has_feature_processor (bool) – existence of a feature processor (i.e. position weighted features).

Example:

3 features
2 hosts with 2 devices each

Bucketize each feature into 2 buckets
Staggered shuffle with feature splits [2, 1]
AlltoAll operation

NOTE: result of staggered shuffle and AlltoAll operation look the same after
reordering in AlltoAll

Result:
    host 0 device 0:
        feature 0 bucket 0
        feature 1 bucket 0

    host 0 device 1:
        feature 0 bucket 1
        feature 1 bucket 1

    host 1 device 0:
        feature 2 bucket 0

    host 1 device 1:
        feature 2 bucket 1
forward(sparse_features: torchrec.distributed.embedding_types.SparseFeatures) torchrec.distributed.types.Awaitable[torchrec.distributed.types.Awaitable[torchrec.distributed.embedding_types.SparseFeatures]]

Bucketizes sparse feature values into local world size number of buckets, performs staggered shuffle on the sparse features, and then performs AlltoAll operation.

Call Args:

sparse_features (SparseFeatures): sparse features to bucketize and redistribute.

Returns

awaitable of an awaitable of SparseFeatures.

Return type

Awaitable[Awaitable[SparseFeatures]]

training: bool
