
torchrec.distributed.planner

Torchrec Planner

The planner provides the specifications necessary for a module to be sharded, considering the possible options to build an optimized plan.

Its features include:
  • generating all possible sharding options.

  • estimating perf and storage for every shard.

  • estimating peak memory usage to eliminate sharding plans that might OOM.

  • customizability for parameter constraints, partitioning, proposers, or performance modeling.

  • automatically building and selecting an optimized sharding plan.

torchrec.distributed.planner.constants

torchrec.distributed.planner.constants.kernel_bw_lookup(compute_device: str, compute_kernel: str, hbm_mem_bw: float, ddr_mem_bw: float, caching_ratio: Optional[float] = None) Optional[float]

Calculates the device bandwidth based on given compute device, compute kernel, and caching ratio.

Parameters:
  • compute_kernel (str) – compute kernel.

  • compute_device (str) – compute device.

  • hbm_mem_bw (float) – the bandwidth of the device HBM.

  • ddr_mem_bw (float) – the bandwidth of the system DDR memory.

  • caching_ratio (Optional[float]) – caching ratio used to determine device bandwidth if UVM caching is enabled.

Returns:

the device bandwidth.

Return type:

Optional[float]
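As a rough illustration of how a caching ratio could affect the effective bandwidth, the sketch below blends the HBM and DDR bandwidths in proportion to the share of lookups served from the cache. The helper name `blended_bandwidth` and the linear blend are assumptions made for illustration; they are not taken from the library, whose actual lookup also depends on the compute device and compute kernel.

```python
def blended_bandwidth(hbm_mem_bw: float, ddr_mem_bw: float, caching_ratio: float) -> float:
    """Hypothetical helper: weight HBM and DDR bandwidth by the cached share.

    Assumption: a fraction `caching_ratio` of accesses hit HBM and the rest
    go to DDR. This is an illustrative model, not the library's formula.
    """
    if not 0.0 <= caching_ratio <= 1.0:
        raise ValueError("caching_ratio must be between 0 and 1")
    return caching_ratio * hbm_mem_bw + (1.0 - caching_ratio) * ddr_mem_bw


# Half of the accesses served from HBM, half from DDR.
effective_bw = blended_bandwidth(hbm_mem_bw=100.0, ddr_mem_bw=10.0, caching_ratio=0.5)
```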

torchrec.distributed.planner.enumerators

class torchrec.distributed.planner.enumerators.EmbeddingEnumerator(topology: Topology, batch_size: int, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None)

Bases: Enumerator

Generates embedding sharding options for given nn.Module, considering user provided constraints.

Parameters:
  • topology (Topology) – device topology.

  • batch_size (int) – batch size.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.

enumerate(module: Module, sharders: List[ModuleSharder[Module]]) List[ShardingOption]

Generates relevant sharding options given module and sharders.

Parameters:
  • module (nn.Module) – module to be sharded.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns:

valid sharding options with values populated.

Return type:

List[ShardingOption]

populate_estimates(sharding_options: List[ShardingOption]) None
torchrec.distributed.planner.enumerators.get_partition_by_type(sharding_type: str) str

Gets corresponding partition by type for provided sharding type.

Parameters:

sharding_type (str) – sharding type string.

Returns:

the corresponding PartitionByType value.

Return type:

str

torchrec.distributed.planner.partitioners

class torchrec.distributed.planner.partitioners.GreedyPerfPartitioner

Bases: Partitioner

Greedy Partitioner

partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption]

Places sharding options on topology based on each sharding option’s partition_by attribute. The topology, storage, and perfs are updated at the end of the placement.

Parameters:
  • proposal (List[ShardingOption]) – list of populated sharding options.

  • storage_constraint (Topology) – device topology.

Returns:

list of sharding options for selected plan.

Return type:

List[ShardingOption]

Example:

sharding_options = [
        ShardingOption(partition_by="uniform",
                shards=[
                    Shard(storage=1, perf=1),
                    Shard(storage=1, perf=1),
                ]),
        ShardingOption(partition_by="uniform",
                shards=[
                    Shard(storage=2, perf=2),
                    Shard(storage=2, perf=2),
                ]),
        ShardingOption(partition_by="device",
                shards=[
                    Shard(storage=3, perf=3),
                    Shard(storage=3, perf=3),
                ]),
        ShardingOption(partition_by="device",
                shards=[
                    Shard(storage=4, perf=4),
                    Shard(storage=4, perf=4),
                ]),
    ]
topology = Topology(world_size=2)

# First, sharding_options[0] and sharding_options[1] will be placed on the
# topology with the uniform strategy, resulting in:

topology.devices[0].perf.total = (1,2)
topology.devices[1].perf.total = (1,2)

# Finally, sharding_options[2] and sharding_options[3] will be placed on the
# topology with the device strategy (see docstring of `partition_by_device` for
# more details).

topology.devices[0].perf.total = (1,2) + (3,4)
topology.devices[1].perf.total = (1,2) + (3,4)

# The topology updates are done after the end of all the placements (the order
# in the example is just for clarity).
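
The device strategy mentioned above can be pictured with a small standalone sketch. It does not use torchrec: `ToyShard`, `ToyDevice`, and `greedy_place` are hypothetical names, and the rule used here (place the costliest shard first, on the least-loaded device that still has room) is an assumed simplification of a greedy placement, not the library's implementation.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyShard:
    storage: int
    perf: float


@dataclass
class ToyDevice:
    capacity: int
    used: int = 0
    perf: float = 0.0
    shards: List[ToyShard] = field(default_factory=list)


def greedy_place(shards: List[ToyShard], devices: List[ToyDevice]) -> None:
    """Place each shard on the least-loaded device that can still hold it."""
    # Handle the most expensive shards first so they spread across devices.
    for shard in sorted(shards, key=lambda s: s.perf, reverse=True):
        candidates = [d for d in devices if d.used + shard.storage <= d.capacity]
        if not candidates:
            raise RuntimeError("no device has room for this shard")
        target = min(candidates, key=lambda d: d.perf)
        target.used += shard.storage
        target.perf += shard.perf
        target.shards.append(shard)


devices = [ToyDevice(capacity=10), ToyDevice(capacity=10)]
shards = [ToyShard(3, 3.0), ToyShard(3, 3.0), ToyShard(4, 4.0), ToyShard(4, 4.0)]
greedy_place(shards, devices)
```

With these toy inputs, each device ends up holding one shard of perf 3 and one of perf 4.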
class torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner(max_search_count: int = 10, tolerance: float = 0.02)

Bases: Partitioner

Memory balanced Partitioner.

partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption]

Repeatedly calls the GreedyPerfPartitioner to find a plan with perf within the tolerance of the original plan that uses the least amount of memory.
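
One way to picture this search is a bisection over a per-device memory budget. The sketch below is an assumption-laden illustration, not the library's code: `perf_at_budget` is a hypothetical callback standing in for a partitioner run under a given budget (returning `None` when no plan fits), and the search assumes that perf does not get worse as the budget grows.

```python
from typing import Callable, Optional


def min_budget_within_tolerance(
    perf_at_budget: Callable[[int], Optional[float]],
    max_budget: int,
    max_search_count: int = 10,
    tolerance: float = 0.02,
) -> int:
    """Return the smallest budget found whose perf stays within tolerance."""
    baseline = perf_at_budget(max_budget)
    if baseline is None:
        raise ValueError("no plan fits even at the full budget")
    best, low, high = max_budget, 0, max_budget
    for _ in range(max_search_count):
        if low >= high:
            break
        mid = (low + high) // 2
        perf = perf_at_budget(mid)
        if perf is not None and perf <= baseline * (1 + tolerance):
            best, high = mid, mid  # acceptable: try an even smaller budget
        else:
            low = mid + 1  # infeasible or too slow: allow more memory
    return best


def toy_perf(budget: int) -> Optional[float]:
    """Toy stand-in: infeasible below 40, slower between 40 and 59."""
    if budget < 40:
        return None
    return 10.0 if budget >= 60 else 10.5


smallest = min_budget_within_tolerance(toy_perf, max_budget=100)
```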

class torchrec.distributed.planner.partitioners.ShardingOptionGroup(sharding_options: List[torchrec.distributed.planner.types.ShardingOption], storage_sum: torchrec.distributed.planner.types.Storage)

Bases: object

sharding_options: List[ShardingOption]
storage_sum: Storage
torchrec.distributed.planner.partitioners.set_hbm_per_device(storage_constraint: Topology, hbm_per_device: int) None

torchrec.distributed.planner.perf_models

class torchrec.distributed.planner.perf_models.NoopPerfModel(topology: Topology)

Bases: PerfModel

rate(plan: List[ShardingOption]) float

torchrec.distributed.planner.planners

class torchrec.distributed.planner.planners.EmbeddingShardingPlanner(topology: Optional[Topology] = None, batch_size: Optional[int] = None, enumerator: Optional[Enumerator] = None, storage_reservation: Optional[StorageReservation] = None, proposer: Optional[Union[Proposer, List[Proposer]]] = None, partitioner: Optional[Partitioner] = None, performance_model: Optional[PerfModel] = None, stats: Optional[Union[Stats, List[Stats]]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, debug: bool = True)

Bases: ShardingPlanner

Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.

collective_plan(module: Module, sharders: Optional[List[ModuleSharder[Module]]] = None, pg: Optional[ProcessGroup] = None) ShardingPlan

Calls self.plan(…) on rank 0 and broadcasts the resulting plan.

plan(module: Module, sharders: List[ModuleSharder[Module]]) ShardingPlan

Plans sharding for provided module and given sharders.

Parameters:
  • module (nn.Module) – module that sharding is planned for.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns:

the computed sharding plan.

Return type:

ShardingPlan
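
The interplay between a proposer, a partitioner, and a perf model can be sketched as a simple search loop. Everything below is a standalone toy: `ListProposer`, `toy_partition`, and `search_best_plan` are hypothetical names, proposals are plain lists of integers, and the convention that a lower rating is better is an assumption of this sketch.

```python
from typing import Callable, List, Optional

Proposal = List[int]  # toy stand-in for List[ShardingOption]


class ListProposer:
    """Hypothetical proposer that replays a fixed list of proposals."""

    def __init__(self, proposals: List[Proposal]) -> None:
        self._pending = list(proposals)

    def propose(self) -> Optional[Proposal]:
        return self._pending[0] if self._pending else None

    def feedback(self, partitionable: bool, plan=None, perf_rating=None) -> None:
        # This static proposer ignores the feedback and moves to the next proposal.
        self._pending.pop(0)


def toy_partition(proposal: Proposal) -> Proposal:
    """Pretend a proposal only fits when its total size is at most 5."""
    if sum(proposal) > 5:
        raise ValueError("proposal does not fit")
    return proposal


def search_best_plan(
    proposer: ListProposer,
    partition: Callable[[Proposal], Proposal],
    rate: Callable[[Proposal], float],
) -> Optional[Proposal]:
    """Try every proposal and keep the feasible plan with the lowest rating."""
    best_plan, best_rating = None, float("inf")
    proposal = proposer.propose()
    while proposal is not None:
        try:
            plan = partition(proposal)
        except ValueError:
            proposer.feedback(partitionable=False)
        else:
            rating = rate(plan)
            if rating < best_rating:
                best_plan, best_rating = plan, rating
            proposer.feedback(partitionable=True, plan=plan, perf_rating=rating)
        proposal = proposer.propose()
    return best_plan


best = search_best_plan(ListProposer([[3, 3], [2, 3], [1, 2], [9]]), toy_partition, max)
```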

torchrec.distributed.planner.proposers

class torchrec.distributed.planner.proposers.GreedyProposer(use_depth: bool = True, threshold: Optional[int] = None)

Bases: Proposer

Proposes sharding plans in greedy fashion.

Sorts sharding options for each shardable parameter by perf. On each iteration, finds parameter with largest current storage usage and tries its next sharding option.

Parameters:
  • use_depth (bool) – When enabled, sharding_options of a fqn are sorted based on max(shard.perf.total), otherwise sharding_options are sorted by sum(shard.perf.total).

  • threshold (Optional[int]) – Threshold for early stopping. When specified, the proposer stops proposing when the proposals have consecutive worse perf_rating than best_perf_rating.
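
The greedy iteration described above can be sketched without torchrec. `ToyGreedyProposer` is a hypothetical class, each option is a `(label, perf, storage)` tuple, and early stopping via `threshold` is left out; treat it as an illustration of the idea rather than the library's behavior.

```python
from typing import Dict, List, Optional, Tuple

Option = Tuple[str, float, int]  # (label, perf, storage)


class ToyGreedyProposer:
    def __init__(self, options_by_param: Dict[str, List[Option]]) -> None:
        # Sort each parameter's options by perf, best (lowest) first.
        self._options = {
            name: sorted(opts, key=lambda o: o[1])
            for name, opts in options_by_param.items()
        }
        self._index = {name: 0 for name in self._options}
        self._exhausted = False

    def propose(self) -> Optional[Dict[str, Option]]:
        if self._exhausted:
            return None
        return {name: opts[self._index[name]] for name, opts in self._options.items()}

    def feedback(self) -> None:
        """Move the parameter with the largest current storage to its next option."""
        heaviest = max(
            self._options, key=lambda n: self._options[n][self._index[n]][2]
        )
        if self._index[heaviest] + 1 < len(self._options[heaviest]):
            self._index[heaviest] += 1
        else:
            self._exhausted = True


proposer = ToyGreedyProposer({
    "a": [("rw", 2.0, 4), ("tw", 1.0, 8)],
    "b": [("tw", 1.0, 2), ("rw", 3.0, 1)],
})
first = proposer.propose()
proposer.feedback()
second = proposer.propose()
proposer.feedback()
third = proposer.propose()
```

Here the first proposal picks the fastest option for both parameters, the second swaps parameter "a" (the largest) to its next option, and the third is `None` because "a" has no options left.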

feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None) None
load(search_space: List[ShardingOption]) None
propose() Optional[List[ShardingOption]]
class torchrec.distributed.planner.proposers.GridSearchProposer(max_proposals: int = 10000)

Bases: Proposer

feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None) None
load(search_space: List[ShardingOption]) None
propose() Optional[List[ShardingOption]]
class torchrec.distributed.planner.proposers.UniformProposer(use_depth: bool = True)

Bases: Proposer

Proposes uniform sharding plans, plans that have the same sharding type for all sharding options.
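
A minimal standalone sketch of the idea: group the search space by sharding type and keep only the types that cover every parameter. The function name `uniform_proposals` and the `(parameter, sharding_type)` tuples are hypothetical simplifications, not torchrec types.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

ToyOption = Tuple[str, str]  # (parameter name, sharding type)


def uniform_proposals(search_space: List[ToyOption]) -> List[List[ToyOption]]:
    """Build one proposal per sharding type that is available for all parameters."""
    all_params = {param for param, _ in search_space}
    by_type: Dict[str, Dict[str, ToyOption]] = defaultdict(dict)
    for param, sharding_type in search_space:
        # Keep the first option seen for each (sharding type, parameter) pair.
        by_type[sharding_type].setdefault(param, (param, sharding_type))
    return [
        list(chosen.values())
        for chosen in by_type.values()
        if set(chosen) == all_params
    ]


proposals = uniform_proposals([
    ("t1", "table_wise"),
    ("t1", "row_wise"),
    ("t2", "row_wise"),
])
```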

feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None) None
load(search_space: List[ShardingOption]) None
propose() Optional[List[ShardingOption]]
torchrec.distributed.planner.proposers.proposers_to_proposals_list(proposers_list: List[Proposer], search_space: List[ShardingOption]) List[List[ShardingOption]]

Only works for static_feedback proposers (the path of proposals to check is independent of the performance of the proposals).

torchrec.distributed.planner.shard_estimators

class torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None, is_inference: bool = False)

Bases: ShardEstimator

Embedding Wall Time Perf Estimator

estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None
class torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None)

Bases: ShardEstimator

Embedding Storage Usage Estimator

estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None
torchrec.distributed.planner.shard_estimators.calculate_shard_storages(sharder: ModuleSharder[Module], sharding_type: str, tensor: Tensor, compute_device: str, compute_kernel: str, shard_sizes: List[List[int]], batch_sizes: List[int], world_size: int, local_world_size: int, input_lengths: List[float], num_poolings: List[float], caching_ratio: float, is_pooled: bool) List[Storage]

Calculates estimated storage sizes for each sharded tensor, comprised of input, output, tensor, gradient, and optimizer sizes.

Parameters:
  • sharder (ModuleSharder[nn.Module]) – sharder for module that supports sharding.

  • sharding_type (str) – provided ShardingType value.

  • tensor (torch.Tensor) – tensor to be sharded.

  • compute_device (str) – compute device to be used.

  • compute_kernel (str) – compute kernel to be used.

  • shard_sizes (List[List[int]]) – list of dimensions of each sharded tensor.

  • batch_sizes (List[int]) – batch size for each input feature.

  • world_size (int) – total number of devices in topology.

  • local_world_size (int) – total number of devices in host group topology.

  • input_lengths (List[float]) – average input lengths synonymous with pooling factors.

  • num_poolings (List[float]) – average number of poolings per sample (typically 1.0).

  • caching_ratio (float) – ratio of HBM to DDR memory for UVM caching.

  • is_pooled (bool) – True if embedding output is pooled (e.g. EmbeddingBag), False if unpooled/sequential (e.g. Embedding).

Returns:

storage object for each device in topology.

Return type:

List[Storage]

torchrec.distributed.planner.shard_estimators.perf_func_emb_wall_time(shard_sizes: List[List[int]], compute_kernel: str, compute_device: str, sharding_type: str, batch_sizes: List[int], world_size: int, local_world_size: int, input_lengths: List[float], input_data_type_size: float, table_data_type_size: float, fwd_a2a_comm_data_type_size: float, bwd_a2a_comm_data_type_size: float, fwd_sr_comm_data_type_size: float, bwd_sr_comm_data_type_size: float, num_poolings: List[float], hbm_mem_bw: float, ddr_mem_bw: float, intra_host_bw: float, inter_host_bw: float, is_pooled: bool, is_weighted: bool = False, caching_ratio: Optional[float] = None, is_inference: bool = False) List[Perf]

Attempts to model perfs as a function of relative wall times.

Parameters:
  • shard_sizes (List[List[int]]) – the list of (local_rows, local_cols) of each shard.

  • compute_kernel (str) – compute kernel.

  • compute_device (str) – compute device.

  • sharding_type (str) – tw, rw, cw, twrw, dp.

  • batch_sizes (List[int]) – batch size for each input feature.

  • world_size (int) – the number of devices for all hosts.

  • local_world_size (int) – the number of devices on each host.

  • input_lengths (List[float]) – the list of the average number of lookups of each input query feature.

  • input_data_type_size (float) – the data type size of the distributed data_parallel input.

  • table_data_type_size (float) – the data type size of the table.

  • fwd_comm_data_type_size (float) – the data type size of the distributed data_parallel input during forward communication.

  • bwd_comm_data_type_size (float) – the data type size of the distributed data_parallel input during backward communication.

  • num_poolings (List[float]) – number of poolings per sample, typically 1.0.

  • hbm_mem_bw (float) – the bandwidth of the device HBM.

  • ddr_mem_bw (float) – the bandwidth of the system DDR memory.

  • intra_host_bw (float) – the bandwidth between devices within a single host.

  • inter_host_bw (float) – the bandwidth between hosts, i.e. across machines.

  • is_pooled (bool) – True if embedding output is pooled (e.g. EmbeddingBag), False if unpooled/sequential (e.g. Embedding).

  • is_weighted (bool = False) – if the module is an EBC and is weighted, typically signifying an id score list feature.

  • is_inference (bool = False) – if planning for inference.

  • caching_ratio (Optional[float] = None) – cache ratio to determine the bandwidth of device.

Returns:

the list of perf for each shard.

Return type:

List[Perf]

torchrec.distributed.planner.stats

class torchrec.distributed.planner.stats.EmbeddingStats

Bases: Stats

Stats for a sharding planner execution.

log(sharding_plan: ShardingPlan, topology: Topology, batch_size: int, storage_reservation: StorageReservation, num_proposals: int, num_plans: int, run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, debug: bool = True) None

Logs stats for a given sharding plan.

Provides a tabular view of stats for the given sharding plan with per device storage usage (HBM and DDR), perf, input, output, and number/type of shards.

Parameters:
  • sharding_plan (ShardingPlan) – sharding plan chosen by the planner.

  • topology (Topology) – device topology.

  • batch_size (int) – batch size.

  • storage_reservation (StorageReservation) – reserves storage for unsharded parts of the model.

  • num_proposals (int) – number of proposals evaluated.

  • num_plans (int) – number of proposals successfully partitioned.

  • run_time (float) – time taken to find plan (in seconds).

  • best_plan (List[ShardingOption]) – plan with expected performance.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.

  • debug (bool) – whether to enable debug mode.

torchrec.distributed.planner.stats.round_to_one_sigfig(x: float) str

torchrec.distributed.planner.storage_reservations

class torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation(percentage: float)

Bases: StorageReservation

reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology
class torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation(percentage: float, parameter_multiplier: float = 6.0, dense_tensor_estimate: Optional[int] = None)

Bases: StorageReservation

Reserves storage for model to be sharded with heuristical calculation. The storage reservation is comprised of dense tensor storage, KJT storage, and an extra percentage of total storage.

Parameters:
  • percentage (float) – extra storage percent to reserve that acts as a margin of error beyond heuristic calculation of storage.

  • parameter_multiplier (float) – heuristic multiplier for total parameter storage.

  • dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, uses default heuristic estimate if not provided.

reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology
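
The composition of the reservation (dense tensor storage, KJT storage, and an extra percentage of total storage) can be illustrated with simple arithmetic. `remaining_hbm` is a hypothetical helper and the byte counts are made up; how the library actually derives the dense and KJT estimates is not shown here.

```python
def remaining_hbm(hbm_cap: int, percentage: float, dense_bytes: int, kjt_bytes: int) -> int:
    """Hypothetical helper: HBM left for sharded tables after the reservation."""
    reserved = dense_bytes + kjt_bytes + int(hbm_cap * percentage)
    remaining = hbm_cap - reserved
    if remaining < 0:
        raise ValueError("reservation exceeds device capacity")
    return remaining


# 1000 bytes of HBM, 25% margin, 200 bytes of dense tensors, 50 bytes of KJT storage.
left_for_tables = remaining_hbm(hbm_cap=1000, percentage=0.25, dense_bytes=200, kjt_bytes=50)
```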
class torchrec.distributed.planner.storage_reservations.InferenceStorageReservation(percentage: float, dense_tensor_estimate: Optional[int] = None)

Bases: StorageReservation

Reserves storage for model to be sharded for inference. The storage reservation is comprised of dense tensor storage, KJT storage, and an extra percentage of total storage. Note that when estimating for storage, dense modules are assumed to be on GPUs and replicated across ranks. If this is not the case, please override the estimates with dense_tensor_estimate.

Parameters:
  • percentage (float) – extra storage percentage to reserve that acts as a margin of error beyond storage calculation.

  • dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, use default heuristic estimate if not provided.

reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology

torchrec.distributed.planner.types

class torchrec.distributed.planner.types.DeviceHardware(rank: int, storage: Storage, perf: Perf)

Bases: object

Representation of a device in a process group. ‘perf’ is an estimation of network, CPU, and storage usages.

perf: Perf
rank: int
storage: Storage
class torchrec.distributed.planner.types.Enumerator(topology: Topology, batch_size: int = 512, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None)

Bases: ABC

Generates all relevant sharding options for given topology, constraints, nn.Module, and sharders.

abstract enumerate(module: Module, sharders: List[ModuleSharder[Module]]) List[ShardingOption]

See class description.

class torchrec.distributed.planner.types.ParameterConstraints(sharding_types: Optional[List[str]] = None, compute_kernels: Optional[List[str]] = None, min_partition: Optional[int] = None, pooling_factors: List[float] = <factory>, num_poolings: Optional[List[float]] = None, batch_sizes: Optional[List[int]] = None, is_weighted: bool = False, cache_params: Optional[CacheParams] = None, enforce_hbm: Optional[bool] = None, stochastic_rounding: Optional[bool] = None, bounds_check_mode: Optional[BoundsCheckMode] = None)

Bases: object

Stores user provided constraints around the sharding plan.

If provided, pooling_factors, num_poolings, and batch_sizes must match in length, as per sample.

batch_sizes: Optional[List[int]] = None
bounds_check_mode: Optional[BoundsCheckMode] = None
cache_params: Optional[CacheParams] = None
compute_kernels: Optional[List[str]] = None
enforce_hbm: Optional[bool] = None
is_weighted: bool = False
min_partition: Optional[int] = None
num_poolings: Optional[List[float]] = None
pooling_factors: List[float]
sharding_types: Optional[List[str]] = None
stochastic_rounding: Optional[bool] = None
class torchrec.distributed.planner.types.PartitionByType(value)

Bases: Enum

Well-known partition types.

DEVICE = 'device'
HOST = 'host'
UNIFORM = 'uniform'
class torchrec.distributed.planner.types.Partitioner

Bases: ABC

Partitions shards.

Several strategies exist today, e.g. Greedy, BLDM, and Linear.

abstract partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption]
class torchrec.distributed.planner.types.Perf(fwd_compute: float, fwd_comms: float, bwd_compute: float, bwd_comms: float)

Bases: object

Representation of the breakdown of the perf estimate for a single shard of an embedding table.

bwd_comms: float
bwd_compute: float
fwd_comms: float
fwd_compute: float
property total: float
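
A minimal standalone sketch of such a breakdown, assuming that `total` is simply the sum of the four components. `ToyPerf` is a hypothetical stand-in, not the torchrec class.

```python
from dataclasses import dataclass


@dataclass
class ToyPerf:
    fwd_compute: float
    fwd_comms: float
    bwd_compute: float
    bwd_comms: float

    @property
    def total(self) -> float:
        # Assumption: the overall estimate is the sum of all four parts.
        return self.fwd_compute + self.fwd_comms + self.bwd_compute + self.bwd_comms


shard_perf = ToyPerf(fwd_compute=1.0, fwd_comms=2.0, bwd_compute=3.0, bwd_comms=4.0)
```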
class torchrec.distributed.planner.types.PerfModel

Bases: ABC

abstract rate(plan: List[ShardingOption]) float
exception torchrec.distributed.planner.types.PlannerError(message: str, error_type: PlannerErrorType = PlannerErrorType.OTHER)

Bases: Exception

class torchrec.distributed.planner.types.PlannerErrorType(value)

Bases: Enum

Classify PlannerError based on the following cases.

INSUFFICIENT_STORAGE = 'insufficient_storage'
OTHER = 'other'
PARTITION = 'partition'
STRICT_CONSTRAINTS = 'strict_constraints'
class torchrec.distributed.planner.types.Proposer

Bases: ABC

Proposes complete lists of sharding options which can be partitioned to generate a plan.

abstract feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None) None
abstract load(search_space: List[ShardingOption]) None
abstract propose() Optional[List[ShardingOption]]
class torchrec.distributed.planner.types.Shard(size: List[int], offset: List[int], storage: Optional[Storage] = None, perf: Optional[Perf] = None, rank: Optional[int] = None)

Bases: object

Representation of a subset of an embedding table. ‘size’ and ‘offset’ fully determine the tensors in the shard. ‘storage’ is an estimate of how much memory it takes to store the shard, and ‘perf’ is an estimate of its performance cost.

offset: List[int]
perf: Optional[Perf] = None
rank: Optional[int] = None
size: List[int]
storage: Optional[Storage] = None
class torchrec.distributed.planner.types.ShardEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None)

Bases: ABC

Estimates shard perf or storage, requires fully specified sharding options.

abstract estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None
class torchrec.distributed.planner.types.ShardingOption(name: str, tensor: Tensor, module: Tuple[str, Module], input_lengths: List[float], batch_size: int, sharding_type: str, partition_by: str, compute_kernel: str, shards: List[Shard], cache_params: Optional[CacheParams] = None, enforce_hbm: Optional[bool] = None, stochastic_rounding: Optional[bool] = None, bounds_check_mode: Optional[BoundsCheckMode] = None, dependency: Optional[str] = None)

Bases: object

One way of sharding an embedding table.

property fqn: str
property is_pooled: bool
property module: Tuple[str, Module]
property num_inputs: int
property num_shards: int
property path: str
property tensor: Tensor
property total_storage: Storage
class torchrec.distributed.planner.types.Stats

Bases: ABC

Logs statistics related to the sharding plan.

abstract log(sharding_plan: ShardingPlan, topology: Topology, batch_size: int, storage_reservation: StorageReservation, num_proposals: int, num_plans: int, run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, debug: bool = False) None

See class description.

class torchrec.distributed.planner.types.Storage(hbm: int, ddr: int)

Bases: object

Representation of the storage capacities of a hardware device used in training.

ddr: int
fits_in(other: Storage) bool
hbm: int
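
The check can be pictured as a component-wise comparison. The sketch below is a standalone assumption about its semantics (both HBM and DDR must fit); `ToyStorage` is a hypothetical stand-in for the real class.

```python
from dataclasses import dataclass


@dataclass
class ToyStorage:
    hbm: int
    ddr: int

    def fits_in(self, other: "ToyStorage") -> bool:
        # Assumption: a requirement fits only if both memory types fit.
        return self.hbm <= other.hbm and self.ddr <= other.ddr


shard_need = ToyStorage(hbm=1, ddr=2)
device_free = ToyStorage(hbm=1, ddr=3)
ok = shard_need.fits_in(device_free)
```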
class torchrec.distributed.planner.types.StorageReservation

Bases: ABC

Reserves storage space for non-sharded parts of the model.

abstract reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology
class torchrec.distributed.planner.types.Topology(world_size: int, compute_device: str, hbm_cap: Optional[int] = None, ddr_cap: Optional[int] = None, local_world_size: Optional[int] = None, hbm_mem_bw: float = 963146416.128, ddr_mem_bw: float = 54760833.024, intra_host_bw: float = 644245094.4, inter_host_bw: float = 13421772.8)

Bases: object

property compute_device: str
property ddr_mem_bw: float
property devices: List[DeviceHardware]
property hbm_mem_bw: float
property inter_host_bw: float
property intra_host_bw: float
property local_world_size: int
property world_size: int

torchrec.distributed.planner.utils

torchrec.distributed.planner.utils.bytes_to_gb(num_bytes: int) float
torchrec.distributed.planner.utils.bytes_to_mb(num_bytes: Union[float, int]) float
torchrec.distributed.planner.utils.gb_to_bytes(gb: float) int
torchrec.distributed.planner.utils.placement(compute_device: str, rank: int, local_size: int) str

Returns placement, formatted as string

torchrec.distributed.planner.utils.prod(iterable: Iterable[int]) int
torchrec.distributed.planner.utils.reset_shard_rank(proposal: List[ShardingOption]) None
torchrec.distributed.planner.utils.sharder_name(t: Type[Any]) str
torchrec.distributed.planner.utils.storage_repr_in_gb(storage: Optional[Storage]) str
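
The conversion helpers are listed without descriptions. The standalone re-implementations below show what functions with these signatures plausibly compute, assuming binary units (1 GB = 1024³ bytes); that unit choice is an assumption of this sketch and is not confirmed by the text above.

```python
from functools import reduce
from operator import mul
from typing import Iterable, Union

BYTES_PER_MB = 1024 ** 2  # assumed binary units
BYTES_PER_GB = 1024 ** 3


def bytes_to_gb(num_bytes: int) -> float:
    return num_bytes / BYTES_PER_GB


def bytes_to_mb(num_bytes: Union[float, int]) -> float:
    return num_bytes / BYTES_PER_MB


def gb_to_bytes(gb: float) -> int:
    return int(gb * BYTES_PER_GB)


def prod(iterable: Iterable[int]) -> int:
    # Multiply all elements together; an empty iterable gives 1.
    return reduce(mul, iterable, 1)


table_bytes = prod([1024, 256]) * 4  # rows * columns * 4 bytes per value
table_mb = bytes_to_mb(table_bytes)
```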
