torchrec.quant

Torchrec Quantization

Torchrec provides a quantized version of EmbeddingBagCollection for inference, built on fbgemm quantized ops. Quantizing the embedding tables reduces the size of the model weights and speeds up model execution.

Example

>>> import torch
>>> import torch.quantization as quant
>>> import torchrec.quant as trec_quant
>>> import torchrec as trec
>>> # `module` is the float model to convert; `inplace` is a bool chosen by the caller.
>>> qconfig = quant.QConfig(
...     activation=quant.PlaceholderObserver,
...     weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
... )
>>> quantized = quant.quantize_dynamic(
...     module,
...     qconfig_spec={
...         trec.EmbeddingBagCollection: qconfig,
...     },
...     mapping={
...         trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection,
...     },
...     inplace=inplace,
... )
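
To make the size reduction concrete, the following is a minimal pure-Python sketch of row-wise 8-bit quantization. It is an illustration under stated assumptions, not the fbgemm implementation: the helper names `quantize_row` and `dequantize_row` are hypothetical, and the byte count assumes that each row stores a 4-byte scale and a 4-byte bias next to its codes.

```python
from typing import List, Tuple


def quantize_row(row: List[float]) -> Tuple[List[int], float, float]:
    """Map one float row to 8-bit codes plus a per-row (scale, bias).

    Hypothetical helper for illustration only.
    """
    lo, hi = min(row), max(row)
    # A constant row would give a zero scale, so fall back to 1.0.
    scale = (hi - lo) / 255.0 or 1.0
    codes = [round((x - lo) / scale) for x in row]
    return codes, scale, lo


def dequantize_row(codes: List[int], scale: float, bias: float) -> List[float]:
    """Recover approximate float values from the 8-bit codes."""
    return [c * scale + bias for c in codes]


row = [-0.5, 0.0, 0.25, 1.0]
codes, scale, bias = quantize_row(row)
restored = dequantize_row(codes, scale, bias)

# The round trip loses at most half a quantization step per element.
max_err = max(abs(a - b) for a, b in zip(row, restored))

# Storage for a 128-dimensional row, assuming 4-byte scale and bias (an assumption).
dim = 128
float32_bytes = 4 * dim
int8_bytes = dim + 4 + 4
```

Under these assumptions a 128-dimensional row shrinks from 512 bytes to 136 bytes, at the cost of a bounded rounding error per element.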

torchrec.quant.embedding_modules

class torchrec.quant.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool, device: device, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16)

Bases: EmbeddingBagCollectionInterface, ModuleNoCopyMixin

EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags). This EmbeddingBagCollection is quantized for lower precision. It relies on fbgemm quantized ops and provides table batching.

It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:

  • F: features (keys)

  • B: batch size

  • L: length of sparse features (jagged)

and outputs a KeyedTensor with values of the form [B * (F * D)] where:

  • F: features (keys)

  • D: each feature’s (key’s) embedding dimension

  • B: batch size
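
The [F X B X L] layout can be read as a flat `values` list plus `offsets` that mark where each (feature, sample) bag starts and ends, with all bags of the first feature stored before those of the second. The sketch below decodes such a layout in plain Python for two features and a batch size of three. The helper `split_bags` is hypothetical and only mirrors the layout; it is not part of torchrec.

```python
from typing import Dict, List


def split_bags(
    keys: List[str], values: List[int], offsets: List[int]
) -> Dict[str, List[List[int]]]:
    """Group flat ids into per-feature, per-sample bags (hypothetical helper)."""
    num_bags = len(offsets) - 1
    batch_size = num_bags // len(keys)
    bags: Dict[str, List[List[int]]] = {}
    for f, key in enumerate(keys):
        per_sample = []
        for b in range(batch_size):
            i = f * batch_size + b  # bags are stored feature-major
            per_sample.append(values[offsets[i]:offsets[i + 1]])
        bags[key] = per_sample
    return bags


bags = split_bags(
    keys=["f1", "f2"],
    values=[0, 1, 2, 3, 4, 5, 6, 7],
    offsets=[0, 2, 2, 3, 4, 5, 8],
)
```

Here the second sample of "f1" is an empty bag, which is why two consecutive offsets are equal.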

Parameters:
  • table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]]) – map of table names to quantized weights

  • tables (List[EmbeddingBagConfig]) – list of embedding tables

  • is_weighted (bool) – whether input KeyedJaggedTensor is weighted

  • device (Optional[torch.device]) – default compute device

Call Args:

features (KeyedJaggedTensor)

Returns:

KeyedTensor

Example:

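import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.quant.embedding_modules import EmbeddingBagCollection as QuantEmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
# Float (unquantized) collection that will be converted below.
ebc = EmbeddingBagCollection(tables=[table_0, table_1])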

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(
        dtype=torch.qint8
    ),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)

qebc = QuantEmbeddingBagCollection.from_float(ebc)
quantized_embeddings = qebc(features)
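
The shape of the pooled output can be checked by hand. The sketch below assumes sum pooling and uses toy float tables in which row i is filled with the value i; both are assumptions made for illustration, and the helper `sum_pool` is hypothetical rather than a torchrec function. With embedding dimensions 3 and 4, each sample ends up with 3 + 4 = 7 values.

```python
from typing import Dict, List


def sum_pool(table: List[List[float]], bag: List[int]) -> List[float]:
    """Sum the table rows selected by one bag; an empty bag pools to zeros."""
    dim = len(table[0])
    pooled = [0.0] * dim
    for idx in bag:
        for d in range(dim):
            pooled[d] += table[idx][d]
    return pooled


# Toy tables (assumption): row i holds the value i in every column.
tables: Dict[str, List[List[float]]] = {
    "f1": [[float(i)] * 3 for i in range(10)],  # embedding_dim=3
    "f2": [[float(i)] * 4 for i in range(10)],  # embedding_dim=4
}
bags = {"f1": [[0, 1], [], [2]], "f2": [[3], [4], [5, 6, 7]]}

batch_size = 3
output: List[List[float]] = []
for b in range(batch_size):
    row: List[float] = []
    for key in ["f1", "f2"]:
        # Concatenate the pooled vector of every feature for this sample.
        row.extend(sum_pool(tables[key], bags[key][b]))
    output.append(row)
```

The result has one row per sample and one column per embedding dimension across all features, which is the [B * (F * D)] form described for the KeyedTensor values.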
property device: device
embedding_bag_configs() → List[EmbeddingBagConfig]
forward(features: KeyedJaggedTensor) → KeyedTensor
Parameters:

features (KeyedJaggedTensor) – KJT of form [F X B X L].

Returns:

KeyedTensor

classmethod from_float(module: EmbeddingBagCollection) → EmbeddingBagCollection
is_weighted() → bool
output_dtype() → dtype
training: bool
class torchrec.quant.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: device, need_indices: bool = False, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16)

Bases: EmbeddingCollectionInterface, ModuleNoCopyMixin

EmbeddingCollection represents a collection of non-pooled embeddings.

It processes sparse data in the form of KeyedJaggedTensor of the form [F X B X L] where:

  • F: features (keys)

  • B: batch size

  • L: length of sparse features (variable)

and outputs Dict[feature (key), JaggedTensor]. Each JaggedTensor contains values of the form (B * L) X D where:

  • B: batch size

  • L: length of sparse features (jagged)

  • D: each feature’s (key’s) embedding dimension

The lengths of each JaggedTensor are of the form L.
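
For contrast with the pooled case, the following plain-Python sketch shows what a non-pooled lookup returns for a single feature: one embedding row per id, stacked in order, plus the per-sample lengths needed to split the rows back into samples. The helper `lookup_sequence` and the toy table are hypothetical and only illustrate the (B * L) X D layout.

```python
from typing import List, Tuple


def lookup_sequence(
    table: List[List[float]], bags: List[List[int]]
) -> Tuple[List[List[float]], List[int]]:
    """Return one row per id (no pooling) and the length of each sample's bag."""
    values = [table[idx] for bag in bags for idx in bag]
    lengths = [len(bag) for bag in bags]
    return values, lengths


# Toy table (assumption): 10 rows, embedding_dim=3, row i filled with i.
table = [[float(i)] * 3 for i in range(10)]

# Three samples for one feature with 1, 1, and 3 ids respectively.
values, lengths = lookup_sequence(table, [[3], [4], [5, 6, 7]])
```

The five ids produce five rows of dimension 3, and `lengths` records how those rows are distributed over the three samples.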

Parameters:
  • tables (List[EmbeddingConfig]) – list of embedding tables.

  • device (Optional[torch.device]) – default compute device.

  • need_indices (bool) – whether to pass the indices through to the final lookup result dict.

Example:

e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1] None    [2]
# 1   [3]    [4]    [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
        [ 0.7352,  0.3210, -3.0399],
        [ 0.1279, -0.1756, -0.4130],
        [ 0.7519, -0.4341, -0.0499],
        [ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)
property device: device
embedding_configs() → List[EmbeddingConfig]
embedding_dim() → int
embedding_names_by_table() → List[List[str]]
forward(features: KeyedJaggedTensor) → Dict[str, JaggedTensor]
Parameters:

features (KeyedJaggedTensor) – KJT of form [F X B X L].

Returns:

Dict[str, JaggedTensor]

classmethod from_float(module: EmbeddingCollection) → EmbeddingCollection
need_indices() → bool
output_dtype() → dtype
training: bool
class torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool, device: device, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16, feature_processor: Optional[FeatureProcessorsCollection] = None)

Bases: EmbeddingBagCollection

embedding_bags: nn.ModuleDict
forward(features: KeyedJaggedTensor) → KeyedTensor
Parameters:

features (KeyedJaggedTensor) – KJT of form [F X B X L].

Returns:

KeyedTensor

classmethod from_float(module: FeatureProcessedEmbeddingBagCollection) → FeatureProcessedEmbeddingBagCollection
tbes: torch.nn.ModuleList
training: bool
torchrec.quant.embedding_modules.for_each_module_of_type_do(module: Module, module_types: List[Type[Module]], op: Callable[[Module], None]) → None
torchrec.quant.embedding_modules.pruned_num_embeddings(pruning_indices_mapping: Tensor) → int
torchrec.quant.embedding_modules.quant_prep_customize_row_alignment(module: Module, module_types: List[Type[Module]], row_alignment: int) → None
torchrec.quant.embedding_modules.quant_prep_enable_quant_state_dict_split_scale_bias(module: Module) → None
torchrec.quant.embedding_modules.quant_prep_enable_quant_state_dict_split_scale_bias_for_types(module: Module, module_types: List[Type[Module]]) → None
torchrec.quant.embedding_modules.quant_prep_enable_register_tbes(module: Module, module_types: List[Type[Module]]) → None
torchrec.quant.embedding_modules.quantize_state_dict(module: Module, table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]], table_name_to_data_type: Dict[str, DataType], table_name_to_pruning_indices_mapping: Optional[Dict[str, Tensor]] = None) → device
