• Docs >
  • Quantized Operations for XLA device (Experimental feature)
Shortcuts

Quantized Operations for XLA device (Experimental feature)


This document outlines how to utilize quantized operations to enable quantization on XLA devices.

XLA Quantized ops offer a high-level abstraction for quantized operations (e.g., blockwise int4 quantized matrix multiplication). These ops are analogous to quantized CUDA kernels (example) in the CUDA ecosystem, providing similar functionality and performance benefits within the XLA framework.

NOTE: Currently this is classified as experimental feature. It’s API specifics will change in the next (2.5) release.

How to use:

XLA quantized operations can be used as torch op, or a torch.nn.Module that wraps the torch.op. These 2 options give model developers the flexibility to choose the best way to integrate XLA quantized ops into their solution.

Both torch op and nn.Module are compatible with torch.compile( backend='openxla').

Call XLA quantized op in model code

Users can call XLA quantized ops in the same way as calling other regular PyTorch ops. This provides maximum flexibility in integrating XLA quantized ops into their applications. The quantized ops work in both eager mode and Dynamo, with regular PyTorch CPU tensor and XLA tensor.

Note Please check the docstring of the quantized ops for the layout of the quantized weights.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_quantized_matmul

N_INPUT_FEATURES=10
N_OUTPUT_FEATURES=20
x = torch.randn((3, N_INPUT_FEATURES), dtype=torch.bfloat16)
w_int = torch.randint(-128, 127, (N_OUTPUT_FEATURES, N_INPUT_FEATURES), dtype=torch.int8)
scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16)

# Call with torch CPU tensor (For debugging purpose)
matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler)

device = xm.xla_device()
x_xla = x.to(device)
w_int_xla = w_int.to(device)
scaler_xla = scaler.to(device)

# Call with XLA Tensor to run on XLA device
matmul_output_xla = torch.ops.xla.quantized_matmul(x_xla, w_int_xla, scaler_xla)

# Use with torch.compile(backend='openxla')
def f(x, w, s):
  return torch.ops.xla.quantized_matmul(x, w, s)

f_dynamo = torch.compile(f, backend="openxla")
dynamo_out_xla = f_dynamo(x_xla, w_int_xla, scaler_xla)

It’s common to wrap the quantized op into a custom nn.Module in model developers model code:

class MyQLinearForXLABackend(torch.nn.Module):
  def __init__(self):
    self.weight = ...
    self.scaler = ...

  def load_weight(self, w, scaler):
    # Load quantized Linear weights
    # Customized way to preprocess the weights
    ...
    self.weight = processed_w
    self.scaler = processed_scaler


  def forward(self, x):
    # Do some random stuff with x
    ...
    matmul_output = torch.ops.xla.quantized_matmul(x, self.weight, self.scaler)
    # Do some random stuff with matmul_output
    ...

Module Swap

Alternatively, users can also use the nn.Module that wraps the XLA quantized ops and do module swap in the model code:

orig_model = MyModel()
# Quantize the model and get quantized weights
q_weights = quantize(orig_model)
# Process the quantized weight to the format that XLA quantized op expects.
q_weights_for_xla = process_for_xla(q_weights)

# Do module swap
q_linear = XlaQuantizedLinear(self.linear.in_features,
                              self.linear.out_features)
q_linear.load_quantized_weight(q_weights_for_xla)
orig_model.linear = q_linear

Supported Quantized Operations:

Matrix Multiply

Weight Quantization Type

Activation Quantization Type

Dtype

Supported

per-channel (sym/asym)

N/A

W8A16

Yes

per-channel (sym/asym)

N/A

W4A16

Yes

per-channel

per-token

W8A8

No

per-channel

per-token

W4A8

No

blockwise (sym/asym)

N/A

W8A16

Yes

blockwise (sym/asym)

N/A

W4A16

Yes

blockwise

per-token

W8A8

No

blockwise

per-token

W4A8

No

Note W[X]A[Y] refers to Weight in X-bit, Activation in Y-bit. If X/Y is 4 or 8, it refers to int4/8. 16 for bfloat16 format.

Embedding

To be added

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources