
torchrec.distributed.planner

TorchRec Planner

The planner provides the specifications necessary for a module to be sharded, considering the possible options to build an optimized plan.

The features include (see the usage sketch after this list):
  • generating all possible sharding options.

  • estimating perf and storage for every shard.

  • estimating peak memory usage to eliminate sharding plans that might OOM.

  • customizability for parameter constraints, partitioning, proposers, or performance modeling.

  • automatically building and selecting an optimized sharding plan.
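
Example (an illustrative sketch, not a canonical recipe: the table configuration and sizes are assumptions, and import paths may vary across torchrec versions):

import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

# A small, hypothetical model with a single shardable embedding table.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=64,
            num_embeddings=100_000,
            feature_names=["f1"],
        )
    ]
)

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda")
)

# Builds, evaluates, and selects a sharding plan for the module.
plan = planner.plan(module=ebc, sharders=[EmbeddingBagCollectionSharder()])
print(plan)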

torchrec.distributed.planner.constants

torchrec.distributed.planner.constants.kernel_bw_lookup(compute_device: str, compute_kernel: str, caching_ratio: Optional[float] = None) Optional[float]

Calculates the device bandwidth based on given compute device, compute kernel, and caching ratio.

Parameters
  • compute_kernel (str) – compute kernel.

  • compute_device (str) – compute device.

  • caching_ratio (Optional[float]) – caching ratio used to determine device bandwidth if UVM caching is enabled.

Returns

the device bandwidth.

Return type

Optional[float]
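
Example (a hedged sketch: compute kernel names come from EmbeddingComputeKernel and differ between torchrec versions, e.g. "batched_fused" in older releases; the lookup returns None for unknown device/kernel combinations):

from torchrec.distributed.planner.constants import kernel_bw_lookup

# Bandwidth for a fused CUDA kernel without UVM caching.
fused_bw = kernel_bw_lookup(compute_device="cuda", compute_kernel="fused")

# With UVM caching, the caching ratio adjusts the effective bandwidth.
uvm_bw = kernel_bw_lookup(
    compute_device="cuda",
    compute_kernel="fused_uvm_caching",
    caching_ratio=0.2,
)
print(fused_bw, uvm_bw)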

torchrec.distributed.planner.enumerators

class torchrec.distributed.planner.enumerators.EmbeddingEnumerator(topology: torchrec.distributed.planner.types.Topology, constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None, estimator: Optional[Union[torchrec.distributed.planner.types.ShardEstimator, List[torchrec.distributed.planner.types.ShardEstimator]]] = None)

Bases: torchrec.distributed.planner.types.Enumerator

Generates embedding sharding options for given nn.Module, considering user provided constraints.

Parameters
  • topology (Topology) – device topology.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.

enumerate(module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]) List[torchrec.distributed.planner.types.ShardingOption]

Generates relevant sharding options given module and sharders.

Parameters
  • module (nn.Module) – module to be sharded.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns

valid sharding options with values populated.

Return type

List[ShardingOption]
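
Example (a hedged sketch: `ebc` is assumed to be an un-sharded EmbeddingBagCollection such as the one built in the module-level example above):

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import Topology
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator

enumerator = EmbeddingEnumerator(
    topology=Topology(world_size=2, compute_device="cuda")
)
# One ShardingOption per valid (sharding type, compute kernel) combination.
options = enumerator.enumerate(module=ebc, sharders=[EmbeddingBagCollectionSharder()])
for option in options:
    print(option.name, option.sharding_type, option.compute_kernel, option.num_shards)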

torchrec.distributed.planner.enumerators.calculate_shard_sizes_and_offsets(tensor: torch.Tensor, world_size: int, local_world_size: int, sharding_type: str, col_wise_shard_dim: Optional[int] = None) Tuple[List[List[int]], List[List[int]]]

Calculates sizes and offsets for tensor sharded according to provided sharding type.

Parameters
  • tensor (torch.Tensor) – tensor to be sharded.

  • world_size (int) – total number of devices in topology.

  • local_world_size (int) – total number of devices in host group topology.

  • sharding_type (str) – provided ShardingType value.

  • col_wise_shard_dim (Optional[int]) – dimension for column wise sharding split.

Returns

shard sizes, represented as a list of the dimensions of the sharded tensor on each device, and shard offsets, represented as a list of coordinates of placement on each device.

Return type

Tuple[List[List[int]], List[List[int]]]

Raises

ValueError – If sharding_type is not a valid ShardingType.
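
Example (a hedged sketch: a 100 x 64 table sharded row-wise across 4 devices; the commented output shows the general shape of the result, and the exact split policy may vary by version):

import torch
from torchrec.distributed.planner.enumerators import calculate_shard_sizes_and_offsets
from torchrec.distributed.types import ShardingType

sizes, offsets = calculate_shard_sizes_and_offsets(
    tensor=torch.empty(100, 64),
    world_size=4,
    local_world_size=4,
    sharding_type=ShardingType.ROW_WISE.value,
)
# sizes   -> [[25, 64], [25, 64], [25, 64], [25, 64]]
# offsets -> [[0, 0], [25, 0], [50, 0], [75, 0]]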

torchrec.distributed.planner.enumerators.get_partition_by_type(sharding_type: str) str

Gets corresponding partition by type for provided sharding type.

Parameters

sharding_type (str) – sharding type string.

Returns

the corresponding PartitionByType value.

Return type

str
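
Example (a hedged sketch of the expected mapping; PartitionByType is the authoritative source for the returned strings):

from torchrec.distributed.planner.enumerators import get_partition_by_type
from torchrec.distributed.types import ShardingType

print(get_partition_by_type(ShardingType.DATA_PARALLEL.value))   # e.g. "uniform"
print(get_partition_by_type(ShardingType.TABLE_ROW_WISE.value))  # e.g. "host"
print(get_partition_by_type(ShardingType.TABLE_WISE.value))      # e.g. "device"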

torchrec.distributed.planner.partitioners

class torchrec.distributed.planner.partitioners.GreedyPerfPartitioner

Bases: torchrec.distributed.planner.types.Partitioner

Greedy Partitioner

partition(proposal: List[torchrec.distributed.planner.types.ShardingOption], storage_constraint: torchrec.distributed.planner.types.Topology) List[torchrec.distributed.planner.types.ShardingOption]

Places sharding options on topology based on each sharding option’s partition_by attribute. The topology, storage, and perfs are updated at the end of the placement.

Parameters
  • proposal (List[ShardingOption]) – list of populated sharding options.

  • storage_constraint (Topology) – device topology.

Returns

list of sharding options for selected plan.

Return type

List[ShardingOption]

Example:

sharding_options = [
        ShardingOption(partition_by="uniform",
                shards=[
                    Shards(storage=1, perf=1),
                    Shards(storage=1, perf=1),
                ]),
        ShardingOption(partition_by="uniform",
                shards=[
                    Shards(storage=2, perf=2),
                    Shards(storage=2, perf=2),
                ]),
        ShardingOption(partition_by="device",
                shards=[
                    Shards(storage=3, perf=3),
                    Shards(storage=3, perf=3),
                ]),
        ShardingOption(partition_by="device",
                shards=[
                    Shards(storage=4, perf=4),
                    Shards(storage=4, perf=4),
                ]),
    ]
topology = Topology(world_size=2)

# First, sharding_options[0] and sharding_options[1] will be placed on the
# topology with the uniform strategy, resulting in

topology.devices[0].perf = (1,2)
topology.devices[1].perf = (1,2)

# Finally, sharding_options[2] and sharding_options[3] will be placed on the
# topology with the device strategy (see docstring of `partition_by_device` for
# more details).

topology.devices[0].perf = (1,2) + (3,4)
topology.devices[1].perf = (1,2) + (3,4)

# The topology updates are done after the end of all the placements (the
# ordering in the example is just for clarity).

class torchrec.distributed.planner.partitioners.ShardingOptionGroup(sharding_options: List[torchrec.distributed.planner.types.ShardingOption], storage_sum: torchrec.distributed.planner.types.Storage)

Bases: object

sharding_options: List[torchrec.distributed.planner.types.ShardingOption]
storage_sum: torchrec.distributed.planner.types.Storage

torchrec.distributed.planner.perf_models

class torchrec.distributed.planner.perf_models.NoopPerfModel(topology: torchrec.distributed.planner.types.Topology)

Bases: torchrec.distributed.planner.types.PerfModel

rate(plan: List[torchrec.distributed.planner.types.ShardingOption]) float

torchrec.distributed.planner.planners

class torchrec.distributed.planner.planners.EmbeddingShardingPlanner(topology: torchrec.distributed.planner.types.Topology, enumerator: Optional[torchrec.distributed.planner.types.Enumerator] = None, storage_reservation: Optional[torchrec.distributed.planner.types.StorageReservation] = None, proposer: Optional[Union[torchrec.distributed.planner.types.Proposer, List[torchrec.distributed.planner.types.Proposer]]] = None, partitioner: Optional[torchrec.distributed.planner.types.Partitioner] = None, performance_model: Optional[torchrec.distributed.planner.types.PerfModel] = None, stats: Optional[torchrec.distributed.planner.types.Stats] = None, constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None, debug: bool = False)

Bases: torchrec.distributed.types.ShardingPlanner

Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.

collective_plan(module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]], pg: torch._C._distributed_c10d.ProcessGroup) torchrec.distributed.types.ShardingPlan

Calls self.plan(…) on rank 0 and broadcasts the resulting plan to all other ranks in the process group.

plan(module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]) torchrec.distributed.types.ShardingPlan

Plans sharding for provided module and given sharders.

Parameters
  • module (nn.Module) – module that sharding is planned for.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns

the computed sharding plan.

Return type

ShardingPlan
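
Example (a hedged sketch for the distributed case: assumes the default process group has been initialized, e.g. via torchrun, and that `ebc` is the un-sharded module from the earlier example):

import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

dist.init_process_group(backend="nccl")
planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=dist.get_world_size(), compute_device="cuda")
)
# Rank 0 computes the plan and broadcasts it so that all ranks shard identically.
plan = planner.collective_plan(
    module=ebc,
    sharders=[EmbeddingBagCollectionSharder()],
    pg=dist.group.WORLD,
)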

torchrec.distributed.planner.proposers

class torchrec.distributed.planner.proposers.GreedyProposer(use_depth: bool = True)

Bases: torchrec.distributed.planner.types.Proposer

Proposes sharding plans in a greedy fashion.

Sorts sharding options for each shardable parameter by perf. On each iteration, finds the parameter with the largest current storage usage and tries its next sharding option.

feedback(partitionable: bool, plan: Optional[List[torchrec.distributed.planner.types.ShardingOption]] = None, perf_rating: Optional[float] = None) None
load(search_space: List[torchrec.distributed.planner.types.ShardingOption]) None
propose() Optional[List[torchrec.distributed.planner.types.ShardingOption]]
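
Example (a hedged sketch of the load/propose/feedback protocol that the planner drives; `search_space` is assumed to come from EmbeddingEnumerator.enumerate(...) and `try_to_partition` is a hypothetical stand-in for the partitioner/perf-model step):

from torchrec.distributed.planner.proposers import GreedyProposer

proposer = GreedyProposer()
proposer.load(search_space=search_space)

proposal = proposer.propose()
while proposal is not None:
    # A real planner partitions the proposal, rates it, and reports back.
    partitionable = try_to_partition(proposal)  # hypothetical helper
    proposer.feedback(partitionable=partitionable, plan=proposal)
    proposal = proposer.propose()
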
class torchrec.distributed.planner.proposers.GridSearchProposer(max_proposals: int = 10000)

Bases: torchrec.distributed.planner.types.Proposer

feedback(partitionable: bool, plan: Optional[List[torchrec.distributed.planner.types.ShardingOption]] = None, perf_rating: Optional[float] = None) None
load(search_space: List[torchrec.distributed.planner.types.ShardingOption]) None
propose() Optional[List[torchrec.distributed.planner.types.ShardingOption]]
class torchrec.distributed.planner.proposers.UniformProposer(use_depth: bool = True)

Bases: torchrec.distributed.planner.types.Proposer

Proposes uniform sharding plans, i.e. plans that use the same sharding type for all sharding options.

feedback(partitionable: bool, plan: Optional[List[torchrec.distributed.planner.types.ShardingOption]] = None, perf_rating: Optional[float] = None) None
load(search_space: List[torchrec.distributed.planner.types.ShardingOption]) None
propose() Optional[List[torchrec.distributed.planner.types.ShardingOption]]

torchrec.distributed.planner.shard_estimators

class torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator(topology: torchrec.distributed.planner.types.Topology, constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None)

Bases: torchrec.distributed.planner.types.ShardEstimator

Embedding Wall Time Perf Estimator

estimate(sharding_options: List[torchrec.distributed.planner.types.ShardingOption], sharder_map: Optional[Dict[str, torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]] = None) None
class torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator(topology: torchrec.distributed.planner.types.Topology, constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None)

Bases: torchrec.distributed.planner.types.ShardEstimator

Embedding Storage Usage Estimator

estimate(sharding_options: List[torchrec.distributed.planner.types.ShardingOption], sharder_map: Optional[Dict[str, torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]] = None) None
torchrec.distributed.planner.shard_estimators.calculate_shard_storages(sharder: torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module], sharding_type: str, tensor: torch.Tensor, compute_device: str, compute_kernel: str, shard_sizes: List[List[int]], batch_size: int, world_size: int, local_world_size: int, input_lengths: List[float], caching_ratio: float, is_pooled: bool) List[torchrec.distributed.planner.types.Storage]

Calculates estimated storage sizes for each sharded tensor, comprised of input, output, tensor, gradient, and optimizer sizes.

Parameters
  • sharder (ModuleSharder[nn.Module]) – sharder for module that supports sharding.

  • sharding_type (str) – provided ShardingType value.

  • tensor (torch.Tensor) – tensor to be sharded.

  • compute_device (str) – compute device to be used.

  • compute_kernel (str) – compute kernel to be used.

  • shard_sizes (List[List[int]]) – list of dimensions of each sharded tensor.

  • batch_size (int) – batch size to be used.

  • world_size (int) – total number of devices in topology.

  • local_world_size (int) – total number of devices in host group topology.

  • input_lengths (List[float]) – average input lengths synonymous with pooling factors.

  • caching_ratio (float) – ratio of HBM to DDR memory for UVM caching.

  • is_pooled (bool) – True if embedding output is pooled (i.e. EmbeddingBag), False if unpooled/sequential (i.e. Embedding).

Returns

storage object for each device in topology.

Return type

List[Storage]
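
Example (a hedged sketch: shard_sizes would normally come from calculate_shard_sizes_and_offsets, the compute kernel name is version dependent, and all numeric values are illustrative):

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.shard_estimators import calculate_shard_storages
from torchrec.distributed.types import ShardingType

storages = calculate_shard_storages(
    sharder=EmbeddingBagCollectionSharder(),
    sharding_type=ShardingType.ROW_WISE.value,
    tensor=torch.empty(100, 64),
    compute_device="cuda",
    compute_kernel="fused",  # assumed kernel name; see kernel_bw_lookup above
    shard_sizes=[[25, 64], [25, 64], [25, 64], [25, 64]],
    batch_size=512,
    world_size=4,
    local_world_size=4,
    input_lengths=[10.0],
    caching_ratio=0.2,
    is_pooled=True,
)
# A list of Storage(hbm=..., ddr=...) estimates (see Returns above).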

torchrec.distributed.planner.shard_estimators.perf_func_emb_wall_time(shard_sizes: List[List[int]], compute_kernel: str, compute_device: str, sharding_type: str, batch_size: int, world_size: int, local_world_size: int, input_lengths: List[float], input_data_type_size: float, output_data_type_size: float, bw_intra_host: float, bw_inter_host: float, is_pooled: bool, is_weighted: bool = False, has_feature_processor: bool = False, caching_ratio: Optional[float] = None) List[float]

Attempts to model perfs as a function of relative wall times.

Parameters
  • shard_sizes (List[List[int]]) – the list of (local_rows, local_cols) of each shard.

  • compute_kernel (str) – compute kernel.

  • compute_device (str) – compute device.

  • sharding_type (str) – sharding type abbreviation: tw (table-wise), rw (row-wise), cw (column-wise), twrw (table-row-wise), or dp (data-parallel).

  • batch_size (int) – the size of each batch.

  • world_size (int) – the number of devices for all hosts.

  • local_world_size (int) – the number of devices per host.

  • input_lengths (List[float]) – the list of the average number of lookups of each input query feature.

  • input_data_type_size (float) – the data type size of the distributed data_parallel input.

  • output_data_type_size (float) – the data type size of the distributed data_parallel output.

  • bw_intra_host (float) – the bandwidth within a single host, e.g. between devices on the same machine.

  • bw_inter_host (float) – the bandwidth between hosts, e.g. across machines.

  • is_pooled (bool) – True if embedding output is pooled (i.e. EmbeddingBag), False if unpooled/sequential (i.e. Embedding).

  • is_weighted (bool = False) – if the module is an EBC and is weighted, typically signifying an id score list feature.

  • has_feature_processor (bool = False) – if the module has a feature processor.

  • caching_ratio (Optional[float] = None) – caching ratio used to determine device bandwidth if UVM caching is enabled.

Returns

the list of perf for each shard.

Return type

List[float]
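
Example (a hedged sketch: bandwidths reuse the defaults visible in Topology below, data type sizes assume int64 inputs and fp32 outputs, and the compute kernel name is version dependent):

from torchrec.distributed.planner.shard_estimators import perf_func_emb_wall_time
from torchrec.distributed.types import ShardingType

perfs = perf_func_emb_wall_time(
    shard_sizes=[[25, 64], [25, 64], [25, 64], [25, 64]],
    compute_kernel="fused",  # assumed kernel name
    compute_device="cuda",
    sharding_type=ShardingType.ROW_WISE.value,
    batch_size=512,
    world_size=4,
    local_world_size=4,
    input_lengths=[10.0],
    input_data_type_size=8.0,   # int64 ids
    output_data_type_size=4.0,  # fp32 embeddings
    bw_intra_host=644245094.4,
    bw_inter_host=13421772.8,
    is_pooled=True,
)
# One relative wall-time estimate per shard.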

torchrec.distributed.planner.stats

class torchrec.distributed.planner.stats.EmbeddingStats

Bases: torchrec.distributed.planner.types.Stats

Stats for a sharding planner execution.

log(sharding_plan: torchrec.distributed.types.ShardingPlan, topology: torchrec.distributed.planner.types.Topology, num_proposals: int, num_plans: int, best_plan: List[torchrec.distributed.planner.types.ShardingOption], debug: bool = False) None

Logs stats for a given sharding plan to stdout.

Provides a tabular view of stats for the given sharding plan with per device storage usage (HBM and DDR), perf, input, output, and number/type of shards.

Parameters
  • sharding_plan (ShardingPlan) – sharding plan chosen by the planner.

  • topology (Topology) – device topology.

  • num_proposals (int) – number of proposals evaluated.

  • num_plans (int) – number of proposals successfully partitioned.

  • best_plan (List[ShardingOption]) – plan with expected performance.

  • debug (bool) – whether to enable debug mode.

torchrec.distributed.planner.storage_reservations

class torchrec.distributed.planner.storage_reservations.FixedPercentageReservation(percentage: float)

Bases: torchrec.distributed.planner.types.StorageReservation

reserve(topology: torchrec.distributed.planner.types.Topology, module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]], constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None) torchrec.distributed.planner.types.Topology
class torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation(percentage: float)

Bases: torchrec.distributed.planner.types.StorageReservation

Reserves storage for the model to be sharded, using a heuristical calculation. The storage reservation is comprised of unshardable tensor storage, KJT storage, and an extra percentage.

Parameters

percentage (float) – extra storage percentage to reserve that acts as a margin of error beyond heuristic calculation of storage.

reserve(topology: torchrec.distributed.planner.types.Topology, module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]], constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None) torchrec.distributed.planner.types.Topology

torchrec.distributed.planner.types

class torchrec.distributed.planner.types.DeviceHardware(rank: int, storage: torchrec.distributed.planner.types.Storage, perf: float = 0)

Bases: object

Representation of a device in a process group. ‘perf’ is an estimation of network, CPU, and storage usages.

perf: float = 0
rank: int
storage: torchrec.distributed.planner.types.Storage
class torchrec.distributed.planner.types.Enumerator(topology: torchrec.distributed.planner.types.Topology, constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None, estimator: Optional[Union[torchrec.distributed.planner.types.ShardEstimator, List[torchrec.distributed.planner.types.ShardEstimator]]] = None)

Bases: abc.ABC

Generates all relevant sharding options for given topology, constraints, nn.Module, and sharders.

abstract enumerate(module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]) List[torchrec.distributed.planner.types.ShardingOption]

See class description.

class torchrec.distributed.planner.types.ParameterConstraints(sharding_types: typing.Optional[typing.List[str]] = None, compute_kernels: typing.Optional[typing.List[str]] = None, min_partition: typing.Optional[int] = None, caching_ratio: typing.Optional[float] = None, pooling_factors: typing.List[float] = <factory>)

Bases: object

Stores user provided constraints around the sharding plan.

caching_ratio: Optional[float] = None
compute_kernels: Optional[List[str]] = None
min_partition: Optional[int] = None
pooling_factors: List[float]
sharding_types: Optional[List[str]] = None
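
Example (a hedged sketch: the table names are hypothetical; constraints are keyed by parameter/table name and passed to the enumerator or planner):

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

constraints = {
    "large_table": ParameterConstraints(
        sharding_types=[ShardingType.ROW_WISE.value],
        pooling_factors=[20.0],
    ),
    "small_table": ParameterConstraints(
        sharding_types=[
            ShardingType.TABLE_WISE.value,
            ShardingType.DATA_PARALLEL.value,
        ],
    ),
}
planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
    constraints=constraints,
)
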
class torchrec.distributed.planner.types.PartitionByType(value)

Bases: enum.Enum

Well-known partition types.

DEVICE = 'device'
HOST = 'host'
UNIFORM = 'uniform'
class torchrec.distributed.planner.types.Partitioner

Bases: abc.ABC

Partitions shards.

Today we have multiple strategies, i.e. Greedy, BLDM, and Linear.

abstract partition(proposal: List[torchrec.distributed.planner.types.ShardingOption], storage_constraint: torchrec.distributed.planner.types.Topology) List[torchrec.distributed.planner.types.ShardingOption]
class torchrec.distributed.planner.types.PerfModel

Bases: abc.ABC

abstract rate(plan: List[torchrec.distributed.planner.types.ShardingOption]) float
exception torchrec.distributed.planner.types.PlannerError

Bases: Exception

class torchrec.distributed.planner.types.Proposer

Bases: abc.ABC

Proposes complete lists of sharding options which can be partitioned to generate a plan.

abstract feedback(partitionable: bool, plan: Optional[List[torchrec.distributed.planner.types.ShardingOption]] = None, perf_rating: Optional[float] = None) None
abstract load(search_space: List[torchrec.distributed.planner.types.ShardingOption]) None
abstract propose() Optional[List[torchrec.distributed.planner.types.ShardingOption]]
class torchrec.distributed.planner.types.Shard(size: List[int], offset: List[int], storage: Optional[torchrec.distributed.planner.types.Storage] = None, perf: Optional[float] = None, rank: Optional[int] = None)

Bases: object

Representation of a subset of an embedding table. ‘size’ and ‘offset’ fully determine the tensors in the shard. ‘storage’ is an estimation of how much memory it takes to store the shard, and ‘perf’ is an estimation of its runtime cost.

offset: List[int]
perf: Optional[float] = None
rank: Optional[int] = None
size: List[int]
storage: Optional[torchrec.distributed.planner.types.Storage] = None
class torchrec.distributed.planner.types.ShardEstimator(topology: torchrec.distributed.planner.types.Topology, constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None)

Bases: abc.ABC

Estimates shard perf or storage, requires fully specified sharding options.

abstract estimate(sharding_options: List[torchrec.distributed.planner.types.ShardingOption], sharder_map: Optional[Dict[str, torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]] = None) None
class torchrec.distributed.planner.types.ShardingOption(name: str, tensor: torch.Tensor, module: typing.Tuple[str, torch.nn.modules.module.Module], upstream_modules: typing.List[typing.Tuple[str, torch.nn.modules.module.Module]], downstream_modules: typing.List[typing.Tuple[str, torch.nn.modules.module.Module]], input_lengths: typing.List[float], batch_size: int, sharding_type: str, partition_by: str, compute_kernel: str, shards: typing.List[torchrec.distributed.planner.types.Shard] = <factory>, dependency: typing.Optional[str] = None)

Bases: object

One way of sharding an embedding table.

batch_size: int
compute_kernel: str
dependency: Optional[str] = None
downstream_modules: List[Tuple[str, torch.nn.modules.module.Module]]
property fqn: str
input_lengths: List[float]
property is_pooled: bool
module: Tuple[str, torch.nn.modules.module.Module]
name: str
property num_inputs: int
property num_shards: int
partition_by: str
property path: str
sharding_type: str
shards: List[torchrec.distributed.planner.types.Shard]
tensor: torch.Tensor
property total_storage: torchrec.distributed.planner.types.Storage
upstream_modules: List[Tuple[str, torch.nn.modules.module.Module]]
class torchrec.distributed.planner.types.Stats

Bases: abc.ABC

Logs statistics related to the sharding plan.

abstract log(sharding_plan: torchrec.distributed.types.ShardingPlan, topology: torchrec.distributed.planner.types.Topology, num_proposals: int, num_plans: int, best_plan: List[torchrec.distributed.planner.types.ShardingOption], debug: bool = False) None

See class description

class torchrec.distributed.planner.types.Storage(hbm: int, ddr: int)

Bases: object

Representation of the storage capacities of hardware used in training.

ddr: int
hbm: int
class torchrec.distributed.planner.types.StorageReservation

Bases: abc.ABC

Reserves storage space for non-sharded parts of the model.

abstract reserve(topology: torchrec.distributed.planner.types.Topology, module: torch.nn.modules.module.Module, sharders: List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]], constraints: Optional[Dict[str, torchrec.distributed.planner.types.ParameterConstraints]] = None) torchrec.distributed.planner.types.Topology
class torchrec.distributed.planner.types.Topology(world_size: int, compute_device: str, hbm_cap: Optional[int] = None, ddr_cap: int = 137438953472, local_world_size: Optional[int] = None, intra_host_bw: float = 644245094.4, inter_host_bw: float = 13421772.8, batch_size: int = 512)

Bases: object

property batch_size: int
property compute_device: str
property devices: List[torchrec.distributed.planner.types.DeviceHardware]
property inter_host_bw: float
property intra_host_bw: float
property local_world_size: int
property world_size: int
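
Example (a hedged sketch: hbm_cap is expressed in bytes; the values are illustrative of an 8-GPU host):

from torchrec.distributed.planner.types import Topology

topology = Topology(
    world_size=8,
    compute_device="cuda",
    hbm_cap=16 * 1024**3,  # 16 GiB of HBM per device
    local_world_size=8,    # all eight devices on a single host
    batch_size=512,
)
print(len(topology.devices))  # one DeviceHardware entry per rank
print(topology.devices[0].storage.hbm, topology.devices[0].storage.ddr)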

torchrec.distributed.planner.utils

torchrec.distributed.planner.utils.bytes_to_gb(num_bytes: int) float
torchrec.distributed.planner.utils.bytes_to_mb(num_bytes: Union[float, int]) float
torchrec.distributed.planner.utils.gb_to_bytes(gb: float) int
torchrec.distributed.planner.utils.prod(iterable: Iterable[int]) int
torchrec.distributed.planner.utils.sharder_name(t: Type[Any]) str
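
Example (a hedged sketch; the helpers assume binary, 1024-based units):

from torchrec.distributed.planner.utils import bytes_to_gb, bytes_to_mb, gb_to_bytes, prod

print(gb_to_bytes(16))           # 17179869184
print(bytes_to_gb(17179869184))  # 16.0
print(bytes_to_mb(1024 * 1024))  # 1.0
print(prod([25, 64]))            # 1600, the number of elements in a [25, 64] shard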
