Shortcuts

Quick Start Guide

In this quick start guide, we will explore how to perform basic quantization using torchao. First, install the latest stable torchao release:

pip install torchao

If you prefer to use the nightly release, you can install torchao using the following command instead:

pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121

torchao is compatible with the latest 3 major versions of PyTorch, which you will also need to install (detailed instructions):

pip install torch

First Quantization Example

The main entry point for quantization in torchao is the quantize_ API. This function mutates your model inplace to insert the custom quantization logic based on what the user configures. All code in this guide can be found in this example script. First, let’s set up our toy model:

import copy
import torch

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")

# Optional: compile model for faster inference and generation
model = torch.compile(model, mode="max-autotune", fullgraph=True)
model_bf16 = copy.deepcopy(model)

Now we call our main quantization API to quantize the linear weights in the model to int4 inplace. More specifically, this applies uint4 weight-only asymmetric per-group quantization, leveraging the tinygemm int4mm CUDA kernel for efficient mixed dtype matrix multiplication:

# torch 2.4+ only
from torchao.quantization import int4_weight_only, quantize_
quantize_(model, int4_weight_only(group_size=32))

The quantized model is now ready to use! Note that the quantization logic is inserted through tensor subclasses, so there is no change to the overall model structure; only the weights tensors are updated, but nn.Linear modules stay as nn.Linear modules:

>>> model.linear1
Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15))

>>> model.linear2
Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15))

First, verify that the int4 quantized model is roughly a quarter of the size of the original bfloat16 model:

>>> import os
>>> torch.save(model, "/tmp/int4_model.pt")
>>> torch.save(model_bf16, "/tmp/bfloat16_model.pt")
>>> int4_model_size_mb = os.path.getsize("/tmp/int4_model.pt") / 1024 / 1024
>>> bfloat16_model_size_mb = os.path.getsize("/tmp/bfloat16_model.pt") / 1024 / 1024

>>> print("int4 model size: %.2f MB" % int4_model_size_mb)
int4 model size: 1.25 MB

>>> print("bfloat16 model size: %.2f MB" % bfloat16_model_size_mb)
bfloat16 model size: 4.00 MB

Next, we demonstrate that not only is the quantized model smaller, it is also much faster!

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    benchmark_model,
    unwrap_tensor_subclass,
)

# Temporary workaround for tensor subclass + torch.compile
# Only needed for torch version < 2.5
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(model)

num_runs = 100
torch._dynamo.reset()
example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)
bf16_time = benchmark_model(model_bf16, num_runs, example_inputs)
int4_time = benchmark_model(model, num_runs, example_inputs)

print("bf16 mean time: %0.3f ms" % bf16_time)
print("int4 mean time: %0.3f ms" % int4_time)
print("speedup: %0.1fx" % (bf16_time / int4_time))

On a single A100 GPU with 80GB memory, this prints:

bf16 mean time: 30.393 ms
int4 mean time: 4.410 ms
speedup: 6.9x

Next Steps

In this quick start guide, we learned how to quantize a simple model with torchao. To learn more about the different workflows supported in torchao, see our main README. For a more detailed overview of quantization in torchao, visit this page.

Finally, if you would like to contribute to torchao, don’t forget to check out our contributor guide and our list of good first issues on Github!

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources