Feature extraction for model inspection
=======================================

.. currentmodule:: torchvision.models.feature_extraction

The ``torchvision.models.feature_extraction`` package contains
feature extraction utilities that let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:

- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
  recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
  with a specific task in mind. For example, passing a hierarchy of features
  to a Feature Pyramid Network with object detection heads.

Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.

|

The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.

.. _about-node-names:

**About Node Names**

In order to specify which nodes should be output nodes for extracted
features, one should be familiar with the node naming convention used here
(which differs slightly from that used in ``torch.fx``). A node name is
specified as a ``.`` separated path walking the module hierarchy from top level
module down to leaf operation or leaf module. For instance ``"layer4.2.relu"``
in ResNet-50 represents the output of the ReLU of the 2nd block of the 4th
layer of the ``ResNet`` module. Here are some finer points to keep in mind:

- When specifying node names for :func:`create_feature_extractor`, you may
  provide a truncated version of a node name as a shortcut. To see how this
  works, try creating a ResNet-50 model and printing the node names with
  ``train_nodes, _ = get_graph_node_names(model) print(train_nodes)`` and
  observe that the last node pertaining to ``layer4`` is
  ``"layer4.2.relu_2"``. One may specify ``"layer4.2.relu_2"`` as the return
  node, or just ``"layer4"`` as this, by convention, refers to the last node
  (in order of execution) of ``layer4``.
- If a certain module or operation is repeated more than once, node names get
  an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
  addition (``+``) operation is used three times in the same ``forward``
  method. Then there would be ``"path.to.module.add"``,
  ``"path.to.module.add_1"``, ``"path.to.module.add_2"``. The counter is
  maintained within the scope of the direct parent. So in ResNet-50 there is
  a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
  operations reside in different blocks, there is no need for a postfix to
  disambiguate.


**An Example**

Here is an example of how we might extract features for MaskRCNN:

.. code-block:: python

  import torch
  from torchvision.models import resnet50
  from torchvision.models.feature_extraction import get_graph_node_names
  from torchvision.models.feature_extraction import create_feature_extractor
  from torchvision.models.detection.mask_rcnn import MaskRCNN
  from torchvision.models.detection.backbone_utils import LastLevelMaxPool
  from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


  # To assist you in designing the feature extractor you may want to print out
  # the available nodes for resnet50.
  m = resnet50()
  train_nodes, eval_nodes = get_graph_node_names(resnet50())

  # The lists returned, are the names of all the graph nodes (in order of
  # execution) for the input model traced in train mode and in eval mode
  # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
  # for this example. But if the model contains control flow that's dependent
  # on the training mode, they may be different.

  # To specify the nodes you want to extract, you could select the final node
  # that appears in each of the main layers:
  return_nodes = {
      # node_name: user-specified key for output dict
      'layer1.2.relu_2': 'layer1',
      'layer2.3.relu_2': 'layer2',
      'layer3.5.relu_2': 'layer3',
      'layer4.2.relu_2': 'layer4',
  }

  # But `create_feature_extractor` can also accept truncated node specifications
  # like "layer1", as it will just pick the last node that's a descendent of
  # of the specification. (Tip: be careful with this, especially when a layer
  # has multiple outputs. It's not always guaranteed that the last operation
  # performed is the one that corresponds to the output you desire. You should
  # consult the source code for the input model to confirm.)
  return_nodes = {
      'layer1': 'layer1',
      'layer2': 'layer2',
      'layer3': 'layer3',
      'layer4': 'layer4',
  }

  # Now you can build the feature extractor. This returns a module whose forward
  # method returns a dictionary like:
  # {
  #     'layer1': output of layer 1,
  #     'layer2': output of layer 2,
  #     'layer3': output of layer 3,
  #     'layer4': output of layer 4,
  # }
  create_feature_extractor(m, return_nodes=return_nodes)

  # Let's put all that together to wrap resnet50 with MaskRCNN

  # MaskRCNN requires a backbone with an attached FPN
  class Resnet50WithFPN(torch.nn.Module):
      def __init__(self):
          super(Resnet50WithFPN, self).__init__()
          # Get a resnet50 backbone
          m = resnet50()
          # Extract 4 main layers (note: MaskRCNN needs this particular name
          # mapping for return nodes)
          self.body = create_feature_extractor(
              m, return_nodes={f'layer{k}': str(v)
                               for v, k in enumerate([1, 2, 3, 4])})
          # Dry run to get number of channels for FPN
          inp = torch.randn(2, 3, 224, 224)
          with torch.no_grad():
              out = self.body(inp)
          in_channels_list = [o.shape[1] for o in out.values()]
          # Build FPN
          self.out_channels = 256
          self.fpn = FeaturePyramidNetwork(
              in_channels_list, out_channels=self.out_channels,
              extra_blocks=LastLevelMaxPool())

      def forward(self, x):
          x = self.body(x)
          x = self.fpn(x)
          return x


  # Now we can build our model!
  model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()


API Reference
-------------

.. autosummary::
    :toctree: generated/
    :template: function.rst

    create_feature_extractor
    get_graph_node_names