Source code for torch.utils.module_tracker
# mypy: allow-untyped-defs
import logging
import weakref
from typing import Set
import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
register_module_forward_hook,
register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten
# Module-level logger; ModuleTracker uses it to report inconsistencies it
# notices in the enter/exit bookkeeping of modules.
logger = logging.getLogger(__name__)
# Names exported by ``from <this module> import *``.
__all__ = ["ModuleTracker"]
[docs]class ModuleTracker:
"""
``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
so that other system can query which Module is currently being executed (or its backward is being
executed).
You can access the ``parents`` attribute on this context manager to get the set of all the
Modules currently being executed via their fqn (fully qualified name, also used as the key within
the state_dict).
You can access the ``is_bw`` attribute to know if you are currently running in backward or not.
Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
will remain ``True`` after the forward until another Module is executed. If you need it to be
more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance
is possible but not done yet, please submit an issue requesting this if you need it.
Example usage
.. code-block:: python
mod = torch.nn.Linear(2, 2)
with ModuleTracker() as tracker:
# Access anything during the forward pass
def my_linear(m1, m2, bias):
print(f"Current modules: {tracker.parents}")
return torch.mm(m1, m2.t()) + bias
torch.nn.functional.linear = my_linear
mod(torch.rand(2, 2))
"""
parents: Set[str]
"""
A Set containing the fqn for each module currently running their forward
"""
def __init__(self):
self.parents = {"Global"}
self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
self._seen_modules: weakref.WeakSet = weakref.WeakSet()
self._has_callback = False
def _maybe_set_engine_callback(self):
# This assumes no concurrent calls to backward
if self._has_callback:
return
def callback():
self.parents = {"Global"}
self._has_callback = False
torch.autograd.Variable._execution_engine.queue_callback(callback)
self._has_callback = True
@property
def is_bw(self):
"""
A boolean marking if this is currently running during the backward pass or not
"""
return torch._C._current_graph_task_id() != -1
def _get_mod_name(self, mod):
if mod not in self._known_modules:
self._known_modules[mod] = type(mod).__name__
mod_name = self._known_modules[mod]
if mod not in self._seen_modules:
for name, submod in mod.named_children():
self._known_modules[submod] = f"{mod_name}.{name}"
self._get_mod_name(submod)
self._seen_modules.add(mod)
return mod_name
def _get_append_fn(self, name, is_bw):
def fn(*args):
if is_bw:
self._maybe_set_engine_callback()
if name in self.parents:
logger.info(
"The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
name,
"backward" if is_bw else "forward",
)
self.parents.add(name)
return fn
def _get_pop_fn(self, name, is_bw):
def fn(*args):
if name in self.parents:
self.parents.remove(name)
else:
logger.info(
"The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
name,
"backward" if is_bw else "forward",
)
return fn
def _fw_pre_hook(self, mod, input):
name = self._get_mod_name(mod)
self._get_append_fn(name, False)()
args, _ = tree_flatten(input)
tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
if tensors:
register_multi_grad_hook(tensors, self._get_pop_fn(name, True))
def _fw_post_hook(self, mod, input, output):
name = self._get_mod_name(mod)
self._get_pop_fn(name, False)()
args, _ = tree_flatten(output)
tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
if tensors:
register_multi_grad_hook(tensors, self._get_append_fn(name, True))
def __enter__(self):
self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
return self
def __exit__(self, *args):
self._fw_pre_handle.remove()
self._fw_post_handle.remove()