import warnings
from collections import namedtuple
from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]


InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs


class Inception3(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of inception_v3 will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        if len(inception_blocks) != 7:
            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        # Remaps inputs normalized with the standard ImageNet mean/std to the
        # (0.5, 0.5) per-channel normalization expected by the ported TF weights.
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)


class InceptionA(nn.Module):
    def __init__(self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):
    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):
    def __init__(self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):
    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):
    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x 1000
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
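
# --- Illustrative sketch, not part of the original module ---
# Inception3 returns different types depending on mode: in train() mode with
# aux_logits=True, forward() yields an InceptionOutputs namedtuple carrying
# both the main and auxiliary logits, while in eval() mode eager_outputs()
# returns a plain Tensor and the aux head is skipped entirely. The helper
# below (_demo_output_modes is a hypothetical name; it is defined but never
# called at import time) demonstrates both paths.
def _demo_output_modes() -> None:
    model = Inception3(init_weights=True)

    model.train()
    out = model(torch.randn(2, 3, 299, 299))
    assert isinstance(out, InceptionOutputs)  # (logits, aux_logits)
    assert out.logits.shape == (2, 1000)
    assert out.aux_logits is not None and out.aux_logits.shape == (2, 1000)

    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 3, 299, 299))
    assert isinstance(logits, Tensor) and logits.shape == (2, 1000)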
class Inception_V3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
        transforms=partial(ImageClassification, crop_size=299, resize_size=342),
        meta={
            "num_params": 27161264,
            "min_size": (75, 75),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.294,
                    "acc@5": 93.450,
                }
            },
            "_ops": 5.713,
            "_file_size": 103.903,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
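
# --- Illustrative sketch, not part of the original module ---
# A Weights entry bundles the checkpoint URL with its preprocessing recipe and
# metadata, all of which can be queried without downloading anything. A
# hypothetical inspection helper (never called at import time):
def _demo_weights_meta() -> None:
    weights = Inception_V3_Weights.DEFAULT  # alias for IMAGENET1K_V1
    preprocess = weights.transforms()  # ImageClassification(crop_size=299, resize_size=342)
    print(weights.url)
    print(weights.meta["_metrics"]["ImageNet-1K"])  # {'acc@1': 77.294, 'acc@5': 93.45}
    print(len(weights.meta["categories"]))  # 1000 ImageNet class names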
@register_model()
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
    """
    Inception v3 model architecture from
    `Rethinking the Inception Architecture for Computer Vision <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        weights (:class:`~torchvision.models.Inception_V3_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.Inception_V3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.Inception3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Inception_V3_Weights
        :members:
    """
    weights = Inception_V3_Weights.verify(weights)

    original_aux_logits = kwargs.get("aux_logits", True)
    if weights is not None:
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = Inception3(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None

    return model
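
# --- Illustrative sketch, not part of the original module ---
# End-to-end inference with the pretrained weights: the factory loads the
# checkpoint, and the weights' own transforms() handle the resize-to-342 /
# center-crop-to-299 preprocessing the model expects. `img` stands in for any
# PIL image or CHW uint8 tensor (hypothetical input; the helper is never
# called at import time).
def _demo_inference(img: Any) -> str:
    weights = Inception_V3_Weights.IMAGENET1K_V1
    model = inception_v3(weights=weights)
    model.eval()  # eval mode: forward() returns plain logits, no aux head

    preprocess = weights.transforms()
    batch = preprocess(img).unsqueeze(0)  # N x 3 x 299 x 299

    with torch.no_grad():
        logits = model(batch)
    class_id = int(logits.argmax(dim=1))
    return weights.meta["categories"][class_id]  # predicted ImageNet label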