Source code for torchvision.ops.feature_pyramid_network
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple

import torch.nn.functional as F
from torch import nn, Tensor

from ..ops.misc import Conv2dNormActivation
from ..utils import _log_api_usage_once


class ExtraFPNBlock(nn.Module):
    """
    Base class for the extra block in the FPN.

    Args:
        results (List[Tensor]): the result of the FPN
        x (List[Tensor]): the original feature maps
        names (List[str]): the names for each one of the original feature maps

    Returns:
        results (List[Tensor]): the extended set of results of the FPN
        names (List[str]): the extended set of names for the results
    """

    def forward(
        self,
        results: List[Tensor],
        x: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
        pass
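
# A minimal sketch of a custom extra block, to illustrate the contract that
# ExtraFPNBlock.forward defines: extend `results` and `names` in lockstep and
# return both. The class name LastLevelAvgPool is hypothetical, introduced
# here for illustration only; it is not part of torchvision.
class LastLevelAvgPool(ExtraFPNBlock):
    def forward(
        self,
        results: List[Tensor],
        x: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
        names.append("avgpool")
        # Stride-2 average pooling appends one coarser pyramid level.
        results.append(F.avg_pool2d(results[-1], kernel_size=2, stride=2))
        return results, names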
class FeaturePyramidNetwork(nn.Module):
    """
    Module that adds an FPN on top of a set of feature maps. This is based on
    `"Feature Pyramid Networks for Object Detection" <https://arxiv.org/abs/1612.03144>`_.

    The feature maps are currently supposed to be in increasing depth order.

    The input to the model is expected to be an OrderedDict[Tensor], containing
    the feature maps on top of which the FPN will be added.

    Args:
        in_channels_list (list[int]): number of channels for each feature map that
            is passed to the module
        out_channels (int): number of channels of the FPN representation
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and return
            a new list of feature maps and their corresponding names
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None

    Examples::

        >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
        >>> # get some dummy data
        >>> x = OrderedDict()
        >>> x['feat0'] = torch.rand(1, 10, 64, 64)
        >>> x['feat2'] = torch.rand(1, 20, 16, 16)
        >>> x['feat3'] = torch.rand(1, 30, 8, 8)
        >>> # compute the FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('feat0', torch.Size([1, 5, 64, 64])),
        >>>    ('feat2', torch.Size([1, 5, 16, 16])),
        >>>    ('feat3', torch.Size([1, 5, 8, 8]))]
    """

    _version = 2

    def __init__(
        self,
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for in_channels in in_channels_list:
            if in_channels == 0:
                raise ValueError("in_channels=0 is currently not supported")
            inner_block_module = Conv2dNormActivation(
                in_channels, out_channels, kernel_size=1, padding=0, norm_layer=norm_layer, activation_layer=None
            )
            layer_block_module = Conv2dNormActivation(
                out_channels, out_channels, kernel_size=3, norm_layer=norm_layer, activation_layer=None
            )
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)

        # initialize parameters now to avoid modifying the initialization of top_blocks
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        if extra_blocks is not None:
            if not isinstance(extra_blocks, ExtraFPNBlock):
                raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
        self.extra_blocks = extra_blocks

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            num_blocks = len(self.inner_blocks)
            for block in ["inner_blocks", "layer_blocks"]:
                for i in range(num_blocks):
                    for type in ["weight", "bias"]:
                        old_key = f"{prefix}{block}.{i}.{type}"
                        new_key = f"{prefix}{block}.{i}.0.{type}"
                        if old_key in state_dict:
                            state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
    def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.inner_blocks[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.inner_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.inner_blocks):
            if i == idx:
                out = module(x)
        return out
    def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.layer_blocks[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.layer_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.layer_blocks):
            if i == idx:
                out = module(x)
        return out
    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Computes the FPN for a set of feature maps.

        Args:
            x (OrderedDict[Tensor]): feature maps for each feature level.

        Returns:
            results (OrderedDict[Tensor]): feature maps after FPN layers.
                They are ordered from the highest resolution first.
        """
        # unpack OrderedDict into two lists for easier handling
        names = list(x.keys())
        x = list(x.values())

        last_inner = self.get_result_from_inner_blocks(x[-1], -1)
        results = []
        results.append(self.get_result_from_layer_blocks(last_inner, -1))

        for idx in range(len(x) - 2, -1, -1):
            inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
            last_inner = inner_lateral + inner_top_down
            results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x, names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out
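
# A hedged usage sketch mirroring the docstring example above, additionally
# passing a norm_layer; it shows that every pyramid level is projected to the
# same out_channels while spatial sizes are preserved. The helper name
# _demo_fpn_forward is hypothetical and exists only for illustration.
def _demo_fpn_forward() -> None:
    import torch

    m = FeaturePyramidNetwork([10, 20, 30], 5, norm_layer=nn.BatchNorm2d)
    x = OrderedDict()
    x["feat0"] = torch.rand(1, 10, 64, 64)
    x["feat2"] = torch.rand(1, 20, 16, 16)
    x["feat3"] = torch.rand(1, 30, 8, 8)
    output = m(x)
    # Channels are unified to 5 at every level; resolutions are untouched.
    assert [v.shape for v in output.values()] == [
        torch.Size([1, 5, 64, 64]),
        torch.Size([1, 5, 16, 16]),
        torch.Size([1, 5, 8, 8]),
    ]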
class LastLevelMaxPool(ExtraFPNBlock):
    """
    Applies a max_pool2d on top of the last feature map. With kernel_size=1 and
    stride=2 this is plain subsampling rather than true max pooling.
    """

    def forward(
        self,
        x: List[Tensor],
        y: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
        names.append("pool")
        # Use max pooling to simulate stride 2 subsampling
        x.append(F.max_pool2d(x[-1], kernel_size=1, stride=2, padding=0))
        return x, names


class LastLevelP6P7(ExtraFPNBlock):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in [self.p6, self.p7]:
            nn.init.kaiming_uniform_(module.weight, a=1)
            nn.init.constant_(module.bias, 0)
        self.use_P5 = in_channels == out_channels

    def forward(
        self,
        p: List[Tensor],
        c: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
        p5, c5 = p[-1], c[-1]
        x = p5 if self.use_P5 else c5
        p6 = self.p6(x)
        p7 = self.p7(F.relu(p6))
        p.extend([p6, p7])
        names.extend(["p6", "p7"])
        return p, names
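
# A hedged usage sketch for the two extra blocks defined above: LastLevelMaxPool
# appends one "pool" level at half the resolution of the coarsest output, while
# LastLevelP6P7 appends strided-conv levels "p6" and "p7" as used by RetinaNet.
# The helper name _demo_extra_blocks is hypothetical, for illustration only.
def _demo_extra_blocks() -> None:
    import torch

    x = OrderedDict()
    x["feat0"] = torch.rand(1, 10, 32, 32)
    x["feat1"] = torch.rand(1, 20, 16, 16)

    m = FeaturePyramidNetwork([10, 20], 5, extra_blocks=LastLevelMaxPool())
    output = m(x)
    # Stride-2 subsampling halves the coarsest 16x16 map to 8x8.
    assert list(output.keys()) == ["feat0", "feat1", "pool"]
    assert output["pool"].shape[-2:] == (8, 8)

    m = FeaturePyramidNetwork([10, 20], 5, extra_blocks=LastLevelP6P7(5, 5))
    output = m(x)
    # With in_channels == out_channels, p6 is computed from P5 (use_P5=True),
    # and each 3x3 stride-2 conv halves the spatial size: 16 -> 8 -> 4.
    assert list(output.keys()) == ["feat0", "feat1", "p6", "p7"]
    assert output["p7"].shape[-2:] == (4, 4)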