Shortcuts

Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule

from __future__ import annotations

import base64
import copy
import logging
import pickle
from typing import Any, List, Optional, Tuple, Union

import torch
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import Platform
from torch_tensorrt._features import (
    ENABLED_FEATURES,
    for_all_methods,
    needs_torch_tensorrt_runtime,
)
from torch_tensorrt.dynamo._settings import CompilationSettings

logger = logging.getLogger(__name__)

# Serialized engine info: a flat list of fields (each a str or bytes) whose positions are
# given by the *_IDX constants below.
SerializedTensorRTEngineFmt = List[
    Union[str, bytes]
]  # Aligned with  //core/runtime/register_jit_hooks.cpp
# Shape of the state produced by TorchTensorRTModule.get_extra_state:
# (module name, serialized engine info or None, input binding names, output binding names)
SerializedTorchTensorRTModuleFmt = Tuple[
    str, Optional[SerializedTensorRTEngineFmt], List[str], List[str]
]

# Placeholder field indices used when the Torch-TensorRT C++ runtime is not available.
# NOTE(review): with SERIALIZATION_LEN == -1, `[""] * SERIALIZATION_LEN` is an empty list,
# so packing engine info is only meaningful once the real values below are loaded.
ABI_TARGET_IDX = -1  # Not implemented
NAME_IDX = -1  # Not implemented
DEVICE_IDX = -1  # Not implemented
ENGINE_IDX = -1  # Not implemented
INPUT_BINDING_NAMES_IDX = -1  # Not implemented
OUTPUT_BINDING_NAMES_IDX = -1  # Not implemented
HW_COMPATIBLE_IDX = -1  # Not implemented
SERIALIZED_METADATA_IDX = -1  # Not implemented
TARGET_PLATFORM_IDX = -1  # Not implemented
SERIALIZATION_LEN = -1  # Not implemented

if ENABLED_FEATURES.torch_tensorrt_runtime:
    # Read the field indices from the runtime's registered ops instead of hard-coding them,
    # presumably so the Python side stays in sync with the C++ serialization format.
    # The trailing numbers are the values observed when this was written; the ops are
    # authoritative.
    ABI_TARGET_IDX = torch.ops.tensorrt.ABI_TARGET_IDX()  # 0
    NAME_IDX = torch.ops.tensorrt.NAME_IDX()  # 1
    DEVICE_IDX = torch.ops.tensorrt.DEVICE_IDX()  # 2
    ENGINE_IDX = torch.ops.tensorrt.ENGINE_IDX()  # 3
    INPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.INPUT_BINDING_NAMES_IDX()  # 4
    OUTPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.OUTPUT_BINDING_NAMES_IDX()  # 5
    HW_COMPATIBLE_IDX = torch.ops.tensorrt.HW_COMPATIBLE_IDX()  # 6
    SERIALIZED_METADATA_IDX = torch.ops.tensorrt.SERIALIZED_METADATA_IDX()  # 7
    TARGET_PLATFORM_IDX = torch.ops.tensorrt.TARGET_PLATFORM_IDX()  # 8
    SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()  # 9


