Shortcuts

Source code for torch_tensorrt.runtime._multi_device_safe_mode

import logging
from typing import Any

import torch
import torch_tensorrt

# Python-side record of the multi-device safe mode flag. When the Torch-TensorRT
# runtime feature is enabled, the initial value is read from the runtime op;
# otherwise the flag starts out disabled.
_PY_RT_MULTI_DEVICE_SAFE_MODE = (
    torch.ops.tensorrt.get_multi_device_safe_mode()
    if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime
    else False
)


# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)


class _MultiDeviceSafeModeContextManager(object):
    """Helper class used in conjunction with `set_multi_device_safe_mode`

    Used to enable `set_multi_device_safe_mode` as a dual-purpose context manager
    """

    def __init__(self, old_mode: bool) -> None:
        self.old_mode = old_mode

    def __enter__(self) -> "_MultiDeviceSafeModeContextManager":
        return self

    def __exit__(self, *args: Any) -> None:
        # Set multi-device safe mode back to old mode in Python
        global _PY_RT_MULTI_DEVICE_SAFE_MODE
        _PY_RT_MULTI_DEVICE_SAFE_MODE = self.old_mode

        # Set multi-device safe mode back to old mode in C++
        if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
            torch.ops.tensorrt.set_multi_device_safe_mode(self.old_mode)


def set_multi_device_safe_mode(mode: bool) -> _MultiDeviceSafeModeContextManager:
    """Sets the runtime (Python-only and default) into multi-device safe mode

    In the case that multiple devices are available on the system, in order for the
    runtime to execute safely, additional device checks are necessary. These checks
    can have a performance impact so they are therefore opt-in. Used to suppress
    the warning about running unsafely in a multi-device context.

    Arguments:
        mode (bool): Enable (``True``) or disable (``False``) multi-device checks

    Returns:
        _MultiDeviceSafeModeContextManager: Context manager holding the mode that
        was active before this call. The new mode takes effect immediately; when
        the return value is used in a ``with`` statement, the previous mode is
        restored on exit.

    Example:

        .. code-block:: py

            with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
                results = trt_compiled_module(*inputs)

    """
    # Fetch existing safe mode and set new mode for Python
    global _PY_RT_MULTI_DEVICE_SAFE_MODE
    old_mode = _PY_RT_MULTI_DEVICE_SAFE_MODE
    _PY_RT_MULTI_DEVICE_SAFE_MODE = mode

    # Set new mode for C++
    if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
        torch.ops.tensorrt.set_multi_device_safe_mode(mode)

    logger.info(f"Set multi-device safe mode to {mode}")

    # Return context manager in case the function is used in a `with` call
    return _MultiDeviceSafeModeContextManager(old_mode)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources