
torch_tensorrt.runtime

Functions

torch_tensorrt.runtime.set_multi_device_safe_mode(mode: bool) → _MultiDeviceSafeModeContextManager

Sets the runtime (both the Python-only runtime and the default runtime) into multi-device safe mode.

When multiple devices are available on the system, the runtime needs additional device checks to execute safely. Because these checks can affect performance, they are opt-in. Enabling safe mode also suppresses the warning about running unsafely in a multi-device context.

Parameters

mode (bool) – Enable (True) or disable (False) multi-device checks

Example

import torch_tensorrt
# trt_compiled_module and inputs are assumed to be defined already
with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
    results = trt_compiled_module(*inputs)

Classes

class torch_tensorrt.runtime.TorchTensorRTModule(**kwargs: Dict[str, Any])

TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.

This module is backed by the Torch-TensorRT runtime and is fully compatible with both FX / Python deployments (just import torch_tensorrt as part of the application) and TorchScript / C++ deployments, since a TorchTensorRTModule can be passed to torch.jit.trace and then saved.

The forward function is simply forward(*inputs: torch.Tensor) -> Tuple[torch.Tensor], where the internal implementation is return tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine)).

> Note: TorchTensorRTModule only supports engines built with explicit batch.

Variables
  • name (str) – Name of module (for easier debugging)

  • engine (torch.classes.tensorrt.Engine) – Torch-TensorRT TensorRT Engine instance, manages [de]serialization, device configuration, profiling

  • input_binding_names (List[str]) – List of input TensorRT engine binding names in the order they would be passed to the TRT modules

  • output_binding_names (List[str]) – List of output TensorRT engine binding names in the order they should be returned
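To make the binding-order contract concrete, here is a minimal pure-Python sketch: callers supply inputs positionally in input_binding_names order and receive outputs in output_binding_names order. The helpers order_inputs and name_outputs are hypothetical and not part of torch_tensorrt; plain integers and strings stand in for tensors so the sketch runs without a GPU.

```python
from typing import Any, Dict, List, Sequence


def order_inputs(named_inputs: Dict[str, Any], input_binding_names: List[str]) -> List[Any]:
    """Arrange named inputs positionally, following the engine's input binding order."""
    missing = [n for n in input_binding_names if n not in named_inputs]
    if missing:
        raise KeyError(f"missing inputs for bindings: {missing}")
    return [named_inputs[n] for n in input_binding_names]


def name_outputs(outputs: Sequence[Any], output_binding_names: List[str]) -> Dict[str, Any]:
    """Label positional outputs with the engine's output binding names."""
    if len(outputs) != len(output_binding_names):
        raise ValueError("output count does not match output_binding_names")
    return dict(zip(output_binding_names, outputs))


args = order_inputs({"y": 2, "x": 1}, ["x", "y"])
print(args)  # [1, 2]
print(name_outputs(("a", "b"), ["out0", "out1"]))  # {'out0': 'a', 'out1': 'b'}
```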

__init__(**kwargs: Dict[str, Any]) → Any

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(**kwargs: Dict[str, Any]) → Any

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
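The note above describes generic PyTorch behavior. A small sketch with a hypothetical Double module shows the difference: a forward hook fires when the module instance is called, but not when forward() is invoked directly.

```python
import torch


class Double(torch.nn.Module):
    """Hypothetical module that doubles its input."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


calls = []
m = Double()
# A forward hook runs after forward() whenever the module instance is called
m.register_forward_hook(lambda module, args, output: calls.append("hook"))

m(torch.ones(1))          # goes through Module.__call__, so the hook fires
m.forward(torch.ones(1))  # bypasses __call__, so the hook is skipped
print(calls)  # ['hook']
```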

get_extra_state(**kwargs: Dict[str, Any]) → Any

Return any extra state to include in the module’s state_dict.

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns

Any extra state to store in the module’s state_dict

Return type

object

set_extra_state(**kwargs: Dict[str, Any]) → Any

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

Parameters

state (dict) – Extra state from the state_dict
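get_extra_state() and set_extra_state() are the standard torch.nn.Module hooks for round-tripping non-tensor state through state_dict(). The hypothetical CallCounter module below only illustrates that mechanism in plain PyTorch; it does not reflect what TorchTensorRTModule itself stores.

```python
import torch


class CallCounter(torch.nn.Module):
    """Hypothetical identity module that counts how many times it was called."""

    def __init__(self) -> None:
        super().__init__()
        self.calls = 0  # plain int: not a parameter or buffer, so not saved by default

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return x

    def get_extra_state(self) -> dict:
        # Stored in state_dict() under the "_extra_state" key
        return {"calls": self.calls}

    def set_extra_state(self, state: dict) -> None:
        # Called by load_state_dict() with the value produced by get_extra_state()
        self.calls = state["calls"]


src = CallCounter()
src(torch.zeros(1))
src(torch.zeros(1))
dst = CallCounter()
dst.load_state_dict(src.state_dict())
print(dst.calls)  # 2
```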

class torch_tensorrt.runtime.PythonTorchTensorRTModule(serialized_engine: ~typing.Optional[bytes] = None, input_binding_names: ~typing.Optional[~typing.List[str]] = None, output_binding_names: ~typing.Optional[~typing.List[str]] = None, *, name: str = '', settings: ~torch_tensorrt.dynamo._settings.CompilationSettings = CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True), weight_name_map: ~typing.Any = None)

PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.

This module is backed by the Torch-TensorRT runtime and is only compatible with FX / Dynamo / Python deployments. It cannot be serialized to TorchScript via torch.jit.trace for C++ deployment.

__init__(serialized_engine: ~typing.Optional[bytes] = None, input_binding_names: ~typing.Optional[~typing.List[str]] = None, output_binding_names: ~typing.Optional[~typing.List[str]] = None, *, name: str = '', settings: ~torch_tensorrt.dynamo._settings.CompilationSettings = CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True), weight_name_map: ~typing.Any = None)

Takes a name, a target device (via settings), a serialized TensorRT engine, and the binding names / order, and constructs a PyTorch torch.nn.Module around them. Uses the TensorRT Python APIs to run the engine.

Parameters
  • serialized_engine (bytes) – Serialized TensorRT engine as a bytes object

  • input_binding_names (List[str]) – List of input TensorRT engine binding names in the order they would be passed to the TRT modules

  • output_binding_names (List[str]) – List of output TensorRT engine binding names in the order they should be returned

Keyword Arguments
  • name (str) – Name for module

  • settings (CompilationSettings) – Settings used to compile the engine. If not passed, the engine is assumed to have been built with the default compilation settings.

Example

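import torch
import torch_tensorrt
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.runtime import PythonTorchTensorRTModule

# engine_str holds the bytes of a serialized TensorRT engine (assumed to exist)
trt_module = PythonTorchTensorRTModule(
    engine_str,
    input_binding_names=["x"],
    output_binding_names=["output"],
    name="my_module",
    settings=CompilationSettings(device=torch_tensorrt.Device(gpu_id=torch.cuda.current_device())),
)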

cudagraphs_validate_shapes(inputs: Sequence[Tensor]) → bool

Validates the input shapes of the forward function against the shapes for which the currently active CUDA graph was recorded.
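The idea behind such a check can be sketched as a "shape key" comparison. This is a hypothetical pure-Python illustration, not the library's implementation: ShapeKeyValidator is an invented name, shapes are passed as plain tuples instead of tensors, and it assumes that a False result means the shapes changed (so a real runtime would discard the old CUDA graph and re-record).

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


class ShapeKeyValidator:
    """Hypothetical helper: remembers the input shapes a CUDA graph was
    recorded for and reports whether new inputs still match them."""

    def __init__(self) -> None:
        self.shape_key: Optional[Tuple[Shape, ...]] = None

    def validate(self, shapes: Sequence[Shape]) -> bool:
        new_key = tuple(tuple(s) for s in shapes)
        if new_key != self.shape_key:
            # First call or shapes changed: remember the new key. A real runtime
            # would also reset the old graph and record a new one here.
            self.shape_key = new_key
            return False
        return True


v = ShapeKeyValidator()
print(v.validate([(1, 3, 224, 224)]))  # False: nothing recorded yet
print(v.validate([(1, 3, 224, 224)]))  # True: same shapes as before
print(v.validate([(2, 3, 224, 224)]))  # False: batch size changed
```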

disable_profiling() → None

Disable TensorRT profiling.

enable_profiling(profiler: IProfiler = None) → None

Enable TensorRT profiling. After calling this function, TensorRT reports the time spent in each layer to stdout on every forward run.
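The profiler argument accepts a TensorRT IProfiler, whose Python callback is report_layer_time(layer_name, ms). The sketch below mirrors that callback to aggregate per-layer timings instead of printing them. LayerTimeAccumulator is a hypothetical name; to actually pass it to enable_profiling() it would also need to subclass tensorrt.IProfiler and call IProfiler.__init__, which is omitted here so the sketch runs without TensorRT. The timings fed in are simulated, not measured.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


class LayerTimeAccumulator:
    """Collects per-layer timings via a report_layer_time(layer_name, ms) callback."""

    def __init__(self) -> None:
        self.total_ms: Dict[str, float] = defaultdict(float)
        self.calls: Dict[str, int] = defaultdict(int)

    def report_layer_time(self, layer_name: str, ms: float) -> None:
        # Invoked once per layer per forward run by the profiling backend
        self.total_ms[layer_name] += ms
        self.calls[layer_name] += 1

    def slowest(self, k: int = 3) -> List[Tuple[str, float]]:
        """Return the k layers with the largest mean time per call."""
        means = {name: self.total_ms[name] / self.calls[name] for name in self.total_ms}
        return sorted(means.items(), key=lambda item: item[1], reverse=True)[:k]


prof = LayerTimeAccumulator()
# Simulated callbacks standing in for two forward runs
for name, ms in [("conv1", 0.5), ("relu1", 0.125), ("conv1", 0.25)]:
    prof.report_layer_time(name, ms)
print(prof.slowest(1))  # [('conv1', 0.375)]
```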

forward(*inputs: Tensor) → Union[Tensor, Tuple[Tensor, ...]]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_layer_info() → str

Get layer information for the engine. Only supported for TensorRT > 8.2.
