
High-performance quantized LLM inference on Intel CPUs with native PyTorch

September 17, 2025

PyTorch 2.8 has just been released with a set of new features, including a limited stable libtorch ABI for third-party C++/CUDA extensions, high-performance quantized LLM inference on Intel CPUs with native PyTorch, experimental Wheel Variant Support, and Inductor CUTLASS backend support. One highlight is that native PyTorch can now deliver low-precision Large Language Model (LLM) inference performance on Intel Xeon platforms that is competitive with other popular LLM frameworks.

In PyTorch 2.8, we enabled and optimized common LLM quantization configurations on Intel Xeon processors, including A16W8, DA8W8, A16W4, and DA8W4. In this notation, AxWy means x-bit activations and y-bit weights, and the "D" prefix means the activations are quantized dynamically at runtime. When a quantized model is compiled with torch.compile and max-autotune is enabled, Inductor lowers the quantized GEMM patterns to template-based high-performance GEMM kernels, which use Intel AMX and Intel AVX-512 instructions to accelerate computation.
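The bit widths in these configuration names can be made concrete with a few lines of arithmetic. The sketch below is a generic symmetric quantization scheme in plain Python, written for illustration only: all helper names are hypothetical, and TorchAO's actual kernels (rounding, zero points, packing) may differ. Weights are quantized once, ahead of time, with one scale per group of `group_size` elements (the int4 configs in this post's example code use `group_size=128`), whereas dynamically quantized activations get a fresh scale for each token at runtime.

```python
# Illustrative sketch only: a generic symmetric scheme, NOT TorchAO's implementation.

def quantize_symmetric(values, bits):
    """Map floats to signed `bits`-bit integers sharing one scale (no zero point)."""
    qmax = 2 ** (bits - 1) - 1  # 127 for int8, 7 for int4
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Round to the nearest integer and clamp to the symmetric range [-qmax, qmax].
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    """Recover approximate floats from the integers and their scale."""
    return [x * scale for x in q]


def quantize_weights_groupwise(row, bits=4, group_size=128):
    """Weight quantization, done once ahead of time: one scale per group_size elements."""
    return [quantize_symmetric(row[i:i + group_size], bits)
            for i in range(0, len(row), group_size)]


def quantize_activations_dynamic(tokens, bits=8):
    """Dynamic activation quantization: a fresh scale computed at runtime per token."""
    return [quantize_symmetric(tok, bits) for tok in tokens]

# A16W4-style: int4 weights in groups of 128; activations stay in 16-bit floats.
# DA8W4-style: additionally quantize each token's activations to int8 on the fly.
```

For a weight row of 256 values, `quantize_weights_groupwise(row)` returns two `(ints, scale)` pairs, and every dequantized value lies within half a scale step of the original. Smaller groups give tighter scales (better accuracy) at the cost of storing more scales.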

With this feature, the native PyTorch stack matches, and in some cases exceeds, the performance of the popular LLM serving framework vLLM running in offline mode on a single Intel Xeon CPU compute node. The graphs below compare time to first token (TTFT) and time per output token (TPOT) between native PyTorch and vLLM, using Llama-3.1-8B as the benchmark model, for three low-precision configurations: DA8W8, A16W4, and DA8W4.

As the graphs show, the native PyTorch stack achieves similar or better performance in most of the tested configurations. Note that vLLM was tested in offline mode rather than serving mode, to match the configuration used for native PyTorch.
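TTFT and TPOT are usually derived from per-token timestamps. The sketch below uses the common definitions (TTFT is the delay until the first output token; TPOT is the mean gap between subsequent output tokens); the post does not describe its exact measurement code, so treat the function name and inputs as assumptions. With the benchmark setting of 128 output tokens, TPOT would average over 127 gaps.

```python
def ttft_tpot(request_start, token_times):
    """Return (TTFT, TPOT) in seconds (hypothetical helper, common definitions).

    request_start: timestamp at which the prompt was submitted.
    token_times:   timestamps at which each output token was produced, in order.
    """
    if not token_times:
        raise ValueError("need at least one output token")
    # Time to first token: covers the prefill over the whole input prompt.
    ttft = token_times[0] - request_start
    if len(token_times) == 1:
        return ttft, 0.0
    # Time per output token: average decode step after the first token.
    tpot = (token_times[-1] - token_times[0]) / (len(token_times) - 1)
    return ttft, tpot
```

For example, with the prompt submitted at t=0.0 s and tokens emitted at 0.50, 0.55, 0.60, and 0.65 s, TTFT is 0.5 s and TPOT is about 0.05 s. Reporting the two separately is useful because TTFT is dominated by prefill over the input prompt, while TPOT reflects the per-token decode cost.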

To use these features, users only need to:

  • Use a machine with an x86 CPU that supports Intel AMX.
  • Quantize the model with TorchAO's quantize_ API.
  • Set a few torch.compile (Inductor) flags for the best performance.
  • Compile the model with torch.compile.
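For the first requirement, one way to check for AMX on Linux is to look at the CPU flags the kernel reports in `/proc/cpuinfo`. The sketch below assumes the Linux flag names `amx_tile`, `amx_int8`, and `amx_bf16`; requiring all three is our own conservative choice, not a documented PyTorch requirement, and the helper names are hypothetical.

```python
def cpu_flags(cpuinfo_text):
    """Collect the feature flags from /proc/cpuinfo-formatted text (Linux)."""
    flags = set()
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        # Match the "flags" line exactly (not e.g. "vmx flags").
        if sep and key.strip() == "flags":
            flags.update(value.split())
    return flags


def has_amx(cpuinfo_text):
    """True if the CPU reports AMX tiles plus the AMX int8 and bf16 extensions."""
    required = {"amx_tile", "amx_int8", "amx_bf16"}
    return required <= cpu_flags(cpuinfo_text)
```

On a Linux host, `has_amx(open("/proc/cpuinfo").read())` returns `True` on AMX-capable Xeon processors such as the Xeon 6980P used in the benchmarks.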

The optimizations are then applied automatically under the hood. For example:

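import torch
import transformers

# 1. Set torch.compile (Inductor) flags
from torch._inductor import config as inductor_config
inductor_config.cpp_wrapper = True  # generate a C++ wrapper instead of a Python one to cut overhead
inductor_config.max_autotune = True  # enable template-based GEMM kernels with autotuning
inductor_config.cpp.enable_concat_linear = True  # fuse linear layers that share the same input
inductor_config.cpp.use_small_dequant_buffer = True  # smaller dequantized-weight buffer in the CPU GEMM template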


# 2. Get model
model = transformers.AutoModelForCausalLM.from_pretrained(<model_id>, ...)

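# 3. Quantize with TorchAO: apply exactly ONE of the configs 3.1 to 3.3 below.
# set_inductor_config=False keeps the Inductor flags from step 1 instead of
# letting TorchAO apply its own recommended Inductor settings.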
from torchao.quantization.quant_api import (
    quantize_,
    Int8DynamicActivationInt8WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
)

## 3.1 DA8W8
quantize_(
    model,
    Int8DynamicActivationInt8WeightConfig(set_inductor_config=False)
)
## 3.2 A16W4
quantize_(
    model,
    Int4WeightOnlyConfig(
        group_size=128,
        int4_packing_format="opaque",
        set_inductor_config=False,
    )
)
## 3.3 DA8W4
from torchao.dtypes import Int8DynamicActInt4WeightCPULayout
from torchao.quantization.quant_primitives import MappingType
quantize_(
    model,
    Int8DynamicActivationInt4WeightConfig(
        group_size=128,
        layout=Int8DynamicActInt4WeightCPULayout(),
        act_mapping_type=MappingType.SYMMETRIC,
        set_inductor_config=False,
    )
)

# 4. Apply optimizations with torch.compile
model.forward = torch.compile(model.forward)

# 5. Run the quantized and optimized model
# Preparation of input_ids and generate_kwargs are not shown
model.generate(
    input_ids, **generate_kwargs
)
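The `inductor_config.cpp.enable_concat_linear` flag is, to our understanding, meant to let Inductor merge linear layers that consume the same input (for example the Q, K, and V projections in attention) into one larger GEMM, which keeps the AMX units busier. The plain-Python sketch below only illustrates why such a merge is numerically equivalent; it is not Inductor's implementation, and all helper names are hypothetical.

```python
# Conceptual sketch only: shows that one GEMM on column-concatenated weights,
# followed by a split, equals several separate GEMMs on the same input.

def matmul(x, w):
    """x: [M][K], w: [K][N] -> [M][N], in plain Python for illustration."""
    return [[sum(x[i][k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]


def concat_columns(*weights):
    """Place several [K][N_i] weight matrices side by side into one [K][sum(N_i)]."""
    return [sum((w[k] for w in weights), []) for k in range(len(weights[0]))]


def split_columns(y, widths):
    """Split an [M][sum(N_i)] output back into per-layer [M][N_i] outputs."""
    out, start = [], 0
    for n in widths:
        out.append([row[start:start + n] for row in y])
        start += n
    return out
```

With `x` of shape [2][3] and Q/K/V weights of widths 2, 1, and 2, `split_columns(matmul(x, concat_columns(wq, wk, wv)), [2, 1, 2])` gives exactly the three results of calling `matmul` separately, while reading `x` only once.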

Summary 

In this post, we showed that PyTorch 2.8 can achieve LLM inference performance on Intel Xeon processors that is competitive with the popular LLM framework vLLM. This feature lets PyTorch users run quantized LLMs with int8 and int4 weights, covering both weight-only quantization (WOQ) and dynamically quantized activations, with a native PyTorch experience and the latest performance optimizations on Intel hardware. The current optimization targets a single Intel Xeon device; next, we plan to support inference across multiple Intel Xeon devices, which will also make more advanced features such as tensor parallelism effective.

Acknowledgements 

The release of PyTorch 2.8 is an exciting milestone for Intel Xeon platforms, and it would not have been possible without the deep collaboration and contributions from the community. We extend our heartfelt thanks to Alban D, Andrey Talman, Bin Bao, Jason Ansel, Jerry Zhang, and Nikita Shulga for sharing their invaluable ideas, meticulously reviewing PRs, and providing insightful feedback on RFCs. Their dedication has driven continuous improvements and pushed the ecosystem forward for Intel platforms. 


Product and Performance Information 

Measurement on 1-node, 2x Intel(R) Xeon(R) 6980P, 128 cores, 500W TDP, HT On, Turbo On, Total Memory 1536GB (24x64GB DDR5 12800MT/s [8800MT/s]), SNC 3, BIOS BHSDCRB1.IPC.3544.D02.2410010029, microcode 0x11000314, 1x I210 Gigabit Network Connection, 1x 3.5T INTEL SSDPF2KX038TZ, 1x 894.3G Micron_7450_MTFDKBG960TFR, 2x 1.8T INTEL SSDPE2MX020T4F, 1x 1.8T INTEL SSDPE2KX020T8, CentOS Stream 9, 6.6.0-gnr.bkc.6.6.16.8.23.x86_64, only one SNC is used for testing, input token length: 1K, output token length: 128, vLLM is tested with offline mode. Tested by Intel in July 2025 with PyTorch 2.8 RC, TorchAO (3addf30), vLLM v0.8.5.

Notices and Disclaimers 

Performance varies by use, configuration, and other factors. Learn more on the Performance Index site. Performance results are based on testing as of dates shown in configurations and may not reflect all publicly available updates.  See backup for configuration details.  No product or component can be absolutely secure. Your costs and results may vary. Intel technologies may require enabled hardware, software, or service activation. 

Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others. 

AI disclaimer:
AI features may require software purchase, subscription, or enablement by a software or platform provider, or may have specific configuration or compatibility requirements. Details at www.intel.com/AIPC. Results may vary.