torchvision.models.feature_extraction

Feature extraction utilities let us tap into our models to access intermediate transformations of our inputs. This could be useful for a variety of applications in computer vision. Just a few examples are:

  • Visualizing feature maps.

  • Extracting features to compute image descriptors for tasks like facial recognition, copy-detection, or image retrieval.

  • Passing selected features to downstream sub-networks for end-to-end training with a specific task in mind. For example, passing a hierarchy of features to a Feature Pyramid Network with object detection heads.

Torchvision provides create_feature_extractor() for this purpose. It works by following roughly these steps:

  1. Symbolically tracing the model to get a graph representation of how it transforms the input, step by step.

  2. Setting the user-selected graph nodes as outputs.

  3. Removing all redundant nodes (anything downstream of the output nodes).

  4. Generating Python code from the resulting graph and bundling it into a PyTorch module together with the graph itself.


The torch.fx documentation provides a more general and detailed explanation of the above procedure and the inner workings of symbolic tracing.
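
To get a feel for what steps 1 and 4 produce, here is a minimal sketch using plain torch.fx rather than torchvision; TinyNet is a toy module invented purely for illustration:

import torch
from torch import fx

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return torch.relu(self.conv(x))

# Step 1: symbolic tracing records one graph node per operation
traced = fx.symbolic_trace(TinyNet())
print(traced.graph)  # placeholder -> conv -> relu -> output
# Step 4: Python code regenerated from the graph
print(traced.code)

create_feature_extractor() builds on the same machinery, additionally rewiring the graph so that the nodes you select become the outputs.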

About Node Names

In order to specify which nodes should be output nodes for extracted features, one should be familiar with the node naming convention used here (which differs slightly from that used in torch.fx). A node name is specified as a . separated path walking the module hierarchy from top level module down to leaf operation or leaf module. For instance "layer4.2.relu" in ResNet-50 represents the output of the ReLU of the third block (index 2) of layer4 of the ResNet module. Here are some finer points to keep in mind:

  • When specifying node names for create_feature_extractor(), you may provide a truncated version of a node name as a shortcut. To see how this works, try creating a ResNet-50 model and printing the node names with train_nodes, _ = get_graph_node_names(model) followed by print(train_nodes) (see the snippet after this list). You will observe that the last node pertaining to layer4 is "layer4.2.relu_2". One may specify "layer4.2.relu_2" as the return node, or just "layer4", which by convention refers to the last node (in order of execution) of layer4.

  • If a certain module or operation is repeated more than once, node names get an additional _{int} postfix to disambiguate. For instance, maybe the addition (+) operation is used three times in the same forward method. Then there would be "path.to.module.add", "path.to.module.add_1", "path.to.module.add_2". The counter is maintained within the scope of the direct parent. So in ResNet-50 there is a "layer4.1.add" and a "layer4.2.add". Because the addition operations reside in different blocks, there is no need for a postfix to disambiguate.
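
A minimal sketch illustrating both points; the printed tail is indicative of torchvision's ResNet-50, though exact node names may vary slightly across versions:

from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names

train_nodes, _ = get_graph_node_names(resnet50())
print(train_nodes[-4:])
# ['layer4.2.relu_2', 'avgpool', 'flatten', 'fc']
# The trailing _2 marks the third ReLU call within the layer4.2 block;
# "layer4" alone would resolve to 'layer4.2.relu_2', the last node of layer4.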

An Example

Here is an example of how we might extract features for MaskRCNN:

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50()
train_nodes, eval_nodes = get_graph_node_names(m)

# The lists returned are the names of all the graph nodes (in order of
# execution) for the input model traced in train mode and in eval mode
# respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# for this example. But if the model contains control flow that's dependent
# on the training mode, they may be different.

# To specify the nodes you want to extract, you could select the final node
# that appears in each of the main layers:
return_nodes = {
    # node_name: user-specified key for output dict
    'layer1.2.relu_2': 'layer1',
    'layer2.3.relu_2': 'layer2',
    'layer3.5.relu_2': 'layer3',
    'layer4.2.relu_2': 'layer4',
}

# But `create_feature_extractor` can also accept truncated node specifications
# like "layer1", as it will just pick the last node that's a descendent of
# of the specification. (Tip: be careful with this, especially when a layer
# has multiple outputs. It's not always guaranteed that the last operation
# performed is the one that corresponds to the output you desire. You should
# consult the source code for the input model to confirm.)
return_nodes = {
    'layer1': 'layer1',
    'layer2': 'layer2',
    'layer3': 'layer3',
    'layer4': 'layer4',
}

# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
#     'layer1': output of layer 1,
#     'layer2': output of layer 2,
#     'layer3': output of layer 3,
#     'layer4': output of layer 4,
# }
extractor = create_feature_extractor(m, return_nodes=return_nodes)

# Let's put all that together to wrap resnet50 with MaskRCNN

# MaskRCNN requires a backbone with an attached FPN
class Resnet50WithFPN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Get a resnet50 backbone
        m = resnet50()
        # Extract 4 main layers (note: MaskRCNN needs this particular name
        # mapping for return nodes)
        self.body = create_feature_extractor(
            m, return_nodes={f'layer{k}': str(v)
                             for v, k in enumerate([1, 2, 3, 4])})
        # Dry run to get number of channels for FPN
        inp = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            out = self.body(inp)
        in_channels_list = [o.shape[1] for o in out.values()]
        # Build FPN
        self.out_channels = 256
        self.fpn = FeaturePyramidNetwork(
            in_channels_list, out_channels=self.out_channels,
            extra_blocks=LastLevelMaxPool())

    def forward(self, x):
        x = self.body(x)
        x = self.fpn(x)
        return x


# Now we can build our model!
model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
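
# As a quick smoke test (not part of the original recipe), the assembled
# model can be run on dummy inputs. In eval mode, MaskRCNN expects a list of
# 3-channel image tensors and returns one prediction dict per image:
with torch.no_grad():
    predictions = model([torch.rand(3, 300, 400), torch.rand(3, 500, 400)])
# Each prediction dict contains 'boxes', 'labels', 'scores' and 'masks'
print(predictions[0]['boxes'].shape)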

API Reference

torchvision.models.feature_extraction.create_feature_extractor(model: torch.nn.modules.module.Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False) → torch.fx.graph_module.GraphModule

Creates a new graph module that returns intermediate nodes from a given model as a dictionary with user-specified keys as strings, and the requested outputs as values. This is achieved by re-writing the computation graph of the model via FX so that the desired nodes are returned as outputs. All unused nodes are removed, together with their corresponding parameters.

Desired output nodes must be specified as a . separated path walking the module hierarchy from top level module down to leaf operation or leaf module. For more details on the node naming conventions used here, please see the relevant subheading in the documentation.

Not all models will be FX traceable, although with some massaging they can be made to cooperate. Here’s a (not exhaustive) list of tips:

  • If you don’t need to trace through a particular, problematic sub-module, turn it into a “leaf module” by passing a list of leaf_modules as one of the tracer_kwargs (see example below). It will not be traced through, but rather, the resulting graph will hold a reference to that module’s forward method.

  • Likewise, you may turn functions into leaf functions by passing a list of autowrap_functions as one of the tracer_kwargs (see example below).

  • Some inbuilt Python functions can be problematic. For instance, int will raise an error during tracing. You may wrap them in your own function and then pass that in autowrap_functions as one of the tracer_kwargs (see example below).

For further information on FX see the torch.fx documentation.

Parameters
  • model (nn.Module) – model on which we will extract the features

  • return_nodes (list or dict, optional) – either a List or a Dict containing the names (or partial names - see note above) of the nodes for which the activations will be returned. If it is a Dict, the keys are the node names, and the values are the user-specified keys for the graph module’s returned dictionary. If it is a List, it is treated as a Dict mapping each node name to itself. This should not be specified if train_return_nodes and eval_return_nodes are specified.

  • train_return_nodes (list or dict, optional) – similar to return_nodes. This can be used if the return nodes for train mode are different than those from eval mode. If this is specified, eval_return_nodes must also be specified, and return_nodes should not be specified.

  • eval_return_nodes (list or dict, optional) – similar to return_nodes. This can be used if the return nodes for train mode are different than those from eval mode. If this is specified, train_return_nodes must also be specified, and return_nodes should not be specified.

  • tracer_kwargs (dict, optional) – a dictionary of keyword arguments for NodePathTracer (which passes them onto its parent class torch.fx.Tracer).

  • suppress_diff_warning (bool, optional) – whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False.

Examples:

>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
>>> model = create_feature_extractor(
>>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
[('feat1', torch.Size([1, 64, 56, 56])),
 ('feat2', torch.Size([1, 256, 14, 14]))]
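
The list form of return_nodes, and the per-mode arguments, can be sketched as follows (the particular node choices are illustrative, not a recommendation):

>>> # A list maps each node name to itself
>>> model = torchvision.models.resnet18()
>>> extractor = create_feature_extractor(model, return_nodes=['layer1', 'layer3'])
>>> # equivalent to return_nodes={'layer1': 'layer1', 'layer3': 'layer3'}
>>>
>>> # Different return nodes for train and eval mode
>>> extractor = create_feature_extractor(
>>>     torchvision.models.resnet18(),
>>>     train_return_nodes=['layer3', 'layer4'],
>>>     eval_return_nodes=['layer4'])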

>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
>>>     # This would raise a TypeError if traced through
>>>     return int(x)
>>>
>>> class LeafModule(torch.nn.Module):
>>>     def forward(self, x):
>>>         # This would raise a TypeError if traced through
>>>         int(x.shape[0])
>>>         return torch.nn.functional.relu(x + 4)
>>>
>>> class MyModule(torch.nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.conv = torch.nn.Conv2d(3, 1, 3)
>>>         self.leaf_module = LeafModule()
>>>
>>>     def forward(self, x):
>>>         leaf_function(x.shape[0])
>>>         x = self.conv(x)
>>>         return self.leaf_module(x)
>>>
>>> model = create_feature_extractor(
>>>     MyModule(), return_nodes=['leaf_module'],
>>>     tracer_kwargs={'leaf_modules': [LeafModule],
>>>                    'autowrap_functions': [leaf_function]})
torchvision.models.feature_extraction.get_graph_node_names(model: torch.nn.modules.module.Module, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False) → Tuple[List[str], List[str]]

Dev utility to return node names in order of execution. See note on node names under create_feature_extractor(). Useful for seeing which node names are available for feature extraction. There are two reasons that node names can’t easily be read directly from the code for a model:

  1. Not all submodules are traced through. Modules from torch.nn all fall within this category.

  2. Nodes representing the repeated application of the same operation or leaf module get a _{counter} postfix.

The model is traced twice: once in train mode, and once in eval mode. Both sets of node names are returned.

For more details on the node naming conventions used here, please see the relevant subheading in the documentation.

Parameters
  • model (nn.Module) – model for which we’d like to print node names

  • tracer_kwargs (dict, optional) – a dictionary of keyword arguments for NodePathTracer (they are eventually passed onto torch.fx.Tracer).

  • suppress_diff_warning (bool, optional) – whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False.

Returns

a list of node names from tracing the model in train mode, and another from tracing the model in eval mode.

Return type

tuple(list, list)

Examples:

>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)