@for_all_methods(needs_torch_tensorrt_runtime)
class TorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
    """TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.

    This module is backed by the Torch-TensorRT runtime and is fully compatible with both
    FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
    well as TorchScript / C++ deployments since TorchTensorRTModule can be passed to ``torch.jit.trace``
    and then saved.

    The forward function is simply ``forward(*args) -> torch.Tensor | Tuple[torch.Tensor, ...]``.
    Internally it calls ``torch.ops.tensorrt.execute_engine(list(inputs), self.engine)``.
    A single engine output is returned as a bare tensor; multiple outputs are returned as a tuple.

    > Note: TorchTensorRTModule only supports engines built with explicit batch

    Attributes:
        name (str): Name of module (for easier debugging)
        engine (torch.classes.tensorrt.Engine | None): Torch-TensorRT TensorRT Engine instance,
            manages [de]serialization, device configuration, profiling. ``None`` until
            ``setup_engine`` has run (setup is deferred when ``settings.lazy_engine_init`` or
            ``settings.enable_cross_compile_for_windows`` is set).
        input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
        output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
        serialized_engine (bytes | bytearray | None): Raw serialized engine used to build ``engine``
        settings (CompilationSettings): Deep copy of the settings the engine was compiled with
        hardware_compatible (bool): Mirrors ``settings.hardware_compatible``
        weight_name_map (dict | None): Stored alongside ``settings`` in the serialized metadata
    """

    def __init__(
        self,
        serialized_engine: Optional[bytes] = None,
        input_binding_names: Optional[List[str]] = None,
        output_binding_names: Optional[List[str]] = None,
        *,
        name: str = "",
        settings: CompilationSettings = CompilationSettings(),  # Assumes engine was built with default compilation settings if object not passed
        weight_name_map: Optional[dict[Any, Any]] = None,
    ):
        """Wrap a serialized TensorRT engine in a ``torch.nn.Module``.

        Takes a name, target device, serialized TensorRT engine, and binding names / order and
        constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime
        extension to run the engines.

        If binding names are not provided, it is assumed that the engine binding names follow
        the following convention:

            - [symbol].[index in input / output array]
                - ex. [x.0, x.1, x.2] -> [y.0]

        Arguments:
            serialized_engine (bytes | bytearray): Serialized TensorRT engine
            input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
            output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned

        Keyword Arguments:
            name (str): Name for module
            settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
            weight_name_map (dict): Mapping stored with the engine metadata and restored by ``set_extra_state``

        Raises:
            ValueError: If ``serialized_engine`` is neither ``None``, ``bytes`` nor ``bytearray``

        Example:

            .. code-block:: py

                with io.BytesIO() as engine_bytes:
                    engine_bytes.write(trt_engine.serialize())
                    engine_str = engine_bytes.getvalue()

                trt_module = TorchTensorRTModule(
                    engine_str,
                    input_binding_names=["x"],
                    output_binding_names=["output"],
                    name="my_module",
                    settings=CompilationSettings(device=torch.cuda.current_device())
                )

        """
        super(TorchTensorRTModule, self).__init__()

        # Validate eagerly. Both bytes (what io.BytesIO.getvalue() returns, as in the
        # example above) and bytearray are valid engine payloads.
        if serialized_engine is not None and not isinstance(
            serialized_engine, (bytes, bytearray)
        ):
            raise ValueError(
                "Expected serialized engine as bytes or bytearray, got "
                f"{type(serialized_engine).__name__}"
            )

        self.input_binding_names = (
            input_binding_names if input_binding_names is not None else []
        )
        self.output_binding_names = (
            output_binding_names if output_binding_names is not None else []
        )
        self.name = name
        self.hardware_compatible = settings.hardware_compatible
        # The default `settings` object is shared across calls; deep-copying keeps this
        # instance from ever aliasing (and later mutating) it.
        self.settings = copy.deepcopy(settings)
        self.weight_name_map = weight_name_map
        self.serialized_engine = serialized_engine
        self.engine = None

        # Build the engine now unless setup was explicitly deferred.
        if (
            serialized_engine
            and not self.settings.lazy_engine_init
            and not self.settings.enable_cross_compile_for_windows
        ):
            self.setup_engine()

    def _pack_engine_info(self) -> List[str | bytes]:
        """Build the flat serialized-engine field list expected by ``torch.classes.tensorrt.Engine``.

        Fields are placed at the module-level ``*_IDX`` positions; the raw engine payload
        is stored unencoded at ``ENGINE_IDX``.
        """
        target_device = (
            self.settings.device
            if self.settings.device is not None
            else Device._current_device()
        )
        metadata = {"settings": self.settings, "weight_name_map": self.weight_name_map}
        target_platform = (
            Platform.current_platform()
            if not self.settings.enable_cross_compile_for_windows
            else Platform.WIN_X86_64
        )  # Change to match target for engine

        engine_info: List[str | bytes] = [""] * SERIALIZATION_LEN

        engine_info[ABI_TARGET_IDX] = torch.ops.tensorrt.ABI_VERSION()
        engine_info[NAME_IDX] = (
            self.name + "_engine" if self.name != "" else "tensorrt_engine"
        )
        engine_info[DEVICE_IDX] = target_device._to_serialized_rt_device()
        assert self.serialized_engine
        engine_info[ENGINE_IDX] = self.serialized_engine

        engine_info[INPUT_BINDING_NAMES_IDX] = TorchTensorRTModule._pack_binding_names(
            self.input_binding_names
        )
        engine_info[OUTPUT_BINDING_NAMES_IDX] = TorchTensorRTModule._pack_binding_names(
            self.output_binding_names
        )
        engine_info[HW_COMPATIBLE_IDX] = str(int(self.hardware_compatible))
        engine_info[SERIALIZED_METADATA_IDX] = self.encode_metadata(metadata)
        engine_info[TARGET_PLATFORM_IDX] = target_platform._to_serialized_rt_platform()

        return engine_info

    def get_streamable_device_memory_budget(self) -> Any:
        """Return the engine's ``streamable_device_memory_budget`` attribute."""
        return self.engine.streamable_device_memory_budget

    def get_automatic_device_memory_budget(self) -> Any:
        """Return the engine's ``automatic_device_memory_budget`` attribute."""
        return self.engine.automatic_device_memory_budget

    def get_device_memory_budget(self) -> Any:
        """Return the engine's current ``device_memory_budget``."""
        return self.engine.device_memory_budget

    def set_device_memory_budget(self, budget_bytes: int) -> int:
        """Set the engine's device memory budget (weight streaming).

        Args:
            budget_bytes (int): Requested budget in bytes. A negative value is replaced with
                the streamable budget, which disables weight streaming.

        Returns:
            int: The budget actually in effect on the engine after the assignment.
        """
        # Disable weight streaming for invalid budget size
        if budget_bytes < 0:
            budget_bytes = self.get_streamable_device_memory_budget()
        self.engine.device_memory_budget = budget_bytes
        if self.engine.device_memory_budget != budget_bytes:
            logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
            budget_bytes = self.engine.device_memory_budget
        if self.get_streamable_device_memory_budget() == budget_bytes:
            logger.warning("Weight streaming is disabled")

        return budget_bytes

    def setup_engine(self) -> None:
        """
        Setup engine for a module which has deferred engine setup.

        Will setup the TensorRT engine for this module in the case that setup has been
        deferred. In the case that the engine has already been setup, will return without
        changing anything. Assumes that serialized engine and settings have already been passed
        to the module.
        """
        if self.engine is not None:
            return
        self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())

    def encode_metadata(self, metadata: Any) -> str:
        """Pickle ``metadata`` and return it as a base64-encoded UTF-8 string."""
        metadata = copy.deepcopy(metadata)
        dumped_metadata = pickle.dumps(metadata)
        encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8")
        return encoded_metadata

    @staticmethod
    def decode_metadata(encoded_metadata: Union[str, bytes]) -> Any:
        """Inverse of ``encode_metadata``: base64-decode and unpickle the metadata.

        Accepts either the ``str`` produced by ``encode_metadata`` or its ``bytes`` form.
        """
        if isinstance(encoded_metadata, str):
            encoded_metadata = encoded_metadata.encode("utf-8")
        dumped_metadata = base64.b64decode(encoded_metadata)
        # NOTE: pickle.loads can execute arbitrary code; only load state from trusted sources.
        metadata = pickle.loads(dumped_metadata)
        return metadata

    def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
        """Return state for ``state_dict``: (name, engine info or None, input names, output names).

        If the engine is live, its own ``__getstate__`` is used. If only the raw serialized
        engine is available (deferred setup), the field list is packed here and the engine
        payload is base64-encoded. Otherwise the engine slot is ``None``.
        """
        if self.engine is not None:
            return (
                self.name,
                self.engine.__getstate__(),
                self.input_binding_names,
                self.output_binding_names,
            )
        elif self.serialized_engine:
            engine_info = self._pack_engine_info()
            raw_engine = engine_info[ENGINE_IDX]
            assert isinstance(raw_engine, (bytes, bytearray))
            engine_info[ENGINE_IDX] = base64.b64encode(raw_engine)
            return (
                self.name,
                engine_info,
                self.input_binding_names,
                self.output_binding_names,
            )
        else:
            return (
                self.name,
                None,
                self.input_binding_names,
                self.output_binding_names,
            )

    def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
        """Restore the module from the tuple produced by ``get_extra_state``."""
        self.name = state[0]

        if state[1] is not None:
            # Decode a copy so the caller's state is left untouched; decoding in place
            # would double-decode if the same state were applied twice.
            serialized_engine_info: SerializedTensorRTEngineFmt = list(state[1])
            serialized_engine_info[ENGINE_IDX] = base64.b64decode(
                serialized_engine_info[ENGINE_IDX]
            )
            self.engine = torch.classes.tensorrt.Engine(serialized_engine_info)
            self.hardware_compatible = bool(
                int(serialized_engine_info[HW_COMPATIBLE_IDX])
            )

            serialized_metadata = serialized_engine_info[SERIALIZED_METADATA_IDX]
            # encode_metadata emits a str; accept its bytes form as well.
            assert isinstance(serialized_metadata, (str, bytes))
            metadata = TorchTensorRTModule.decode_metadata(serialized_metadata)
            self.settings = metadata["settings"]
            self.weight_name_map = metadata["weight_name_map"]

        else:
            self.engine = None
            self.settings = CompilationSettings()
            self.hardware_compatible = False

        self.input_binding_names = state[2]
        self.output_binding_names = state[3]

    def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
        """Implementation of the forward pass for a TensorRT engine

        Args:
            *inputs (Union[torch.Tensor, int]): Inputs to the forward function

        Returns:
            torch.Tensor or Tuple(torch.Tensor): Result of the engine computation; a bare
            tensor when the engine has exactly one output, otherwise a tuple

        Raises:
            RuntimeError: If the engine has not been set up
        """
        if self.engine is None:
            raise RuntimeError("Engine has not been setup yet.")

        assert len(inputs) == len(
            self.input_binding_names
        ), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}."

        # If the inputs are not Torch Tensors, which can occur in scenarios such as shape tensors
        # which are outputs of a preceding Torch subgraph (where the Dynamic input may be an integer)
        # directly cast the input to a Torch Tensor.
        #
        # This also avoids the need for type-checking inputs, since they are now explicitly casted to Torch tensors
        input_tensors: List[torch.Tensor] = [
            (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
            for i in inputs
        ]

        outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
            list(input_tensors), self.engine
        )

        if len(outputs) == 1:
            return outputs[0]

        return tuple(outputs)

    def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None:
        """Enable the profiler to collect latency information about the execution of the engine

        Traces can be visualized using https://ui.perfetto.dev/ or compatible alternatives

        Keyword Arguments:
            profiling_results_dir (str): Absolute path to the directory to store results of profiling.
        """
        if self.engine is None:
            raise RuntimeError("Engine has not been initialized yet.")

        if profiling_results_dir is not None:
            self.engine.profile_path_prefix = profiling_results_dir
        self.engine.enable_profiling()

    def disable_profiling(self) -> None:
        """Disable the profiler"""
        if self.engine is None:
            raise RuntimeError("Engine has not been initialized yet.")

        self.engine.disable_profiling()

    def get_layer_info(self) -> str:
        """Get a JSON string containing the layer information encoded by the TensorRT engine in this module

        Returns:

            str: A JSON string which contains the layer information of the engine encapsulated in this module
        """
        if self.engine is None:
            raise RuntimeError("Engine has not been initialized yet.")

        layer_info: str = self.engine.get_engine_layer_info()
        return layer_info

    def dump_layer_info(self) -> None:
        """Dump layer information encoded by the TensorRT engine in this module to STDOUT"""
        if self.engine is None:
            raise RuntimeError("Engine has not been initialized yet.")

        self.engine.dump_engine_layer_info()

    @staticmethod
    def _pack_binding_names(binding_names: List[str]) -> str:
        """Join binding names with the runtime's serialized-binding delimiter."""
        delim = torch.ops.tensorrt.SERIALIZED_ENGINE_BINDING_DELIM()[0]
        packed_bindings: str = delim.join(binding_names)
        return packed_bindings

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources