
torchrec.quant

Torchrec Quantization

Torchrec provides a quantized version of EmbeddingBagCollection for inference. It relies on fbgemm quantized ops. This reduces the size of the model weights and speeds up model execution.

Example

>>> import torch
>>> import torch.quantization as quant
>>> import torchrec.quant as trec_quant
>>> import torchrec as trec
>>> qconfig = quant.QConfig(
>>>     activation=quant.PlaceholderObserver,
>>>     weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
>>> )
>>> quantized = quant.quantize_dynamic(
>>>     module,
>>>     qconfig_spec={
>>>         trec.EmbeddingBagCollection: qconfig,
>>>     },
>>>     mapping={
>>>         trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection,
>>>     },
>>>     inplace=inplace,
>>> )
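In the example above, `module` and `inplace` are placeholders: `module` is any float model that is (or contains) a `trec.EmbeddingBagCollection`, and `inplace` controls whether the original module is mutated or a quantized copy is returned. A minimal sketch of such a setup, with made-up table sizes and feature names for illustration (assuming `EmbeddingBagConfig` is exported at the top-level `torchrec` namespace):

>>> # Hypothetical float model: a single EmbeddingBagCollection over one table
>>> # (the table name, sizes, and feature name are illustrative only).
>>> module = trec.EmbeddingBagCollection(
>>>     tables=[
>>>         trec.EmbeddingBagConfig(
>>>             name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
>>>         )
>>>     ],
>>>     device=torch.device("cpu"),
>>> )
>>> inplace = False  # return a new quantized module and keep the float module intact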

torchrec.quant.embedding_modules

class torchrec.quant.embedding_modules.EmbeddingBagCollection(table_name_to_quantized_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]], embedding_configs: List[torchrec.modules.embedding_configs.EmbeddingBagConfig], is_weighted: bool, device: torch.device)

Bases: torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface

EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags). This EmbeddingBagCollection is quantized for lower precision. It relies on fbgemm quantized ops.

It processes sparse data in the form of a KeyedJaggedTensor with values of the form [F X B X L], where:
  • F: features (keys)

  • B: batch size

  • L: length of sparse features (jagged)

and outputs a KeyedTensor with values of the form [B * (F * D)], where:
  • F: features (keys)

  • D: each feature’s (key’s) embedding dimension

  • B: batch size

Parameters
  • table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]]) – map of tables to quantized weights

  • embedding_configs (List[EmbeddingBagConfig]) – list of embedding tables

  • is_weighted (bool) – whether the input KeyedJaggedTensor is weighted

  • device (Optional[torch.device]) – default compute device

Call Args:

features: KeyedJaggedTensor,

Returns

KeyedTensor

Example:

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(
        dtype=torch.qint8
    ),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)

qebc = QuantEmbeddingBagCollection.from_float(ebc)
quantized_embeddings = qebc(features)
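
The output is a KeyedTensor whose dense values concatenate each feature's pooled embedding along the last dimension, as described above. A short follow-up sketch of the shapes, assuming the two tables above (embedding dims 3 and 4) and the batch size of 3 implied by the offsets:

quantized_embeddings.keys()          # ["f1", "f2"]
quantized_embeddings.values().shape  # torch.Size([3, 7]): B=3, D_f1 + D_f2 = 3 + 4
per_feature = quantized_embeddings.to_dict()
per_feature["f1"].shape              # torch.Size([3, 3])
per_feature["f2"].shape              # torch.Size([3, 4])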
property embedding_bag_configs: List[torchrec.modules.embedding_configs.EmbeddingBagConfig]
forward(features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) → torchrec.sparse.jagged_tensor.KeyedTensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
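
Concretely, prefer invoking the module instance over calling forward() directly (a small sketch reusing qebc and features from the class example above):

# Preferred: calling the instance also runs any registered hooks.
pooled = qebc(features)

# Works, but silently bypasses registered hooks.
pooled = qebc.forward(features)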

classmethod from_float(module: torchrec.modules.embedding_modules.EmbeddingBagCollection) → torchrec.quant.embedding_modules.EmbeddingBagCollection
property is_weighted: bool
named_buffers(prefix: str = '', recurse: bool = True) → Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields

(string, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) → Dict[str, Any]

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters
  • destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns

a dictionary containing a whole state of the module

Return type

dict

Example:

>>> module.state_dict().keys()
['bias', 'weight']
training: bool
torchrec.quant.embedding_modules.quantize_state_dict(module: torch.nn.modules.module.Module, table_name_to_quantized_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]], data_type: torchrec.modules.embedding_configs.DataType) → torch.device
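
This helper has no docstring; judging from the signature, it fills table_name_to_quantized_weights with per-table (quantized weight, scale/shift) tensor pairs derived from the float module and returns the device the weights live on. A hedged sketch of how it might be combined with the quantized EmbeddingBagCollection constructor, reusing ebc, table_0, and table_1 from the class example above (the exact semantics of the populated pairs are an assumption, not documented here):

from typing import Dict, Tuple

import torch
from torchrec.modules.embedding_configs import DataType
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
    quantize_state_dict,
)

# Assumption: the helper populates this dict in place, keyed by table name.
table_name_to_quantized_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
device = quantize_state_dict(ebc, table_name_to_quantized_weights, DataType.INT8)

qebc = QuantEmbeddingBagCollection(
    table_name_to_quantized_weights=table_name_to_quantized_weights,
    embedding_configs=[table_0, table_1],
    is_weighted=False,
    device=device,
)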

