
Planner

The TorchRec Planner is responsible for determining the most performant, balanced sharding plan for distributed training and inference.

The main API for generating a sharding plan is EmbeddingShardingPlanner.plan.

class torchrec.distributed.types.ShardingPlan(plan: Dict[str, ModuleShardingPlan])

Representation of a sharding plan. This uses the FQNs of the larger wrapped model (i.e., the model that is wrapped using DistributedModelParallel). EmbeddingModuleShardingPlan should be used when TorchRec composability is desired.

plan

dict keyed by module path, whose values are dicts of parameter sharding specs keyed by parameter name.

Type:

Dict[str, EmbeddingModuleShardingPlan]

get_plan_for_module(module_path: str) → Optional[ModuleShardingPlan]
Parameters:

module_path (str) – the module path (FQN) to look up.

Returns:

dict of parameter sharding specs keyed by parameter name, or None if no sharding specs exist for the given module_path.

Return type:

Optional[ModuleShardingPlan]
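To make the nested layout concrete, the following sketch models a plan as plain Python dicts. The module path, table names, and spec fields are hypothetical placeholders, and get_plan_for_module here is a stand-alone function that only mirrors the documented lookup behavior; it is not the TorchRec method.

```python
from typing import Dict, Optional

# Hypothetical stand-ins: the real TorchRec types (ParameterSharding,
# ModuleShardingPlan) carry more fields than these plain dicts.
ParamSpec = Dict[str, object]
ModulePlan = Dict[str, ParamSpec]


def get_plan_for_module(
    plan: Dict[str, ModulePlan], module_path: str
) -> Optional[ModulePlan]:
    """Return the parameter specs for module_path, or None if absent."""
    return plan.get(module_path)


# Outer keys: module paths. Inner keys: parameter (table) names.
# All names and fields below are illustrative only.
plan: Dict[str, ModulePlan] = {
    "sparse_arch.ebc": {
        "table_0": {"sharding_type": "table_wise", "ranks": [0]},
        "table_1": {"sharding_type": "row_wise", "ranks": [0, 1]},
    },
}

print(get_plan_for_module(plan, "sparse_arch.ebc")["table_1"]["ranks"])  # [0, 1]
print(get_plan_for_module(plan, "dense_arch"))  # None
```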

class torchrec.distributed.planner.planners.EmbeddingShardingPlanner(topology: Optional[Topology] = None, batch_size: Optional[int] = None, enumerator: Optional[Enumerator] = None, storage_reservation: Optional[StorageReservation] = None, proposer: Optional[Union[Proposer, List[Proposer]]] = None, partitioner: Optional[Partitioner] = None, performance_model: Optional[PerfModel] = None, stats: Optional[Union[Stats, List[Stats]]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, debug: bool = True, callbacks: Optional[List[Callable[[List[ShardingOption]], List[ShardingOption]]]] = None)

Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.

Parameters:
  • topology (Optional[Topology]) – the topology of the current process group.

  • batch_size (Optional[int]) – the batch size of the model.

  • enumerator (Optional[Enumerator]) – the enumerator to use.

  • storage_reservation (Optional[StorageReservation]) – the storage reservation to use.

  • proposer (Optional[Union[Proposer, List[Proposer]]]) – the proposer(s) to use.

  • partitioner (Optional[Partitioner]) – the partitioner to use.

  • performance_model (Optional[PerfModel]) – the performance model to use.

  • stats (Optional[Union[Stats, List[Stats]]]) – the stats to use.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – per-table constraints for sharding.

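  • debug (bool) – whether to print debug information.

  • callbacks (Optional[List[Callable[[List[ShardingOption]], List[ShardingOption]]]]) – functions that each take a list of sharding options and return a list of sharding options.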

Example:

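import torch
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner

# eb_configs: a list of EmbeddingBagConfig describing the tables to shard
ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
planner = EmbeddingShardingPlanner()
plan = planner.plan(
    module=ebc,
    sharders=[EmbeddingBagCollectionSharder()],
)
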
collective_plan(module: Module, sharders: Optional[List[ModuleSharder[Module]]] = None, pg: Optional[ProcessGroup] = None) → ShardingPlan

Calls self.plan(…) on rank 0 and broadcasts the resulting plan to the other ranks.

Parameters:
  • module (nn.Module) – the module to shard.

  • sharders (Optional[List[ModuleSharder[nn.Module]]]) – the sharders to use for sharding

  • pg (Optional[dist.ProcessGroup]) – the process group to use for collective operations

Returns:

the sharding plan for the module.

Return type:

ShardingPlan

plan(module: Module, sharders: List[ModuleSharder[Module]]) → ShardingPlan

Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.

Parameters:
  • module (nn.Module) – the module to shard.

  • sharders (List[ModuleSharder[nn.Module]]) – the sharders to use for sharding.

Returns:

the sharding plan for the module.

Return type:

ShardingPlan
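The constructor arguments above (enumerator, storage reservation, proposer, partitioner, performance model) suggest a propose, partition, and rate loop. The sketch below is a toy version of that loop with made-up Option, partition, and rate helpers; it assumes a lower rating is better and is not TorchRec's implementation.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Option:
    # Hypothetical, simplified sharding option: name, storage cost, perf cost.
    name: str
    storage: int
    perf: float


class PlannerError(Exception):
    """Raised by the toy partitioner when a proposal does not fit."""


def partition(proposal: List[Option], capacity: int) -> List[Option]:
    # Toy partitioner: reject the proposal if total storage exceeds capacity.
    if sum(o.storage for o in proposal) > capacity:
        raise PlannerError("proposal does not fit")
    return proposal


def rate(plan: List[Option]) -> float:
    # Toy perf model: lower is better, as with a wall-time estimate.
    return sum(o.perf for o in plan)


def plan_loop(proposals: List[List[Option]], capacity: int) -> Optional[List[Option]]:
    """Partition and rate each proposal, keeping the best partitionable one."""
    best_plan, best_rating = None, float("inf")
    for proposal in proposals:  # stands in for repeated propose() calls
        try:
            placed = partition(proposal, capacity)
        except PlannerError:
            continue  # an unpartitionable proposal is skipped
        rating = rate(placed)
        if rating < best_rating:
            best_plan, best_rating = placed, rating
    return best_plan


proposals = [
    [Option("t0", storage=8, perf=1.0)],  # fastest, but does not fit
    [Option("t0", storage=4, perf=3.0)],
    [Option("t0", storage=4, perf=2.0)],
]
best = plan_loop(proposals, capacity=5)
print(best)  # [Option(name='t0', storage=4, perf=2.0)]
```

Separating "does it fit" (partitioner) from "how fast is it" (performance model) is what lets the fastest proposal be rejected here in favor of the best one that actually fits.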

class torchrec.distributed.planner.enumerators.EmbeddingEnumerator(topology: Topology, batch_size: int, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None, use_exact_enumerate_order: Optional[bool] = False)

Generates embedding sharding options for a given nn.Module, considering user-provided constraints.

Parameters:
  • topology (Topology) – device topology.

  • batch_size (int) – batch size.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – dict mapping parameter names to user-provided ParameterConstraints.

  • estimator (Optional[Union[ShardEstimator, List[ShardEstimator]]]) – shard performance estimators.

  • use_exact_enumerate_order (bool) – whether to enumerate shardable parameters in the exact named_children enumeration order.
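The effect of use_exact_enumerate_order can be pictured as a difference in traversal order. The sketch assumes the enumerator walks the module tree with a work list, popping from the front when the flag is set (preserving named_children order) and from the back otherwise; the tree and function are hypothetical.

```python
from collections import deque
from typing import Dict, List

# Hypothetical module tree: each path maps to its named children, in order.
TREE: Dict[str, List[str]] = {
    "": ["a", "b"],
    "a": ["a.x", "a.y"],
    "b": [],
    "a.x": [],
    "a.y": [],
}


def visit_order(exact: bool) -> List[str]:
    """Return module paths in visit order.

    exact=True pops from the front (FIFO), preserving child order;
    exact=False pops from the back (LIFO), which reverses sibling order.
    """
    work = deque([""])
    order = []
    while work:
        path = work.popleft() if exact else work.pop()
        order.append(path)
        work.extend(TREE[path])
    return order


print(visit_order(exact=True))   # ['', 'a', 'b', 'a.x', 'a.y']
print(visit_order(exact=False))  # ['', 'b', 'a', 'a.y', 'a.x']
```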

enumerate(module: Module, sharders: List[ModuleSharder[Module]]) → List[ShardingOption]

Generates relevant sharding options given module and sharders.

Parameters:
  • module (nn.Module) – module to be sharded.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns:

valid sharding options with values populated.

Return type:

List[ShardingOption]

populate_estimates(sharding_options: List[ShardingOption]) → None

Populates the estimates of the given sharding options using the enumerator's shard estimators.

class torchrec.distributed.planner.partitioners.GreedyPerfPartitioner(sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False)

Greedy Partitioner.

Parameters:
  • sort_by (SortBy) – Sort sharding options by storage or perf in descending order (i.e., large tables will be placed first).

  • balance_modules (bool) – Whether to sort by modules first, where smaller modules will be sorted first. In effect, this will place tables in each module in a balanced way.
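A sketch of the two orderings, assuming "module size" means the summed storage of a module's sharding options (the real partitioner may measure it differently). Opt and sort_options are hypothetical names.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Opt:
    # Hypothetical simplified sharding option.
    name: str
    module: str
    storage: int


def sort_options(options: List[Opt], balance_modules: bool) -> List[Opt]:
    """Largest storage first; optionally grouped by module, smallest module first."""
    # Assumed definition of module size: total storage of its options.
    module_size: Dict[str, int] = {}
    for o in options:
        module_size[o.module] = module_size.get(o.module, 0) + o.storage
    if balance_modules:
        # Module name in the key keeps equal-sized modules from interleaving.
        return sorted(options, key=lambda o: (module_size[o.module], o.module, -o.storage))
    return sorted(options, key=lambda o: -o.storage)


options = [Opt("t0", "ebc1", 10), Opt("t1", "ebc1", 2), Opt("t2", "ebc2", 5)]
print([o.name for o in sort_options(options, balance_modules=False)])  # ['t0', 't2', 't1']
print([o.name for o in sort_options(options, balance_modules=True)])   # ['t2', 't0', 't1']
```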

partition(proposal: List[ShardingOption], storage_constraint: Topology) → List[ShardingOption]

Places sharding options on topology based on each sharding option’s partition_by attribute. The topology, storage, and perfs are updated at the end of the placement.

Parameters:
  • proposal (List[ShardingOption]) – list of populated sharding options.

  • storage_constraint (Topology) – device topology.

Returns:

list of sharding options for selected plan.

Return type:

List[ShardingOption]

Example:

# Schematic only: the real constructors take more arguments.
sharding_options = [
    ShardingOption(partition_by="uniform",
        shards=[
            Shard(storage=1, perf=1),
            Shard(storage=1, perf=1),
        ]),
    ShardingOption(partition_by="uniform",
        shards=[
            Shard(storage=2, perf=2),
            Shard(storage=2, perf=2),
        ]),
    ShardingOption(partition_by="device",
        shards=[
            Shard(storage=3, perf=3),
            Shard(storage=3, perf=3),
        ]),
    ShardingOption(partition_by="device",
        shards=[
            Shard(storage=4, perf=4),
            Shard(storage=4, perf=4),
        ]),
]
topology = Topology(world_size=2)

# First, sharding_options[0] and sharding_options[1] are placed on the
# topology with the uniform strategy, resulting in

topology.devices[0].perf.total = (1,2)
topology.devices[1].perf.total = (1,2)

# Finally, sharding_options[2] and sharding_options[3] are placed on the
# topology with the device strategy (see the docstring of `partition_by_device`
# for more details).

topology.devices[0].perf.total = (1,2) + (3,4)
topology.devices[1].perf.total = (1,2) + (3,4)

# The topology updates are applied after all placements are done (the order
# shown in this example is only for clarity).

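The example above can be reproduced with a toy greedy partitioner. This sketch assumes that "uniform" places shard i on device i and that "device" places each shard on the device with the lowest running perf total, processing larger options first; unlike the description above, it updates totals as it goes. ToyOption and partition are hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyOption:
    # Hypothetical stand-in for ShardingOption: a partition_by strategy
    # plus the perf cost of each shard.
    partition_by: str
    shard_perfs: List[float]


def partition(options: List[ToyOption], world_size: int) -> List[float]:
    """Return the total perf per device after greedy placement."""
    totals = [0.0] * world_size
    # "uniform": shard i goes to device i (one shard per device).
    for opt in (o for o in options if o.partition_by == "uniform"):
        assert len(opt.shard_perfs) == world_size
        for rank, perf in enumerate(opt.shard_perfs):
            totals[rank] += perf
    # "device": larger options first; each shard goes to the least-loaded device.
    device_opts = [o for o in options if o.partition_by == "device"]
    for opt in sorted(device_opts, key=lambda o: max(o.shard_perfs), reverse=True):
        for perf in opt.shard_perfs:
            rank = min(range(world_size), key=lambda r: totals[r])
            totals[rank] += perf
    return totals


options = [
    ToyOption("uniform", [1, 1]),
    ToyOption("uniform", [2, 2]),
    ToyOption("device", [3, 3]),
    ToyOption("device", [4, 4]),
]
print(partition(options, world_size=2))  # [10.0, 10.0]
```

Both devices end with a total of 10, matching (1+2) + (3+4) in the example.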
class torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation(percentage: float, parameter_multiplier: float = 6.0, dense_tensor_estimate: Optional[int] = None)

Reserves storage for the model to be sharded using a heuristic calculation. The storage reservation consists of dense tensor storage, KJT storage, and an extra percentage of total storage.

Parameters:
  • percentage (float) – extra percentage of storage to reserve, acting as a margin of error beyond the heuristic storage calculation.

  • parameter_multiplier (float) – heuristic multiplier for total parameter storage.

  • dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, uses default heuristic estimate if not provided.
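As a rough illustration of how the three parts could add up, the sketch below assumes that percentage is a fraction of the device's storage and that, absent dense_tensor_estimate, dense tensor storage is the dense parameter size scaled by parameter_multiplier. Both are assumptions, and reserved_bytes is a hypothetical helper.

```python
from typing import Optional


def reserved_bytes(
    device_storage: int,
    percentage: float,
    dense_param_bytes: int,
    kjt_bytes: int,
    parameter_multiplier: float = 6.0,
    dense_tensor_estimate: Optional[int] = None,
) -> int:
    """Hypothetical helper: margin + dense tensor storage + KJT storage."""
    # Margin of error: a fraction of the device's total storage (assumed).
    margin = percentage * device_storage
    # Dense tensors: the explicit estimate if given, otherwise dense parameter
    # bytes scaled by parameter_multiplier (assumed heuristic).
    if dense_tensor_estimate is not None:
        dense = dense_tensor_estimate
    else:
        dense = dense_param_bytes * parameter_multiplier
    return int(round(margin + dense + kjt_bytes))


print(reserved_bytes(16_000, 0.15, dense_param_bytes=100, kjt_bytes=50))  # 3050
```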

class torchrec.distributed.planner.proposers.GreedyProposer(use_depth: bool = True, threshold: Optional[int] = None)

Proposes sharding plans in greedy fashion.

Sorts the sharding options for each shardable parameter by perf. On each iteration, finds the parameter with the largest current storage usage and tries its next sharding option.

Parameters:
  • use_depth (bool) – When enabled, the sharding options of an FQN are sorted by max(shard.perf.total); otherwise they are sorted by sum(shard.perf.total).

  • threshold (Optional[int]) – Threshold for early stopping. When specified, the proposer stops proposing once this many consecutive proposals have a worse perf_rating than best_perf_rating.

feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None) → None

Provides feedback to the proposer.

Parameters:
  • partitionable (bool) – whether the plan is partitionable.

  • plan (Optional[List[ShardingOption]]) – plan to provide feedback on.

  • perf_rating (Optional[float]) – performance rating of the plan.

  • storage_constraint (Optional[Topology]) – storage constraint of the plan.

load(search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None) → None

Loads the search space into the proposer.

Parameters:
  • search_space (List[ShardingOption]) – search space to load.

  • enumerator (Optional[Enumerator]) – the enumerator used to generate the search space.

propose() → Optional[List[ShardingOption]]

Proposes a sharding plan.

Returns:

the proposed plan, or None if there are no further proposals.

Return type:

Optional[List[ShardingOption]]
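The load, propose, and feedback cycle described above can be sketched as a small class. ToyGreedyProposer follows the class description (options sorted by max or sum of shard perfs depending on use_depth; the parameter with the largest current storage advances to its next option), but its early-stopping rule and simplified feedback signature are assumptions rather than TorchRec's code.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Option:
    # Hypothetical simplified sharding option for one parameter (fqn).
    fqn: str
    shard_perfs: List[float]
    storage: int


class ToyGreedyProposer:
    """Hypothetical sketch of a greedy proposer; not TorchRec's class."""

    def __init__(self, use_depth: bool = True, threshold: Optional[int] = None):
        self._use_depth = use_depth
        self._threshold = threshold
        self._options: Dict[str, List[Option]] = {}
        self._current: Dict[str, int] = {}  # index of the option in use per fqn
        self._best_rating = float("inf")
        self._num_inferior = 0
        self._done = False

    def load(self, search_space: List[Option]) -> None:
        # use_depth: sort by the slowest shard; otherwise by total shard perf.
        key = max if self._use_depth else sum
        for opt in search_space:
            self._options.setdefault(opt.fqn, []).append(opt)
        for fqn in self._options:
            self._options[fqn].sort(key=lambda o: key(o.shard_perfs))
            self._current[fqn] = 0

    def propose(self) -> Optional[List[Option]]:
        if self._done:
            return None
        return [self._options[f][i] for f, i in self._current.items()]

    def feedback(self, perf_rating: Optional[float] = None) -> None:
        # Assumed early stopping: count consecutive proposals that are not better.
        if self._threshold is not None and perf_rating is not None:
            if perf_rating < self._best_rating:
                self._best_rating, self._num_inferior = perf_rating, 0
            else:
                self._num_inferior += 1
                if self._num_inferior >= self._threshold:
                    self._done = True
                    return
        # Advance the parameter with the largest current storage that still
        # has an untried option.
        candidates = [f for f, i in self._current.items() if i + 1 < len(self._options[f])]
        if not candidates:
            self._done = True
            return
        fqn = max(candidates, key=lambda f: self._options[f][self._current[f]].storage)
        self._current[fqn] += 1


space = [
    Option("t0", [4.0, 4.0], storage=10),  # max 4, sum 8
    Option("t0", [6.0], storage=6),        # max 6, sum 6
    Option("t1", [1.0], storage=3),
]
p = ToyGreedyProposer(use_depth=True)
p.load(space)
print([(o.fqn, o.storage) for o in p.propose()])  # [('t0', 10), ('t1', 3)]
p.feedback(perf_rating=5.0)
print([(o.fqn, o.storage) for o in p.propose()])  # [('t0', 6), ('t1', 3)]
p.feedback(perf_rating=6.0)
print(p.propose())  # None
```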

class torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None, is_inference: bool = False)

Embedding Wall Time Perf Estimator. This estimator estimates the wall time of a given sharding option.

Parameters:
  • topology (Topology) – device topology.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – parameter constraints.

  • is_inference (bool) – whether or not the estimator is used for inference.

estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) → None

Estimates the wall time of each given sharding option.

Parameters:
  • sharding_options (List[ShardingOption]) – list of sharding options.

  • sharder_map (Optional[Dict[str, ModuleSharder[nn.Module]]]) – sharder map.

classmethod perf_func_emb_wall_time(shard_sizes: List[List[int]], compute_kernel: str, compute_device: str, sharding_type: str, batch_sizes: List[int], world_size: int, local_world_size: int, input_lengths: List[float], input_data_type_size: float, table_data_type_size: float, output_data_type_size: float, fwd_a2a_comm_data_type_size: float, bwd_a2a_comm_data_type_size: float, fwd_sr_comm_data_type_size: float, bwd_sr_comm_data_type_size: float, num_poolings: List[float], hbm_mem_bw: float, ddr_mem_bw: float, hbm_to_ddr_mem_bw: float, intra_host_bw: float, inter_host_bw: float, bwd_compute_multiplier: float, weighted_feature_bwd_compute_multiplier: float, is_pooled: bool, is_weighted: bool = False, caching_ratio: Optional[float] = None, is_inference: bool = False, prefetch_pipeline: bool = False, expected_cache_fetches: float = 0, uneven_sharding_perf_multiplier: float = 1.0) → List[Perf]

Attempts to model perfs as a function of relative wall times.

Parameters:
  • shard_sizes (List[List[int]]) – the list of (local_rows, local_cols) of each shard.

  • compute_kernel (str) – compute kernel.

  • compute_device (str) – compute device.

  • sharding_type (str) – tw, rw, cw, twrw, dp.

  • batch_sizes (List[int]) – batch size for each input feature.

  • world_size (int) – the total number of devices across all hosts.

  • local_world_size (int) – the number of devices per host.

  • input_lengths (List[float]) – the list of the average number of lookups of each input query feature.

  • input_data_type_size (float) – the data type size of the distributed data_parallel input.

  • table_data_type_size (float) – the data type size of the table.

  • output_data_type_size (float) – the data type size of the output embeddings.

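  • fwd_a2a_comm_data_type_size (float) – the data type size of data sent in forward all-to-all (a2a) communication.

  • bwd_a2a_comm_data_type_size (float) – the data type size of data sent in backward all-to-all (a2a) communication.

  • fwd_sr_comm_data_type_size (float) – the data type size of data sent in forward reduce-scatter (sr) communication.

  • bwd_sr_comm_data_type_size (float) – the data type size of data sent in backward reduce-scatter (sr) communication.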

  • num_poolings (List[float]) – number of poolings per sample, typically 1.0.

  • hbm_mem_bw (float) – the bandwidth of the device HBM.

  • ddr_mem_bw (float) – the bandwidth of the system DDR memory.

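  • hbm_to_ddr_mem_bw (float) – the bandwidth between device HBM and system DDR.

  • intra_host_bw (float) – the bandwidth between devices within a single host.

  • inter_host_bw (float) – the bandwidth between hosts.

  • bwd_compute_multiplier (float) – multiplier applied to the forward compute cost to estimate the backward compute cost.

  • weighted_feature_bwd_compute_multiplier (float) – additional backward compute multiplier for weighted features.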

  • is_pooled (bool) – True if embedding output is pooled (i.e., EmbeddingBag), False if unpooled/sequential (i.e., Embedding).

  • is_weighted (bool = False) – if the module is an EBC and is weighted, typically signifying an id score list feature.

  • is_inference (bool = False) – if planning for inference.

  • caching_ratio (Optional[float] = None) – cache ratio used to determine the effective bandwidth of the device.

  • prefetch_pipeline (bool = False) – whether prefetch pipeline is enabled.

  • expected_cache_fetches (float) – number of expected cache fetches across the global batch.

  • uneven_sharding_perf_multiplier (float = 1.0) – multiplier to account for uneven sharding perf.

Returns:

the list of perf for each shard.

Return type:

List[Perf]
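To show what modeling perf "as a function of relative wall times" can look like, here is a deliberately simplified bytes-over-bandwidth estimate for a single pooled table-wise shard. It is not TorchRec's formula; the function name and the choice of terms are illustrative assumptions that borrow the parameter names documented above.

```python
from typing import List


def toy_table_wise_perf(
    batch_size: int,
    input_length: float,
    emb_dim: int,
    table_data_type_size: float,
    output_data_type_size: float,
    hbm_mem_bw: float,
    comm_bw: float,
    bwd_compute_multiplier: float,
) -> List[float]:
    """Hypothetical model: [fwd_compute, fwd_comm, bwd_compute] in seconds."""
    # Forward lookups read input_length rows of emb_dim elements per sample.
    fwd_compute = batch_size * input_length * emb_dim * table_data_type_size / hbm_mem_bw
    # Pooled output: one emb_dim vector per sample is sent to other ranks.
    fwd_comm = batch_size * emb_dim * output_data_type_size / comm_bw
    # Backward compute is modeled as a multiple of forward compute.
    bwd_compute = fwd_compute * bwd_compute_multiplier
    return [fwd_compute, fwd_comm, bwd_compute]


print(toy_table_wise_perf(1024, 10.0, 128, 4, 4, 2.0**30, 2.0**30, 2.0))
# [0.0048828125, 0.00048828125, 0.009765625]
```

Dividing bytes moved by the relevant bandwidth keeps every term in the same unit, so compute and communication costs can be compared or summed across shards.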

class torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None, pipeline_type: PipelineType = PipelineType.NONE, run_embedding_at_peak_memory: bool = False, is_inference: bool = False)

Embedding Storage Usage Estimator.

Parameters:
  • topology (Topology) – device topology.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – parameter constraints.

  • pipeline_type (PipelineType) – the type of pipeline, if any. It determines the input replication factor during memory estimation.

  • run_embedding_at_peak_memory (bool) –

    Whether the embedding forward/backward pass is executed while HBM usage is at its peak. When set to True, any temporary memory allocated during the embedding forward/backward pass, as well as the output sizes before output_dist, is counted toward the HBM storage cost. Otherwise, this memory is not counted, since it is “hidden” by the real memory peak.

    Only takes effect if pipeline_type is set, for backward compatibility (models using the old pipeline-agnostic formula are not affected).

    Defaults to False because this is typically the case for RecSys models, where the memory peak occurs at the end of the dense forward pass or the beginning of the dense backward pass instead.

  • is_inference (bool) – whether the model is an inference model. Defaults to False.

estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) → None

Estimate the storage cost of each sharding option.

Parameters:
  • sharding_options (List[ShardingOption]) – list of sharding options.

  • sharder_map (Optional[Dict[str, ModuleSharder[nn.Module]]]) – map from module type to sharder.
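For a sense of what a storage estimate involves, this sketch computes the weight storage of one shard and, assuming a caching kernel keeps the full table in DDR with only caching_ratio of it in HBM, splits it across the two tiers. It ignores optimizer state, inputs, outputs, and pipeline effects, and toy_shard_storage is a hypothetical helper.

```python
from typing import Optional, Tuple


def toy_shard_storage(
    local_rows: int,
    local_cols: int,
    table_data_type_size: int,
    caching_ratio: Optional[float] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: (hbm_bytes, ddr_bytes) for one shard's weights.

    Without caching, the weights live in HBM. With a caching ratio, the full
    table is assumed to live in DDR and only that fraction is cached in HBM.
    """
    size = local_rows * local_cols * table_data_type_size
    if caching_ratio is None:
        return size, 0
    return int(round(size * caching_ratio)), size


print(toy_shard_storage(1000, 64, 4))                     # (256000, 0)
print(toy_shard_storage(1000, 64, 4, caching_ratio=0.2))  # (51200, 256000)
```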
