
torch_tensorrt.dynamo

Functions

torch_tensorrt.dynamo.compile(exported_program: torch.export.exported_program.ExportedProgram, inputs: typing.Tuple[typing.Any, ...], *, device: typing.Optional[typing.Union[torch_tensorrt._Device.Device, torch.device, str]] = None, disable_tf32: bool = False, sparse_weights: bool = False, enabled_precisions: typing.Union[typing.Set[torch.dtype], typing.Tuple[torch.dtype]] = (torch.float32,), engine_capability: torch_tensorrt._C.EngineCapability = <EngineCapability.DEFAULT: 0>, refit: bool = False, debug: bool = False, capability: torch_tensorrt._C.EngineCapability = <EngineCapability.default: 0>, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, dla_local_dram_size: int = 1073741824, dla_global_dram_size: int = 536870912, calibrator: object = None, truncate_long_and_double: bool = False, require_full_compilation: bool = False, min_block_size: int = 5, torch_executed_ops: typing.Optional[typing.Collection[typing.Union[typing.Callable[[...], typing.Any], str]]] = None, torch_executed_modules: typing.Optional[typing.List[str]] = None, pass_through_build_failures: bool = False, max_aux_streams: typing.Optional[int] = None, version_compatible: bool = False, optimization_level: typing.Optional[int] = None, use_python_runtime: bool = False, use_fast_partitioner: bool = True, enable_experimental_decompositions: bool = False, dryrun: bool = False, hardware_compatible: bool = False, **kwargs: typing.Any) → GraphModule

Compile an ExportedProgram for NVIDIA GPUs using TensorRT

Takes an existing ExportedProgram and a set of settings to configure the compiler, and converts it into an FX GraphModule whose supported subgraphs call equivalent TensorRT engines.

Specifically converts the forward method of the module from which the ExportedProgram was created.

Parameters
  • exported_program (torch.export.ExportedProgram) – Source module, produced by running torch.export on a torch.nn.Module

  • inputs (Tuple[Any, ...]) –

    List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select the device type.

    inputs=[
        torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
        torch_tensorrt.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last,
        ), # Dynamic input shape for input #2
        torch.randn((1, 3, 224, 224)), # Use an example tensor and let torch_tensorrt infer settings
    ]
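For a dynamic input, each dimension is expected to satisfy min ≤ opt ≤ max, because TensorRT builds an optimization profile over that range. The sketch below checks a shape triple in plain Python and reports which dimensions are dynamic. `check_dynamic_shape` is a hypothetical helper for illustration only, not part of torch_tensorrt, and it does not look at dtype or memory format.

```python
from typing import Sequence, Tuple


def check_dynamic_shape(
    min_shape: Sequence[int],
    opt_shape: Sequence[int],
    max_shape: Sequence[int],
) -> Tuple[bool, ...]:
    """Hypothetical helper: report which dimensions are dynamic.

    Raises ValueError if the shapes differ in rank or if any dimension
    violates min <= opt <= max.
    """
    if not (len(min_shape) == len(opt_shape) == len(max_shape)):
        raise ValueError("min_shape, opt_shape and max_shape must have the same rank")
    dynamic = []
    for i, (lo, opt, hi) in enumerate(zip(min_shape, opt_shape, max_shape)):
        if not lo <= opt <= hi:
            raise ValueError(f"dim {i}: expected {lo} <= {opt} <= {hi}")
        # A dimension is dynamic when its allowed range spans more than one value.
        dynamic.append(lo != hi)
    return tuple(dynamic)


# Shapes taken from the dynamic input #2 example (NHWC layout).
print(check_dynamic_shape((1, 224, 224, 3), (1, 512, 512, 3), (1, 1024, 1024, 3)))
# -> (False, True, True, False)
```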
    

Keyword Arguments
  • device (Union(Device, torch.device, str)) –

    Target device for TensorRT engines to run on

    device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
    

  • disable_tf32 (bool) – Force FP32 layers to use the traditional FP32 format instead of the default TF32 behavior, which rounds the inputs to 10-bit mantissas before multiplying but accumulates the sum using 23-bit mantissas

  • sparse_weights (bool) – Enable sparsity for convolution and fully connected layers.

  • enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels

  • refit (bool) – Enable refitting

  • debug (bool) – Enable debuggable engine

  • capability (EngineCapability) – Restrict kernel selection to safe GPU kernels or safe DLA kernels

  • num_avg_timing_iters (int) – Number of averaging timing iterations used to select kernels

  • workspace_size (int) – Maximum size of workspace given to TensorRT

  • dla_sram_size (int) – Fast software-managed RAM used by DLA to communicate within a layer.

  • dla_local_dram_size (int) – Host RAM used by DLA to share intermediate tensor data across operations

  • dla_global_dram_size (int) – Host RAM used by DLA to store weights and metadata for execution

  • truncate_long_and_double (bool) – Truncate weights provided in int64 or double (float64) to int32 and float32

  • calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)) – Calibrator object which will provide data to the PTQ system for INT8 Calibration

  • require_full_compilation (bool) – Require modules to be compiled end-to-end or return an error, as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch

  • min_block_size (int) – The minimum number of contiguous TensorRT-convertible operations required to run a set of operations in TensorRT

  • torch_executed_ops (Collection[Target]) – Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but require_full_compilation is True

  • torch_executed_modules (List[str]) – List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but require_full_compilation is True

  • pass_through_build_failures (bool) – Error out if there are issues during compilation (only applicable to torch.compile workflows)

  • max_aux_streams (Optional[int]) – Maximum number of auxiliary streams in the engine

  • version_compatible (bool) – Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines)

  • optimization_level (Optional[int]) – Setting a higher optimization level allows TensorRT to spend more engine-building time searching for optimization options. The resulting engine may perform better than an engine built with a lower optimization level. The default optimization level is 3. Valid values are integers from 0 to the maximum optimization level, which is currently 5. Setting it higher than the maximum level results in behavior identical to the maximum level.

  • use_python_runtime (bool) – Return a graph using a pure Python runtime; this reduces the options for serialization

  • use_fast_partitioner (bool) – Use the adjacency-based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (False) if looking for the best performance

  • enable_experimental_decompositions (bool) – Use the full set of operator decompositions. These decompositions may not be tested, but they make the graph easier to convert to TensorRT, potentially increasing the portion of the graph run in TensorRT.

  • dryrun (bool) – Toggle for “Dryrun” mode, running everything except conversion to TRT and logging outputs

  • hardware_compatible (bool) – Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)

  • **kwargs (Any) – Additional keyword arguments
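The rounding that disable_tf32 turns off can be modeled in plain Python. This is a simplified numerical sketch, not TensorRT code: it assumes round-to-nearest-even when keeping 10 of the 23 FP32 mantissa bits, and it emulates FP32 accumulation by rounding every product and partial sum back to float32 with the struct module. The helper names are hypothetical.

```python
import struct
from typing import Sequence


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x: float) -> float:
    """Round a finite value to a 10-bit mantissa (TF32 input precision).

    FP32 has a 23-bit mantissa, so the 13 low bits are dropped using
    round-to-nearest-even. NaN/Inf handling is ignored in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 13) & 1  # lowest mantissa bit that is kept
    # Add just under half an ulp (plus lsb for ties-to-even), then clear the dropped bits.
    bits = (bits + 0x0FFF + lsb) & 0xFFFFE000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot(a: Sequence[float], b: Sequence[float], tf32: bool) -> float:
    """Dot product with FP32 accumulation; inputs optionally rounded to TF32 first."""
    acc = 0.0
    for x, y in zip(a, b):
        x, y = (to_tf32(x), to_tf32(y)) if tf32 else (to_fp32(x), to_fp32(y))
        acc = to_fp32(acc + to_fp32(x * y))  # keep the running sum at FP32 precision
    return acc


a = [1.0 + 2**-11] * 4
b = [1.0] * 4
print(dot(a, b, tf32=False), dot(a, b, tf32=True))  # 4.001953125 4.0
```

The value 1 + 2^-11 is exactly representable in FP32 but sits halfway between two TF32 values, so it rounds to 1.0 and the small term vanishes from the TF32 result.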

Returns

Compiled FX module; when run, it executes via TensorRT

Return type

torch.fx.GraphModule
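How min_block_size, torch_executed_ops and require_full_compilation interact can be pictured with a toy partitioner over a linear sequence of operators. This is a simplified model in the spirit of the adjacency-based (fast) partitioner, not Torch-TensorRT's implementation: real graphs are DAGs, and the `partition` helper, the operator names and the `supported` set are hypothetical.

```python
from typing import Collection, List, Sequence, Tuple


def partition(
    ops: Sequence[str],
    supported: Collection[str],
    min_block_size: int = 5,
    torch_executed_ops: Collection[str] = (),
    require_full_compilation: bool = False,
) -> List[Tuple[str, List[str]]]:
    """Toy partitioner: group adjacent ops into ("tensorrt", [...]) or ("torch", [...])."""
    # An op is convertible if a converter exists and it is not forced to run in PyTorch.
    marks = [op in supported and op not in torch_executed_ops for op in ops]

    if require_full_compilation:
        if torch_executed_ops:
            raise ValueError("torch_executed_ops must be empty when require_full_compilation=True")
        if not all(marks):
            raise RuntimeError("graph cannot be compiled end-to-end in TensorRT")
        return [("tensorrt", list(ops))]

    # Build maximal runs of adjacent ops with the same convertibility.
    runs: List[Tuple[bool, List[str]]] = []
    for op, ok in zip(ops, marks):
        if runs and runs[-1][0] == ok:
            runs[-1][1].append(op)
        else:
            runs.append((ok, [op]))

    # Convertible runs shorter than min_block_size fall back to PyTorch;
    # neighbouring PyTorch runs are then merged into one block.
    blocks: List[Tuple[str, List[str]]] = []
    for ok, run in runs:
        target = "tensorrt" if ok and len(run) >= min_block_size else "torch"
        if blocks and blocks[-1][0] == target:
            blocks[-1][1].extend(run)
        else:
            blocks.append((target, list(run)))
    return blocks


ops = ["conv", "relu", "conv", "relu", "custom", "add", "relu"]
supported = {"conv", "relu", "add"}
print(partition(ops, supported, min_block_size=3))
# [('tensorrt', ['conv', 'relu', 'conv', 'relu']), ('torch', ['custom', 'add', 'relu'])]
```

Raising min_block_size trades fewer TensorRT engines (and fewer PyTorch/TensorRT boundary crossings) for more operators left in PyTorch.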

torch_tensorrt.dynamo.trace(mod: torch.nn.modules.module.Module | torch.fx.graph_module.GraphModule, inputs: Tuple[Any, ...], **kwargs: Any) → ExportedProgram

Exports a torch.export.ExportedProgram from a torch.nn.Module or torch.fx.GraphModule specifically for compilation with Torch-TensorRT

Exports a torch.export.ExportedProgram from either a torch.nn.Module or a torch.fx.GraphModule. Runs specific operator decompositions geared towards compilation by Torch-TensorRT's dynamo frontend.

Parameters
  • mod (torch.nn.Module | torch.fx.GraphModule) – Source module to later be compiled by Torch-TensorRT's dynamo frontend

  • inputs (Tuple[Any, ...]) –

    List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select the device type.

    inputs=[
        torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
        torch_tensorrt.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last,
        ), # Dynamic input shape for input #2
        torch.randn((1, 3, 224, 224)), # Use an example tensor and let torch_tensorrt infer settings
    ]
    

Keyword Arguments
  • device (Union(torch.device, dict)) –

    Target device for TensorRT engines to run on

    device=torch.device("cuda:0")
    

  • debug (bool) – Enable debuggable engine

  • enable_experimental_decompositions (bool) – Use the full set of operator decompositions. These decompositions may not be tested, but they make the graph easier to convert to TensorRT, potentially increasing the portion of the graph run in TensorRT.

  • **kwargs (Any) – Additional keyword arguments

Returns

ExportedProgram with Torch-TensorRT-specific operator decompositions applied, ready to be compiled with torch_tensorrt.dynamo.compile

Return type

torch.export.ExportedProgram

torch_tensorrt.dynamo.export(gm: GraphModule, inputs: Sequence[Tensor], *, ir: str = 'torchscript') → ExportedProgram

Export a program (torch.fx.GraphModule) for serialization with the TensorRT engines embedded.

> Note: When ExportedProgram becomes stable, this function will get merged into torch_tensorrt.dynamo.compile

Parameters
  • src_gm (torch.fx.GraphModule) – Source module, generated by torch.export (The module provided to torch_tensorrt.dynamo.compile)

  • gm (torch.fx.GraphModule) – Compiled Torch-TensorRT module, generated by torch_tensorrt.dynamo.compile

Keyword Arguments
  • inputs (Any) –

    List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select the device type.

    inputs=[
        torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
        torch_tensorrt.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last,
        ), # Dynamic input shape for input #2
        torch.randn((1, 3, 224, 224)), # Use an example tensor and let torch_tensorrt infer settings
    ]

  • ir (str) – Either "torchscript" or "exported_program". Determines whether the output is a TorchScript module or an ExportedProgram.

Classes

class torch_tensorrt.dynamo.CompilationSettings(precision: torch.dtype = torch.float32, debug: bool = False, workspace_size: int = 0, min_block_size: int = 5, torch_executed_ops: typing.Collection[typing.Union[typing.Callable[[...], typing.Any], str]] = <factory>, pass_through_build_failures: bool = False, max_aux_streams: typing.Optional[int] = None, version_compatible: bool = False, optimization_level: typing.Optional[int] = None, use_python_runtime: typing.Optional[bool] = False, truncate_long_and_double: bool = False, use_fast_partitioner: bool = True, enable_experimental_decompositions: bool = False, device: torch_tensorrt._Device.Device = <factory>, require_full_compilation: bool = False, disable_tf32: bool = False, sparse_weights: bool = False, refit: bool = False, engine_capability: tensorrt_bindings.tensorrt.EngineCapability = <EngineCapability.DEFAULT: 0>, num_avg_timing_iters: int = 1, dla_sram_size: int = 1048576, dla_local_dram_size: int = 1073741824, dla_global_dram_size: int = 536870912, dryrun: typing.Union[bool, str] = False, hardware_compatible: bool = False)

Compilation settings for Torch-TensorRT Dynamo Paths

Parameters
  • precision (torch.dtype) – Model layer precision

  • debug (bool) – Whether to print out verbose debugging information

  • workspace_size (int) – Workspace TRT is allowed to use for the module (0 is default)

  • min_block_size (int) – Minimum number of operators per TRT-Engine Block

  • torch_executed_ops (Collection[Target]) – Collection of operations to run in Torch, regardless of converter coverage

  • pass_through_build_failures (bool) – Whether to fail on TRT engine build errors (True) or not (False)

  • max_aux_streams (Optional[int]) – Maximum number of allowed auxiliary TRT streams for each engine

  • version_compatible (bool) – Provide version forward-compatibility for engine plan files

  • optimization_level (Optional[int]) – Builder optimization level from 0 to 5; higher levels imply longer build times while searching for more optimization options. TRT defaults to 3

  • use_python_runtime (Optional[bool]) – Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the argument as None

  • truncate_long_and_double (bool) – Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32

  • use_fast_partitioner (bool) – Whether to use the fast or global graph partitioning system

  • enable_experimental_decompositions (bool) – Whether to enable all core aten decompositions or only a selected subset of them

  • device (Device) – GPU to compile the model on

  • require_full_compilation (bool) – Whether to require that the graph is fully compiled in TensorRT. Only applicable for ir="dynamo"; has no effect for the torch.compile path

  • disable_tf32 (bool) – Whether to disable TF32 computation for TRT layers

  • sparse_weights (bool) – Whether to allow the builder to use sparse weights

  • refit (bool) – Whether to build a refittable engine

  • engine_capability (trt.EngineCapability) – Restrict kernel selection to safe GPU kernels or safe DLA kernels

  • num_avg_timing_iters (int) – Number of averaging timing iterations used to select kernels

  • dla_sram_size (int) – Fast software-managed RAM used by DLA to communicate within a layer.

  • dla_local_dram_size (int) – Host RAM used by DLA to share intermediate tensor data across operations

  • dla_global_dram_size (int) – Host RAM used by DLA to store weights and metadata for execution

  • dryrun (Union[bool, str]) – Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to TRT engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the output to a file if a string path is specified

  • hardware_compatible (bool) – Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
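Two of the settings above follow stated rules: optimization_level runs from 0 to 5 with TensorRT defaulting to 3 (levels above the maximum behave like the maximum, per the compile() documentation), and use_python_runtime=None auto-selects the C++ runtime when it is available. The sketch below restates those rules in plain Python. `resolve_optimization_level` and `resolve_runtime` are hypothetical helpers, not torch_tensorrt functions; the C++ availability flag is passed in rather than detected, and rejecting negative levels or an unavailable explicitly requested C++ runtime are assumptions of this sketch.

```python
from typing import Optional

MAX_OPTIMIZATION_LEVEL = 5      # current maximum, per the documentation above
DEFAULT_OPTIMIZATION_LEVEL = 3  # TensorRT default when the setting is None


def resolve_optimization_level(level: Optional[int]) -> int:
    """Hypothetical: map an optimization_level setting to the effective level."""
    if level is None:
        return DEFAULT_OPTIMIZATION_LEVEL
    if level < 0:
        # Assumption for this sketch: negative levels are rejected.
        raise ValueError("optimization_level must be >= 0")
    # Levels above the maximum behave identically to the maximum.
    return min(level, MAX_OPTIMIZATION_LEVEL)


def resolve_runtime(use_python_runtime: Optional[bool], cpp_runtime_available: bool) -> str:
    """Hypothetical: pick "python" or "cpp" following the use_python_runtime rules."""
    if use_python_runtime is None:
        # Auto-select, preferring the C++ runtime when its dependencies are present.
        return "cpp" if cpp_runtime_available else "python"
    if use_python_runtime:
        return "python"
    if not cpp_runtime_available:
        # Assumption for this sketch: explicitly requesting a missing C++ runtime is an error.
        raise RuntimeError("C++ runtime requested but not available")
    return "cpp"


print(resolve_optimization_level(None), resolve_optimization_level(7))  # 3 5
print(resolve_runtime(None, cpp_runtime_available=False))  # python
```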

class torch_tensorrt.dynamo.SourceIR(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
