# Source code for torch_tensorrt.runtime._pre_allocated_outputs
import logging
from typing import Any
import torch
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
logger = logging.getLogger(__name__)
class _PreAllocatedOutputContextManager(object):
"""
Helper class used to enable pre-allocated output feature in runtime module
"""
def __init__(self, module: torch.fx.GraphModule) -> None:
rt_mods = []
for name, rt_mod in module.named_children():
if "_run_on_acc" in name and isinstance(
rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
):
rt_mods.append(rt_mod)
self.rt_mods = rt_mods
def set_pre_allocated_output(self, enable: bool) -> None:
for mod in self.rt_mods:
mod.set_pre_allocated_outputs(enable)
def __enter__(self) -> "_PreAllocatedOutputContextManager":
# Enable pre-allocated output
self.set_pre_allocated_output(True)
return self
def __exit__(self, *args: Any) -> None:
# Disable pre-allocated output
self.set_pre_allocated_output(False)
def enable_pre_allocated_outputs(
    module: torch.fx.GraphModule,
) -> _PreAllocatedOutputContextManager:
    """Return a context manager that enables pre-allocated outputs for ``module``.

    The returned object is intended for use in a ``with`` statement: entering
    the block turns pre-allocated outputs on for the matching TensorRT runtime
    submodules of ``module``, and leaving the block turns them off again.

    Args:
        module: Graph module whose runtime submodules should be toggled.

    Returns:
        A ``_PreAllocatedOutputContextManager`` bound to ``module``.
    """
    return _PreAllocatedOutputContextManager(module)