Shortcuts

Source code for torch_tensorrt.runtime._output_allocator

import logging
from typing import Any, Union

import torch
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
    CudaGraphsTorchTensorRTModule,
)

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class _OutputAllocatorContextManager:
    """
    Context manager that toggles the output allocator on TensorRT runtime modules.

    Entering the context calls ``set_use_output_allocator(True)`` on every
    collected runtime module; leaving it calls ``set_use_output_allocator(False)``
    on the same modules.
    """

    def __init__(
        self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule]
    ) -> None:
        """
        Collect the runtime modules that this context manager will toggle.

        Args:
            module: Either a ``CudaGraphsTorchTensorRTModule``, which is used
                directly, or a module whose direct children are scanned. Only
                children whose name contains ``"_run_on_acc"`` and that are
                ``PythonTorchTensorRTModule`` or ``TorchTensorRTModule``
                instances are kept.
        """
        if isinstance(module, CudaGraphsTorchTensorRTModule):
            # The CUDA-graphs wrapper is toggled as a single unit.
            self.rt_mods = [module]
            return

        trt_runtime_types = (PythonTorchTensorRTModule, TorchTensorRTModule)
        # Only direct children are inspected (named_children is not recursive).
        self.rt_mods = [
            child
            for child_name, child in module.named_children()
            if "_run_on_acc" in child_name and isinstance(child, trt_runtime_types)
        ]

    def set_output_allocator_output(self, enable: bool) -> None:
        """Call ``set_use_output_allocator(enable)`` on every collected module."""
        for runtime_module in self.rt_mods:
            runtime_module.set_use_output_allocator(enable)

    def __enter__(self) -> "_OutputAllocatorContextManager":
        """Switch the output allocator on and return this manager."""
        self.set_output_allocator_output(True)
        return self

    def __exit__(self, *args: Any) -> None:
        """
        Switch the output allocator off.

        Returns ``None``, so an exception raised inside the ``with`` body
        is not suppressed.
        """
        self.set_output_allocator_output(False)


def enable_output_allocator(
    module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule],
) -> _OutputAllocatorContextManager:
    """
    Build a context manager that enables the output allocator for ``module``.

    The returned ``_OutputAllocatorContextManager`` calls
    ``set_use_output_allocator(True)`` on the TensorRT runtime modules it
    collects from ``module`` when the ``with`` block is entered, and
    ``set_use_output_allocator(False)`` when the block is left.

    Args:
        module: A ``torch.fx.GraphModule`` whose direct ``_run_on_acc``
            children are TensorRT runtime modules, or a
            ``CudaGraphsTorchTensorRTModule``.

    Returns:
        A context manager wrapping ``module``; nothing is toggled until it
        is entered.

    Example::

        with enable_output_allocator(trt_module):
            outputs = trt_module(*inputs)
    """
    return _OutputAllocatorContextManager(module)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources