torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) → GraphModule

Creates a new graph module that returns intermediate nodes from a given model as a dictionary, with user-specified strings as keys and the requested outputs as values. This is achieved by rewriting the computation graph of the model via FX so that it returns the desired nodes as outputs. All unused nodes are removed, together with their corresponding parameters.

Desired output nodes must be specified as a .-separated path walking the module hierarchy from the top-level module down to a leaf operation or leaf module. For more details on the node naming conventions used here, please see the relevant subheading in the documentation.

Not all models will be FX traceable, although with some massaging they can be made to cooperate. Here’s a (not exhaustive) list of tips:

  • If you don’t need to trace through a particular, problematic sub-module, turn it into a “leaf module” by passing a list of leaf_modules as one of the tracer_kwargs (see example below). It will not be traced through, but rather, the resulting graph will hold a reference to that module’s forward method.

  • Likewise, you may turn functions into leaf functions by passing a list of autowrap_functions as one of the tracer_kwargs (see example below).

  • Some built-in Python functions can be problematic. For instance, calling int on a traced value will raise an error during tracing. You may wrap such a function in your own function and then pass that wrapper in autowrap_functions as one of the tracer_kwargs.

For further information on FX see the torch.fx documentation.

Parameters:

  • model (nn.Module) – model from which the features will be extracted

  • return_nodes (list or dict, optional) – either a List or a Dict containing the names (or partial names; see the node naming conventions referenced above) of the nodes whose activations will be returned. If it is a Dict, the keys are the node names and the values are the user-specified keys for the graph module’s returned dictionary. If it is a List, it is treated as a Dict mapping each node specification string to itself, so the node names become the output keys. It should not be specified if train_return_nodes and eval_return_nodes are specified.

  • train_return_nodes (list or dict, optional) – similar to return_nodes. This can be used if the return nodes for train mode differ from those for eval mode. If this is specified, eval_return_nodes must also be specified, and return_nodes should not be specified.

  • eval_return_nodes (list or dict, optional) – similar to return_nodes. This can be used if the return nodes for train mode differ from those for eval mode. If this is specified, train_return_nodes must also be specified, and return_nodes should not be specified.

  • tracer_kwargs (dict, optional) – a dictionary of keyword arguments for NodePathTracer (which passes them on to its parent class torch.fx.Tracer). By default, it is set to wrap all torchvision ops and make them leaf nodes: {"autowrap_modules": (math, torchvision.ops,), "leaf_modules": _get_leaf_modules_for_ops(),}. WARNING: If the user provides tracer_kwargs, the default arguments above will be appended to the user-provided dictionary.

  • suppress_diff_warning (bool, optional) – whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False.

  • concrete_args (Optional[Dict[str, Any]]) – concrete arguments that should not be treated as Proxies. According to the PyTorch docs, the API of this parameter is not guaranteed to be stable.


Examples:

>>> import torch
>>> import torchvision
>>> from torchvision.models.feature_extraction import create_feature_extractor
>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, naming the outputs `feat1` and `feat2`
>>> model = create_feature_extractor(
...     model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
[('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]

>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
...     # This would raise a TypeError if traced through
...     return int(x)
>>> class LeafModule(torch.nn.Module):
...     def forward(self, x):
...         # This would raise a TypeError if traced through
...         int(x.shape[0])
...         return torch.nn.functional.relu(x + 4)
>>> class MyModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.conv = torch.nn.Conv2d(3, 1, 3)
...         self.leaf_module = LeafModule()
...     def forward(self, x):
...         leaf_function(x.shape[0])
...         x = self.conv(x)
...         return self.leaf_module(x)
>>> model = create_feature_extractor(
...     MyModule(), return_nodes=['leaf_module'],
...     tracer_kwargs={'leaf_modules': [LeafModule],
...                    'autowrap_functions': [leaf_function]})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print(out['leaf_module'].shape)
torch.Size([1, 1, 222, 222])

