create_feature_extractor
torchvision.models.feature_extraction.create_feature_extractor(model: torch.nn.modules.module.Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False) → torch.fx.graph_module.GraphModule

Creates a new graph module that returns intermediate nodes from a given model as a dictionary, with user-specified strings as keys and the requested outputs as values. This is achieved by rewriting the computation graph of the model via FX to return the desired nodes as outputs. All unused nodes are removed, together with their corresponding parameters.
Desired output nodes must be specified as a "."-separated path walking the module hierarchy from the top-level module down to a leaf operation or leaf module. For more details on the node naming conventions used here, please see the relevant subheading in the documentation.

Not all models will be FX traceable, although with some massaging they can be made to cooperate. Here is a non-exhaustive list of tips:
- If you don’t need to trace through a particular problematic sub-module, turn it into a “leaf module” by passing a list of leaf_modules as one of the tracer_kwargs (see example below). It will not be traced through; instead, the resulting graph will hold a reference to that module’s forward method.
- Likewise, you may turn functions into leaf functions by passing a list of autowrap_functions as one of the tracer_kwargs (see example below).
- Some built-in Python functions can be problematic. For instance, calling int on a traced value will raise an error during tracing. You may wrap such calls in your own function and then pass that function in autowrap_functions as one of the tracer_kwargs.
For further information on FX, see the torch.fx documentation.
- Parameters
  - model (nn.Module) – model from which the features will be extracted.
  - return_nodes (list or dict, optional) – either a List or a Dict containing the names (or partial names; see the note on node naming above) of the nodes for which the activations will be returned. If it is a Dict, the keys are the node names, and the values are the user-specified keys for the graph module’s returned dictionary. If it is a List, it is treated as a Dict mapping node specification strings directly to output names. If train_return_nodes and eval_return_nodes are specified, this should not be specified.
  - train_return_nodes (list or dict, optional) – similar to return_nodes. This can be used if the return nodes for train mode are different from those for eval mode. If this is specified, eval_return_nodes must also be specified, and return_nodes should not be specified.
  - eval_return_nodes (list or dict, optional) – similar to return_nodes. This can be used if the return nodes for train mode are different from those for eval mode. If this is specified, train_return_nodes must also be specified, and return_nodes should not be specified.
  - tracer_kwargs (dict, optional) – a dictionary of keyword arguments for NodePathTracer (which passes them on to its parent class, torch.fx.Tracer). By default, it is set to wrap all torchvision ops and make them leaf nodes.
  - suppress_diff_warning (bool, optional) – whether to suppress a warning when there are discrepancies between the train and eval versions of the graph. Defaults to False.
Examples:
>>> import torch
>>> import torchvision
>>> from torchvision.models.feature_extraction import create_feature_extractor
>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # Extract layer1 and layer3, naming the outputs 'feat1' and 'feat2'
>>> model = create_feature_extractor(
...     model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
[('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]

>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
...     # This would raise a TypeError if traced through
...     return int(x)
...
>>> class LeafModule(torch.nn.Module):
...     def forward(self, x):
...         # This would raise a TypeError if traced through
...         int(x.shape[0])
...         return torch.nn.functional.relu(x + 4)
...
>>> class MyModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.conv = torch.nn.Conv2d(3, 1, 3)
...         self.leaf_module = LeafModule()
...
...     def forward(self, x):
...         leaf_function(x.shape[0])
...         x = self.conv(x)
...         return self.leaf_module(x)
...
>>> model = create_feature_extractor(
...     MyModule(), return_nodes=['leaf_module'],
...     tracer_kwargs={'leaf_modules': [LeafModule],
...                    'autowrap_functions': [leaf_function]})