get_graph_node_names
torchvision.models.feature_extraction.get_graph_node_names(model: torch.nn.modules.module.Module, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False) → Tuple[List[str], List[str]]

Dev utility to return node names in order of execution. See note on node names under create_feature_extractor(). Useful for seeing which node names are available for feature extraction. There are two reasons that node names can't easily be read directly from the code for a model:

1. Not all submodules are traced through. Modules from torch.nn all fall within this category.
2. Nodes representing the repeated application of the same operation or leaf module get a _{counter} postfix.
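The second point can be made concrete with a small stdlib sketch. It assumes that the first occurrence of a qualified name is kept bare and later repeats get `_1`, `_2`, and so on (so a ReLU module applied twice inside one block would show up as `layer1.0.relu` and then `layer1.0.relu_1`). The helper `dedupe_node_names` is hypothetical and only mimics the visible naming convention; it is not how torchvision implements it internally.

```python
from collections import Counter
from typing import Iterable, List


def dedupe_node_names(qualnames: Iterable[str]) -> List[str]:
    """Hypothetical helper: append a _{counter} postfix to repeated names.

    Assumed convention: the first occurrence keeps the bare name, the
    second becomes name_1, the third name_2, and so on.
    """
    seen: Counter = Counter()
    result: List[str] = []
    for name in qualnames:
        count = seen[name]
        # First use keeps the bare name; repeats get the running counter.
        result.append(name if count == 0 else f"{name}_{count}")
        seen[name] += 1
    return result


# Illustrative trace: a leaf module (relu) applied twice plus one functional op (add).
trace = ["layer1.0.conv1", "layer1.0.bn1", "layer1.0.relu",
         "layer1.0.conv2", "layer1.0.bn2", "layer1.0.add", "layer1.0.relu"]
print(dedupe_node_names(trace))
```

Because the counter is keyed on the full qualified name, the same operation in two different blocks (say `layer1.0.relu` and `layer2.0.relu`) would not receive a postfix under this sketch.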
The model is traced twice: once in train mode, and once in eval mode. Both sets of node names are returned.
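The two lists can differ when a model's forward pass takes a different Python branch depending on `self.training`, since tracing follows whichever branch actually runs. To see exactly which nodes exist in only one mode, a minimal stdlib sketch follows. `diff_node_names` is a hypothetical helper; in practice its inputs would be the two lists returned by get_graph_node_names, and the lists used here are illustrative only.

```python
from typing import List, Sequence, Tuple


def diff_node_names(train_nodes: Sequence[str],
                    eval_nodes: Sequence[str]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (train-only, eval-only) node names.

    Execution order is preserved within each returned list.
    """
    train_set, eval_set = set(train_nodes), set(eval_nodes)
    train_only = [n for n in train_nodes if n not in eval_set]
    eval_only = [n for n in eval_nodes if n not in train_set]
    return train_only, eval_only


# Illustrative lists only: an auxiliary head that is traced in train mode alone.
train = ["x", "features", "aux.fc", "classifier"]
eval_ = ["x", "features", "classifier"]
print(diff_node_names(train, eval_))  # (['aux.fc'], [])
```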
For more details on the node naming conventions used here, please see the relevant subheading in the documentation.
- Parameters
model (nn.Module) – model for which we’d like to print node names
tracer_kwargs (dict, optional) – a dictionary of keyword arguments for NodePathTracer (they are eventually passed on to torch.fx.Tracer). By default, it is set so that all torchvision ops are wrapped and treated as leaf nodes.
suppress_diff_warning (bool, optional) – whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False.
- Returns
a list of node names from tracing the model in train mode, and another from tracing the model in eval mode.
- Return type
Tuple[List[str], List[str]]
Examples:
>>> import torchvision
>>> from torchvision.models.feature_extraction import get_graph_node_names
>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)
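Since each returned list holds dotted, module-qualified names in execution order, a common next step is to pick a subset to pass as `return_nodes` to create_feature_extractor(), which accepts a dict mapping node names to output keys. The stdlib sketch below assumes you want the last node executed under each of a few prefixes. `pick_last_nodes` is a hypothetical helper, and the node list is illustrative, not real output for any model.

```python
from typing import Dict, Iterable, Sequence


def pick_last_nodes(node_names: Sequence[str],
                    prefixes: Iterable[str]) -> Dict[str, str]:
    """Hypothetical helper: map the last node under each prefix to that prefix.

    Produces a {node_name: output_key} dict; prefixes with no matching
    node are skipped.
    """
    picked: Dict[str, str] = {}
    for prefix in prefixes:
        # Match the prefix itself or a dotted child, so "layer1" does not match "layer10".
        matches = [n for n in node_names
                   if n == prefix or n.startswith(prefix + ".")]
        if matches:
            # Lists are in execution order, so the last match ran last under this prefix.
            picked[matches[-1]] = prefix
    return picked


# Illustrative node names (not real get_graph_node_names output).
nodes = ["x", "conv1", "layer1.0.conv1", "layer1.0.relu_1",
         "layer2.0.conv1", "layer2.0.relu_1", "fc"]
print(pick_last_nodes(nodes, ["layer1", "layer2", "layer9"]))
# {'layer1.0.relu_1': 'layer1', 'layer2.0.relu_1': 'layer2'}
```

Matching on `prefix + "."` rather than a bare `startswith(prefix)` is a deliberate choice: dotted paths make sibling names such as `layer1` and `layer10` share a textual prefix, and the extra dot keeps them apart.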