# Source code for torch.ao.nn.quantized.modules.embedding_ops
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch import Tensor # noqa: F401
from torch._jit_internal import List, Optional # noqa: F401
from .utils import _hide_packed_params_repr, _quantize_weight
# Public names exported by ``from ... import *``: the packed-weight holder and
# the two quantized embedding modules built on it.
__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]
class EmbeddingPackedParams(torch.nn.Module):
    """Container for the prepacked weight used by the quantized embedding modules.

    The packed weight is kept in ``self._packed_weight`` as the object returned
    by ``torch.ops.quantized.embedding_bag_prepack``; ``_weight()`` unpacks it
    back into a quantized tensor. Only ``torch.quint8`` and ``torch.quint4x2``
    are accepted; any other dtype raises ``NotImplementedError``.

    Args:
        num_embeddings: number of rows of the placeholder weight packed at
            construction time.
        embedding_dim: number of columns of the placeholder weight.
        dtype: quantized dtype of the packed weight.
    """

    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super().__init__()
        self.dtype = dtype
        if dtype not in (torch.quint8, torch.quint4x2):
            raise NotImplementedError(
                f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}"
            )
        # Placeholder of the requested shape/dtype with per-row (axis 0) unit
        # scales and float zero points; its contents are not meaningful and
        # are meant to be replaced through set_weight().
        placeholder = torch._empty_per_channel_affine_quantized(
            [num_embeddings, embedding_dim],
            scales=torch.ones(num_embeddings, dtype=torch.float),
            zero_points=torch.zeros(num_embeddings, dtype=torch.float),
            axis=0,
            dtype=self.dtype,
        )
        self.set_weight(placeholder)

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        """Prepack the quantized ``weight`` and store it as ``_packed_weight``.

        Raises:
            NotImplementedError: if ``self.dtype`` is not quint8 or quint4x2.
        """
        if self.dtype not in [torch.quint8, torch.quint4x2]:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2."
            )
        self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)

    @torch.jit.export
    def _weight(self):
        """Return the packed weight unpacked back into a quantized tensor.

        Raises:
            NotImplementedError: if ``self.dtype`` is not quint8 or quint4x2.
        """
        if self.dtype not in [torch.quint8, torch.quint4x2]:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2."
            )
        return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)

    def forward(self, x):
        # Identity: this module only carries the packed weight.
        return x

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        # Key order matters for the serialized layout: dtype first, then weight.
        destination[f"{prefix}dtype"] = self.dtype
        destination[f"{prefix}_packed_weight"] = self._weight()

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Consume the two custom entries so the base loader never sees them.
        self.dtype = state_dict.pop(prefix + "dtype")
        unpacked = state_dict.pop(prefix + "_packed_weight")
        self.set_weight(unpacked)
        # NOTE(review): strict is forced to False for the base implementation
        # regardless of the caller's value; the reason is not visible here.
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __repr__(self):
        return repr(self._weight())
class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`, please see
    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation.

    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Only ``torch.quint8`` (byte kernel) and ``torch.quint4x2`` (4-bit kernel)
    weights are supported. ``padding_idx``, ``max_norm``, ``norm_type``,
    ``scale_grad_by_freq`` and ``sparse`` are accepted for signature
    compatibility with :class:`~torch.nn.Embedding` but are not used.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """

    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        dtype=torch.quint8,
    ) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is not None:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"

        # EmbeddingPackedParams rejects unsupported dtypes (NotImplementedError)
        # and already packs a placeholder weight of *this* dtype. No separate
        # placeholder is built here: one of a different dtype would be packed
        # for the wrong kernel relative to the one ``forward`` selects.
        self._packed_params = EmbeddingPackedParams(
            num_embeddings, embedding_dim, dtype
        )
        if _weight is not None:
            self._packed_params.set_weight(_weight)

    def forward(self, indices: Tensor) -> Tensor:
        # The kernel must match how the weight was packed (4-bit vs 8-bit).
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(
                self._packed_params._packed_weight, indices
            )
        else:
            return torch.ops.quantized.embedding_byte(
                self._packed_params._packed_weight, indices
            )

    def _get_name(self):
        return "QuantizedEmbedding"

    def __repr__(self):
        return _hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = (
            f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, "
            f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}"
        )

        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        """Replace the packed weight with the quantized tensor ``w``."""
        self._packed_params.set_weight(w)

    def weight(self):
        """Return the current weight unpacked into a quantized tensor."""
        return self._packed_params._weight()

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
            use_precomputed_fake_quant (bool): accepted for API compatibility;
                          not used by this implementation.

        Returns:
            Embedding: a quantized module whose ``dtype`` equals the weight
            observer's dtype (``torch.quint8`` or ``torch.quint4x2``).
        """
        if hasattr(mod, "weight_fake_quant"):
            assert type(mod) == torch.ao.nn.qat.Embedding, (
                "nnq."
                + cls.__name__
                + ".from_float "
                + "with fake quant only works for "
                + torch.ao.nn.qat.Embedding.__name__
            )
            # QAT module: its weight fake-quant doubles as the weight observer.
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.Embedding, (
                "nnq."
                + cls.__name__
                + ".from_float only works for "
                + nn.Embedding.__name__
            )
            assert hasattr(
                mod, "qconfig"
            ), "Embedding input float module must have qconfig defined"
            from torch.ao.quantization import float_qparams_weight_only_qconfig

            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        )
        assert (
            is_float_qparams_qconfig
        ), "Embedding quantization is only supported with float_qparams_weight_only_qconfig."

        assert (
            dtype == torch.quint8 or dtype == torch.quint4x2
        ), f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}"

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Pass the observer's dtype so ``forward`` selects the kernel (4-bit or
        # byte) that matches how ``qweight`` is packed by ``set_weight``.
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding.set_weight(qweight)
        return qembedding

    @classmethod
    def from_reference(cls, ref_embedding):
        """Build a quantized module from a reference quantized embedding.

        The reference module's quantized weight and ``weight_dtype`` are used
        directly; its remaining attributes are forwarded to the constructor.
        """
        qembedding = cls(
            ref_embedding.num_embeddings,
            ref_embedding.embedding_dim,
            ref_embedding.padding_idx,
            ref_embedding.max_norm,
            ref_embedding.norm_type,
            ref_embedding.scale_grad_by_freq,
            ref_embedding.sparse,
            ref_embedding.get_quantized_weight(),
            ref_embedding.weight_dtype,
        )
        return qembedding
class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    ``max_norm``, ``norm_type``, ``scale_grad_by_freq`` and ``sparse`` are
    accepted for signature compatibility but are not used. ``mode`` is stored
    on the module but is not forwarded to the kernel (see ``forward``).

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """

    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "sum",
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        include_last_offset: bool = False,
        dtype=torch.quint8,
    ) -> None:
        # The parent builds and packs the weight; only bag-specific state is added here.
        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)
        self.mode = mode
        self.pruned_weights = False
        self.include_last_offset = include_last_offset
        # Same value the parent constructor already assigned.
        self.dtype = dtype

    def forward(
        self,
        indices: Tensor,
        offsets: Optional[Tensor] = None,
        per_sample_weights: Optional[Tensor] = None,
        compressed_indices_mapping: Optional[Tensor] = None,
    ) -> Tensor:
        # Both kernels take the same positional arguments. The literal ``False``
        # and ``0`` after ``offsets`` are constants.
        # NOTE(review): presumably the kernel's scale_grad_by_freq and mode
        # slots; ``self.mode`` is never forwarded. Confirm against the op schema.
        packed_weight = self._packed_params._packed_weight
        if self.dtype != torch.quint4x2:
            return torch.ops.quantized.embedding_bag_byte(
                packed_weight,
                indices,
                offsets,
                False,
                0,
                self.pruned_weights,
                per_sample_weights,
                compressed_indices_mapping,
                self.include_last_offset,
            )
        return torch.ops.quantized.embedding_bag_4bit(
            packed_weight,
            indices,
            offsets,
            False,
            0,
            self.pruned_weights,
            per_sample_weights,
            compressed_indices_mapping,
            self.include_last_offset,
        )

    def _get_name(self):
        return "QuantizedEmbeddingBag"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized embedding_bag module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
            use_precomputed_fake_quant (bool): accepted for API compatibility;
                          not used by this implementation.

        Returns:
            EmbeddingBag: a quantized module whose ``dtype`` equals the weight
            observer's dtype.
        """
        if hasattr(mod, "weight_fake_quant"):
            # QAT module: its weight fake-quant doubles as the weight observer.
            observer = mod.weight_fake_quant
        else:
            assert (
                type(mod) == nn.EmbeddingBag
            ), f"nnq.{cls.__name__}.from_float only works for {nn.EmbeddingBag.__name__}"
            assert hasattr(
                mod, "qconfig"
            ), "EmbeddingBag input float module must have qconfig defined"
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig

            qconfig = mod.qconfig
            if qconfig is None or qconfig.weight is None:  # type: ignore[union-attr]
                observer = float_qparams_weight_only_qconfig.weight()
            else:
                observer = qconfig.weight()  # type: ignore[union-attr, operator]

        qdtype = observer.dtype
        assert (
            observer.qscheme == torch.per_channel_affine_float_qparams
        ), "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig."
        assert qdtype in (
            torch.quint8,
            torch.quint4x2,
        ), f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {qdtype}"

        # Feed the float weight through the observer so it can compute qparams.
        observer(mod.weight)
        quantized = _quantize_weight(mod.weight.float(), observer)

        result = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=qdtype)
        result.set_weight(quantized)
        return result

    @classmethod
    def from_reference(cls, ref_embedding_bag):
        """Build a quantized module from a reference quantized embedding bag.

        The reference module's quantized weight and ``weight_dtype`` are used
        directly; its remaining attributes are forwarded positionally.
        """
        ref = ref_embedding_bag
        return cls(
            ref.num_embeddings,
            ref.embedding_dim,
            ref.max_norm,
            ref.norm_type,
            ref.scale_grad_by_freq,
            ref.mode,
            ref.sparse,
            ref.get_quantized_weight(),
            ref.include_last_offset,
            ref.weight_dtype,
        )