Shortcuts

Source code for torch.ao.nn.quantized.modules.embedding_ops

# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import List, Optional  # noqa: F401

from .utils import _hide_packed_params_repr, _quantize_weight


__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]


class EmbeddingPackedParams(torch.nn.Module):
    """Container for the prepacked weight table of a quantized embedding.

    The weight is held in ``_packed_weight`` in the form returned by
    ``torch.ops.quantized.embedding_bag_prepack``; ``_weight()`` turns it
    back into a quantized tensor via ``embedding_bag_unpack``. Only
    ``torch.quint8`` and ``torch.quint4x2`` are accepted as ``dtype``;
    anything else raises ``NotImplementedError``.
    """

    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super().__init__()
        self.dtype = dtype
        if self.dtype not in [torch.quint8, torch.quint4x2]:
            raise NotImplementedError(
                f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}"
            )
        # Per-row (axis=0) quantized placeholder with unit scales and zero
        # offsets; its contents are uninitialized and callers may replace it
        # through set_weight().
        placeholder = torch._empty_per_channel_affine_quantized(
            [num_embeddings, embedding_dim],
            scales=torch.ones(num_embeddings, dtype=torch.float),
            zero_points=torch.zeros(num_embeddings, dtype=torch.float),
            axis=0,
            dtype=self.dtype,
        )
        self.set_weight(placeholder)

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        """Prepack ``weight`` and keep the result in ``_packed_weight``.

        Raises:
            NotImplementedError: if ``self.dtype`` is not quint8/quint4x2.
        """
        if self.dtype not in [torch.quint8, torch.quint4x2]:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2."
            )
        self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)

    @torch.jit.export
    def _weight(self):
        """Unpack and return the stored quantized weight tensor.

        Raises:
            NotImplementedError: if ``self.dtype`` is not quint8/quint4x2.
        """
        if self.dtype not in [torch.quint8, torch.quint4x2]:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2."
            )
        return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)

    def forward(self, x):
        # Pass-through: this module only stores parameters.
        return x

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        # Serialize the dtype and the *unpacked* weight tensor rather than
        # the packed representation.
        destination[prefix + "dtype"] = self.dtype
        destination[prefix + "_packed_weight"] = self._weight()

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Consume the two custom entries (removing them from state_dict) and
        # repack the weight; dtype must be restored first because set_weight
        # validates it.
        self.dtype = state_dict.pop(prefix + "dtype")
        unpacked = state_dict.pop(prefix + "_packed_weight")
        self.set_weight(unpacked)

        # strict is deliberately forced to False for the base-class loader.
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __repr__(self):
        return repr(self._weight())


class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`, please see
    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation.

    Similar to :class:`~torch.nn.Embedding`, the weight is allocated at module
    creation time (as an uninitialized quantized placeholder unless ``_weight``
    is given) and is expected to be overwritten later, e.g. via ``set_weight``.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """

    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        dtype=torch.quint8,
    ) -> None:
        # NOTE: padding_idx, max_norm, norm_type, scale_grad_by_freq and sparse
        # are accepted for signature compatibility with nn.Embedding; this
        # module neither stores nor uses them.
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is not None:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"

        # EmbeddingPackedParams already packs a placeholder table created in
        # ``dtype``. Keep it instead of packing a second, hard-coded quint8
        # placeholder over it: self.dtype selects the 4-bit vs. byte kernel in
        # forward(), so the packed data must be in that same dtype.
        self._packed_params = EmbeddingPackedParams(
            num_embeddings, embedding_dim, dtype
        )
        if _weight is not None:
            self._packed_params.set_weight(_weight)

    def forward(self, indices: Tensor) -> Tensor:
        # The kernel is chosen by self.dtype and must match the packed data.
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(
                self._packed_params._packed_weight, indices
            )
        else:
            return torch.ops.quantized.embedding_byte(
                self._packed_params._packed_weight, indices
            )

    def _get_name(self):
        return "QuantizedEmbedding"

    def __repr__(self):
        return _hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = (
            f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, "
            f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}"
        )

        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        """Replace the stored weight with the quantized tensor ``w``."""
        self._packed_params.set_weight(w)

    def weight(self):
        """Return the unpacked quantized weight tensor."""
        return self._packed_params._weight()

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
            use_precomputed_fake_quant (bool): accepted for interface
                          compatibility; not used by this method.
        """
        if hasattr(mod, "weight_fake_quant"):
            assert type(mod) == torch.ao.nn.qat.Embedding, (
                "nnq."
                + cls.__name__
                + ".from_float "
                + "with fake quant only works for "
                + torch.ao.nn.qat.Embedding.__name__
            )
            weight_observer = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == nn.Embedding, (
                "nnq."
                + cls.__name__
                + ".from_float only works for "
                + nn.Embedding.__name__
            )
            assert hasattr(
                mod, "qconfig"
            ), "Embedding input float module must have qconfig defined"
            from torch.ao.quantization import float_qparams_weight_only_qconfig

            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        )
        assert (
            is_float_qparams_qconfig
        ), "Embedding quantization is only supported with float_qparams_weight_only_qconfig."

        assert (
            dtype == torch.quint8 or dtype == torch.quint4x2
        ), f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}"

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create the quantized Embedding in the observer's dtype. Without
        # dtype=dtype the module would default to quint8 and forward() would
        # run the byte kernel on a quint4x2-packed weight.
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding.set_weight(qweight)
        return qembedding

    @classmethod
    def from_reference(cls, ref_embedding):
        qembedding = cls(
            ref_embedding.num_embeddings,
            ref_embedding.embedding_dim,
            ref_embedding.padding_idx,
            ref_embedding.max_norm,
            ref_embedding.norm_type,
            ref_embedding.scale_grad_by_freq,
            ref_embedding.sparse,
            ref_embedding.get_quantized_weight(),
            ref_embedding.weight_dtype,
        )
        return qembedding
class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """

    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "sum",
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        include_last_offset: bool = False,
        dtype=torch.quint8,
    ) -> None:
        # Only the table size, optional initial weight and dtype reach the
        # base class; max_norm, norm_type, scale_grad_by_freq and sparse are
        # accepted for interface compatibility but not stored.
        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)
        self.mode = mode
        self.include_last_offset = include_last_offset
        self.pruned_weights = False
        self.dtype = dtype

    def forward(
        self,
        indices: Tensor,
        offsets: Optional[Tensor] = None,
        per_sample_weights: Optional[Tensor] = None,
        compressed_indices_mapping: Optional[Tensor] = None,
    ) -> Tensor:
        # NOTE(review): the 4th/5th positional arguments are literal False
        # and 0 (presumably scale_grad_by_freq and mode per the op schema --
        # confirm); self.mode is not forwarded to either kernel.
        if self.dtype != torch.quint4x2:
            return torch.ops.quantized.embedding_bag_byte(
                self._packed_params._packed_weight,
                indices,
                offsets,
                False,
                0,
                self.pruned_weights,
                per_sample_weights,
                compressed_indices_mapping,
                self.include_last_offset,
            )
        return torch.ops.quantized.embedding_bag_4bit(
            self._packed_params._packed_weight,
            indices,
            offsets,
            False,
            0,
            self.pruned_weights,
            per_sample_weights,
            compressed_indices_mapping,
            self.include_last_offset,
        )

    def _get_name(self):
        return "QuantizedEmbeddingBag"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized embedding_bag module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
            use_precomputed_fake_quant (bool): accepted for interface
                          compatibility; not used by this method.
        """
        if hasattr(mod, "weight_fake_quant"):
            # QAT path: reuse the module's own fake-quant as the observer.
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.EmbeddingBag, (
                "nnq."
                + cls.__name__
                + ".from_float only works for "
                + nn.EmbeddingBag.__name__
            )
            assert hasattr(
                mod, "qconfig"
            ), "EmbeddingBag input float module must have qconfig defined"
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig

            qconfig = mod.qconfig
            if qconfig is not None and qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        assert (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        ), "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig."
        assert dtype in (
            torch.quint8,
            torch.quint4x2,
        ), f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}"

        # Feed the float weight through the observer so it computes qparams,
        # then quantize the weight with them.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        qembedding_bag = EmbeddingBag(
            mod.num_embeddings,
            mod.embedding_dim,
            max_norm=mod.max_norm,
            norm_type=mod.norm_type,
            scale_grad_by_freq=mod.scale_grad_by_freq,
            mode=mod.mode,
            sparse=mod.sparse,
            include_last_offset=mod.include_last_offset,
            dtype=dtype,
        )
        qembedding_bag.set_weight(qweight)
        return qembedding_bag

    @classmethod
    def from_reference(cls, ref_embedding_bag):
        # Positional order mirrors __init__: sizes, max_norm, norm_type,
        # scale_grad_by_freq, mode, sparse, _weight, include_last_offset, dtype.
        ref = ref_embedding_bag
        return cls(
            ref.num_embeddings,
            ref.embedding_dim,
            ref.max_norm,
            ref.norm_type,
            ref.scale_grad_by_freq,
            ref.mode,
            ref.sparse,
            ref.get_quantized_weight(),
            ref.include_last_offset,
            ref.weight_dtype,
        )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources