
Distributed Quantization-Aware Training (QAT)

QAT lets you take advantage of the memory savings that quantization provides at inference time without significantly degrading model performance. In torchtune, we use torchao to implement QAT. It works by simulating quantization numerics during fine-tuning. This may add memory and compute overhead during training. However, our tests found that QAT significantly reduced the performance degradation of quantized models in evaluations, without giving up any of the model size reduction.
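The following minimal pure-Python sketch shows what "simulating quantization numerics" means for a single group of weights. It assumes symmetric per-group int4 quantization with one scale per group. It is a simplified illustration, not torchao's implementation, and fake_quantize_group is a hypothetical helper. Each value is rounded onto the int4 grid and then immediately mapped back to a float. Training therefore continues in the original data type, while the forward pass still sees the quantization error.

```python
def fake_quantize_group(weights: list[float], n_bits: int = 4) -> list[float]:
    """Quantize-dequantize one group of weights (hypothetical helper).

    Assumes symmetric quantization with representable integer codes in
    [-(2**(n_bits-1)), 2**(n_bits-1) - 1], i.e. [-8, 7] for int4.
    """
    qmin = -(2 ** (n_bits - 1))
    qmax = 2 ** (n_bits - 1) - 1
    max_abs = max(abs(w) for w in weights)
    # One scale for the whole group; guard against an all-zero group.
    scale = max_abs / qmax if max_abs > 0 else 1.0
    fake_quantized = []
    for w in weights:
        # Round to the nearest integer code and clamp to the representable range...
        q = max(qmin, min(qmax, round(w / scale)))
        # ...then immediately dequantize, so the result is still a float.
        fake_quantized.append(q * scale)
    return fake_quantized


group = [0.70, -0.35, 0.12, -0.01]
print(fake_quantize_group(group))
```

Every output is a multiple of the group's scale. For values within the group's range, the rounding error is at most half a scale step. This is the error that QAT teaches the model to tolerate.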

Note

The PyTorch blogpost on QAT provides further insight into how QAT works.

We provide pre-tested, out-of-the-box configs that get you up and running with the latest Llama models in just two steps:

Note

You may need to be granted access to the Llama model you’re interested in. See here for details on accessing gated repositories.

tune download meta-llama/Meta-Llama-3-8B-Instruct \
--output-dir /tmp/Meta-Llama-3-8B-Instruct \
--ignore-patterns "original/consolidated.00.pth" \
--hf-token <HF_TOKEN>

tune run --nproc_per_node 6 qat_distributed \
--config llama3/8B_qat_full

Note

This workload requires at least 6 GPUs, each with at least 80GB of VRAM.

Currently, the main lever you can pull for QAT is delayed fake quantization, which lets you control the step at which fake quantization begins. Empirically, letting the model fine-tune without fake quantization at first allows the weight and activation values to stabilize before they are fake quantized. This can lead to improved quantized accuracy. You can set this delay through fake_quant_after_n_steps. As a rough guide, we’ve achieved the best results with fake_quant_after_n_steps ~= total_steps // 2.
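The scheduling logic behind delayed fake quantization can be sketched in a few lines of Python. This is an illustrative sketch, not the qat_distributed recipe's code. fake_quant_enabled and schedule are hypothetical helpers. The sketch assumes that steps are 0-indexed and that fake quantization stays off for steps 0 through fake_quant_after_n_steps - 1 and stays on afterwards. It also assumes that a value of None means "no delay".

```python
from typing import Optional


def fake_quant_enabled(step: int, fake_quant_after_n_steps: Optional[int]) -> bool:
    """Return True if fake quantization should be active at this 0-indexed step.

    Hypothetical helper: None means fake quantization is never delayed.
    """
    if fake_quant_after_n_steps is None:
        return True
    return step >= fake_quant_after_n_steps


def schedule(total_steps: int, fake_quant_after_n_steps: Optional[int]) -> list[bool]:
    """Fake-quantization on/off flag for every step of a run."""
    return [fake_quant_enabled(s, fake_quant_after_n_steps) for s in range(total_steps)]


total_steps = 1000
# Rule of thumb from the text: delay fake quantization for about half of training.
delay = total_steps // 2
flags = schedule(total_steps, delay)
print(f"fake quant off for {flags.count(False)} steps, on for {flags.count(True)} steps")
# → fake quant off for 500 steps, on for 500 steps
```

The flag flips only once. A training loop therefore only needs to act at the single step equal to fake_quant_after_n_steps, rather than checking and reconfiguring the model on every step.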

In the future, we plan to support different quantization strategies. For now, note that you’ll need torch>=2.4.0 to use the Int8DynActInt4WeightQATQuantizer strategy. Generally, the pipeline for training, quantizing, and evaluating a model using QAT is:
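To fail fast when the installed PyTorch is too old for Int8DynActInt4WeightQATQuantizer, you could run a small check like the following before training. meets_min_version is a hypothetical helper, not part of torchtune. It compares only the leading major.minor.patch numbers, so local suffixes such as +cu121 are ignored. As a simplifying assumption, it treats pre-release builds such as 2.4.0a0 as 2.4.0. In practice, you would pass it torch.__version__.

```python
import re


def meets_min_version(version: str, minimum: tuple[int, int, int] = (2, 4, 0)) -> bool:
    """Return True if `version` (e.g. "2.4.0+cu121") is at least `minimum`.

    Only the leading numeric components are compared; suffixes such as
    "+cu121", "a0", or ".dev20240801" are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    # A missing patch component (e.g. "2.10") is treated as 0.
    major, minor, patch = (int(part) if part is not None else 0 for part in match.groups())
    return (major, minor, patch) >= minimum


for v in ["2.3.1", "2.4.0+cu121", "2.5.0.dev20240801"]:
    print(v, meets_min_version(v))
```

The helper compares integer tuples rather than raw strings because, as strings, "2.10" would sort before "2.4".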

  1. Run the qat_distributed recipe using the above command, or by following the tutorial. By default, this will use Int8DynActInt4WeightQATQuantizer.

  2. This produces an un-quantized model in the original data type. To get an actual quantized model, follow this with tune run quantize while specifying the same quantizer in the config, e.g.

    # QAT specific args
    quantizer:
      _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
      groupsize: 256
    
  3. Evaluate or run inference using your quantized model by specifying the corresponding post-training quantizer:

    quantizer:
      _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
      groupsize: 256
    

Note

We’re using config files to show how to customize the recipe in these examples. Check out the configs tutorial to learn more.
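The groupsize: 256 setting in the quantizer configs above controls how many weights share one quantization scale. The following back-of-the-envelope sketch estimates the effective storage cost per weight. It assumes 4-bit weight codes, one scale per group, and no zero points, and it ignores any layers left unquantized. The helper names and the default 16-bit scale width are assumptions for illustration, not torchao's actual storage layout.

```python
def effective_bits_per_weight(groupsize: int, weight_bits: int = 4, scale_bits: int = 16) -> float:
    """Average storage bits per weight for grouped quantization (hypothetical helper).

    Each group of `groupsize` weights stores `groupsize` codes of `weight_bits`
    bits plus one scale of `scale_bits` bits.
    """
    return weight_bits + scale_bits / groupsize


def approx_weight_gib(num_params: int, bits_per_weight: float) -> float:
    """Approximate weight storage in GiB (2**30 bytes)."""
    return num_params * bits_per_weight / 8 / 2**30


bits = effective_bits_per_weight(groupsize=256)
print(bits)  # → 4.0625
# Compare 16-bit weights with grouped 4-bit weights for a hypothetical 8-billion-parameter model.
print(round(approx_weight_gib(8_000_000_000, 16), 2), round(approx_weight_gib(8_000_000_000, bits), 2))
```

Smaller groups track the weight distribution more closely but spend more bits on scales. Under these assumptions, groupsize 32 raises the overhead to 0.5 bits per weight.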

Many of our other memory optimization features can be used in this recipe, too.

You can learn more about all of our memory optimization features in our memory optimization overview.

Interested in seeing this recipe in action? Check out our tutorials that show how it can be used.
