torchrec.distributed.planner¶
Torchrec Planner
The planner provides the specifications necessary for a module to be sharded, considering the possible options to build an optimized plan.
- The features include:
generating all possible sharding options.
estimating perf and storage for every shard.
estimating peak memory usage to eliminate sharding plans that might OOM.
customizability for parameter constraints, partitioning, proposers, or performance modeling.
automatically building and selecting an optimized sharding plan.
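The features above fit together as a loop: propose a candidate, try to partition it, rate it, and give feedback to the proposer. The following is a minimal pure-Python sketch of that loop. It does not import torchrec; `ToyProposer`, `toy_partition`, `toy_rate`, and `toy_plan` are hypothetical stand-ins loosely modeled on the Proposer, Partitioner, and PerfModel interfaces documented below, and the real planner may differ in its details.

```python
from typing import List, Optional, Sequence, Tuple

# Toy plan: one (storage, perf) pair per shard. Purely illustrative.
ToyPlan = List[Tuple[int, float]]


class ToyProposer:
    """Hypothetical proposer that walks a fixed list of candidate plans."""

    def __init__(self) -> None:
        self._candidates: List[ToyPlan] = []
        self._index = 0

    def load(self, search_space: Sequence[ToyPlan]) -> None:
        self._candidates = list(search_space)
        self._index = 0

    def propose(self) -> Optional[ToyPlan]:
        if self._index < len(self._candidates):
            return self._candidates[self._index]
        return None

    def feedback(self, partitionable: bool) -> None:
        # This toy ignores the feedback value and simply advances.
        self._index += 1


def toy_partition(plan: ToyPlan, capacity: int) -> Optional[ToyPlan]:
    """Stand-in for a partitioner: reject plans that exceed the storage budget."""
    total_storage = sum(storage for storage, _ in plan)
    return plan if total_storage <= capacity else None


def toy_rate(plan: ToyPlan) -> float:
    """Stand-in for a perf model: the slowest shard sets the rating."""
    return max(perf for _, perf in plan)


def toy_plan(search_space: Sequence[ToyPlan], capacity: int) -> Optional[ToyPlan]:
    """Keep the lowest-rated (fastest) plan that can be partitioned."""
    proposer = ToyProposer()
    proposer.load(search_space)
    best: Optional[ToyPlan] = None
    best_rating = float("inf")
    proposal = proposer.propose()
    while proposal is not None:
        placed = toy_partition(proposal, capacity)
        if placed is not None:
            rating = toy_rate(placed)
            if rating < best_rating:
                best, best_rating = placed, rating
        proposer.feedback(partitionable=placed is not None)
        proposal = proposer.propose()
    return best
```

With three candidates and a storage budget of 10, the first candidate is rejected for exceeding the budget and the faster of the remaining two is kept.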
torchrec.distributed.planner.constants¶
- torchrec.distributed.planner.constants.kernel_bw_lookup(compute_device: str, compute_kernel: str, hbm_mem_bw: float, ddr_mem_bw: float, caching_ratio: Optional[float] = None) Optional[float] ¶
Calculates the device bandwidth based on given compute device, compute kernel, and caching ratio.
- Parameters:
compute_kernel (str) – compute kernel.
compute_device (str) – compute device.
hbm_mem_bw (float) – the bandwidth of the device HBM.
ddr_mem_bw (float) – the bandwidth of the system DDR memory.
caching_ratio (Optional[float]) – caching ratio used to determine device bandwidth if UVM caching is enabled.
- Returns:
the device bandwidth.
- Return type:
Optional[float]
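As an illustration of why the return type is Optional and how a caching ratio could enter the calculation, here is a small sketch. It is not the library's lookup table: the function name `toy_bw_lookup`, the kernel name `"uvm_caching"`, the fallback ratio of 0.2, and the linear blend of HBM and DDR bandwidth are all assumptions made for this example.

```python
from typing import Dict, Optional, Tuple


def toy_bw_lookup(
    compute_device: str,
    compute_kernel: str,
    hbm_mem_bw: float,
    ddr_mem_bw: float,
    caching_ratio: Optional[float] = None,
) -> Optional[float]:
    """Hypothetical bandwidth lookup keyed on (device, kernel)."""
    # Assumed fallback when no caching ratio is supplied.
    ratio = 0.2 if caching_ratio is None else caching_ratio
    table: Dict[Tuple[str, str], float] = {
        ("cpu", "dense"): ddr_mem_bw,
        ("cuda", "dense"): hbm_mem_bw,
        # Assumed model: cached lookups hit HBM, the rest go to DDR.
        ("cuda", "uvm_caching"): ratio * hbm_mem_bw + (1 - ratio) * ddr_mem_bw,
    }
    # Unknown combinations yield None rather than raising.
    return table.get((compute_device, compute_kernel))
```

A caller would then need to handle the `None` case explicitly, for example by treating the combination as unsupported.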
torchrec.distributed.planner.enumerators¶
- class torchrec.distributed.planner.enumerators.EmbeddingEnumerator(topology: Topology, batch_size: int, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None)¶
Bases:
Enumerator
Generates embedding sharding options for given nn.Module, considering user provided constraints.
- Parameters:
topology (Topology) – device topology.
batch_size (int) – batch size.
constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.
- enumerate(module: Module, sharders: List[ModuleSharder[Module]]) List[ShardingOption] ¶
Generates relevant sharding options given module and sharders.
- Parameters:
module (nn.Module) – module to be sharded.
sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.
- Returns:
valid sharding options with values populated.
- Return type:
List[ShardingOption]
- populate_estimates(sharding_options: List[ShardingOption]) None ¶
- torchrec.distributed.planner.enumerators.get_partition_by_type(sharding_type: str) str ¶
Gets corresponding partition by type for provided sharding type.
- Parameters:
sharding_type (str) – sharding type string.
- Returns:
the corresponding PartitionByType value.
- Return type:
str
torchrec.distributed.planner.partitioners¶
- class torchrec.distributed.planner.partitioners.GreedyPerfPartitioner¶
Bases:
Partitioner
Greedy Partitioner
- partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption] ¶
Places sharding options on topology based on each sharding option’s partition_by attribute. The topology, storage, and perfs are updated at the end of the placement.
- Parameters:
proposal (List[ShardingOption]) – list of populated sharding options.
storage_constraint (Topology) – device topology.
- Returns:
list of sharding options for selected plan.
- Return type:
List[ShardingOption]
Example:
sharding_options = [
    ShardingOption(partition_by="uniform",
        shards=[
            Shard(storage=1, perf=1),
            Shard(storage=1, perf=1),
        ]),
    ShardingOption(partition_by="uniform",
        shards=[
            Shard(storage=2, perf=2),
            Shard(storage=2, perf=2),
        ]),
    ShardingOption(partition_by="device",
        shards=[
            Shard(storage=3, perf=3),
            Shard(storage=3, perf=3),
        ]),
    ShardingOption(partition_by="device",
        shards=[
            Shard(storage=4, perf=4),
            Shard(storage=4, perf=4),
        ]),
]
topology = Topology(world_size=2)

# First, sharding_options[0] and sharding_options[1] will be placed on the
# topology with the uniform strategy, resulting in:
topology.devices[0].perf.total = (1,2)
topology.devices[1].perf.total = (1,2)

# Finally, sharding_options[2] and sharding_options[3] will be placed on the
# topology with the device strategy (see docstring of `partition_by_device` for
# more details):
topology.devices[0].perf.total = (1,2) + (3,4)
topology.devices[1].perf.total = (1,2) + (3,4)

# The topology updates are done after the end of all the placements (the order
# in the example is just for clarity).
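The placement described in the example above can be sketched in plain Python. This is a simplified illustration, not the library's implementation: `toy_greedy_partition` is a hypothetical name, each option is reduced to a `partition_by` string plus a list of shard perf values, and storage limits are ignored. It assumes "uniform" places shard i on device i and "device" places each shard on the currently least-loaded device.

```python
from typing import List, Tuple


def toy_greedy_partition(
    options: List[Tuple[str, List[float]]], world_size: int
) -> List[float]:
    """Return per-device perf totals after placing every shard."""
    loads = [0.0] * world_size

    # Uniform options first: shard i goes to device i.
    for partition_by, shard_perfs in options:
        if partition_by == "uniform":
            if len(shard_perfs) != world_size:
                raise ValueError("uniform option needs one shard per device")
            for rank, perf in enumerate(shard_perfs):
                loads[rank] += perf

    # Device options next: each shard goes to the least-loaded device so far.
    for partition_by, shard_perfs in options:
        if partition_by == "device":
            for perf in shard_perfs:
                rank = min(range(world_size), key=loads.__getitem__)
                loads[rank] += perf

    return loads


example = [
    ("uniform", [1.0, 1.0]),
    ("uniform", [2.0, 2.0]),
    ("device", [3.0, 3.0]),
    ("device", [4.0, 4.0]),
]
print(toy_greedy_partition(example, world_size=2))  # [10.0, 10.0]
```

Both devices end up with a total of 1 + 2 + 3 + 4 = 10, matching the per-device totals listed in the example.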
- class torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner(max_search_count: int = 10, tolerance: float = 0.02)¶
Bases:
Partitioner
Memory balanced Partitioner.
- partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption] ¶
Repeatedly calls the GreedyPerfPartitioner to find a plan with perf within the tolerance of the original plan that uses the least amount of memory.
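One way to picture the repeated search is a binary search over a per-device memory cap, re-running a greedy placement at each step and keeping results whose perf stays within the tolerance of the uncapped baseline. The sketch below follows that idea under stated assumptions: `toy_greedy` and `toy_memory_balanced` are hypothetical, shards are `(storage, perf)` pairs, and the search strategy is an illustrative guess rather than the library's exact procedure.

```python
from typing import List, Optional, Tuple

ToyShard = Tuple[int, float]  # (storage, perf)


def toy_greedy(
    shards: List[ToyShard], world_size: int, cap: int
) -> Optional[Tuple[float, int]]:
    """Place each shard on the least-loaded device that still has room.

    Returns (max device perf, max device storage), or None if a shard
    does not fit on any device under the given cap.
    """
    perf = [0.0] * world_size
    used = [0] * world_size
    for storage, cost in sorted(shards, key=lambda s: s[1], reverse=True):
        fits = [r for r in range(world_size) if used[r] + storage <= cap]
        if not fits:
            return None
        rank = min(fits, key=perf.__getitem__)
        perf[rank] += cost
        used[rank] += storage
    return max(perf), max(used)


def toy_memory_balanced(
    shards: List[ToyShard],
    world_size: int,
    cap: int,
    max_search_count: int = 10,
    tolerance: float = 0.02,
) -> Optional[Tuple[float, int]]:
    """Shrink the memory cap while perf stays within tolerance of the baseline."""
    baseline = toy_greedy(shards, world_size, cap)
    if baseline is None:
        return None
    base_perf, base_mem = baseline
    best = baseline
    lo, hi = 0, base_mem
    for _ in range(max_search_count):
        if lo >= hi:
            break
        mid = (lo + hi) // 2
        result = toy_greedy(shards, world_size, mid)
        if result is not None and result[0] <= base_perf * (1 + tolerance):
            best = result
            hi = result[1]  # tighten the cap to the memory actually used
        else:
            lo = mid + 1
    return best
```

For four shards `[(10, 1.0), (1, 1.0), (10, 1.0), (1, 1.0)]` on two devices, the uncapped greedy placement puts both large shards on one device (peak storage 20), while the capped search finds an equally fast placement with peak storage 11.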
- class torchrec.distributed.planner.partitioners.ShardingOptionGroup(sharding_options: List[ShardingOption], storage_sum: Storage)¶
Bases:
object
- sharding_options: List[ShardingOption]¶
torchrec.distributed.planner.perf_models¶
- class torchrec.distributed.planner.perf_models.NoopPerfModel(topology: Topology)¶
Bases:
PerfModel
- rate(plan: List[ShardingOption]) float ¶
torchrec.distributed.planner.planners¶
- class torchrec.distributed.planner.planners.EmbeddingShardingPlanner(topology: Optional[Topology] = None, batch_size: Optional[int] = None, enumerator: Optional[Enumerator] = None, storage_reservation: Optional[StorageReservation] = None, proposer: Optional[Union[Proposer, List[Proposer]]] = None, partitioner: Optional[Partitioner] = None, performance_model: Optional[PerfModel] = None, stats: Optional[Union[Stats, List[Stats]]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, debug: bool = True)¶
Bases:
ShardingPlanner
Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.
- collective_plan(module: Module, sharders: Optional[List[ModuleSharder[Module]]] = None, pg: Optional[ProcessGroup] = None) ShardingPlan ¶
Calls self.plan(…) on rank 0 and broadcasts the resulting plan to the other ranks.
- plan(module: Module, sharders: List[ModuleSharder[Module]]) ShardingPlan ¶
Plans sharding for provided module and given sharders.
- Parameters:
module (nn.Module) – module that sharding is planned for.
sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.
- Returns:
the computed sharding plan.
- Return type:
ShardingPlan
torchrec.distributed.planner.proposers¶
- class torchrec.distributed.planner.proposers.GreedyProposer(use_depth: bool = True, threshold: Optional[int] = None)¶
Bases:
Proposer
Proposes sharding plans in greedy fashion.
Sorts sharding options for each shardable parameter by perf. On each iteration, finds parameter with largest current storage usage and tries its next sharding option.
- Parameters:
use_depth (bool) – When enabled, sharding_options of a fqn are sorted based on max(shard.perf.total), otherwise sharding_options are sorted by sum(shard.perf.total).
threshold (Optional[int]) – Threshold for early stopping. When specified, the proposer stops proposing when the proposals have consecutive worse perf_rating than best_perf_rating.
- feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None) None ¶
- load(search_space: List[ShardingOption]) None ¶
- propose() Optional[List[ShardingOption]] ¶
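The greedy strategy described above can be sketched without torchrec. In this toy version each option is a `(perf, storage)` pair, the search space is a dict from parameter name to its options, and `feedback` takes only the `partitionable` flag; `ToyGreedyProposer` and these simplifications are assumptions for illustration, and the stopping rule (stop once the largest parameter has no further option) is one plausible choice.

```python
from typing import Dict, List, Optional, Tuple

ToyOption = Tuple[float, int]  # (perf, storage)


class ToyGreedyProposer:
    """Hypothetical greedy proposer over per-parameter option lists."""

    def __init__(self) -> None:
        self._options: Dict[str, List[ToyOption]] = {}
        self._current: Dict[str, int] = {}
        self._done = True

    def load(self, search_space: Dict[str, List[ToyOption]]) -> None:
        # Sort each parameter's options by perf, best (lowest) first.
        self._options = {fqn: sorted(opts) for fqn, opts in search_space.items()}
        self._current = {fqn: 0 for fqn in self._options}
        self._done = not self._options

    def propose(self) -> Optional[Dict[str, ToyOption]]:
        if self._done:
            return None
        return {fqn: opts[self._current[fqn]] for fqn, opts in self._options.items()}

    def feedback(self, partitionable: bool = True) -> None:
        # Find the parameter whose current option uses the most storage
        # and move it to its next option, if it has one.
        largest = max(
            self._current,
            key=lambda fqn: self._options[fqn][self._current[fqn]][1],
        )
        if self._current[largest] + 1 < len(self._options[largest]):
            self._current[largest] += 1
        else:
            self._done = True


proposer = ToyGreedyProposer()
proposer.load({"a": [(1.0, 100), (2.0, 10)], "b": [(1.0, 50), (3.0, 5)]})
proposals = []
proposal = proposer.propose()
while proposal is not None:
    proposals.append(proposal)
    proposer.feedback()
    proposal = proposer.propose()
```

Here the proposer first offers the fastest option for both parameters, then trades perf for storage on "a" (the largest), then on "b", and stops after three proposals.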
- class torchrec.distributed.planner.proposers.GridSearchProposer(max_proposals: int = 10000)¶
Bases:
Proposer
- feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None) None ¶
- load(search_space: List[ShardingOption]) None ¶
- propose() Optional[List[ShardingOption]] ¶
- class torchrec.distributed.planner.proposers.UniformProposer(use_depth: bool = True)¶
Bases:
Proposer
Proposes uniform sharding plans, plans that have the same sharding type for all sharding options.
- feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None) None ¶
- load(search_space: List[ShardingOption]) None ¶
- propose() Optional[List[ShardingOption]] ¶
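To make "the same sharding type for all sharding options" concrete, the sketch below groups a toy search space by sharding type and keeps only the types that every parameter supports. `toy_uniform_proposals` is a hypothetical helper, options are reduced to `(parameter name, sharding type)` pairs, and the toy assumes at most one option per parameter and type.

```python
from typing import Dict, List, Tuple

ToyOption = Tuple[str, str]  # (parameter name, sharding type)


def toy_uniform_proposals(search_space: List[ToyOption]) -> List[List[ToyOption]]:
    """Return one proposal per sharding type that covers every parameter."""
    all_params = {name for name, _ in search_space}
    by_type: Dict[str, List[ToyOption]] = {}
    for name, sharding_type in search_space:
        by_type.setdefault(sharding_type, []).append((name, sharding_type))

    proposals: List[List[ToyOption]] = []
    for sharding_type in sorted(by_type):
        options = by_type[sharding_type]
        # A uniform plan is only possible if every parameter offers this type.
        if {name for name, _ in options} == all_params:
            proposals.append(sorted(options))
    return proposals
```

With `t1` offering table-wise and row-wise options and `t2` offering only row-wise, the single uniform proposal is the row-wise one.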
- torchrec.distributed.planner.proposers.proposers_to_proposals_list(proposers_list: List[Proposer], search_space: List[ShardingOption]) List[List[ShardingOption]] ¶
Only works for static_feedback proposers, i.e. proposers whose sequence of proposals is independent of the performance of earlier proposals.
torchrec.distributed.planner.shard_estimators¶
- class torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None, is_inference: bool = False)¶
Bases:
ShardEstimator
Embedding Wall Time Perf Estimator
- estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None ¶
- class torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None)¶
Bases:
ShardEstimator
Embedding Storage Usage Estimator
- estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None ¶
- torchrec.distributed.planner.shard_estimators.calculate_shard_storages(sharder: ModuleSharder[Module], sharding_type: str, tensor: Tensor, compute_device: str, compute_kernel: str, shard_sizes: List[List[int]], batch_sizes: List[int], world_size: int, local_world_size: int, input_lengths: List[float], num_poolings: List[float], caching_ratio: float, is_pooled: bool) List[Storage] ¶
Calculates estimated storage sizes for each sharded tensor, comprised of input, output, tensor, gradient, and optimizer sizes.
- Parameters:
sharder (ModuleSharder[nn.Module]) – sharder for module that supports sharding.
sharding_type (str) – provided ShardingType value.
tensor (torch.Tensor) – tensor to be sharded.
compute_device (str) – compute device to be used.
compute_kernel (str) – compute kernel to be used.
shard_sizes (List[List[int]]) – list of dimensions of each sharded tensor.
batch_sizes (List[int]) – batch size for each input feature.
world_size (int) – total number of devices in topology.
local_world_size (int) – total number of devices in host group topology.
input_lengths (List[float]) – average input lengths synonymous with pooling factors.
num_poolings (List[float]) – average number of poolings per sample (typically 1.0).
caching_ratio (float) – ratio of HBM to DDR memory for UVM caching.
is_pooled (bool) – True if embedding output is pooled (e.g. EmbeddingBag), False if unpooled/sequential (e.g. Embedding).
- Returns:
storage object for each device in topology.
- Return type:
List[Storage]
- torchrec.distributed.planner.shard_estimators.perf_func_emb_wall_time(shard_sizes: List[List[int]], compute_kernel: str, compute_device: str, sharding_type: str, batch_sizes: List[int], world_size: int, local_world_size: int, input_lengths: List[float], input_data_type_size: float, table_data_type_size: float, fwd_a2a_comm_data_type_size: float, bwd_a2a_comm_data_type_size: float, fwd_sr_comm_data_type_size: float, bwd_sr_comm_data_type_size: float, num_poolings: List[float], hbm_mem_bw: float, ddr_mem_bw: float, intra_host_bw: float, inter_host_bw: float, is_pooled: bool, is_weighted: bool = False, caching_ratio: Optional[float] = None, is_inference: bool = False) List[Perf] ¶
Attempts to model perfs as a function of relative wall times.
- Parameters:
shard_sizes (List[List[int]]) – the list of (local_rows, local_cols) of each shard.
compute_kernel (str) – compute kernel.
compute_device (str) – compute device.
sharding_type (str) – tw, rw, cw, twrw, dp.
batch_sizes (List[int]) – batch size for each input feature.
world_size (int) – the number of devices for all hosts.
local_world_size (int) – the number of devices on each host.
input_lengths (List[float]) – the list of the average number of lookups of each input query feature.
input_data_type_size (float) – the data type size of the distributed data_parallel input.
table_data_type_size (float) – the data type size of the table.
fwd_comm_data_type_size (float) – the data type size of the distributed data_parallel input during forward communication.
bwd_comm_data_type_size (float) – the data type size of the distributed data_parallel input during backward communication.
num_poolings (List[float]) – number of poolings per sample, typically 1.0.
hbm_mem_bw (float) – the bandwidth of the device HBM.
ddr_mem_bw (float) – the bandwidth of the system DDR memory.
intra_host_bw (float) – the bandwidth within a single host, e.g. across multiple threads.
inter_host_bw (float) – the bandwidth between two hosts, e.g. across multiple machines.
is_pooled (bool) – True if embedding output is pooled (e.g. EmbeddingBag), False if unpooled/sequential (e.g. Embedding).
is_weighted (bool = False) – if the module is an EBC and is weighted, typically signifying an id score list feature.
is_inference (bool = False) – if planning for inference.
caching_ratio (Optional[float] = None) – cache ratio to determine the bandwidth of device.
- Returns:
the list of perf for each shard.
- Return type:
List[Perf]
torchrec.distributed.planner.stats¶
- class torchrec.distributed.planner.stats.EmbeddingStats¶
Bases:
Stats
Stats for a sharding planner execution.
- log(sharding_plan: ShardingPlan, topology: Topology, batch_size: int, storage_reservation: StorageReservation, num_proposals: int, num_plans: int, run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, debug: bool = True) None ¶
Logs stats for a given sharding plan.
Provides a tabular view of stats for the given sharding plan with per device storage usage (HBM and DDR), perf, input, output, and number/type of shards.
- Parameters:
sharding_plan (ShardingPlan) – sharding plan chosen by the planner.
topology (Topology) – device topology.
batch_size (int) – batch size.
storage_reservation (StorageReservation) – reserves storage for unsharded parts of the model.
num_proposals (int) – number of proposals evaluated.
num_plans (int) – number of proposals successfully partitioned.
run_time (float) – time taken to find plan (in seconds).
best_plan (List[ShardingOption]) – plan with expected performance.
constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.
debug (bool) – whether to enable debug mode.
- torchrec.distributed.planner.stats.round_to_one_sigfig(x: float) str ¶
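The signature above carries no description, so the following sketch shows one way a float could be rounded to a single significant figure and returned as a string. `toy_round_to_one_sigfig` is a hypothetical name, and the formatting approach is an assumption rather than a statement of how the library does it.

```python
def toy_round_to_one_sigfig(x: float) -> str:
    """Round x to one significant figure and format it compactly."""
    # ".1g" keeps one significant digit (possibly in exponent form);
    # converting back to float and formatting with "g" drops the exponent
    # for moderately sized values.
    return f"{float(f'{x:.1g}'):g}"


print(toy_round_to_one_sigfig(0.0234))  # 0.02
print(toy_round_to_one_sigfig(1234.0))  # 1000
```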
torchrec.distributed.planner.storage_reservations¶
- class torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation(percentage: float)¶
Bases:
StorageReservation
- reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology ¶
- class torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation(percentage: float, parameter_multiplier: float = 6.0, dense_tensor_estimate: Optional[int] = None)¶
Bases:
StorageReservation
Reserves storage for model to be sharded with heuristical calculation. The storage reservation is comprised of dense tensor storage, KJT storage, and an extra percentage of total storage.
- Parameters:
percentage (float) – extra storage percent to reserve that acts as a margin of error beyond heuristic calculation of storage.
parameter_multiplier (float) – heuristic multiplier for total parameter storage.
dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, uses default heuristic estimate if not provided.
- reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology ¶
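The three components named above (dense tensor storage, KJT storage, and an extra percentage of total storage) can be combined as in the sketch below. The additive formula, the helper name `toy_reserved_bytes`, and the inputs `dense_param_bytes` and `kjt_bytes` are assumptions made for illustration; the library's actual estimate may be computed differently.

```python
from typing import Optional


def toy_reserved_bytes(
    total_bytes: int,
    percentage: float,
    dense_param_bytes: int,
    kjt_bytes: int,
    parameter_multiplier: float = 6.0,
    dense_tensor_estimate: Optional[int] = None,
) -> int:
    """Hypothetical per-device reservation in bytes."""
    # Use the caller's dense estimate if given, else scale the raw
    # parameter size by the heuristic multiplier.
    if dense_tensor_estimate is not None:
        dense = dense_tensor_estimate
    else:
        dense = int(dense_param_bytes * parameter_multiplier)
    # Margin of error taken as a fraction of total device storage.
    margin = int(total_bytes * percentage)
    return margin + dense + kjt_bytes
```

For example, with 1000 bytes of device storage, a 25% margin, 10 bytes of dense parameters, and 20 bytes of KJT storage, the sketch reserves 250 + 60 + 20 = 330 bytes.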
- class torchrec.distributed.planner.storage_reservations.InferenceStorageReservation(percentage: float, dense_tensor_estimate: Optional[int] = None)¶
Bases:
StorageReservation
Reserves storage for model to be sharded for inference. The storage reservation is comprised of dense tensor storage, KJT storage, and an extra percentage of total storage. Note that when estimating for storage, dense modules are assumed to be on GPUs and replicated across ranks. If this is not the case, please override the estimates with dense_tensor_estimate.
- Parameters:
percentage (float) – extra storage percentage to reserve that acts as a margin of error beyond storage calculation.
dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, use default heuristic estimate if not provided.
- reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology ¶
torchrec.distributed.planner.types¶
- class torchrec.distributed.planner.types.DeviceHardware(rank: int, storage: Storage, perf: Perf)¶
Bases:
object
Representation of a device in a process group. ‘perf’ is an estimation of network, CPU, and storage usages.
- rank: int¶
- class torchrec.distributed.planner.types.Enumerator(topology: Topology, batch_size: int = 512, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None)¶
Bases:
ABC
Generates all relevant sharding options for given topology, constraints, nn.Module, and sharders.
- abstract enumerate(module: Module, sharders: List[ModuleSharder[Module]]) List[ShardingOption] ¶
See class description.
- class torchrec.distributed.planner.types.ParameterConstraints(sharding_types: Optional[List[str]] = None, compute_kernels: Optional[List[str]] = None, min_partition: Optional[int] = None, pooling_factors: List[float] = <factory>, num_poolings: Optional[List[float]] = None, batch_sizes: Optional[List[int]] = None, is_weighted: bool = False, cache_params: Optional[CacheParams] = None, enforce_hbm: Optional[bool] = None, stochastic_rounding: Optional[bool] = None, bounds_check_mode: Optional[BoundsCheckMode] = None)¶
Bases:
object
Stores user provided constraints around the sharding plan.
If provided, pooling_factors, num_poolings, and batch_sizes must match in length, as per sample.
- batch_sizes: Optional[List[int]] = None¶
- bounds_check_mode: Optional[BoundsCheckMode] = None¶
- cache_params: Optional[CacheParams] = None¶
- compute_kernels: Optional[List[str]] = None¶
- enforce_hbm: Optional[bool] = None¶
- is_weighted: bool = False¶
- min_partition: Optional[int] = None¶
- num_poolings: Optional[List[float]] = None¶
- pooling_factors: List[float]¶
- sharding_types: Optional[List[str]] = None¶
- stochastic_rounding: Optional[bool] = None¶
- class torchrec.distributed.planner.types.PartitionByType(value)¶
Bases:
Enum
Well-known partition types.
- DEVICE = 'device'¶
- HOST = 'host'¶
- UNIFORM = 'uniform'¶
- class torchrec.distributed.planner.types.Partitioner¶
Bases:
ABC
Partitions shards.
Today we have multiple strategies, e.g. Greedy, BLDM, and Linear.
- abstract partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption] ¶
- class torchrec.distributed.planner.types.Perf(fwd_compute: float, fwd_comms: float, bwd_compute: float, bwd_comms: float)¶
Bases:
object
Representation of the breakdown of the perf estimate for a single shard of an embedding table.
- bwd_comms: float¶
- bwd_compute: float¶
- fwd_comms: float¶
- fwd_compute: float¶
- property total: float¶
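A breakdown like this is naturally modeled as a small dataclass with a derived total. The sketch below uses a hypothetical `ToyPerf` and assumes the total is the plain sum of the four components, which the listing above does not state explicitly.

```python
from dataclasses import dataclass


@dataclass
class ToyPerf:
    """Hypothetical mirror of the four-part perf breakdown."""

    fwd_compute: float
    fwd_comms: float
    bwd_compute: float
    bwd_comms: float

    @property
    def total(self) -> float:
        # Assumption: the total is the sum of all four components.
        return self.fwd_compute + self.fwd_comms + self.bwd_compute + self.bwd_comms


print(ToyPerf(1.0, 2.0, 3.0, 4.0).total)  # 10.0
```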
- class torchrec.distributed.planner.types.PerfModel¶
Bases:
ABC
- abstract rate(plan: List[ShardingOption]) float ¶
- exception torchrec.distributed.planner.types.PlannerError(message: str, error_type: PlannerErrorType = PlannerErrorType.OTHER)¶
Bases:
Exception
- class torchrec.distributed.planner.types.PlannerErrorType(value)¶
Bases:
Enum
Classify PlannerError based on the following cases.
- INSUFFICIENT_STORAGE = 'insufficient_storage'¶
- OTHER = 'other'¶
- PARTITION = 'partition'¶
- STRICT_CONSTRAINTS = 'strict_constraints'¶
- class torchrec.distributed.planner.types.Proposer¶
Bases:
ABC
Proposes complete lists of sharding options which can be partitioned to generate a plan.
- abstract feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None) None ¶
- abstract load(search_space: List[ShardingOption]) None ¶
- abstract propose() Optional[List[ShardingOption]] ¶
- class torchrec.distributed.planner.types.Shard(size: List[int], offset: List[int], storage: Optional[Storage] = None, perf: Optional[Perf] = None, rank: Optional[int] = None)¶
Bases:
object
Representation of a subset of an embedding table. ‘size’ and ‘offset’ fully determine the tensors in the shard. ‘storage’ is an estimate of how much storage the shard requires, and ‘perf’ is an estimate of its performance cost.
- offset: List[int]¶
- rank: Optional[int] = None¶
- size: List[int]¶
- class torchrec.distributed.planner.types.ShardEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None)¶
Bases:
ABC
Estimates shard perf or storage, requires fully specified sharding options.
- abstract estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None ¶
- class torchrec.distributed.planner.types.ShardingOption(name: str, tensor: Tensor, module: Tuple[str, Module], input_lengths: List[float], batch_size: int, sharding_type: str, partition_by: str, compute_kernel: str, shards: List[Shard], cache_params: Optional[CacheParams] = None, enforce_hbm: Optional[bool] = None, stochastic_rounding: Optional[bool] = None, bounds_check_mode: Optional[BoundsCheckMode] = None, dependency: Optional[str] = None)¶
Bases:
object
One way of sharding an embedding table.
- property fqn: str¶
- property is_pooled: bool¶
- property module: Tuple[str, Module]¶
- property num_inputs: int¶
- property num_shards: int¶
- property path: str¶
- property tensor: Tensor¶
- class torchrec.distributed.planner.types.Stats¶
Bases:
ABC
Logs statistics related to the sharding plan.
- abstract log(sharding_plan: ShardingPlan, topology: Topology, batch_size: int, storage_reservation: StorageReservation, num_proposals: int, num_plans: int, run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, debug: bool = False) None ¶
See class description.
- class torchrec.distributed.planner.types.Storage(hbm: int, ddr: int)¶
Bases:
object
Representation of the storage capacities of a hardware device used in training.
- ddr: int¶
- hbm: int¶
- class torchrec.distributed.planner.types.StorageReservation¶
Bases:
ABC
Reserves storage space for non-sharded parts of the model.
- abstract reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology ¶
- class torchrec.distributed.planner.types.Topology(world_size: int, compute_device: str, hbm_cap: Optional[int] = None, ddr_cap: Optional[int] = None, local_world_size: Optional[int] = None, hbm_mem_bw: float = 963146416.128, ddr_mem_bw: float = 54760833.024, intra_host_bw: float = 644245094.4, inter_host_bw: float = 13421772.8)¶
Bases:
object
- property compute_device: str¶
- property ddr_mem_bw: float¶
- property devices: List[DeviceHardware]¶
- property hbm_mem_bw: float¶
- property inter_host_bw: float¶
- property intra_host_bw: float¶
- property local_world_size: int¶
- property world_size: int¶
torchrec.distributed.planner.utils¶
- torchrec.distributed.planner.utils.bytes_to_gb(num_bytes: int) float ¶
- torchrec.distributed.planner.utils.bytes_to_mb(num_bytes: Union[float, int]) float ¶
- torchrec.distributed.planner.utils.gb_to_bytes(gb: float) int ¶
- torchrec.distributed.planner.utils.placement(compute_device: str, rank: int, local_size: int) str ¶
Returns placement, formatted as string
- torchrec.distributed.planner.utils.prod(iterable: Iterable[int]) int ¶
- torchrec.distributed.planner.utils.reset_shard_rank(proposal: List[ShardingOption]) None ¶
- torchrec.distributed.planner.utils.sharder_name(t: Type[Any]) str ¶
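The conversion and product helpers listed above have no descriptions, so here is a sketch of what such utilities typically look like. The `toy_` names are hypothetical, and the use of binary units (1 GB = 1024³ bytes) is an assumption; check the library source before relying on a particular unit convention.

```python
from functools import reduce
from operator import mul
from typing import Iterable, Union

BYTES_PER_MB = 1024 ** 2  # assumed binary units
BYTES_PER_GB = 1024 ** 3  # assumed binary units


def toy_bytes_to_gb(num_bytes: int) -> float:
    return num_bytes / BYTES_PER_GB


def toy_bytes_to_mb(num_bytes: Union[float, int]) -> float:
    return num_bytes / BYTES_PER_MB


def toy_gb_to_bytes(gb: float) -> int:
    return int(gb * BYTES_PER_GB)


def toy_prod(iterable: Iterable[int]) -> int:
    # Product of all items; an empty iterable gives 1.
    return reduce(mul, iterable, 1)


print(toy_gb_to_bytes(2.0))       # 2147483648
print(toy_prod([128, 64]))        # 8192
```

A product helper of this kind is handy for turning a shard's `size` list, such as `[rows, cols]`, into an element count.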