Demonstration of torch.export flow, common challenges and the solutions to address them
Authors: Ankith Gunapal, Jordi Ramon, Marcos Carranza
In the Introduction to torch.export Tutorial, we learned how to use torch.export. This tutorial expands on the previous one and explores the process of exporting popular models with code, as well as addresses common challenges that may arise with torch.export.
In this tutorial, you will learn how to export models for these use cases:
Video classifier (MViT)
Automatic Speech Recognition (OpenAI Whisper-Tiny)
Image Captioning (BLIP)
Promptable Image Segmentation (SAM2)
Each of the four models was chosen to demonstrate unique features of torch.export, as well as some practical considerations and issues faced in the implementation.
Key requirement for torch.export: No graph break
torch.compile speeds up PyTorch code by JIT-compiling it into optimized kernels. It traces the given model using TorchDynamo and creates an optimized graph, which is then lowered to the hardware using the backend specified in the API.
When TorchDynamo encounters unsupported Python features, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, and then resumes capturing the graph. This break in the computation graph is called a graph break.
One of the key differences between torch.export and torch.compile is that torch.export doesn’t support graph breaks, which means that the entire model or the part of the model that you are exporting needs to be a single graph. This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with what torch.export is designed for. You can read details about the differences between the various PyTorch frameworks in this link.
You can identify graph breaks in your program by using the following command:
TORCH_LOGS="graph_breaks" python <file_name>.py
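For illustration, here is a minimal sketch (not part of this recipe) of a function with a graph break: the print call is handled by the Python interpreter rather than TorchDynamo, so torch.compile splits the graph around it. Running this file with the command above reports the break.

import torch

def fn(x):
    y = x.sin()
    # print() is an unsupported side effect for TorchDynamo,
    # so the captured graph is split at this point.
    print("intermediate value computed")
    return y.cos()

compiled_fn = torch.compile(fn)
compiled_fn(torch.randn(4))
# Run with: TORCH_LOGS="graph_breaks" python <file_name>.py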
You will need to modify your program to get rid of graph breaks. Once resolved, you are ready to export the model. PyTorch runs nightly benchmarks for torch.compile on popular HuggingFace and TIMM models. Most of these models have no graph breaks.
The models in this recipe have no graph breaks, but fail with torch.export.
Video Classification
MViT is a class of models based on MultiScale Vision Transformers. This model has been trained for video classification using the Kinetics-400 Dataset. This model with a relevant dataset can be used for action recognition in the context of gaming.
The code below exports MViT by tracing with batch_size=2 and then checks if the ExportedProgram can run with batch_size=4.
import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb

model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2, 16, 224, 224, 3)
# Transpose to get [batch_size, 3, num_frames, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
exported_program = torch.export.export(
    model,
    (input_frames,),
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4, 16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()
Error: Static batch size
raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4
By default, the exporting flow traces the program assuming that all input shapes are static, so if you run the program with input shapes different from the ones you used while tracing, you will run into an error.
Solution
To address the error, we mark the first dimension of the input (batch_size) as dynamic and specify the expected range of batch_size.
In the corrected example shown below, we specify that the expected batch_size can range from 1 to 16.
One detail to notice is that min=2 is not a bug; this is explained in The 0/1 Specialization Problem. A detailed description of dynamic shapes for torch.export can be found in the export tutorial. The code shown below demonstrates how to export MViT with a dynamic batch size:
import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb

model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2, 16, 224, 224, 3)
# Transpose to get [batch_size, 3, num_frames, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
batch_dim = torch.export.Dim("batch", min=2, max=16)
exported_program = torch.export.export(
    model,
    (input_frames,),
    # Specify the first dimension of the input x as dynamic.
    dynamic_shapes={"x": {0: batch_dim}},
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4, 16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()
Automatic Speech Recognition
Automatic Speech Recognition (ASR) is the use of machine learning to transcribe spoken language into text.
Whisper is a Transformer-based encoder-decoder model from OpenAI that was trained on 680k hours of labelled data for ASR and speech translation.
The code below tries to export the whisper-tiny model for ASR.
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1, 80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1, 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram = torch.export.export(
    model, args=(input_features, attention_mask, decoder_input_ids,)
)
Error: strict tracing with TorchDynamo
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'
By default, torch.export traces your code using TorchDynamo, a bytecode analysis engine, which symbolically analyzes your code and builds a graph.
This analysis provides a stronger guarantee about safety, but not all Python code is supported. When we export the whisper-tiny model using the default strict mode, it typically returns an error in Dynamo due to an unsupported feature. To understand why this errors in Dynamo, you can refer to this GitHub issue.
Solution
To address the above error, torch.export supports non-strict mode, where the program is traced using the Python interpreter, which works similarly to PyTorch eager execution. The only difference is that all Tensor objects are replaced by ProxyTensors, which record all their operations into a graph. By using strict=False, we are able to export the program.
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1, 80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1, 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram = torch.export.export(
    model, args=(input_features, attention_mask, decoder_input_ids,), strict=False
)
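As a quick sanity check (a sketch; it assumes the export above succeeded and that the installed transformers release registers its output classes with pytree, which recent releases do), the exported program can be run with the same dummy inputs:

out = exported_program.module()(input_features, attention_mask, decoder_input_ids)
print(out.logits.shape)  # expected: (batch, decoder sequence length, vocab size)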
Promptable Image Segmentation
Image segmentation is a computer vision technique that divides a digital image into distinct groups of pixels, or segments, based on their characteristics. The Segment Anything Model (SAM) introduced promptable image segmentation, which predicts object masks given prompts that indicate the desired object. SAM 2 is the first unified model for segmenting objects across images and videos. The SAM2ImagePredictor class provides an easy interface for prompting the model. The model can take as input both point and box prompts, as well as masks from the previous iteration of prediction. Since SAM2 provides strong zero-shot performance for object tracking, it can be used for tracking game objects in a scene.
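For context, here is a minimal sketch of how SAM2ImagePredictor is typically prompted. The checkpoint name facebook/sam2-hiera-tiny and the random image are illustrative assumptions, not part of this recipe.

import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # stand-in for a real frame
predictor.set_image(image)

# A single foreground point prompt at pixel (x=320, y=240).
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=True,
)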
The tensor operations in the predict method of SAM2ImagePredictor happen in the _predict method, so we try to export it as follows:
ep = torch.export.export(
    self._predict,
    args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
    kwargs={"return_logits": return_logits},
    strict=False,
)
Error: Model is not of type torch.nn.Module
torch.export expects the module to be of type torch.nn.Module. However, the module we are trying to export is a class method, so the call fails.
Traceback (most recent call last):
  File "/sam2/image_predict.py", line 20, in <module>
    masks, scores, _ = predictor.predict(
  File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
    ep = torch.export.export(
  File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
    raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.
Solution
We write a helper class that inherits from torch.nn.Module and calls the _predict method in its forward method. The complete code can be found here.
class ExportHelper(torch.nn.Module):
    def __init__(self):
        super().__init__()

    # The module's own instance is deliberately bound to `_`, so `self` refers to
    # the enclosing SAM2ImagePredictor instance captured from the predict method.
    def forward(_, *args, **kwargs):
        return self._predict(*args, **kwargs)

model_to_export = ExportHelper()
ep = torch.export.export(
    model_to_export,
    args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
    kwargs={"return_logits": return_logits},
    strict=False,
)
Conclusion
In this tutorial, we have learned how to use torch.export to export models for popular use cases by addressing challenges through correct configuration and simple code modifications.
Once you are able to export a model, you can lower the ExportedProgram into your hardware using AOTInductor in the case of servers and ExecuTorch in the case of edge devices.
To learn more about AOTInductor (AOTI), please refer to the AOTI tutorial.
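As a minimal sketch of that last step (assuming a recent PyTorch release where these APIs are available; the output file name is illustrative), an ExportedProgram such as the Whisper one above can be compiled and packaged with AOTInductor like this:

import torch._inductor

# Ahead-of-time compile the ExportedProgram and save it as a .pt2 package.
pt2_path = torch._inductor.aoti_compile_and_package(
    exported_program, package_path="whisper_tiny.pt2"
)

# Later, for example on the serving host, load and run the compiled artifact.
compiled = torch._inductor.aoti_load_package(pt2_path)
out = compiled(input_features, attention_mask, decoder_input_ids)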
To learn more about ExecuTorch, please refer to the ExecuTorch tutorial.
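Similarly, a sketch of lowering an ExportedProgram for edge deployment (it assumes the executorch package is installed and that the model's operators are supported by the chosen backend):

from executorch.exir import to_edge

# Convert the ExportedProgram to the edge dialect and then to an ExecuTorch program.
et_program = to_edge(exported_program).to_executorch()

# Serialize to a .pte file that the ExecuTorch runtime can load on device.
with open("model.pte", "wb") as f:
    f.write(et_program.buffer)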