Dynamic shapes with Torch-TensorRT

By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly. However, Torch-TensorRT is an ahead-of-time (AOT) compiler, which requires some prior information about the input shapes to compile and optimize the model. In the case of dynamic input shapes, we must provide the (min_shape, opt_shape, max_shape) arguments so that the model can be optimized for this range of input shapes. An example usage of static and dynamic shapes is as follows.

NOTE: The following code uses the Dynamo frontend. In case of the TorchScript frontend, swap ir=dynamo for ir=ts; the behavior is exactly the same.

import torch
import torch_tensorrt

# MyModel is a placeholder for your own torch.nn.Module
model = MyModel().eval().cuda()
# Compile with static shapes
inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
# or compile with dynamic shapes
inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                              opt_shape=[4, 3, 224, 224],
                              max_shape=[8, 3, 224, 224],
                              dtype=torch.float32)
# inputs must be passed by keyword because it follows the ir keyword argument
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[inputs])
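
Before handing a (min_shape, opt_shape, max_shape) triple to the compiler, it can help to sanity-check it. The sketch below is a plain-Python illustration and not part of the Torch-TensorRT API: validate_shape_range and shape_in_range are hypothetical helper names. It assumes the three shapes must share the same rank and satisfy min <= opt <= max in every dimension.

```python
from typing import Sequence


def validate_shape_range(
    min_shape: Sequence[int],
    opt_shape: Sequence[int],
    max_shape: Sequence[int],
) -> None:
    """Raise ValueError unless min <= opt <= max holds in every dimension."""
    if not (len(min_shape) == len(opt_shape) == len(max_shape)):
        raise ValueError("min_shape, opt_shape and max_shape must have the same rank")
    for dim, (lo, opt, hi) in enumerate(zip(min_shape, opt_shape, max_shape)):
        if not (lo <= opt <= hi):
            raise ValueError(
                f"dimension {dim}: expected min <= opt <= max, got {lo}, {opt}, {hi}"
            )


def shape_in_range(
    shape: Sequence[int],
    min_shape: Sequence[int],
    max_shape: Sequence[int],
) -> bool:
    """Return True if a runtime shape lies inside the [min, max] range."""
    if len(shape) != len(min_shape) or len(shape) != len(max_shape):
        return False
    return all(lo <= s <= hi for s, lo, hi in zip(shape, min_shape, max_shape))


# Shapes taken from the dynamic-shape example above.
MIN_SHAPE = [1, 3, 224, 224]
OPT_SHAPE = [4, 3, 224, 224]
MAX_SHAPE = [8, 3, 224, 224]

validate_shape_range(MIN_SHAPE, OPT_SHAPE, MAX_SHAPE)
print(shape_in_range([2, 3, 224, 224], MIN_SHAPE, MAX_SHAPE))   # batch size 2
print(shape_in_range([16, 3, 224, 224], MIN_SHAPE, MAX_SHAPE))  # batch size 16
```

A check like this catches a swapped min and max, or a runtime batch size outside the compiled range, before any compilation time is spent.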

Under the hood

There are two phases of compilation when we use the torch_tensorrt.compile API with ir=dynamo (default).

  • aten_tracer.trace (which uses torch.export to trace the graph with the given inputs)

In the tracing phase, we use torch.export along with the constraints. For dynamic-shaped inputs, the range is passed to tracing through these constraints. Please refer to this docstring for detailed information on how to set constraints. In short, we create new inputs for torch.export tracing and constrain each dynamic dimension to the min and max values provided by the user. Please take a look at the aten_tracer.py file to understand how this works under the hood.

  • dynamo.compile (which compiles a torch.fx.GraphModule object using TensorRT)

In the conversion to TensorRT, we use the user-provided dynamic shape inputs. We perform shape analysis using dummy inputs (across the min, opt and max shapes) and store the intermediate output shapes, which can be used in case the graph has a mix of PyTorch and TensorRT submodules.
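
The shape analysis step can be pictured as evaluating the same computation once per profile shape and recording what comes out. The following pure-Python sketch shows that idea only. It does not use Torch-TensorRT internals, and pool_and_flatten_shape and analyze_output_shapes are hypothetical names. The stand-in shape function assumes a global average pool to 1x1 followed by a flatten from dimension 1.

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]


def pool_and_flatten_shape(input_shape: Shape) -> Shape:
    """Output shape of AdaptiveAvgPool2d((1, 1)) followed by flatten(x, 1)."""
    batch, channels, _height, _width = input_shape
    return (batch, channels * 1 * 1)


def analyze_output_shapes(
    shape_fn: Callable[[Shape], Shape],
    profile: Dict[str, Shape],
) -> Dict[str, Shape]:
    """Evaluate shape_fn at each profile point (min, opt, max)."""
    return {name: shape_fn(shape) for name, shape in profile.items()}


# Illustrative profile: only the batch dimension varies.
profile = {
    "min": (1, 512, 7, 7),
    "opt": (4, 512, 7, 7),
    "max": (8, 512, 7, 7),
}
output_shapes = analyze_output_shapes(pool_and_flatten_shape, profile)
for name, shape in output_shapes.items():
    print(name, shape)
```

Recorded ranges like these are the kind of information a downstream PyTorch or TensorRT submodule would need about its own inputs.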

Custom Constraints

Given an input x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype), Torch-TensorRT automatically sets the constraints during torch.export tracing as follows:

for dim in constraint_dims:
    if min_shape[dim] > 1:
        constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
    if max_shape[dim] > 1:
        constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])
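
To make the rule above concrete, here is a plain-Python sketch that turns a (min_shape, max_shape) pair into readable bounds. It is an illustration, not the library's code: derive_bounds is a hypothetical name, and treating every dimension where min and max differ as a constraint dimension is an assumption, since the snippet above does not show how constraint_dims is computed.

```python
from typing import List, Optional, Sequence, Tuple

# (dimension index, lower bound or None, upper bound or None)
Bound = Tuple[int, Optional[int], Optional[int]]


def derive_bounds(min_shape: Sequence[int], max_shape: Sequence[int]) -> List[Bound]:
    """Mirror the rule above: a bound is only recorded when the limit exceeds 1."""
    bounds: List[Bound] = []
    for dim, (lo, hi) in enumerate(zip(min_shape, max_shape)):
        if lo == hi:
            continue  # assumed static: same size in min_shape and max_shape
        lower = lo if lo > 1 else None
        upper = hi if hi > 1 else None
        bounds.append((dim, lower, upper))
    return bounds


bounds = derive_bounds([1, 3, 224, 224], [8, 3, 224, 224])
for dim, lower, upper in bounds:
    print(f"dim {dim}: lower={lower}, upper={upper}")
```

With min_shape[0] equal to 1, only the upper bound is recorded for the batch dimension, which matches the `> 1` checks in the snippet above.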

Sometimes, we might need to set additional constraints, and TorchDynamo errors out if we don’t specify them. For example, in the case of BERT model compilation, there are two inputs, and a constraint has to be set involving the sequence length of these two inputs.

constraints.append(dynamic_dim(trace_inputs[0], 0) == dynamic_dim(trace_inputs[1], 0))

If you have to provide any custom constraints to your model, the overall workflow for model compilation using ir=dynamo would involve a few steps.

import unittest.mock

import torch
import torch_tensorrt
# dynamic_dim and export are assumed to come from torch._export, the module patched below
from torch._export import dynamic_dim, export
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions

# Assume the model has two inputs
model = MyModel()
# torch.randn cannot create integer tensors, so use torch.randint for int32 inputs
torch_input_1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()
torch_input_2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()

dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
                    opt_shape=[4, 14],
                    max_shape=[8, 14],
                    dtype=torch.int32),
                  torch_tensorrt.Input(min_shape=[1, 14],
                    opt_shape=[4, 14],
                    max_shape=[8, 14],
                    dtype=torch.int32)]

# Export the model with additional constraints
constraints = []
# The following constraints are automatically added by Torch-TensorRT in the
# general case when you call torch_tensorrt.compile directly on MyModel()
# (<= so that the max_shape batch size of 8 is itself allowed)
constraints.append(dynamic_dim(torch_input_1, 0) <= 8)
constraints.append(dynamic_dim(torch_input_2, 0) <= 8)
# This is an additional constraint as instructed by TorchDynamo
constraints.append(dynamic_dim(torch_input_1, 0) == dynamic_dim(torch_input_2, 0))
# experimental_decompositions and compile_spec are assumed to be defined by the caller
with unittest.mock.patch(
    "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
):
    # dynamo.compile expects a torch.fx.GraphModule, hence .module()
    graph_module = export(
        model, (torch_input_1, torch_input_2), constraints=constraints
    ).module()

# Use the dynamo.compile API
trt_mod = torch_tensorrt.dynamo.compile(graph_module, inputs=dynamic_inputs, **compile_spec)


Limitations

If there are operations in the graph that use the dynamic dimension of the input, PyTorch introduces torch.ops.aten.sym_size.int ops in the graph. Currently, we cannot handle these operators, and compilation results in undefined behavior. We plan to add support for these operators and implement robust support for shape tensors in the next release. Here is an example of the limitation described above:

import torch
import torch_tensorrt

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()  # must run before submodules are registered
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.avgpool(x)
        # Flattening from dim 1 keeps the dynamic batch dimension
        out = torch.flatten(x, 1)
        return out

model = MyModule().eval().cuda()
# Compile with dynamic shapes
inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
                     opt_shape=(4, 512, 1, 1),
                     max_shape=(8, 512, 1, 1),
                     dtype=torch.float32)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[inputs])

The traced graph of MyModule() looks as follows:

Post export graph: graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0_1, [-1, -2], True), kwargs = {})
%sym_size : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
%view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mean, [%sym_size, 512]), kwargs = {})
return (view,)

Here the %sym_size node captures the dynamic batch size and uses it in the aten.view layer. This requires shape tensor support, which is planned for our next release.
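
One way to notice this situation early is to look for sym_size nodes before compiling. The sketch below scans the text of a printed graph, such as the one above, for the operator name. find_sym_size_nodes is a hypothetical helper, and a text scan is a rough heuristic rather than how Torch-TensorRT inspects graphs.

```python
import re
from typing import List

# Printed graph copied from the example above.
GRAPH_DUMP = """\
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0_1, [-1, -2], True), kwargs = {})
%sym_size : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
%view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mean, [%sym_size, 512]), kwargs = {})
return (view,)
"""

# Matches a node line whose target is any overload of aten.sym_size.
_SYM_SIZE_NODE = re.compile(r"^%(\w+)\s*:.*target=torch\.ops\.aten\.sym_size\.\w+")


def find_sym_size_nodes(graph_text: str) -> List[str]:
    """Return the names of nodes whose target is an aten.sym_size overload."""
    names = []
    for line in graph_text.splitlines():
        match = _SYM_SIZE_NODE.match(line.strip())
        if match:
            names.append(match.group(1))
    return names


sym_nodes = find_sym_size_nodes(GRAPH_DUMP)
print(sym_nodes)
```

A non-empty result suggests the graph depends on a dynamic dimension at runtime and may hit the limitation described here.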

Workaround (BERT static compilation example)

If you encounter the issues mentioned in the Limitations section, you can compile the model in static mode with the maximum input size you expect to provide. Smaller inputs can then be padded up to that size. This is only a workaround until we address the limitations.

import torch
import torch_tensorrt
from transformers import BertModel
from transformers.utils.fx import symbolic_trace as transformers_trace

model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()

# Input sequence length is 20.
input1 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
input2 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")

model = transformers_trace(model, input_names=["input_ids", "attention_mask"]).eval().cuda()
# compile_spec is assumed to be a dict of compilation settings defined by the caller
trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
model_outputs = trt_mod(input1, input2)

# If you have a sequence of length 14, pad 6 zero tokens and run inference
# or recompile for sequence length of 14.
input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
model_outputs = trt_mod(input1, input2)
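
The padding half of this workaround can be sketched without any framework. The example below pads token ids and the attention mask up to the sequence length the engine was compiled for. pad_to_length is a hypothetical helper, and using 0 as the pad token id is an assumption that matches bert-base-uncased but should be checked for other tokenizers.

```python
from typing import List, Tuple

COMPILED_SEQ_LEN = 20  # sequence length used when compiling in static mode


def pad_to_length(
    input_ids: List[int],
    attention_mask: List[int],
    target_len: int = COMPILED_SEQ_LEN,
    pad_id: int = 0,
) -> Tuple[List[int], List[int]]:
    """Right-pad one sequence and its attention mask to target_len."""
    if len(input_ids) != len(attention_mask):
        raise ValueError("input_ids and attention_mask must have the same length")
    if len(input_ids) > target_len:
        raise ValueError("sequence is longer than the compiled length")
    pad = target_len - len(input_ids)
    # Padded positions get mask value 0 so attention ignores them.
    return input_ids + [pad_id] * pad, attention_mask + [0] * pad


ids = list(range(1, 15))  # a 14-token sequence; the token values are placeholders
mask = [1] * 14
padded_ids, padded_mask = pad_to_length(ids, mask)
print(len(padded_ids), padded_ids[-6:], sum(padded_mask))
```

The padded lists can then be converted to int32 tensors of shape (1, 20) and passed to the module compiled for sequence length 20, which avoids the recompilation shown above.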

Dynamic shapes with ir=torch_compile

torch_tensorrt.compile(model, inputs, ir="torch_compile") returns a torch.compile boxed function with the backend configured to TensorRT. With ir=torch_compile, Dynamo automatically triggers a recompilation of the TensorRT engine when the input size changes. This gives dynamic shape behavior similar to native PyTorch eager mode, but at the cost of rebuilding the TRT engine. This limitation will be addressed in future versions of Torch-TensorRT.

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=[inputs])
# Compilation happens when you call the model
trt_gm(inputs)

# Recompilation happens with modified batch size
inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32).cuda()
trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=[inputs_bs2])
trt_gm(inputs_bs2)
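
The cost model described here, one engine build per new input shape, can be illustrated with a small cache keyed by shape. This is a conceptual sketch only: ShapeKeyedCache and fake_build_engine are hypothetical and stand in for the real recompilation machinery, which this page does not describe in detail.

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]


def fake_build_engine(shape: Shape) -> str:
    """Stand-in for an expensive engine build; returns a label for the engine."""
    return "engine_for_" + "x".join(str(d) for d in shape)


class ShapeKeyedCache:
    """Builds once per distinct input shape and reuses the result afterwards."""

    def __init__(self, build: Callable[[Shape], str]) -> None:
        self._build = build
        self._engines: Dict[Shape, str] = {}
        self.build_count = 0

    def get(self, shape: Shape) -> str:
        if shape not in self._engines:
            self.build_count += 1  # an unseen shape triggers a build
            self._engines[shape] = self._build(shape)
        return self._engines[shape]


cache = ShapeKeyedCache(fake_build_engine)
cache.get((1, 3, 224, 224))  # first call: build
cache.get((1, 3, 224, 224))  # same shape: reuse
cache.get((2, 3, 224, 224))  # new batch size: build again
print(cache.build_count)
```

Counting builds this way makes the trade-off visible: workloads with many distinct input shapes pay the build cost repeatedly, which is where an engine compiled over a (min_shape, opt_shape, max_shape) range with ir=dynamo may be the better fit.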

