Shortcuts

Source code for torch_tensorrt.runtime._pre_allocated_outputs

import logging
from typing import Any

import torch
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule

# Module-level logger named after this module; not referenced by the code visible here.
logger = logging.getLogger(__name__)


class _PreAllocatedOutputContextManager:
    """
    Context manager that switches the pre-allocated output feature on for the
    TensorRT runtime submodules of a graph module, and off again on exit.

    Only children yielded by ``module.named_children()`` are considered. A
    child is selected when its name contains ``"_run_on_acc"`` and it is an
    instance of ``PythonTorchTensorRTModule`` or ``TorchTensorRTModule``.
    """

    def __init__(self, module: torch.fx.GraphModule) -> None:
        """Collect the runtime submodules of *module* that will be toggled."""
        self.rt_mods = [
            child
            for child_name, child in module.named_children()
            if "_run_on_acc" in child_name
            and isinstance(child, (PythonTorchTensorRTModule, TorchTensorRTModule))
        ]

    def set_pre_allocated_output(self, enable: bool) -> None:
        """Forward *enable* to ``set_pre_allocated_outputs`` on every collected module."""
        for runtime_module in self.rt_mods:
            runtime_module.set_pre_allocated_outputs(enable)

    def __enter__(self) -> "_PreAllocatedOutputContextManager":
        """Turn the feature on and hand back this manager."""
        self.set_pre_allocated_output(True)
        return self

    def __exit__(self, *args: Any) -> None:
        """Turn the feature off; returns ``None``, so exceptions are not suppressed."""
        self.set_pre_allocated_output(False)


def enable_pre_allocated_outputs(
    module: torch.fx.GraphModule,
) -> _PreAllocatedOutputContextManager:
    """
    Build a context manager that toggles pre-allocated outputs for *module*.

    Args:
        module: Graph module whose children are searched for TensorRT
            runtime modules by ``_PreAllocatedOutputContextManager``.

    Returns:
        A ``_PreAllocatedOutputContextManager`` wrapping *module*. Nothing is
        toggled until the returned object is entered, e.g.::

            with enable_pre_allocated_outputs(module):
                ...
    """
    return _PreAllocatedOutputContextManager(module)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources