Contributor Guide

General Guide on Extending torchao

For a new use case, for example a training dtype (like fp4 training), it's fine to start by adding a new tensor subclass in the prototype folder torchao/prototype. You could also take a look at AffineQuantizedTensor if what you want to do is mostly supported there, e.g. adding an int3 kernel for the exact same affine quantization. Please feel free to open an issue if you have questions about what to do for a specific new use case. For more details, please refer to our quantization overview page.

To contribute to the existing code base:

Adding Efficient Kernels

Custom triton kernels

Custom Triton kernels can be implemented and registered in torchao/kernel.

You may need to define your own autotuner as well.
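A custom autotuner usually comes down to two steps. First, time each candidate launch configuration once per input key, such as a shape. Then cache the fastest configuration and reuse it. Triton's own `@triton.autotune` decorator follows the same idea.

The plain-Python sketch below shows only that control flow. `Autotuner`, `wall_clock_bench`, the config dicts and the `bench` hook are hypothetical names, not a torchao or Triton API. Timing a real GPU kernel would also require device synchronization around the timer.

```python
import time
from typing import Any, Callable, Dict, Hashable, List

Config = Dict[str, Any]


def wall_clock_bench(fn: Callable[[], Any], repeats: int = 3) -> float:
    """Best-of-`repeats` wall-clock time of fn() in seconds (CPU timer only)."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


class Autotuner:
    """Hypothetical autotuner: benchmarks every config once per key, caches the winner."""

    def __init__(
        self,
        kernel: Callable[..., Any],
        configs: List[Config],
        key_fn: Callable[..., Hashable],
        bench: Callable[[Callable[[], Any]], float] = wall_clock_bench,
    ) -> None:
        if not configs:
            raise ValueError("need at least one config")
        self.kernel = kernel
        self.configs = configs
        self.key_fn = key_fn  # maps call arguments to a cache key, e.g. the shape
        self.bench = bench
        self.cache: Dict[Hashable, Config] = {}

    def best_config(self, *args: Any) -> Config:
        key = self.key_fn(*args)
        if key not in self.cache:
            # Time each candidate once; ties go to the earlier config.
            timings = [
                (self.bench(lambda cfg=cfg: self.kernel(*args, **cfg)), idx)
                for idx, cfg in enumerate(self.configs)
            ]
            _, best_idx = min(timings)
            self.cache[key] = self.configs[best_idx]
        return self.cache[key]

    def __call__(self, *args: Any) -> Any:
        return self.kernel(*args, **self.best_config(*args))
```

The `bench` hook is injectable so that a GPU-aware timer can replace the CPU wall-clock timer without changing the caching logic.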

Custom hand written kernels

Custom kernels for CPU/CUDA/MPS can be implemented in torchao/csrc (e.g. the int4 CUDA kernel) and are then accessible through torch.ops.my_custom_op.

Dispatches

To dispatch to optimized kernels on CPU/CUDA/MPS devices, we can check the dispatch conditions in __torch_function__ or __torch_dispatch__ and route to the target operators. For example, the condition for the bfloat16 activation and uint4 weight kernel can be found here.

Specifically for AffineQuantizedTensor, we also allow people to extend quantized linear with a new efficient kernel by defining two functions, both taking input_tensor, weight_tensor and bias as arguments. The first, dispatch_condition, defines the condition for dispatching to the kernel. The second, impl, is the actual implementation: it takes the activation, the (quantized) weight and the bias tensor and runs the efficient kernel. The pair is then registered into the quantized linear dispatch of AffineQuantizedTensor with register_aqt_quantized_linear_dispatch. Here is an example showing how it works.

Layout/TensorImpl

Sometimes the quantized weights have to be packed in order to yield optimal performance, and this can be abstracted with a layout. See here for a full example.
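Here is a simplified example of why packing matters. Two 4-bit values fit in one byte, so packed uint4 weights take half the storage of one byte per value.

The pure-Python sketch below packs and unpacks values this way. The nibble order (first value in the low 4 bits) is an arbitrary choice for illustration. Real kernel layouts are typically more involved; for instance, they may also tile or reorder data to match the kernel's access pattern.

```python
from typing import List


def pack_uint4(values: List[int]) -> bytes:
    """Pack unsigned 4-bit values two per byte (first value in the low nibble)."""
    if any(not 0 <= v <= 15 for v in values):
        raise ValueError("uint4 values must be in [0, 15]")
    if len(values) % 2:
        values = values + [0]  # pad to an even count
    return bytes(lo | (hi << 4) for lo, hi in zip(values[0::2], values[1::2]))


def unpack_uint4(packed: bytes, count: int) -> List[int]:
    """Inverse of pack_uint4; `count` drops the padding nibble, if any."""
    out: List[int] = []
    for byte in packed:
        out.append(byte & 0x0F)  # low nibble first
        out.append(byte >> 4)    # then high nibble
    return out[:count]
```

For example, `pack_uint4([1, 2])` produces the single byte `0x21`.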

Flow

After the tensor subclass is implemented, we can also wrap it in factory functions, e.g.::

    # convert from a floating point tensor to the MyDTypeTensor subclass
    to_my_dtype = MyDTypeTensor.from_float
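As an illustration, here is a toy, pure-Python stand-in for such a subclass and its factory function. Its `from_float` applies per-tensor symmetric int8 quantization (scale = max|x| / 127) to a flat list of floats.

This `MyDTypeTensor` is a hypothetical plain class, not a `torch.Tensor` subclass. A real subclass would store torch tensors and implement the operator handling described in the other sections.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MyDTypeTensor:
    """Toy stand-in: int8 values plus one float scale (per-tensor, symmetric)."""
    int_data: List[int]
    scale: float

    @classmethod
    def from_float(cls, values: List[float]) -> "MyDTypeTensor":
        max_abs = max((abs(v) for v in values), default=0.0)
        scale = max_abs / 127 if max_abs > 0 else 1.0  # avoid dividing by zero
        # Round to nearest and clamp into the int8 range.
        int_data = [max(-128, min(127, round(v / scale))) for v in values]
        return cls(int_data, scale)

    def dequantize(self) -> List[float]:
        return [q * self.scale for q in self.int_data]


# Factory function: convert floating point values to MyDTypeTensor.
to_my_dtype = MyDTypeTensor.from_float
```

Exposing the classmethod under a short factory name keeps model-level code independent of how the subclass is constructed internally.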

For the model-level API, people can reuse torchao.quantize_. It applies a tensor subclass conversion to the weights of linear layers and accepts a filtering function that chooses which modules the conversion should be applied to.
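The model-level pattern can be sketched without PyTorch using a toy module tree. It walks the modules, tests each one with a filter function, and swaps the weight for its converted version when the filter matches.

`ToyModule`, `ToyLinear`, `is_linear` and `apply_to_weights` are hypothetical. The `filter_fn(module, fqn)` shape is meant to resemble filtering by module instance and fully qualified name. Check the torchao API reference for the exact `quantize_` signature.

```python
from typing import Callable, Dict, List, Optional


class ToyModule:
    """Minimal stand-in for nn.Module: just named child modules."""

    def __init__(self, **children: "ToyModule") -> None:
        self.children: Dict[str, "ToyModule"] = dict(children)


class ToyLinear(ToyModule):
    def __init__(self, weight: List[float]) -> None:
        super().__init__()
        self.weight = weight


def is_linear(module: ToyModule, fqn: str) -> bool:
    return isinstance(module, ToyLinear)


def apply_to_weights(
    model: ToyModule,
    convert: Callable[[List[float]], object],
    filter_fn: Optional[Callable[[ToyModule, str], bool]] = None,
    prefix: str = "",
) -> None:
    """Replace module.weight with convert(module.weight) wherever filter_fn matches."""
    filter_fn = filter_fn or is_linear  # default: convert every linear
    for name, child in model.children.items():
        fqn = f"{prefix}.{name}" if prefix else name
        if filter_fn(child, fqn):
            child.weight = convert(child.weight)
        apply_to_weights(child, convert, filter_fn, fqn)  # recurse into submodules
```

For example, a filter such as `lambda m, fqn: isinstance(m, ToyLinear) and not fqn.startswith("head")` would leave an output head unconverted.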

See Quantization Algorithms/Flows section for examples of weight only/dynamic quant/static quant and other types of model level APIs based on the factory function.

Using torch.compile for Performance

Note: for PyTorch 2.4 and below, we need to use the following::

    from torchao.utils import unwrap_tensor_subclass
    m_unwrapped = unwrap_tensor_subclass(m)

To be compatible with torch.compile and get the best performance, we should first run the model through torch.compile with fullgraph=True and remove any unnecessary graph breaks. To see the code generated by Inductor, set TORCH_LOGS="output_code" when you run the script, e.g. TORCH_LOGS="output_code" python example.py. The compile call looks like::

    model = torch.compile(model, mode="max-autotune", fullgraph=True)

Serialization

Please check out the serialization doc for more details.

Note: torchao is integrated with Hugging Face Transformers and supports serialization/deserialization through the Hugging Face save_pretrained/push_to_hub/from_pretrained APIs: https://huggingface.co/docs/transformers/main/en/quantization/torchao

Note: Another example can be found in the integration with Diffusers: https://github.com/sayakpaul/diffusers-torchao/blob/main/inference/serialization_and_loading.md

Other Feature Support

The sections above only cover basic feature support. We also provide examples of how to add support for training, tensor parallel and FSDP by extending MyDTypeTensor, and we will put more examples in the developer_api_guide folder covering the following use cases.

Tensor Subclass Functionality/Composability Testing

We are also working on test suites that check the functionality of tensor subclasses and their composability with different systems such as torch.compile and DTensor (for now, we recommend copying these tests and adapting them to your own tensor subclass):

Kernel Microbenchmarks

Before testing performance on models, we can run microbenchmarks on a single linear operator (or other compute-intensive or memory-intensive operators) with different input dimensions to get a sense of the speedup. For a specific kernel that you'd like to benchmark, you can create a benchmark file like benchmarks/benchmark_aq.py and run it with the shapes that matter for your target model. A quick way to get the relevant shapes for linear and other ops is to run the example with this.

Replace the model in the script with the model you are interested in optimizing, and run::

    python tutorials/developer_api_guide/print_op_and_shapes.py

Example output::

    TORCH_FUNC=<built-in function linear> (M, K, N): 10 10 10
    TORCH_FUNC=<method 'add' of 'torch._C.TensorBase' objects> args[0] shape: torch.Size([10, 10])

    all linear shapes (M, K, N): [(10, 10, 10)]

The list of all linear shapes can be copied into a microbenchmarking script, e.g. benchmarks/benchmark_your_kernel.py, for benchmarking.
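The (M, K, N) triple for a linear op can be read off the activation and weight shapes. In `torch.nn.functional.linear`, the weight has shape (N, K), so K is the input feature size and N is the output feature size. M is the number of activation rows.

The helper below is a hypothetical sketch of that bookkeeping. It assumes that all leading activation dimensions are flattened into M; the script's own convention may differ.

```python
from math import prod
from typing import Sequence, Tuple


def linear_mkn(input_shape: Sequence[int], weight_shape: Sequence[int]) -> Tuple[int, int, int]:
    """Return (M, K, N) for F.linear(input, weight) given the two shapes."""
    if len(weight_shape) != 2:
        raise ValueError("linear weight must be 2-D (N, K)")
    n, k = weight_shape
    if input_shape[-1] != k:
        raise ValueError(f"in_features mismatch: {input_shape[-1]} vs {k}")
    m = prod(input_shape[:-1])  # flatten batch/sequence dims into M
    return m, k, n
```

With this convention, an activation of shape (2, 5, 64) feeding a (128, 64) weight becomes the GEMM shape (10, 64, 128).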

For benchmark helper functions, we currently have two options, 1 and 2. Feel free to use either one for now, but we will probably keep only one of them in the future.
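The core of such a helper is small. The following is a hypothetical, stdlib-only sketch of the pattern:

1. Make one untimed warm-up call.
2. Time several repeats with `time.perf_counter`.
3. Report the median.
4. Loop over (M, K, N) shapes like the ones printed above.

The naive Python matmul is only a placeholder for the kernel under test. GPU kernels additionally need device synchronization before the timer is read.

```python
import statistics
import time
from typing import Callable, Dict, List, Tuple


def benchmark_median(fn: Callable[[], object], warmup: int = 1, repeats: int = 5) -> float:
    """Median wall-clock time of fn() in seconds, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def naive_matmul(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    """Placeholder kernel: (M, K) @ (K, N) in pure Python."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def run_shapes(shapes: List[Tuple[int, int, int]]) -> Dict[Tuple[int, int, int], float]:
    """Benchmark the placeholder kernel for each (M, K, N) shape."""
    results = {}
    for m, k, n in shapes:
        a = [[1.0] * k for _ in range(m)]
        b = [[1.0] * n for _ in range(k)]
        results[(m, k, n)] = benchmark_median(lambda: naive_matmul(a, b))
        print(f"(M, K, N)=({m}, {k}, {n}): {results[(m, k, n)] * 1e6:.1f} us")
    return results
```

The median is used instead of the mean because it is less sensitive to occasional outliers, such as a slow first iteration or OS scheduling noise.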

Model Benchmarks and Eval

After you have implemented the quantization flow, you can run benchmarks and evals on llama (llama2/llama3) or SAM models, which have already been modified to be friendly to torch.compile, and compare with the existing techniques in torchao.

Note: llama (llama2/llama3) is our representative memory-bound model and SAM is our representative compute-bound model.

Please check the --help option of each script to see the supported options. For example, you can use --profile=profile_path to get a Chrome trace of the run and inspect it in detail.

Please let us know if there are any important new models that would make sense to add to the torchao model benchmark/eval folder.
