Shortcuts

Source code for torch_tensorrt.runtime._weight_streaming

import logging
from typing import Any, Union

import torch
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
    CudaGraphsTorchTensorRTModule,
)

logger = logging.getLogger(__name__)


class _WeightStreamingContextManager(object):
    """
    Helper class used to setup weight streaming budget
    """

    def __init__(
        self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule]
    ) -> None:
        rt_mods = []
        self.current_device_budget = 0
        self.cuda_graphs_module = None

        if isinstance(module, CudaGraphsTorchTensorRTModule):
            self.cuda_graphs_module = module
            module = module.compiled_module
        for name, rt_mod in module.named_children():
            if "_run_on_acc" in name and isinstance(
                rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
            ):
                rt_mods.append((name, rt_mod))
                self.current_device_budget += rt_mod.get_device_memory_budget()
        self.streamable_budget = [
            mod.get_streamable_device_memory_budget() for _, mod in rt_mods
        ]
        self.rt_mods = rt_mods
        total_device_budget = sum(self.streamable_budget)
        super().__setattr__("device_budget", self.current_device_budget)
        super().__setattr__("total_device_budget", total_device_budget)

    def get_automatic_weight_streaming_budget(self) -> int:
        ws_budget_bytes = 0
        for _, rt_mod in self.rt_mods:
            ws_budget_bytes += rt_mod.get_automatic_device_memory_budget()
        return ws_budget_bytes

    def __enter__(self) -> "_WeightStreamingContextManager":
        return self

    def __exit__(self, *args: Any) -> None:
        if self.total_device_budget > 0:
            logger.debug(
                f"Revert weight streaming budget to initial size: {self.current_device_budget}"
            )
            self.device_budget = self.current_device_budget

    def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
        ws_budget_bytes = 0
        total_bytes = self.total_device_budget
        if total_bytes == 0:
            raise RuntimeError(
                "Streamable bytes are zero. Was module complied with enable_weight_streaming=True option?"
            )
        elif total_bytes < requested_budget:
            logger.error(
                f"Requested budget is greater than streamable bytes: {total_bytes}. requested budget: {requested_budget}"
            )
            requested_budget = total_bytes
        elif requested_budget < 0:
            raise RuntimeError("Requested budget cannot be negative")
        # Normalized size is applied for multiple runtime module.
        # e.g. 100B budget is applied to two modules and they have 1000B and 3000B streamable size respectively.
        # Then 25B and 75B are applied for each module.
        normalized_size = [
            int(streamable_bytes / total_bytes * requested_budget)
            for streamable_bytes in self.streamable_budget
        ]
        for i, (name, rt_mod) in enumerate(self.rt_mods):
            ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i])
            logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")

        if self.cuda_graphs_module:
            self.cuda_graphs_module.is_weight_streaming_set = True
        return ws_budget_bytes

    def __setattr__(self, name: str, value: Any) -> None:
        if name == "device_budget":
            value = self._set_streamable_weight_bytes(value)
        super().__setattr__(name, value)


[docs]def weight_streaming( module: torch.fx.GraphModule, ) -> _WeightStreamingContextManager: return _WeightStreamingContextManager(module)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources