torchrec.distributed.planner

Torchrec Planner

The planner provides the specifications necessary for a module to be sharded, considering the possible options to build an optimized plan.

Its features include:
  • generating all possible sharding options.

  • estimating perf and storage for every shard.

  • estimating peak memory usage to eliminate sharding plans that might OOM.

  • customizability for parameter constraints, partitioning, proposers, or performance modeling.

  • automatically building and selecting an optimized sharding plan.

torchrec.distributed.planner.constants

torchrec.distributed.planner.constants.kernel_bw_lookup(compute_device: str, compute_kernel: str, hbm_mem_bw: float, ddr_mem_bw: float, caching_ratio: Optional[float] = None, prefetch_pipeline: bool = False) Optional[float]

Calculates the device bandwidth based on given compute device, compute kernel, and caching ratio.

Parameters:
  • compute_kernel (str) – compute kernel.

  • compute_device (str) – compute device.

  • hbm_mem_bw (float) – the bandwidth of the device HBM.

  • ddr_mem_bw (float) – the bandwidth of the system DDR memory.

  • caching_ratio (Optional[float]) – caching ratio used to determine device bandwidth if UVM caching is enabled.

  • prefetch_pipeline (bool) – whether prefetch pipeline is enabled.

Returns:

the device bandwidth.

Return type:

Optional[float]
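As a rough illustration of how a caching ratio can influence the effective device bandwidth, the sketch below blends HBM and DDR bandwidth linearly. The helper name `blended_bandwidth`, the linear blend, and the fallback ratio are assumptions made for illustration; they are not taken from the torchrec implementation, which may apply additional kernel-specific handling.

```python
from typing import Optional


def blended_bandwidth(
    hbm_mem_bw: float,
    ddr_mem_bw: float,
    caching_ratio: Optional[float] = None,
) -> float:
    """Hypothetical helper: weighted average of HBM and DDR bandwidth."""
    # Assumption: when no ratio is given, use an illustrative fallback value.
    ratio = 0.2 if caching_ratio is None else caching_ratio
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("caching_ratio must be in [0, 1]")
    # The cached fraction is served at HBM speed, the rest at DDR speed.
    return ratio * hbm_mem_bw + (1.0 - ratio) * ddr_mem_bw
```

With a ratio of 1.0 the result equals the HBM bandwidth, and with 0.0 it equals the DDR bandwidth.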

torchrec.distributed.planner.enumerators

class torchrec.distributed.planner.enumerators.EmbeddingEnumerator(topology: Topology, batch_size: int, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None)

Bases: Enumerator

Generates embedding sharding options for a given nn.Module, considering user-provided constraints.

Parameters:
  • topology (Topology) – device topology.

  • batch_size (int) – batch size.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.

enumerate(module: Module, sharders: List[ModuleSharder[Module]]) List[ShardingOption]

Generates relevant sharding options given module and sharders.

Parameters:
  • module (nn.Module) – module to be sharded.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns:

valid sharding options with values populated.

Return type:

List[ShardingOption]

populate_estimates(sharding_options: List[ShardingOption]) None

See class description.

torchrec.distributed.planner.enumerators.get_partition_by_type(sharding_type: str) str

Gets corresponding partition by type for provided sharding type.

Parameters:

sharding_type (str) – sharding type string.

Returns:

the corresponding PartitionByType value.

Return type:

str

torchrec.distributed.planner.partitioners

class torchrec.distributed.planner.partitioners.GreedyPerfPartitioner(sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False)

Bases: Partitioner

Greedy Partitioner

Parameters:
  • sort_by (SortBy) – Sort sharding options by storage or perf in descending order (i.e., large tables will be placed first).

  • balance_modules (bool) – Whether to sort by modules first, where smaller modules will be sorted first. In effect, this will place tables in each module in a balanced way.

partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption]

Places sharding options on topology based on each sharding option’s partition_by attribute. The topology, storage, and perfs are updated at the end of the placement.

Parameters:
  • proposal (List[ShardingOption]) – list of populated sharding options.

  • storage_constraint (Topology) – device topology.

Returns:

list of sharding options for selected plan.

Return type:

List[ShardingOption]

Example:

# Schematic example: arguments not relevant to partitioning are omitted.
sharding_options = [
        ShardingOption(partition_by="uniform",
                shards=[
                    Shard(storage=1, perf=1),
                    Shard(storage=1, perf=1),
                ]),
        ShardingOption(partition_by="uniform",
                shards=[
                    Shard(storage=2, perf=2),
                    Shard(storage=2, perf=2),
                ]),
        ShardingOption(partition_by="device",
                shards=[
                    Shard(storage=3, perf=3),
                    Shard(storage=3, perf=3),
                ]),
        ShardingOption(partition_by="device",
                shards=[
                    Shard(storage=4, perf=4),
                    Shard(storage=4, perf=4),
                ]),
    ]
topology = Topology(world_size=2)

# First, sharding_options[0] and sharding_options[1] will be placed on the
# topology with the uniform strategy, resulting in

topology.devices[0].perf.total = (1,2)
topology.devices[1].perf.total = (1,2)

# Finally, sharding_options[2] and sharding_options[3] will be placed on the
# topology with the device strategy (see docstring of `partition_by_device` for
# more details).

topology.devices[0].perf.total = (1,2) + (3,4)
topology.devices[1].perf.total = (1,2) + (3,4)

# The topology updates are done after the end of all the placements (the order
# in the example is just for clarity).
class torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner(max_search_count: int = 10, tolerance: float = 0.02, balance_modules: bool = False)

Bases: Partitioner

Memory balanced Partitioner.

Parameters:
  • max_search_count (int) – Maximum number of times to call the GreedyPerfPartitioner.

  • tolerance (float) – The maximum acceptable difference between the original plan and the new plan. If tolerance is 1, that means a new plan will be rejected if its perf is 200% of the original plan (i.e., the plan is 100% worse).

  • balance_modules (bool) – Whether to sort by modules first, where smaller modules will be sorted first. In effect, this will place tables in each module in a balanced way.

partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption]

Repeatedly calls the GreedyPerfPartitioner to find a plan with perf within the tolerance of the original plan that uses the least amount of memory.
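
The tolerance rule described for this class can be written as a one-line check. This is a sketch of the stated rule only: the name `within_tolerance` is hypothetical, lower perf values are assumed to be better, and how a plan exactly on the boundary is treated is an assumption.

```python
def within_tolerance(original_perf: float, new_perf: float, tolerance: float) -> bool:
    """Hypothetical check: accept a plan whose perf is at most (1 + tolerance) x original."""
    return new_perf <= original_perf * (1.0 + tolerance)


# With the default tolerance of 0.02, a plan may be up to 2% slower.
accept_small_regression = within_tolerance(100.0, 101.5, 0.02)
reject_large_regression = within_tolerance(100.0, 103.0, 0.02)
```

With a tolerance of 1, a plan that takes 150% of the original perf would pass this check, while one at 250% would not.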

class torchrec.distributed.planner.partitioners.OrderedDeviceHardware(device: torchrec.distributed.planner.types.DeviceHardware, local_world_size: int)

Bases: object

device: DeviceHardware
local_world_size: int
class torchrec.distributed.planner.partitioners.ShardingOptionGroup(sharding_options: List[torchrec.distributed.planner.types.ShardingOption], storage_sum: torchrec.distributed.planner.types.Storage, perf_sum: float, param_count: int)

Bases: object

param_count: int
perf_sum: float
sharding_options: List[ShardingOption]
storage_sum: Storage
class torchrec.distributed.planner.partitioners.SortBy(value)

Bases: Enum

An enumeration.

PERF = 'perf'
STORAGE = 'storage'
torchrec.distributed.planner.partitioners.set_hbm_per_device(storage_constraint: Topology, hbm_per_device: int) None

torchrec.distributed.planner.perf_models

class torchrec.distributed.planner.perf_models.NoopPerfModel(topology: Topology)

Bases: PerfModel

A no-op model that returns the maximum perf among all shards. Here, no-op means we estimate the performance of a model without actually running it.

rate(plan: List[ShardingOption]) float
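
A plan is rated from the estimates attached to its shards, not from a measured run. Below is a minimal sketch of one plausible reading, in which per-shard perf estimates are first summed per rank and the largest total is reported. The per-rank summation, the `(rank, perf)` tuples, and the name `rate_plan` are assumptions for illustration, not the library's code.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def rate_plan(shards: List[Tuple[int, float]]) -> float:
    """Hypothetical rating: perf total of the most heavily loaded rank.

    Each entry is a (rank, perf) pair for one shard.
    """
    per_rank: Dict[int, float] = defaultdict(float)
    for rank, perf in shards:
        per_rank[rank] += perf
    # An empty plan has no cost in this toy model.
    return max(per_rank.values(), default=0.0)


rating = rate_plan([(0, 1.0), (1, 2.5), (0, 2.0)])
```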
class torchrec.distributed.planner.perf_models.NoopStorageModel(topology: Topology)

Bases: PerfModel

A no-op model that returns the maximum hbm usage among all shards. Here, no-op means we estimate the performance of a model without actually running it.

rate(plan: List[ShardingOption]) float

torchrec.distributed.planner.planners

class torchrec.distributed.planner.planners.EmbeddingShardingPlanner(topology: Optional[Topology] = None, batch_size: Optional[int] = None, enumerator: Optional[Enumerator] = None, storage_reservation: Optional[StorageReservation] = None, proposer: Optional[Union[Proposer, List[Proposer]]] = None, partitioner: Optional[Partitioner] = None, performance_model: Optional[PerfModel] = None, stats: Optional[Union[Stats, List[Stats]]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, debug: bool = True)

Bases: ShardingPlanner

Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.

collective_plan(module: Module, sharders: Optional[List[ModuleSharder[Module]]] = None, pg: Optional[ProcessGroup] = None) ShardingPlan

Calls self.plan(…) on rank 0 and broadcasts the resulting plan to the other ranks.

plan(module: Module, sharders: List[ModuleSharder[Module]]) ShardingPlan

Plans sharding for provided module and given sharders.

Parameters:
  • module (nn.Module) – module that sharding is planned for.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns:

the computed sharding plan.

Return type:

ShardingPlan

class torchrec.distributed.planner.planners.HeteroEmbeddingShardingPlanner(topology_groups: Optional[Dict[str, Topology]] = None, batch_size: Optional[int] = None, enumerators: Optional[Dict[str, Enumerator]] = None, storage_reservations: Optional[Dict[str, StorageReservation]] = None, proposers: Optional[Dict[str, Union[Proposer, List[Proposer]]]] = None, partitioners: Optional[Dict[str, Partitioner]] = None, performance_models: Optional[Dict[str, PerfModel]] = None, stats: Optional[Dict[str, Union[Stats, List[Stats]]]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, debug: bool = True)

Bases: ShardingPlanner

Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.

collective_plan(module: Module, sharders: Optional[List[ModuleSharder[Module]]] = None, pg: Optional[ProcessGroup] = None) ShardingPlan

Calls self.plan(…) on rank 0 and broadcasts the resulting plan to the other ranks.

plan(module: Module, sharders: List[ModuleSharder[Module]]) ShardingPlan

Plans sharding for provided module and given sharders.

Parameters:
  • module (nn.Module) – module that sharding is planned for.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns:

the computed sharding plan.

Return type:

ShardingPlan

torchrec.distributed.planner.proposers

class torchrec.distributed.planner.proposers.EmbeddingOffloadScaleupProposer(use_depth: bool = True)

Bases: Proposer

static allocate_budget(model: Tensor, clfs: Tensor, budget: int, allocation_priority: Tensor) Tensor
static build_affine_storage_model(uvm_caching_sharding_options: List[ShardingOption], enumerator: Enumerator) Tensor
static clf_to_bytes(model: Tensor, clfs: Union[float, Tensor]) Tensor
feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None) None
static get_budget(proposal: List[ShardingOption], storage_constraint: Topology) int

Returns the additional HBM budget available for GPU caches.

static get_cacheability(sharding_option: ShardingOption) Optional[float]
static get_expected_lookups(sharding_option: ShardingOption) Optional[float]
load(search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None) None
static next_plan(starting_proposal: List[ShardingOption], budget: Optional[int], enumerator: Optional[Enumerator]) Optional[List[ShardingOption]]
propose() Optional[List[ShardingOption]]
class torchrec.distributed.planner.proposers.GreedyProposer(use_depth: bool = True, threshold: Optional[int] = None)

Bases: Proposer

Proposes sharding plans in greedy fashion.

Sorts sharding options for each shardable parameter by perf. On each iteration, finds parameter with largest current storage usage and tries its next sharding option.

Parameters:
  • use_depth (bool) – When enabled, sharding_options of a fqn are sorted based on max(shard.perf.total), otherwise sharding_options are sorted by sum(shard.perf.total).

  • threshold (Optional[int]) – Threshold for early stopping. When specified, the proposer stops proposing when the proposals have consecutive worse perf_rating than best_perf_rating.

feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None) None
load(search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None) None
propose() Optional[List[ShardingOption]]
class torchrec.distributed.planner.proposers.GridSearchProposer(max_proposals: int = 10000)

Bases: Proposer

feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None) None
load(search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None) None
propose() Optional[List[ShardingOption]]
class torchrec.distributed.planner.proposers.UniformProposer(use_depth: bool = True)

Bases: Proposer

Proposes uniform sharding plans, plans that have the same sharding type for all sharding options.

feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None) None
load(search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None) None
propose() Optional[List[ShardingOption]]
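
To make "the same sharding type for all sharding options" concrete, here is a small sketch that builds one candidate plan per sharding type supported by every table. The dictionary-based input and the name `uniform_plans` are hypothetical simplifications; the real proposer works on ShardingOption objects.

```python
from typing import Dict, List


def uniform_plans(options: Dict[str, List[str]]) -> List[Dict[str, str]]:
    """Map each table to one sharding type, using only types all tables support."""
    if not options:
        return []
    # Keep only the sharding types that every table supports.
    common = set.intersection(*(set(types) for types in options.values()))
    return [{table: sharding_type for table in options} for sharding_type in sorted(common)]


plans = uniform_plans(
    {
        "table_a": ["table_wise", "row_wise"],
        "table_b": ["row_wise", "column_wise"],
    }
)
```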
torchrec.distributed.planner.proposers.proposers_to_proposals_list(proposers_list: List[Proposer], search_space: List[ShardingOption]) List[List[ShardingOption]]

Only works for static_feedback proposers, i.e., proposers for which the sequence of proposals to check is independent of the performance of the proposals.

torchrec.distributed.planner.shard_estimators

class torchrec.distributed.planner.shard_estimators.EmbeddingOffloadStats(cacheability: float, expected_lookups: int, mrc_hist_counts: Tensor, height: int)

Bases: CacheStatistics

Computes cache statistics for uvm_fused_cache tables.

Parameters:
  • cacheability (float) – the area under the miss-ratio curve.

  • expected_lookups (int) – the expected number of unique embedding ids per global batch.

  • mrc_hist_counts (torch.Tensor) – a 1D tensor (size n) holding a histogram of the LRU miss-ratio curve. Each bin represents 1/nth of the possible LRU cache sizes (from load_factor 0 to load_factor 1.0). The bin contains the number of expected LRU operations that could be handled without a cache miss if the LRU load_factor was at least that size.

  • height (int) – the height (num_embeddings) of the embedding table.

property cacheability: float

Summarized measure of how difficult a dataset is to cache, independent of cache size. A score of 0 means the dataset is very cacheable (e.g., high locality between accesses); a score of 1 means it is very difficult to cache.

static estimate_cache_miss_rate(cache_sizes: Tensor, hist: Tensor, bins: Tensor) Tensor

Calculate estimated cache miss ratio for the proposed cache_sizes, given the MRC histogram.

property expected_lookups: int

Number of expected cache lookups per training step.

This is the expected number of distinct values in a global training batch.

expected_miss_rate(clf: float) float

Expected cache lookup miss rate for a given cache size.

When clf (cache load factor) is 0, returns 1.0 (100% miss). When clf is 1.0, returns 0 (100% hit). For values of clf between these extremes, returns the estimated miss rate of the cache, e.g. based on knowledge of the statistical properties of the training data set.
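
The relationship between the miss-ratio histogram and the miss rate can be sketched in plain Python. This is an illustrative reading of the description above rather than the library's method: `miss_rate_from_hist` is a hypothetical name, the histogram is a plain list instead of a tensor, and the linear interpolation inside a partially covered bin is an assumption.

```python
from typing import List


def miss_rate_from_hist(hist: List[float], clf: float) -> float:
    """Estimate the miss rate for a cache load factor in [0, 1]."""
    if not 0.0 <= clf <= 1.0:
        raise ValueError("clf must be in [0, 1]")
    total = sum(hist)
    if total == 0:
        return 1.0  # no recorded lookups; treat everything as a miss
    n = len(hist)
    position = clf * n         # how many bins the cache covers
    full_bins = int(position)  # bins covered completely
    hits = sum(hist[:full_bins])
    if full_bins < n:
        # Assumption: interpolate linearly inside the partially covered bin.
        hits += hist[full_bins] * (position - full_bins)
    return 1.0 - hits / total


hist = [6.0, 3.0, 1.0, 0.0]
no_cache = miss_rate_from_hist(hist, 0.0)
full_cache = miss_rate_from_hist(hist, 1.0)
```

For this histogram the two extremes reproduce the documented behavior: a load factor of 0 gives a 100% miss rate and a load factor of 1.0 gives 0%.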

class torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None, is_inference: bool = False)

Bases: ShardEstimator

Embedding Wall Time Perf Estimator

estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None
classmethod perf_func_emb_wall_time(shard_sizes: List[List[int]], compute_kernel: str, compute_device: str, sharding_type: str, batch_sizes: List[int], world_size: int, local_world_size: int, input_lengths: List[float], input_data_type_size: float, table_data_type_size: float, output_data_type_size: float, fwd_a2a_comm_data_type_size: float, bwd_a2a_comm_data_type_size: float, fwd_sr_comm_data_type_size: float, bwd_sr_comm_data_type_size: float, num_poolings: List[float], hbm_mem_bw: float, ddr_mem_bw: float, intra_host_bw: float, inter_host_bw: float, bwd_compute_multiplier: float, is_pooled: bool, is_weighted: bool = False, caching_ratio: Optional[float] = None, is_inference: bool = False, prefetch_pipeline: bool = False, expected_cache_fetches: float = 0) List[Perf]

Attempts to model perfs as a function of relative wall times.

Parameters:
  • shard_sizes (List[List[int]]) – the list of (local_rows, local_cols) of each shard.

  • compute_kernel (str) – compute kernel.

  • compute_device (str) – compute device.

  • sharding_type (str) – tw, rw, cw, twrw, dp.

  • batch_sizes (List[int]) – batch size for each input feature.

  • world_size (int) – the total number of devices across all hosts.

  • local_world_size (int) – the number of devices on each host.

  • input_lengths (List[float]) – the list of the average number of lookups of each input query feature.

  • input_data_type_size (float) – the data type size of the distributed data_parallel input.

  • table_data_type_size (float) – the data type size of the table.

  • output_data_type_size (float) – the data type size of the output embeddings.

  • fwd_comm_data_type_size (float) – the data type size of the distributed data_parallel input during forward communication.

  • bwd_comm_data_type_size (float) – the data type size of the distributed data_parallel input during backward communication.

  • num_poolings (List[float]) – number of poolings per sample, typically 1.0.

  • hbm_mem_bw (float) – the bandwidth of the device HBM.

  • ddr_mem_bw (float) – the bandwidth of the system DDR memory.

  • intra_host_bw (float) – the bandwidth within a single host (e.g., between devices on the same machine).

  • inter_host_bw (float) – the bandwidth between hosts (e.g., across machines).

  • is_pooled (bool) – True if embedding output is pooled (i.e., EmbeddingBag), False if unpooled/sequential (i.e., Embedding).

  • is_weighted (bool = False) – if the module is an EBC and is weighted, typically signifying an id score list feature.

  • is_inference (bool = False) – if planning for inference.

  • caching_ratio (Optional[float] = None) – cache ratio to determine the bandwidth of device.

  • prefetch_pipeline (bool = False) – whether prefetch pipeline is enabled.

  • expected_cache_fetches (float) – number of expected cache fetches across the global batch.

Returns:

the list of perf for each shard.

Return type:

List[Perf]

class torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None)

Bases: ShardEstimator

Embedding Storage Usage Estimator

estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None
torchrec.distributed.planner.shard_estimators.calculate_shard_storages(sharder: ModuleSharder[Module], sharding_type: str, tensor: Tensor, compute_device: str, compute_kernel: str, shard_sizes: List[List[int]], batch_sizes: List[int], world_size: int, local_world_size: int, input_lengths: List[float], num_poolings: List[float], caching_ratio: float, is_pooled: bool, input_data_type_size: float, output_data_type_size: float) List[Storage]

Calculates estimated storage sizes for each sharded tensor, comprised of input, output, tensor, gradient, and optimizer sizes.

Parameters:
  • sharder (ModuleSharder[nn.Module]) – sharder for module that supports sharding.

  • sharding_type (str) – provided ShardingType value.

  • tensor (torch.Tensor) – tensor to be sharded.

  • compute_device (str) – compute device to be used.

  • compute_kernel (str) – compute kernel to be used.

  • shard_sizes (List[List[int]]) – list of dimensions of each sharded tensor.

  • batch_sizes (List[int]) – batch size for each input feature.

  • world_size (int) – total number of devices in topology.

  • local_world_size (int) – total number of devices in host group topology.

  • input_lengths (List[float]) – average input lengths synonymous with pooling factors.

  • num_poolings (List[float]) – average number of poolings per sample (typically 1.0).

  • caching_ratio (float) – ratio of HBM to DDR memory for UVM caching.

  • is_pooled (bool) – True if embedding output is pooled (i.e., EmbeddingBag), False if unpooled/sequential (i.e., Embedding).

  • input_data_type_size (float) – number of bytes of input data type.

  • output_data_type_size (float) – number of bytes of output data type.

Returns:

storage object for each device in topology.

Return type:

List[Storage]

torchrec.distributed.planner.stats

class torchrec.distributed.planner.stats.EmbeddingStats

Bases: Stats

Stats for a sharding planner execution.

log(sharding_plan: ShardingPlan, topology: Topology, batch_size: int, storage_reservation: StorageReservation, num_proposals: int, num_plans: int, run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, debug: bool = True) None

Logs stats for a given sharding plan.

Provides a tabular view of stats for the given sharding plan with per device storage usage (HBM and DDR), perf, input, output, and number/type of shards.

Parameters:
  • sharding_plan (ShardingPlan) – sharding plan chosen by the planner.

  • topology (Topology) – device topology.

  • batch_size (int) – batch size.

  • storage_reservation (StorageReservation) – reserves storage for unsharded parts of the model.

  • num_proposals (int) – number of proposals evaluated.

  • num_plans (int) – number of proposals successfully partitioned.

  • run_time (float) – time taken to find plan (in seconds).

  • best_plan (List[ShardingOption]) – the selected plan, with its expected performance.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.

  • debug (bool) – whether to enable debug mode.

class torchrec.distributed.planner.stats.NoopEmbeddingStats

Bases: Stats

Noop Stats for a sharding planner execution.

log(sharding_plan: ShardingPlan, topology: Topology, batch_size: int, storage_reservation: StorageReservation, num_proposals: int, num_plans: int, run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, debug: bool = True) None

See class description.

torchrec.distributed.planner.stats.round_to_one_sigfig(x: float) str

torchrec.distributed.planner.storage_reservations

class torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation(percentage: float)

Bases: StorageReservation

reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology
class torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation(percentage: float, parameter_multiplier: float = 6.0, dense_tensor_estimate: Optional[int] = None)

Bases: StorageReservation

Reserves storage for model to be sharded with heuristical calculation. The storage reservation is comprised of dense tensor storage, KJT storage, and an extra percentage of total storage.

Parameters:
  • percentage (float) – extra storage percent to reserve that acts as a margin of error beyond heuristic calculation of storage.

  • parameter_multiplier (float) – heuristic multiplier for total parameter storage.

  • dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, uses default heuristic estimate if not provided.

reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology
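
The three components of the reservation can be combined as in the sketch below. It is a simplified illustration, not the library's calculation: the name `remaining_hbm` is hypothetical, the dense and KJT sizes are passed in directly instead of being estimated, and the percentage is applied to the device's total HBM capacity.

```python
def remaining_hbm(hbm_cap: int, percentage: float, dense_bytes: int, kjt_bytes: int) -> int:
    """Hypothetical helper: HBM left for sharded tables after all reservations."""
    # Extra margin plus the storage needed by dense tensors and KJT inputs.
    reserved = int(hbm_cap * percentage) + dense_bytes + kjt_bytes
    remaining = hbm_cap - reserved
    if remaining < 0:
        raise ValueError("reservation exceeds device capacity")
    return remaining


left_for_tables = remaining_hbm(hbm_cap=1000, percentage=0.25, dense_bytes=200, kjt_bytes=50)
```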
class torchrec.distributed.planner.storage_reservations.InferenceStorageReservation(percentage: float, dense_tensor_estimate: Optional[int] = None)

Bases: StorageReservation

Reserves storage for model to be sharded for inference. The storage reservation is comprised of dense tensor storage, KJT storage, and an extra percentage of total storage. Note that when estimating for storage, dense modules are assumed to be on GPUs and replicated across ranks. If this is not the case, please override the estimates with dense_tensor_estimate.

Parameters:
  • percentage (float) – extra storage percentage to reserve that acts as a margin of error beyond storage calculation.

  • dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, use default heuristic estimate if not provided.

reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology

torchrec.distributed.planner.types

class torchrec.distributed.planner.types.CustomTopologyData(data: Dict[str, List[int]], world_size: int)

Bases: object

Custom device data for individual device in a topology.

get_data(key: str) List[int]
has_data(key: str) bool
supported_fields = ['ddr_cap', 'hbm_cap']
class torchrec.distributed.planner.types.DeviceHardware(rank: int, storage: Storage, perf: Perf)

Bases: object

Representation of a device in a process group. ‘perf’ is an estimation of network, CPU, and storage usages.

perf: Perf
rank: int
storage: Storage
class torchrec.distributed.planner.types.Enumerator(topology: Topology, batch_size: int = 512, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None)

Bases: ABC

Generates all relevant sharding options for given topology, constraints, nn.Module, and sharders.

abstract enumerate(module: Module, sharders: List[ModuleSharder[Module]]) List[ShardingOption]

See class description.

abstract populate_estimates(sharding_options: List[ShardingOption]) None

See class description.

class torchrec.distributed.planner.types.ParameterConstraints(sharding_types: ~typing.Optional[~typing.List[str]] = None, compute_kernels: ~typing.Optional[~typing.List[str]] = None, min_partition: ~typing.Optional[int] = None, pooling_factors: ~typing.List[float] = <factory>, num_poolings: ~typing.Optional[~typing.List[float]] = None, batch_sizes: ~typing.Optional[~typing.List[int]] = None, is_weighted: bool = False, cache_params: ~typing.Optional[~torchrec.distributed.types.CacheParams] = None, enforce_hbm: ~typing.Optional[bool] = None, stochastic_rounding: ~typing.Optional[bool] = None, bounds_check_mode: ~typing.Optional[~fbgemm_gpu.split_table_batched_embeddings_ops_common.BoundsCheckMode] = None, feature_names: ~typing.Optional[~typing.List[str]] = None, output_dtype: ~typing.Optional[~torchrec.types.DataType] = None, device_group: ~typing.Optional[str] = None)

Bases: object

Stores user provided constraints around the sharding plan.

If provided, pooling_factors, num_poolings, and batch_sizes must all have the same length, with one entry per feature.

sharding_types

sharding types allowed for the table. Values of enum ShardingType.

Type:

Optional[List[str]]

compute_kernels

compute kernels allowed for the table. Values of enum EmbeddingComputeKernel.

Type:

Optional[List[str]]

min_partition

lower bound for dimension of column wise shards. Planner will search for the column wise shard dimension in the range of [min_partition, embedding_dim], as long as the column wise shard dimension divides embedding_dim and is divisible by 4. Used for column wise sharding only.

Type:

Optional[int]
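
The search range described for min_partition can be enumerated directly. The sketch below applies only the three stated conditions (within [min_partition, embedding_dim], divides embedding_dim, divisible by 4); the function name `candidate_shard_dims` is hypothetical, and the planner may apply further filters.

```python
from typing import List


def candidate_shard_dims(embedding_dim: int, min_partition: int) -> List[int]:
    """List the column-wise shard dimensions allowed by the stated rule."""
    start = max(min_partition, 1)  # guard against a zero or negative lower bound
    return [
        dim
        for dim in range(start, embedding_dim + 1)
        if embedding_dim % dim == 0 and dim % 4 == 0
    ]


dims = candidate_shard_dims(embedding_dim=128, min_partition=32)
```

For an embedding dimension of 128 and a min_partition of 32, the candidates are 32, 64, and 128.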

pooling_factors

pooling factors for each feature of the table. This is the average number of values each sample has for the feature. Length of pooling_factors should match the number of features.

Type:

List[float]

num_poolings

number of poolings for each feature of the table. Length of num_poolings should match the number of features.

Type:

Optional[List[float]]

batch_sizes

batch sizes for each feature of the table. Length of batch_sizes should match the number of features.

Type:

Optional[List[int]]

is_weighted

whether the table is weighted.

Type:

bool

cache_params

cache parameters to be used by this table. These are passed to FBGEMM’s Split TBE kernel.

Type:

Optional[CacheParams]

enforce_hbm

whether to place all weights/momentums in HBM when using cache.

Type:

Optional[bool]

stochastic_rounding

whether to do stochastic rounding. This is passed to FBGEMM’s Split TBE kernel. Stochastic rounding is non-deterministic, but it is important for maintaining accuracy over the long term with FP16 embedding tables.

Type:

Optional[bool]

bounds_check_mode

bounds check mode to be used by FBGEMM’s Split TBE kernel. Bounds checking means checking whether values (i.e., row ids) are within the table size. If a row id exceeds the table size, it will be set to 0.

Type:

Optional[BoundsCheckMode]
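
The effect described above, where an out-of-range row id is reset to 0, can be pictured with a plain Python list. This only illustrates that one sentence; the name `reset_out_of_range` is hypothetical, treating negative ids as out of range is an assumption, and the actual kernel behavior depends on the selected mode.

```python
from typing import List


def reset_out_of_range(row_ids: List[int], num_rows: int) -> List[int]:
    """Replace ids outside [0, num_rows) with 0."""
    return [row_id if 0 <= row_id < num_rows else 0 for row_id in row_ids]


checked = reset_out_of_range([3, 10, 7, 42], num_rows=10)
```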

feature_names

list of feature names for this table.

Type:

Optional[List[str]]

output_dtype

output dtype to be used by this table. The default is FP32. If not None, the output dtype will also be used by the planner to produce a more balanced plan.

Type:

Optional[DataType]

device_group

device group to be used by this table. It can be cpu or cuda. This specifies if the table should be placed on a cpu device or a gpu device.

Type:

Optional[str]

batch_sizes: Optional[List[int]] = None
bounds_check_mode: Optional[BoundsCheckMode] = None
cache_params: Optional[CacheParams] = None
compute_kernels: Optional[List[str]] = None
device_group: Optional[str] = None
enforce_hbm: Optional[bool] = None
feature_names: Optional[List[str]] = None
is_weighted: bool = False
min_partition: Optional[int] = None
num_poolings: Optional[List[float]] = None
output_dtype: Optional[DataType] = None
pooling_factors: List[float]
sharding_types: Optional[List[str]] = None
stochastic_rounding: Optional[bool] = None
class torchrec.distributed.planner.types.PartitionByType(value)

Bases: Enum

Well-known partition types.

DEVICE = 'device'
HOST = 'host'
UNIFORM = 'uniform'
class torchrec.distributed.planner.types.Partitioner

Bases: ABC

Partitions shards.

Multiple strategies are currently available (e.g., Greedy, BLDM, Linear).

abstract partition(proposal: List[ShardingOption], storage_constraint: Topology) List[ShardingOption]
class torchrec.distributed.planner.types.Perf(fwd_compute: float, fwd_comms: float, bwd_compute: float, bwd_comms: float, prefetch_compute: float = 0.0)

Bases: object

Representation of the breakdown of the perf estimate for a single shard of an embedding table.

bwd_comms: float
bwd_compute: float
fwd_comms: float
fwd_compute: float
prefetch_compute: float = 0.0
property total: float
class torchrec.distributed.planner.types.PerfModel

Bases: ABC

abstract rate(plan: List[ShardingOption]) float
exception torchrec.distributed.planner.types.PlannerError(message: str, error_type: PlannerErrorType = PlannerErrorType.OTHER)

Bases: Exception

class torchrec.distributed.planner.types.PlannerErrorType(value)

Bases: Enum

Classify PlannerError based on the following cases.

INSUFFICIENT_STORAGE = 'insufficient_storage'
OTHER = 'other'
PARTITION = 'partition'
STRICT_CONSTRAINTS = 'strict_constraints'
class torchrec.distributed.planner.types.Proposer

Bases: ABC

Proposes complete lists of sharding options which can be partitioned to generate a plan.

abstract feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None) None
abstract load(search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None) None
abstract propose() Optional[List[ShardingOption]]
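
The three abstract methods suggest a load, propose, feedback cycle. The sketch below shows how such a cycle might be driven, under several assumptions: `ToyProposer` and `search` are hypothetical names, the toy search space is already a list of candidate plans (the real interface receives a flat list of sharding options), a rating of None stands for a proposal that could not be partitioned, and lower ratings are assumed to be better.

```python
from typing import Callable, List, Optional, Tuple

Plan = List[str]


class ToyProposer:
    """Hypothetical proposer that yields each candidate plan once, in order."""

    def __init__(self) -> None:
        self._candidates: List[Plan] = []
        self._index = 0

    def load(self, search_space: List[Plan]) -> None:
        self._candidates = list(search_space)
        self._index = 0

    def propose(self) -> Optional[Plan]:
        if self._index < len(self._candidates):
            return self._candidates[self._index]
        return None  # search space exhausted

    def feedback(self, partitionable: bool, perf_rating: Optional[float] = None) -> None:
        # This toy proposer ignores the rating and simply advances.
        self._index += 1


def search(
    proposer: ToyProposer,
    search_space: List[Plan],
    rate: Callable[[Plan], Optional[float]],
) -> Tuple[Optional[Plan], float]:
    """Drive the propose/feedback loop and keep the lowest-rated plan."""
    proposer.load(search_space)
    best_plan: Optional[Plan] = None
    best_rating = float("inf")
    proposal = proposer.propose()
    while proposal is not None:
        rating = rate(proposal)
        proposer.feedback(partitionable=rating is not None, perf_rating=rating)
        if rating is not None and rating < best_rating:
            best_plan, best_rating = proposal, rating
        proposal = proposer.propose()
    return best_plan, best_rating


costs = {("tw", "tw"): 5.0, ("rw", "tw"): 3.0, ("cw", "cw"): None}
candidates = [["tw", "tw"], ["rw", "tw"], ["cw", "cw"]]
best_plan, best_rating = search(ToyProposer(), candidates, lambda plan: costs[tuple(plan)])
```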
class torchrec.distributed.planner.types.Shard(size: List[int], offset: List[int], storage: Optional[Storage] = None, perf: Optional[Perf] = None, rank: Optional[int] = None)

Bases: object

Representation of a subset of an embedding table. ‘size’ and ‘offset’ fully determine the tensors in the shard. ‘storage’ is an estimate of how much storage the shard requires, and ‘perf’ is an estimate of its performance.

offset: List[int]
perf: Optional[Perf] = None
rank: Optional[int] = None
size: List[int]
storage: Optional[Storage] = None
class torchrec.distributed.planner.types.ShardEstimator(topology: Topology, constraints: Optional[Dict[str, ParameterConstraints]] = None)

Bases: ABC

Estimates shard perf or storage. Requires fully specified sharding options.

abstract estimate(sharding_options: List[ShardingOption], sharder_map: Optional[Dict[str, ModuleSharder[Module]]] = None) None
class torchrec.distributed.planner.types.ShardingOption(name: str, tensor: Tensor, module: Tuple[str, Module], input_lengths: List[float], batch_size: int, sharding_type: str, partition_by: str, compute_kernel: str, shards: List[Shard], cache_params: Optional[CacheParams] = None, enforce_hbm: Optional[bool] = None, stochastic_rounding: Optional[bool] = None, bounds_check_mode: Optional[BoundsCheckMode] = None, dependency: Optional[str] = None, is_pooled: Optional[bool] = None, feature_names: Optional[List[str]] = None, output_dtype: Optional[DataType] = None)

Bases: object

One way of sharding an embedding table. In the enumerator, we generate multiple sharding options per table, but in the planner output, there should only be one sharding option per table.

name

name of the sharding option.

Type:

str

tensor

tensor of the sharding option. Usually on meta device.

Type:

torch.Tensor

module

module and its fqn that contains the table.

Type:

Tuple[str, nn.Module]

input_lengths

list of pooling factors of the feature for the table.

Type:

List[float]

batch_size

batch size of training / eval job.

Type:

int

sharding_type

sharding type of the table. Value of enum ShardingType.

Type:

str

compute_kernel

compute kernel of the table. Value of enum EmbeddingComputeKernel.

Type:

str

shards

list of shards of the table.

Type:

List[Shard]

cache_params

cache parameters to be used by this table. These are passed to FBGEMM’s Split TBE kernel.

Type:

Optional[CacheParams]

enforce_hbm

whether to place all weights/momentums in HBM when using cache.

Type:

Optional[bool]

stochastic_rounding

whether to do stochastic rounding. This is passed to FBGEMM’s Split TBE kernel. Stochastic rounding is non-deterministic, but important for maintaining accuracy over the long term with FP16 embedding tables.

Type:

Optional[bool]

bounds_check_mode

bounds check mode to be used by FBGEMM’s Split TBE kernel. Bounds checking verifies that each value (i.e. row id) is within the table size. If a row id exceeds the table size, it will be set to 0.

Type:

Optional[BoundsCheckMode]

dependency

dependency of the table. Related to Embedding tower.

Type:

Optional[str]

is_pooled

whether the table is pooled. Pooling can be sum pooling or mean pooling. Unpooled tables are also known as sequence embeddings.

Type:

Optional[bool]

feature_names

list of feature names for this table.

Type:

Optional[List[str]]

output_dtype

output dtype to be used by this table. The default is FP32. If not None, the output dtype will also be used by the planner to produce a more balanced plan.

Type:

Optional[DataType]

property cache_load_factor: Optional[float]
property fqn: str
property is_pooled: bool
property module: Tuple[str, Module]
static module_pooled(module: Module, sharding_option_name: str) bool

Determine if module pools output (e.g. EmbeddingBag) or uses unpooled/sequential output.

property num_inputs: int
property num_shards: int
property path: str
property tensor: Tensor
property total_perf: float
property total_storage: Storage
class torchrec.distributed.planner.types.Stats

Bases: ABC

Logs statistics related to the sharding plan.

abstract log(sharding_plan: ShardingPlan, topology: Topology, batch_size: int, storage_reservation: StorageReservation, num_proposals: int, num_plans: int, run_time: float, best_plan: List[ShardingOption], constraints: Optional[Dict[str, ParameterConstraints]] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, debug: bool = False) None

See class description

class torchrec.distributed.planner.types.Storage(hbm: int, ddr: int)

Bases: object

Representation of the storage capacities of hardware used in training.

ddr: int
fits_in(other: Storage) bool
hbm: int
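A plausible reading of `fits_in` is a per-tier comparison. The sketch below is a stand-in (`StorageSketch` is a hypothetical name, not the torchrec class) and assumes HBM and DDR must each fit independently, with values expressed in bytes.

```python
from dataclasses import dataclass


@dataclass
class StorageSketch:
    """Stand-in mirroring the documented Storage fields; NOT the torchrec class."""

    hbm: int  # bytes of device high-bandwidth memory
    ddr: int  # bytes of host DDR memory

    def fits_in(self, other: "StorageSketch") -> bool:
        # Assumption: each memory tier is checked on its own;
        # spare DDR cannot compensate for an HBM shortfall.
        return self.hbm <= other.hbm and self.ddr <= other.ddr


GIB = 1024 ** 3
device_cap = StorageSketch(hbm=16 * GIB, ddr=64 * GIB)
fits = StorageSketch(hbm=2 * GIB, ddr=0).fits_in(device_cap)
too_big = StorageSketch(hbm=32 * GIB, ddr=0).fits_in(device_cap)
print(fits, too_big)
```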
class torchrec.distributed.planner.types.StorageReservation

Bases: ABC

Reserves storage space for non-sharded parts of the model.

abstract reserve(topology: Topology, batch_size: int, module: Module, sharders: List[ModuleSharder[Module]], constraints: Optional[Dict[str, ParameterConstraints]] = None) Topology
class torchrec.distributed.planner.types.Topology(world_size: int, compute_device: str, hbm_cap: Optional[int] = None, ddr_cap: Optional[int] = None, local_world_size: Optional[int] = None, hbm_mem_bw: float = 963146416.128, ddr_mem_bw: float = 54760833.024, intra_host_bw: float = 644245094.4, inter_host_bw: float = 13421772.8, bwd_compute_multiplier: float = 2, custom_topology_data: Optional[CustomTopologyData] = None)

Bases: object

property bwd_compute_multiplier: float
property compute_device: str
property ddr_mem_bw: float
property devices: List[DeviceHardware]
property hbm_mem_bw: float
property inter_host_bw: float
property intra_host_bw: float
property local_world_size: int
property world_size: int

torchrec.distributed.planner.utils

class torchrec.distributed.planner.utils.BinarySearchPredicate(A: int, B: int, tolerance: int)

Bases: object

Generates values of X between A & B to invoke on an external predicate F(X) to discover the largest X for which F(X) is true. Uses binary search to minimize the number of invocations of F. Assumes F is a step function, i.e. if F(X) is false, there is no point trying F(X+1).

next(prior_result: bool) Optional[int]

next() returns the next value to probe, given the result of the prior probe. The first time next() is invoked, prior_result is ignored. Returns None once the entire range has been explored or the tolerance threshold has been reached.
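The probing protocol described above can be sketched as follows. `BinarySearchSketch` and `largest_true` are hypothetical names, and the stopping rule (stop when the remaining range is narrower than `tolerance`) is an assumption for illustration; this is not torchrec's implementation.

```python
from typing import Callable, Optional


class BinarySearchSketch:
    """Illustrative stand-in with the same next() protocol; not torchrec's code."""

    def __init__(self, A: int, B: int, tolerance: int) -> None:
        self.left, self.right, self.tolerance = A, B, tolerance
        self._last: Optional[int] = None  # most recent probe, if any

    def next(self, prior_result: bool) -> Optional[int]:
        if self._last is not None:  # prior_result is ignored on the first call
            if prior_result:
                self.left = self._last + 1   # F held: try larger values
            else:
                self.right = self._last - 1  # F failed: larger values cannot hold
        if self.right - self.left < self.tolerance:
            return None  # assumed stopping rule: remaining range is too narrow
        self._last = self.left + (self.right - self.left) // 2
        return self._last


def largest_true(F: Callable[[int], bool], A: int, B: int, tolerance: int = 0) -> Optional[int]:
    """Drive the search and return the largest probed X for which F(X) was true."""
    search = BinarySearchSketch(A, B, tolerance)
    best = None
    x = search.next(True)  # the argument is ignored on the first call
    while x is not None:
        ok = F(x)
        if ok:
            best = x
        x = search.next(ok)
    return best


# Example: find the largest value in [0, 100] that stays within a limit of 37.
result = largest_true(lambda x: x <= 37, 0, 100)
print(result)
```

With `tolerance=0` this sketch narrows the range until it is empty, so the result is exact for a step function; a larger tolerance stops earlier and trades precision for fewer invocations of F.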

class torchrec.distributed.planner.utils.LuusJaakolaSearch(A: float, B: float, max_iterations: int, seed: int = 42, left_cost: Optional[float] = None)

Bases: object

Implements a clamped variant of Luus Jaakola search.

See https://en.wikipedia.org/wiki/Luus-Jaakola.

best() Tuple[float, float]

Return the best position so far, and its associated cost.

clamp(x: float) float

Clamp x into range [left, right]

next(fy: float) Optional[float]

Return the next probe point ‘y’ to evaluate, given the previous result.

The first time around fy is ignored. Subsequent invocations should provide the result of evaluating the function being minimized, i.e. f(y).

Returns None when the maximum number of iterations has been reached.

shrink_right(B: float) None

Shrink the right boundary to B, given that the function evaluates to infinity everywhere in [B, infinity).

uniform(A: float, B: float) float

Return a random uniform position in range [A,B].
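To make the probe-and-report protocol of `next()` and `best()` concrete, here is a minimal one-dimensional sketch of a clamped Luus-Jaakola search. `LuusJaakolaSketch` and `minimize` are hypothetical names, the shrink factor of 0.95 is an assumption, and `left_cost` and `shrink_right` are omitted; the library's class may behave differently.

```python
import random
from typing import Callable, Optional, Tuple


class LuusJaakolaSketch:
    """Illustrative clamped Luus-Jaakola search in 1D; not torchrec's code."""

    def __init__(self, A: float, B: float, max_iterations: int, seed: int = 42) -> None:
        self.left, self.right = A, B
        self.iterations_left = max_iterations
        self.rng = random.Random(seed)
        self.d = B - A                   # current sampling radius
        self.x: Optional[float] = None   # best position so far
        self.fx = float("inf")           # cost at the best position
        self.y: Optional[float] = None   # probe awaiting its result

    def clamp(self, x: float) -> float:
        return max(self.left, min(self.right, x))

    def best(self) -> Tuple[Optional[float], float]:
        return self.x, self.fx

    def next(self, fy: float) -> Optional[float]:
        if self.y is not None:  # fy is ignored on the first call
            if fy < self.fx:
                self.x, self.fx = self.y, fy  # improvement: move to the probe
            else:
                self.d *= 0.95                # assumption: shrink radius by 5%
        if self.iterations_left <= 0:
            self.y = None
            return None
        self.iterations_left -= 1
        if self.x is None:
            self.y = self.rng.uniform(self.left, self.right)
        else:
            # Sample around the best point, then clamp back into [left, right].
            self.y = self.clamp(self.x + self.rng.uniform(-self.d, self.d))
        return self.y


def minimize(f: Callable[[float], float], A: float, B: float, max_iterations: int = 200):
    search = LuusJaakolaSketch(A, B, max_iterations)
    y = search.next(0.0)  # the argument is ignored on the first call
    while y is not None:
        y = search.next(f(y))
    return search.best()


x_best, cost = minimize(lambda x: (x - 3.0) ** 2, 0.0, 10.0)
print(x_best, cost)
```

Because the caller evaluates the function and feeds the result back, the search object never needs to know what is being minimized, which suits expensive evaluations such as a full planning run.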

torchrec.distributed.planner.utils.bytes_to_gb(num_bytes: int) float
torchrec.distributed.planner.utils.bytes_to_mb(num_bytes: Union[float, int]) float
torchrec.distributed.planner.utils.gb_to_bytes(gb: float) int
torchrec.distributed.planner.utils.placement(compute_device: str, rank: int, local_size: int) str

Returns the placement, formatted as a string.

torchrec.distributed.planner.utils.prod(iterable: Iterable[int]) int
torchrec.distributed.planner.utils.reset_shard_rank(proposal: List[ShardingOption]) None
torchrec.distributed.planner.utils.sharder_name(t: Type[Any]) str
torchrec.distributed.planner.utils.storage_repr_in_gb(storage: Optional[Storage]) str
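The conversion helpers above are listed without descriptions. The sketch below shows one plausible reading, assuming binary units (1 GB = 1024^3 bytes, 1 MB = 1024^2 bytes) and that `prod` multiplies the elements of an iterable. The `_sketch` functions are hypothetical stand-ins, not the library's code.

```python
import math
from typing import Iterable, Union

BYTES_PER_MB = 1024 ** 2  # assumption: binary megabytes
BYTES_PER_GB = 1024 ** 3  # assumption: binary gigabytes


def bytes_to_gb_sketch(num_bytes: int) -> float:
    return num_bytes / BYTES_PER_GB


def bytes_to_mb_sketch(num_bytes: Union[float, int]) -> float:
    return num_bytes / BYTES_PER_MB


def gb_to_bytes_sketch(gb: float) -> int:
    return int(gb * BYTES_PER_GB)


def prod_sketch(iterable: Iterable[int]) -> int:
    return math.prod(iterable)


# Example: size of a [1000, 128] shard holding 4-byte (FP32) elements.
shard_bytes = prod_sketch([1000, 128]) * 4
shard_mb = bytes_to_mb_sketch(shard_bytes)
round_trip = bytes_to_gb_sketch(gb_to_bytes_sketch(1.5))
print(shard_bytes, shard_mb, round_trip)
```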
