Inference

TorchRec provides easy-to-use APIs for transforming an authored TorchRec model into an optimized inference model for distributed inference, via eager module swaps.

The transformation replaces TorchRec modules in the model, such as EmbeddingBagCollection, with quantized, sharded versions that can be compiled using torch.fx and TorchScript for inference in a C++ environment.

The intended use is to call quantize_inference_model on the model, then call shard_quant_model on the result.
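
The eager module swap described above can be pictured with a small, framework-free sketch. Everything in it is hypothetical: ToyModule, ToyEmbeddingBagCollection, ToyQuantEmbeddingBagCollection, and swap_modules are illustrative stand-ins, not TorchRec classes, and keying the mapping by class name is an assumption made for the example. It only shows the general idea of walking a module tree and replacing mapped module types in place.

```python
from typing import Dict, Type


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: a node with named children."""

    def __init__(self, **children: "ToyModule") -> None:
        self.children: Dict[str, "ToyModule"] = dict(children)


class ToyEmbeddingBagCollection(ToyModule):
    """Placeholder for a trainable embedding module (not the TorchRec class)."""


class ToyQuantEmbeddingBagCollection(ToyModule):
    """Placeholder for the quantized counterpart; keeps a handle to its source."""

    def __init__(self, source: ToyModule) -> None:
        super().__init__()
        self.source = source


def swap_modules(root: ToyModule, mapping: Dict[str, Type[ToyModule]]) -> ToyModule:
    """Replace every child whose class name appears in `mapping`, in place."""
    for name, child in list(root.children.items()):
        replacement = mapping.get(type(child).__name__)
        if replacement is not None:
            # Build the replacement from the original module and swap it in.
            root.children[name] = replacement(child)
        else:
            # Unmapped modules may still contain mapped ones deeper down.
            swap_modules(child, mapping)
    return root


model = ToyModule(sparse=ToyModule(ebc=ToyEmbeddingBagCollection()), dense=ToyModule())
swap_modules(model, {"ToyEmbeddingBagCollection": ToyQuantEmbeddingBagCollection})
print(type(model.children["sparse"].children["ebc"]).__name__)
```

Only the mapped module is replaced; the surrounding containers and unmapped modules are left as they were.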

torchrec.inference.modules.quantize_inference_model(model: Module, quantization_mapping: Optional[Dict[str, Type[Module]]] = None, per_table_weight_dtype: Optional[Dict[str, dtype]] = None, fp_weight_dtype: dtype = torch.int8, quantization_dtype: dtype = torch.int8, output_dtype: dtype = torch.float32) -> Module

Quantize the model by swapping TorchRec training modules with their quantized counterparts (e.g. EmbeddingBagCollection -> QuantEmbeddingBagCollection).

Parameters:
  • model (torch.nn.Module) – the model to be quantized

  • quantization_mapping (Optional[Dict[str, Type[torch.nn.Module]]]) – a mapping from the original module type to the quantized module type. If not provided, the default mapping will be used: (EmbeddingBagCollection -> QuantEmbeddingBagCollection, EmbeddingCollection -> QuantEmbeddingCollection).

  • per_table_weight_dtype (Optional[Dict[str, torch.dtype]]) – a mapping from table name to weight dtype. If not provided, the default quantization dtype will be used (int8).

  • fp_weight_dtype (torch.dtype) – the desired quantized dtype for feature processor weights in FeatureProcessedEmbeddingBagCollection if used. Default is int8.

Returns:

the quantized model

Return type:

torch.nn.Module
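
The fallback behaviour described for per_table_weight_dtype can be illustrated with a small sketch. The helper resolve_table_dtypes is hypothetical and not part of TorchRec, and plain strings stand in for torch.dtype values so the snippet runs without PyTorch installed. It only shows the lookup rule: a table listed in the mapping gets its own dtype, and any other table gets the default.

```python
from typing import Dict, Iterable, Optional


def resolve_table_dtypes(
    table_names: Iterable[str],
    per_table_weight_dtype: Optional[Dict[str, str]] = None,
    default_dtype: str = "int8",
) -> Dict[str, str]:
    """Return the weight dtype for each table, falling back to the default."""
    overrides = per_table_weight_dtype or {}
    return {name: overrides.get(name, default_dtype) for name in table_names}


# Hypothetical table names; only "t_item" overrides the default.
resolved = resolve_table_dtypes(["t_user", "t_item"], {"t_item": "int4"})
print(resolved)
```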

Example:

ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))

module = DLRMPredictModule(
    embedding_bag_collection=ebc,
    dense_in_features=self.model_config.dense_in_features,
    dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes,
    over_arch_layer_sizes=self.model_config.over_arch_layer_sizes,
    id_list_features_keys=self.model_config.id_list_features_keys,
    dense_device=device,
)

quant_model = quantize_inference_model(module)

torchrec.inference.modules.shard_quant_model(model: Module, world_size: int = 1, compute_device: str = 'cuda', sharding_device: str = 'meta', sharders: Optional[List[ModuleSharder[Module]]] = None, device_memory_size: Optional[int] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None) -> Tuple[Module, ShardingPlan]

Shard a quantized TorchRec model. Sharding is used to generate an optimized model for inference and is required for distributed inference.

Parameters:
  • model (torch.nn.Module) – the quantized model to be sharded

  • world_size (int) – the number of devices to shard the model across; defaults to 1

  • compute_device (str) – the device on which the model runs; defaults to "cuda"

  • sharding_device (str) – the device on which sharding is performed; defaults to "meta"

  • sharders (Optional[List[ModuleSharder[torch.nn.Module]]]) – the sharders to use for sharding the quantized model; defaults to QuantEmbeddingBagCollectionSharder, QuantEmbeddingCollectionSharder, and QuantFeatureProcessedEmbeddingBagCollectionSharder

  • device_memory_size (Optional[int]) – the memory limit for CUDA devices; defaults to None

  • constraints (Optional[Dict[str, ParameterConstraints]]) – the constraints to use for sharding; defaults to None, in which case default constraints are applied, with QuantEmbeddingBagCollection sharded table-wise (TableWise)

Returns:

the sharded model and the sharding plan

Return type:

Tuple[torch.nn.Module, ShardingPlan]
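
To make the table-wise default more concrete, here is a minimal sketch of what table-wise placement means: each table is assigned whole to a single device. The function plan_table_wise, the greedy balancing strategy, and the size numbers are all assumptions made for illustration. This is not TorchRec's planner and does not produce a ShardingPlan.

```python
from typing import Dict, Optional


def plan_table_wise(
    table_sizes: Dict[str, int],
    world_size: int = 1,
    device_memory_size: Optional[int] = None,
) -> Dict[str, int]:
    """Assign each table, whole, to one rank, balancing total size greedily."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    used = [0] * world_size  # memory already placed on each rank
    plan: Dict[str, int] = {}
    # Place the largest tables first; break ties by name for determinism.
    for name, size in sorted(table_sizes.items(), key=lambda kv: (-kv[1], kv[0])):
        rank = min(range(world_size), key=lambda r: used[r])  # least-loaded rank
        if device_memory_size is not None and used[rank] + size > device_memory_size:
            raise MemoryError(f"table {name!r} does not fit on any device")
        used[rank] += size
        plan[name] = rank
    return plan


# Hypothetical table sizes in arbitrary units.
plan = plan_table_wise({"t_a": 4, "t_b": 3, "t_c": 2}, world_size=2)
print(plan)
```

With world_size left at 1, every table lands on rank 0, which matches the single-device default of the real function only in spirit.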

Example:

ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))

module = DLRMPredictModule(
    embedding_bag_collection=ebc,
    dense_in_features=self.model_config.dense_in_features,
    dense_arch_layer_sizes=self.model_config.dense_arch_layer_sizes,
    over_arch_layer_sizes=self.model_config.over_arch_layer_sizes,
    id_list_features_keys=self.model_config.id_list_features_keys,
    dense_device=device,
)

quant_model = quantize_inference_model(module)
sharded_model, _ = shard_quant_model(quant_model)
