torch_tensorrt.runtime
Functions
- torch_tensorrt.runtime.set_multi_device_safe_mode(mode: bool) → _MultiDeviceSafeModeContextManager
Sets the runtime (Python-only and default) into multi-device safe mode.
When multiple devices are available on the system, the runtime needs additional device checks to execute safely. These checks can affect performance, so they are opt-in. Enabling this mode also suppresses the warning about running unsafely in a multi-device context.
- Parameters
mode (bool) – Enable (True) or disable (False) multi-device checks
Example
with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
    results = trt_compiled_module(*inputs)
Classes
- class torch_tensorrt.runtime.TorchTensorRTModule(**kwargs: Dict[str, Any])
TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
This module is backed by the Torch-TensorRT runtime and is fully compatible with both FX / Python deployments (just import torch_tensorrt as part of the application) and TorchScript / C++ deployments, since TorchTensorRTModule can be passed to torch.jit.trace and then saved.
The forward function is simply forward(*args: torch.Tensor) -> Tuple[torch.Tensor], where the internal implementation is
return tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))
> Note: TorchTensorRTModule only supports engines built with explicit batch
- Variables
name (str) – Name of module (for easier debugging)
engine (torch.classes.tensorrt.Engine) – Torch-TensorRT TensorRT Engine instance, manages [de]serialization, device configuration, profiling
input_binding_names (List[str]) – List of input TensorRT engine binding names in the order they would be passed to the TRT modules
output_binding_names (List[str]) – List of output TensorRT engine binding names in the order they should be returned
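To make the role of the two binding-name lists concrete, here is a minimal sketch in plain Python. It assumes a stand-in `fake_engine` callable that takes and returns positional lists; `run_with_bindings` and `fake_engine` are hypothetical helpers for illustration, not part of Torch-TensorRT.

```python
from typing import Callable, Dict, List


def run_with_bindings(
    engine: Callable[[List[float]], List[float]],  # stand-in for an engine call
    input_binding_names: List[str],
    output_binding_names: List[str],
    named_inputs: Dict[str, float],
) -> Dict[str, float]:
    # Arrange inputs positionally, following the input binding order.
    ordered_inputs = [named_inputs[name] for name in input_binding_names]
    outputs = engine(ordered_inputs)
    if len(outputs) != len(output_binding_names):
        raise ValueError("engine returned an unexpected number of outputs")
    # Pair each output with its binding name, in the declared order.
    return dict(zip(output_binding_names, outputs))


def fake_engine(inputs: List[float]) -> List[float]:
    # Hypothetical engine with two inputs and two outputs.
    x, y = inputs
    return [x + y, x - y]


result = run_with_bindings(
    fake_engine, ["x", "y"], ["sum", "diff"], {"y": 1.0, "x": 3.0}
)
```

The point of the sketch is that the order of the name lists, not the order in which values happen to be supplied, decides which value lands in which engine slot.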
- __init__(**kwargs: Dict[str, Any]) → Any
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(**kwargs: Dict[str, Any]) → Any
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
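The difference between calling the instance and calling forward directly can be sketched with a toy class. `TinyModule` below is a hypothetical miniature of the call/forward split and is not torch.nn.Module; it only mirrors the idea that hooks are run by the instance call.

```python
class TinyModule:
    """Hypothetical miniature of the call/forward split."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        # The instance call wraps forward and then runs every hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


seen = []
module = TinyModule()
module.register_forward_hook(lambda mod, inp, out: seen.append(out))

module(3)          # goes through __call__, so the hook records the output
module.forward(4)  # bypasses __call__, so the hook never runs
```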
- get_extra_state(**kwargs: Dict[str, Any]) → Any
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().
Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.
- Returns
Any extra state to store in the module’s state_dict
- Return type
object
- set_extra_state(**kwargs: Dict[str, Any]) → Any
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.
- Parameters
state (dict) – Extra state from the state_dict
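The get_extra_state() / set_extra_state() contract and the picklability requirement can be sketched without importing torch. `EngineHolder` is a hypothetical stand-in, not a torch.nn.Module, and the explicit pickle round trip only approximates what saving and loading a state_dict would do.

```python
import pickle
from typing import Any, Dict


class EngineHolder:
    """Stand-in for a module that carries extra, non-tensor state."""

    def __init__(self, name: str = "", engine_bytes: bytes = b""):
        self.name = name
        self.engine_bytes = engine_bytes

    def get_extra_state(self) -> Dict[str, Any]:
        # Only plain, picklable values go into the extra state.
        return {"name": self.name, "engine_bytes": self.engine_bytes}

    def set_extra_state(self, state: Dict[str, Any]) -> None:
        self.name = state["name"]
        self.engine_bytes = state["engine_bytes"]


source = EngineHolder("my_module", b"\x00\x01")
blob = pickle.dumps(source.get_extra_state())  # roughly what saving does

restored = EngineHolder()
restored.set_extra_state(pickle.loads(blob))   # roughly what loading does
```

Keeping the extra state to built-in types such as str and bytes avoids depending on the pickled form of custom classes, which is the compatibility risk the note above warns about.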
- class torch_tensorrt.runtime.PythonTorchTensorRTModule(serialized_engine: Optional[bytes] = None, input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, *, name: str = '', settings: CompilationSettings = CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=False), weight_name_map: Optional[dict[Any, Any]] = None)
PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
This module is backed by the Torch-TensorRT runtime and is only compatible with FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment.
- __init__(serialized_engine: Optional[bytes] = None, input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, *, name: str = '', settings: CompilationSettings = CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=False), weight_name_map: Optional[dict[Any, Any]] = None)
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch torch.nn.Module around it. Uses TensorRT Python APIs to run the engine.
- Parameters
serialized_engine (bytes) – Serialized TensorRT engine in the form of a bytearray
input_binding_names (List[str]) – List of input TensorRT engine binding names in the order they would be passed to the TRT modules
output_binding_names (List[str]) – List of output TensorRT engine binding names in the order they should be returned
- Keyword Arguments
name (str) – Name for module
settings (CompilationSettings) – Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
weight_name_map (dict) – Mapping of engine weight name to state_dict weight name
Example
trt_module = PythonTorchTensorRTModule(
    engine_str,
    input_binding_names=["x"],
    output_binding_names=["output"],
    name="my_module",
    settings=CompilationSettings(device=torch.cuda.current_device),
)
- enable_profiling(profiler: IProfiler = None) → None
Enable TensorRT profiling. After calling this function, TensorRT will report time spent on each layer in stdout for each forward run.
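As a rough illustration of what a custom profiler might look like, the sketch below accumulates per-layer times instead of printing them. It is duck-typed so that it runs without TensorRT installed; `LayerTimeRecorder` and `slowest` are hypothetical names, and in real use the class would most likely need to subclass tensorrt.IProfiler, whose callback is report_layer_time(layer_name, ms).

```python
from collections import defaultdict
from typing import List, Tuple


class LayerTimeRecorder:
    """Duck-typed sketch; a real profiler would subclass tensorrt.IProfiler."""

    def __init__(self) -> None:
        self.total_ms = defaultdict(float)

    def report_layer_time(self, layer_name: str, ms: float) -> None:
        # Called once per layer per run; accumulate across forward passes.
        self.total_ms[layer_name] += ms

    def slowest(self, n: int = 3) -> List[Tuple[str, float]]:
        # Layers sorted by accumulated time, largest first.
        ranked = sorted(self.total_ms.items(), key=lambda kv: kv[1], reverse=True)
        return ranked[:n]


recorder = LayerTimeRecorder()
recorder.report_layer_time("conv1", 0.5)
recorder.report_layer_time("relu", 0.125)
recorder.report_layer_time("conv1", 0.25)
```

Whether enable_profiling accepts such an object unchanged is an assumption to verify against the installed TensorRT version.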
- forward(*inputs: Tensor) → Union[Tensor, Tuple[Tensor, ...]]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.