Compiling GPT2 using the Torch-TensorRT torch.compile frontend¶
This example illustrates how to optimize the state-of-the-art GPT2 model using the torch.compile frontend of Torch-TensorRT. Install the following dependencies before compilation:
pip install -r requirements.txt
GPT2 is a causal (unidirectional) transformer pretrained with a language modeling objective on a very large corpus of text data. In this example, we use the GPT2 model available on Hugging Face and apply torch.compile to it to obtain the graph module representation of the model. Torch-TensorRT then converts this graph into an optimized TensorRT engine.
Import necessary libraries¶
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
Define the necessary parameters¶
Torch-TensorRT requires a GPU for successful compilation of the model. MAX_LENGTH is the maximum length the generated tokens can have; it corresponds to the length of the input prompt plus the number of new tokens generated.
MAX_LENGTH = 32
DEVICE = torch.device("cuda:0")
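Since compilation requires a CUDA-capable GPU, a quick sanity check (an optional addition, not part of the original example) can fail fast when none is available:
# Optional: fail early if no CUDA GPU is present (illustrative addition).
if not torch.cuda.is_available():
    raise RuntimeError("Torch-TensorRT compilation requires a CUDA GPU.")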
Model definition¶
We use the AutoModelForCausalLM class to load the pretrained GPT2 model from Hugging Face. The KV cache is not currently supported in Torch-TensorRT, so we set use_cache=False.
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = (
        AutoModelForCausalLM.from_pretrained(
            "gpt2",
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,
            attn_implementation="eager",
        )
        .eval()
        .cuda()
    )
PyTorch inference¶
Tokenize a sample input prompt and get the PyTorch model outputs.
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"].cuda()
The generate() API of the AutoModelForCausalLM class is used for auto-regressive generation with greedy decoding.
pyt_gen_tokens = model.generate(
    input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)
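For reference, greedy decoding is conceptually equivalent to the following manual loop, which repeatedly picks the highest-probability next token. This sketch is not part of the original example; it simply reuses the model, tokenizer, input_ids, and MAX_LENGTH defined above.
# Illustrative manual greedy-decoding loop, equivalent in spirit to
# model.generate(...) with greedy search (do_sample=False).
with torch.no_grad():
    generated = input_ids
    while generated.shape[1] < MAX_LENGTH:
        logits = model(generated).logits  # [batch, seq_len, vocab_size]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
print(tokenizer.decode(generated[0], skip_special_tokens=True))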
Torch-TensorRT compilation and inference¶
The input sequence length is dynamic, so we mark it using the torch._dynamo.mark_dynamic API.
We provide a (min, max) range for this dimension so that TensorRT knows in advance what values to optimize for; usually, the maximum would be the context length of the model. We start with min=2 because torch.compile specializes on dimension sizes of 0 and 1 (0/1 specialization), so those values cannot be part of a dynamic range.
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    dynamic=None,
    options={
        "enabled_precisions": {torch.float32},
        "disable_tf32": True,
        "min_block_size": 1,
    },
)
Auto-regressive generation loop for greedy decoding using the TensorRT model. Generating the first token compiles the model with TensorRT, and generating the second token triggers a recompilation (a known issue at the moment that will be resolved in a future release).
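To observe that recompilation, dynamo's recompile logging can be enabled before calling generate(). This is an optional addition to the example and assumes a recent PyTorch release that provides torch._logging.set_logs.
# Optional (not part of the original example): log recompilation events so the
# extra compile on the second generated token is visible in the console.
torch._logging.set_logs(recompiles=True)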
trt_gen_tokens = model.generate(
    inputs=input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)
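Optionally, the average generation latency of the TensorRT-compiled model can be measured. This is an illustrative sketch, not part of the original example; note that model.forward is already the compiled version at this point, so an eager baseline would require a separate, uncompiled model instance.
import time

def time_generate(n_iter=5):
    # Average wall-clock time per generate() call on the compiled model.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iter):
        model.generate(
            inputs=input_ids,
            max_length=MAX_LENGTH,
            use_cache=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iter

print(f"Average TensorRT generation latency: {time_generate():.3f} s")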
Decode the output sentences of PyTorch and TensorRT¶
print(
    "Pytorch model generated text: ",
    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)
The output sentences should look like:
"""
Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
=============================
TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
"""