from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import torch.nn as nn
from torch import Tensor

from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "VideoResNet",
    "R3D_18_Weights",
    "MC3_18_Weights",
    "R2Plus1D_18_Weights",
    "r3d_18",
    "mc3_18",
    "r2plus1d_18",
]


class Conv3DSimple(nn.Conv3d):
    # Full 3x3x3 spatiotemporal convolution.
    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:
        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        return stride, stride, stride


class Conv2Plus1D(nn.Sequential):
    # (2+1)D factorization: a spatial (1x3x3) conv followed by a temporal (3x1x1) conv,
    # with `midplanes` hidden channels in between.
    def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
        super().__init__(
            nn.Conv3d(
                in_planes,
                midplanes,
                kernel_size=(1, 3, 3),
                stride=(1, stride, stride),
                padding=(0, padding, padding),
                bias=False,
            ),
            nn.BatchNorm3d(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False
            ),
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        return stride, stride, stride


class Conv3DNoTemporal(nn.Conv3d):
    # Spatial-only (1x3x3) convolution; the temporal dimension is left untouched.
    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:
        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        return 1, stride, stride


class BasicBlock(nn.Module):

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super().__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )
        # Second kernel
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )

        # 1x1x1
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
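# The `midplanes` value computed in BasicBlock and Bottleneck above matters only when the
# conv builder is Conv2Plus1D: per the R(2+1)D paper, the hidden width of the factorized
# convolution is chosen so that its parameter count approximately matches a full 3x3x3
# convolution. Illustrative arithmetic (example values only, not part of the model code):
#
#     inplanes, planes = 64, 64
#     midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)  # 144
#     full_3d = inplanes * planes * 3 * 3 * 3                            # 110592 weights
#     factored = inplanes * midplanes * 3 * 3 + midplanes * planes * 3   # 110592 weights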
stem"""def__init__(self)->None:super().__init__(nn.Conv3d(3,64,kernel_size=(3,7,7),stride=(1,2,2),padding=(1,3,3),bias=False),nn.BatchNorm3d(64),nn.ReLU(inplace=True),)classR2Plus1dStem(nn.Sequential):"""R(2+1)D stem is different than the default one as it uses separated 3D convolution"""def__init__(self)->None:super().__init__(nn.Conv3d(3,45,kernel_size=(1,7,7),stride=(1,2,2),padding=(0,3,3),bias=False),nn.BatchNorm3d(45),nn.ReLU(inplace=True),nn.Conv3d(45,64,kernel_size=(3,1,1),stride=(1,1,1),padding=(1,0,0),bias=False),nn.BatchNorm3d(64),nn.ReLU(inplace=True),)classVideoResNet(nn.Module):def__init__(self,block:Type[Union[BasicBlock,Bottleneck]],conv_makers:Sequence[Type[Union[Conv3DSimple,Conv3DNoTemporal,Conv2Plus1D]]],layers:List[int],stem:Callable[...,nn.Module],num_classes:int=400,zero_init_residual:bool=False,)->None:"""Generic resnet video generator. Args: block (Type[Union[BasicBlock, Bottleneck]]): resnet building block conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator function for each layer layers (List[int]): number of blocks per layer stem (Callable[..., nn.Module]): module specifying the ResNet stem. num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. """super().__init__()_log_api_usage_once(self)self.inplanes=64self.stem=stem()self.layer1=self._make_layer(block,conv_makers[0],64,layers[0],stride=1)self.layer2=self._make_layer(block,conv_makers[1],128,layers[1],stride=2)self.layer3=self._make_layer(block,conv_makers[2],256,layers[2],stride=2)self.layer4=self._make_layer(block,conv_makers[3],512,layers[3],stride=2)self.avgpool=nn.AdaptiveAvgPool3d((1,1,1))self.fc=nn.Linear(512*block.expansion,num_classes)# init weightsforminself.modules():ifisinstance(m,nn.Conv3d):nn.init.kaiming_normal_(m.weight,mode="fan_out",nonlinearity="relu")ifm.biasisnotNone:nn.init.constant_(m.bias,0)elifisinstance(m,nn.BatchNorm3d):nn.init.constant_(m.weight,1)nn.init.constant_(m.bias,0)elifisinstance(m,nn.Linear):nn.init.normal_(m.weight,0,0.01)nn.init.constant_(m.bias,0)ifzero_init_residual:forminself.modules():ifisinstance(m,Bottleneck):nn.init.constant_(m.bn3.weight,0)# type: ignore[union-attr, arg-type]defforward(self,x:Tensor)->Tensor:x=self.stem(x)x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)x=self.layer4(x)x=self.avgpool(x)# Flatten the layer to 
        # Flatten the layer to fc
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)


def _video_resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
    layers: List[int],
    stem: Callable[..., nn.Module],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VideoResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = VideoResNet(block, conv_makers, layers, stem, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _KINETICS400_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
    "_docs": (
        "The weights reproduce closely the accuracy of the paper. The accuracies are estimated on video-level "
        "with parameters `frame_rate=15`, `clips_per_video=5`, and `clip_len=16`."
    ),
}
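# The builders below reference three WeightsEnum classes. They are sketched here in the
# shape the rest of the module expects (`.KINETICS400_V1`, `.verify`, `.url`,
# `.meta["categories"]`); the checkpoint URLs and the num_params/accuracy metadata are
# deliberately elided -- consult the upstream torchvision/models/video/resnet.py for the
# actual values.
class R3D_18_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/...",  # exact checkpoint filename elided
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={**_COMMON_META},  # num_params / accuracy metrics elided
    )
    DEFAULT = KINETICS400_V1


class MC3_18_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/...",  # exact checkpoint filename elided
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={**_COMMON_META},  # num_params / accuracy metrics elided
    )
    DEFAULT = KINETICS400_V1


class R2Plus1D_18_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/...",  # exact checkpoint filename elided
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={**_COMMON_META},  # num_params / accuracy metrics elided
    )
    DEFAULT = KINETICS400_V1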
@register_model()
@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Construct an 18-layer ResNet3D model.

    .. betastatus:: video module

    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

    Args:
        weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.R3D_18_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.R3D_18_Weights
        :members:
    """
    weights = R3D_18_Weights.verify(weights)

    return _video_resnet(
        BasicBlock,
        [Conv3DSimple] * 4,  # full 3x3x3 convolutions in all four stages
        [2, 2, 2, 2],
        BasicStem,
        weights,
        progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Construct an 18-layer Mixed Convolution network as in the reference below.

    .. betastatus:: video module

    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

    Args:
        weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.MC3_18_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.MC3_18_Weights
        :members:
    """
    weights = MC3_18_Weights.verify(weights)

    return _video_resnet(
        BasicBlock,
        # full 3D convs in the first stage, spatial-only (1x3x3) convs thereafter
        [Conv3DSimple] + [Conv3DNoTemporal] * 3,  # type: ignore[list-item]
        [2, 2, 2, 2],
        BasicStem,
        weights,
        progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Construct an 18-layer deep R(2+1)D network as in the reference below.

    .. betastatus:: video module

    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

    Args:
        weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.R2Plus1D_18_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.R2Plus1D_18_Weights
        :members:
    """
    weights = R2Plus1D_18_Weights.verify(weights)

    return _video_resnet(
        BasicBlock,
        [Conv2Plus1D] * 4,  # factorized (2+1)D convolutions throughout
        [2, 2, 2, 2],
        R2Plus1dStem,
        weights,
        progress,
        **kwargs,
    )
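# Usage sketch (illustrative only, via the public torchvision API; real video inputs
# should first go through `weights.transforms()` for resizing/cropping/normalization):
#
#     import torch
#     from torchvision.models.video import r3d_18, R3D_18_Weights
#
#     weights = R3D_18_Weights.KINETICS400_V1
#     model = r3d_18(weights=weights).eval()
#     clip = torch.rand(1, 3, 16, 112, 112)  # (N, C, T, H, W): one 16-frame RGB clip
#     with torch.inference_mode():
#         logits = model(clip)               # shape (1, 400): Kinetics-400 class scores
#     label = logits.softmax(dim=1).argmax(dim=1).item()
#     category = weights.meta["categories"][label]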
# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs


model_urls = _ModelURLs(
    {
        "r3d_18": R3D_18_Weights.KINETICS400_V1.url,
        "mc3_18": MC3_18_Weights.KINETICS400_V1.url,
        "r2plus1d_18": R2Plus1D_18_Weights.KINETICS400_V1.url,
    }
)