TensorRT Backend for torch.compile
This guide presents the Torch-TensorRT torch.compile backend: a deep learning compiler which uses TensorRT to accelerate JIT-style workflows across a wide variety of models.
Key Features¶
The primary goal of the Torch-TensorRT torch.compile backend is to enable Just-In-Time compilation workflows by combining the simplicity of the torch.compile API with the performance of TensorRT. Invoking the torch.compile backend is as simple as importing the torch_tensorrt package and specifying the backend:
import torch
import torch_tensorrt
...
optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False)
Note
Many additional customization options are available to the user. These will be discussed in further depth in this guide.
The backend handles a variety of challenging model structures and offers a simple interface for accelerating models. It also exposes many customization options so that the compilation process can be tailored to the specific use case.
Customizable Settings¶
- class torch_tensorrt.dynamo.CompilationSettings(enabled_precisions: Set[dtype] = <factory>, debug: bool = False, workspace_size: int = 0, min_block_size: int = 5, torch_executed_ops: Collection[Union[Callable[[...], Any], str]] = <factory>, pass_through_build_failures: bool = False, max_aux_streams: Optional[int] = None, version_compatible: bool = False, optimization_level: Optional[int] = None, use_python_runtime: Optional[bool] = False, truncate_double: bool = False, use_fast_partitioner: bool = True, enable_experimental_decompositions: bool = False, device: Device = <factory>, require_full_compilation: bool = False, disable_tf32: bool = False, assume_dynamic_shape_support: bool = False, sparse_weights: bool = False, make_refittable: bool = False, engine_capability: EngineCapability = <factory>, num_avg_timing_iters: int = 1, dla_sram_size: int = 1048576, dla_local_dram_size: int = 1073741824, dla_global_dram_size: int = 536870912, dryrun: Union[bool, str] = False, hardware_compatible: bool = False, timing_cache_path: str = '/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init: bool = False, cache_built_engines: bool = False, reuse_cached_engines: bool = False, use_explicit_typing: bool = False, use_fp32_acc: bool = False, enable_weight_streaming: bool = False, enable_cross_compile_for_windows: bool = False)
Compilation settings for Torch-TensorRT Dynamo Paths
- Parameters
enabled_precisions (Set[dtype]) – Available kernel dtype precisions
debug (bool) – Whether to print out verbose debugging information
workspace_size (int) – Workspace TRT is allowed to use for the module (0 is default)
min_block_size (int) – Minimum number of operators per TRT-Engine Block
torch_executed_ops (Collection[Target]) – Collection of operations to run in Torch, regardless of converter coverage
pass_through_build_failures (bool) – Whether to fail on TRT engine build errors (True) or not (False)
max_aux_streams (Optional[int]) – Maximum number of allowed auxiliary TRT streams for each engine
version_compatible (bool) – Provide version forward-compatibility for engine plan files
optimization_level (Optional[int]) – Builder optimization level, 0-5. Higher levels imply longer build times, since the builder searches for more optimization options. TRT defaults to 3
use_python_runtime (Optional[bool]) – Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the argument as None
truncate_double (bool) – Whether to truncate float64 TRT engine inputs or weights to float32
use_fast_partitioner (bool) – Whether to use the fast or global graph partitioning system
enable_experimental_decompositions (bool) – Whether to enable all core aten decompositions or only a selected subset of them
device (Device) – GPU to compile the model on
require_full_compilation (bool) – Whether to require that the graph is fully compiled in TensorRT. Only applicable for ir="dynamo"; has no effect for the torch.compile path
assume_dynamic_shape_support (bool) – Setting this to True enables the converters to work for both dynamic and static shapes. Default: False
disable_tf32 (bool) – Whether to disable TF32 computation for TRT layers
sparse_weights (bool) – Whether to allow the builder to use sparse weights
make_refittable (bool) – Whether to build a refittable engine
engine_capability (trt.EngineCapability) – Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int) – Number of averaging timing iterations used to select kernels
dla_sram_size (int) – Fast software managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int) – Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int) – Host RAM used by DLA to store weights and metadata for execution
dryrun (Union[bool, str]) – Toggle “Dryrun” mode, which runs everything through partitioning, short of conversion to TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the output to a file if a string path is specified
hardware_compatible (bool) – Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str) – Path to the timing cache if it exists (or) where it will be saved after compilation
cache_built_engines (bool) – Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool) – Whether to load the compiled TRT engines from storage
use_explicit_typing (bool) – Enables strong typing in TensorRT compilation, which respects the precisions set in the PyTorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool) – This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
enable_weight_streaming (bool) – Enable weight streaming.
enable_cross_compile_for_windows (bool) – Defaults to False, meaning TensorRT engines can only be executed on the same platform where they were built. Setting this to True enables cross-platform compatibility, which allows the engine to be built on Linux and run on Windows
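Because options are passed to the backend as a plain dictionary, a misspelled key is easy to miss. The following is a minimal sketch, not part of Torch-TensorRT, of a pre-flight check that compares option keys against the setting names in the signature above. The helper name `unknown_options` is hypothetical, and how the backend itself treats unrecognized keys is not covered by this guide.

```python
import difflib

# Setting names copied from the CompilationSettings signature above.
KNOWN_SETTINGS = sorted({
    "enabled_precisions", "debug", "workspace_size", "min_block_size",
    "torch_executed_ops", "pass_through_build_failures", "max_aux_streams",
    "version_compatible", "optimization_level", "use_python_runtime",
    "truncate_double", "use_fast_partitioner",
    "enable_experimental_decompositions", "device",
    "require_full_compilation", "disable_tf32",
    "assume_dynamic_shape_support", "sparse_weights", "make_refittable",
    "engine_capability", "num_avg_timing_iters", "dla_sram_size",
    "dla_local_dram_size", "dla_global_dram_size", "dryrun",
    "hardware_compatible", "timing_cache_path", "lazy_engine_init",
    "cache_built_engines", "reuse_cached_engines", "use_explicit_typing",
    "use_fp32_acc", "enable_weight_streaming",
    "enable_cross_compile_for_windows",
})


def unknown_options(options):
    """Hypothetical helper: map each unrecognized key to similar known names."""
    return {
        key: difflib.get_close_matches(key, KNOWN_SETTINGS, n=3)
        for key in options
        if key not in KNOWN_SETTINGS
    }


# Example: one valid key and one typo.
suggestions = unknown_options({"debug": True, "min_blocksize": 2})
```

Here `suggestions` contains only the misspelled key, paired with candidate corrections, so the check can run before any compilation is attempted.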
Custom Setting Usage¶
import torch
import torch_tensorrt
...
optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False,
                                options={"truncate_double": True,
                                         "enabled_precisions": {torch.float, torch.half},
                                         "debug": True,
                                         "min_block_size": 2,
                                         "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
                                         "optimization_level": 4,
                                         "use_python_runtime": False,})
Note
Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
Compilation¶
Compilation is triggered by passing inputs to the model, like so:
import torch_tensorrt
...
# Causes model compilation to occur
first_outputs = optimized_model(*inputs)
# Subsequent inference runs with the same, or similar inputs will not cause recompilation
# For a full discussion of this, see "Recompilation Conditions" below
second_outputs = optimized_model(*inputs)
After Compilation¶
The compiled model can be used for inference within the Python session, and will recompile according to the recompilation conditions detailed below. In addition to general inference, the compilation process can be a helpful tool in determining model performance, current operator coverage, and feasibility of serialization. Each of these points is covered in detail below.
Model Performance¶
The optimized model returned from torch.compile is useful for model benchmarking since it can automatically handle changes in the compilation context, or differing inputs that could require recompilation. When benchmarking inputs of varying distributions, batch sizes, or other criteria, this can save time.
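As an illustration of such benchmarking, here is a minimal timing sketch that works with any callable. The `benchmark` helper and its parameters are assumptions made for this example and are not part of Torch-TensorRT. Warm-up calls are kept separate so that one-off costs, such as the compilation triggered by the first call, do not distort the measurements.

```python
import statistics
import time


def benchmark(fn, *args, warmup=3, iters=20, sync=None):
    """Time fn(*args) and return summary statistics in seconds.

    `sync` is an optional zero-argument callable run around each timed call.
    For GPU work it would typically wait for queued kernels to finish.
    """
    # Warm-up calls absorb one-off costs such as compilation.
    for _ in range(warmup):
        fn(*args)

    timings = []
    for _ in range(iters):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()
        timings.append(time.perf_counter() - start)

    return {
        "median_s": statistics.median(timings),
        "min_s": min(timings),
        "runs": len(timings),
    }
```

For a compiled model on a CUDA device, a call might look like `benchmark(optimized_model, *inputs, sync=torch.cuda.synchronize)`. Each new input shape should get its own warm-up, since it may trigger a recompilation.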
Operator Coverage¶
Compilation is also a useful tool in determining operator coverage for a particular model. For instance, the following compilation command will display the operator coverage for each graph but will not compile the model, effectively providing a “dryrun” mechanism:
import torch
import torch_tensorrt
...
optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False,
                                options={"debug": True,
                                         "min_block_size": float("inf"),})
If key operators for your model are unsupported, see dynamo_conversion to contribute your own converters, or file an issue here: https://github.com/pytorch/TensorRT/issues.
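To see why an infinite min_block_size leaves nothing to compile, consider the toy model below. It is a deliberately simplified sketch, not the partitioner Torch-TensorRT uses: it treats the model as a flat list of operator names, whereas the real partitioner works on graphs and applies further rules. The `partition` function and the operator names are hypothetical.

```python
def partition(ops, supported, min_block_size):
    """Toy model: split a linear op list into ("trt" | "torch", [ops]) blocks."""
    blocks = []

    def emit(target, items):
        # Merge into the previous block when it runs on the same target.
        if blocks and blocks[-1][0] == target:
            blocks[-1][1].extend(items)
        else:
            blocks.append((target, list(items)))

    run = []  # current run of consecutive supported ops
    for op in list(ops) + [None]:  # None is a sentinel that flushes the last run
        if op is not None and op in supported:
            run.append(op)
            continue
        if run:
            # Runs shorter than min_block_size fall back to Torch.
            emit("trt" if len(run) >= min_block_size else "torch", run)
            run = []
        if op is not None:
            emit("torch", [op])
    return blocks


ops = ["conv", "relu", "conv", "relu", "sort", "add"]
supported = {"conv", "relu", "add"}

small = partition(ops, supported, min_block_size=2)
dry = partition(ops, supported, min_block_size=float("inf"))
```

In this toy model, `small` contains one accelerated block of four operators followed by a Torch block, while `dry` is a single Torch block because no run of operators can ever reach an infinite minimum size.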
Feasibility of Serialization¶
Compilation can also be helpful in demonstrating graph breaks and the feasibility of serialization of a particular model. For instance, if a model has no graph breaks and compiles successfully with the Torch-TensorRT backend, then that model should be compilable and serializable via the torch_tensorrt Dynamo IR, as discussed in Dynamic shapes with Torch-TensorRT. To determine the number of graph breaks in a model, the torch._dynamo.explain function is very useful:
import torch
import torch_tensorrt
...
explanation = torch._dynamo.explain(model)(*inputs)
print(f"Graph breaks: {explanation.graph_break_count}")
optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False, options={"truncate_double": True})
Dynamic Shape Support¶
The Torch-TensorRT torch.compile backend will currently require recompilation for each new batch size encountered, and it is preferred to use the dynamic=False argument when compiling with this backend. Full dynamic shape support is planned for a future release.
Recompilation Conditions¶
Once the model has been compiled, subsequent inference inputs with the same shape and data type, which traverse the graph in the same way, will not require recompilation. Furthermore, each new recompilation will be cached for the duration of the Python session. For instance, if inputs of batch size 4 and 8 are provided to the model, causing two recompilations, no further recompilation would be necessary for future inputs with those batch sizes during inference within the same session. Support for engine cache serialization is planned for a future release.
Recompilation is generally triggered by one of two events: encountering inputs of different sizes, or encountering inputs which traverse the model code differently. The latter scenario can occur when the model code includes conditional logic, complex loops, or data-dependent shapes. torch.compile handles guarding in both of these scenarios and determines when recompilation is necessary.
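The session-level caching described above can be pictured with a small stand-in. This is a conceptual sketch only: torch.compile's real guards check much more than shape and data type, and the `ShapeKeyedCache` class is hypothetical. It shows why batch sizes 4 and 8 cost two compilations in total, no matter how often they recur.

```python
class ShapeKeyedCache:
    """Toy stand-in for guard-based caching; not torch.compile's implementation."""

    def __init__(self, compile_fn):
        self.compile_fn = compile_fn
        self.cache = {}
        self.compile_count = 0

    def __call__(self, shape, dtype):
        key = (tuple(shape), dtype)
        if key not in self.cache:
            # Cache miss: an unseen shape or dtype forces a new compilation.
            self.compile_count += 1
            self.cache[key] = self.compile_fn(shape, dtype)
        return self.cache[key]


# The "compiler" here just returns a label standing in for a built engine.
cache = ShapeKeyedCache(lambda shape, dtype: f"engine{tuple(shape)}-{dtype}")
for batch in (4, 8, 4, 8, 4):
    cache((batch, 3, 224, 224), "float16")

print(cache.compile_count)  # prints 2
```

Like the cache described above, this one lives in memory, so it is lost when the Python session ends.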