from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union

import torch.nn as nn
from torch import Tensor

from ..._internally_replaced_utils import load_state_dict_from_url
from ...utils import _log_api_usage_once


__all__ = ["r3d_18", "mc3_18", "r2plus1d_18"]

model_urls = {
    "r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
    "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
    "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
}


class Conv3DSimple(nn.Conv3d):
    """Full 3x3x3 convolution (no bias).

    ``midplanes`` is accepted only so every conv builder shares the call
    signature ``(in_planes, out_planes, midplanes, stride)``; it is unused here.
    """

    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:
        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        """Return the shortcut-conv stride: ``stride`` on every axis."""
        return stride, stride, stride


class Conv2Plus1D(nn.Sequential):
    """(2+1)D factorized convolution: a (1,3,3) spatial conv to ``midplanes``
    channels, BatchNorm + ReLU, then a (3,1,1) temporal conv to ``out_planes``."""

    def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
        super().__init__(
            nn.Conv3d(
                in_planes,
                midplanes,
                kernel_size=(1, 3, 3),
                stride=(1, stride, stride),
                padding=(0, padding, padding),
                bias=False,
            ),
            nn.BatchNorm3d(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False
            ),
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        """Return the shortcut-conv stride: ``stride`` on every axis."""
        return stride, stride, stride


class Conv3DNoTemporal(nn.Conv3d):
    """Spatial-only (1,3,3) convolution; the first (temporal) axis is never strided.

    ``midplanes`` is accepted only for call-signature compatibility; it is unused.
    """

    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:
        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        """Return the shortcut-conv stride: 1 on the first axis, ``stride`` on the others."""
        return 1, stride, stride


class BasicBlock(nn.Module):
    """Two-conv residual block (conv-BN-ReLU, conv-BN, add shortcut, ReLU)."""

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        # Chosen so a Conv2Plus1D (in*mid*9 + mid*out*3 weights) has roughly the
        # same parameter count as a full 3x3x3 conv (in*out*27 weights).
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super().__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1x1 reduce -> ``conv_builder`` kernel -> 1x1x1 expand (x4) residual block."""

    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:

        super().__init__()
        # NOTE(review): computed from ``inplanes`` although the conv_builder below
        # maps planes -> planes; kept as-is to preserve the existing architecture.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )
        # Second kernel
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )

        # 1x1x1; index 1 is the residual branch's last BatchNorm.
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem"""

    def __init__(self) -> None:
        super().__init__(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )


class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem: unlike the default stem, it factorizes the 3D convolution
    into a (1,7,7) spatial conv to 45 channels followed by a (3,1,1) temporal
    conv to 64 channels."""

    def __init__(self) -> None:
        super().__init__(
            nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )


class VideoResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
        layers: List[int],
        stem: Callable[..., nn.Module],
        num_classes: int = 400,
        zero_init_residual: bool = False,
    ) -> None:
        """Generic resnet video generator.

        Args:
            block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
            conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
                function for each layer
            layers (List[int]): number of blocks per layer
            stem (Callable[..., nn.Module]): module specifying the ResNet stem.
            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
            zero_init_residual (bool, optional): Zero-initialize the weight of the last
                BatchNorm in each Bottleneck residual branch. Has no effect on
                BasicBlock models. Defaults to False.
        """
        super().__init__()
        _log_api_usage_once(self)
        self.inplanes = 64

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # init weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    # Bottleneck has no ``bn3`` attribute; its last BatchNorm is
                    # the second element of ``conv3`` (Conv3d, BatchNorm3d).
                    last_bn = m.conv3[1]
                    if isinstance(last_bn, nn.BatchNorm3d):
                        nn.init.constant_(last_bn.weight, 0)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the layer to fc
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks.

        Only the first block strides and (when stride or channel count changes)
        gets a 1x1x1 conv + BN shortcut. Updates ``self.inplanes`` to the
        stage's output channel count.
        """
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)


def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Instantiate ``VideoResNet(**kwargs)`` and optionally load the weights
    registered under ``model_urls[arch]``."""
    model = VideoResNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model
def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Construct 18 layer Resnet3D model as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``VideoResNet``
            (for example ``num_classes`` or ``zero_init_residual``).

    Returns:
        VideoResNet: R3D-18 network
    """

    return _video_resnet(
        "r3d_18",
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv3DSimple] * 4,
        layers=[2, 2, 2, 2],
        stem=BasicStem,
        **kwargs,
    )
def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Constructor for 18 layer Mixed Convolution network as in
    https://arxiv.org/abs/1711.11248

    The first stage uses full 3D convolutions; the remaining three stages use
    spatial-only (1x3x3) convolutions.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``VideoResNet``
            (for example ``num_classes`` or ``zero_init_residual``).

    Returns:
        VideoResNet: MC3 Network definition
    """
    return _video_resnet(
        "mc3_18",
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,  # type: ignore[list-item]
        layers=[2, 2, 2, 2],
        stem=BasicStem,
        **kwargs,
    )
def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Constructor for the 18 layer deep R(2+1)D network as in
    https://arxiv.org/abs/1711.11248

    Every stage, and the stem, factorizes 3D convolutions into a spatial
    convolution followed by a temporal one.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``VideoResNet``
            (for example ``num_classes`` or ``zero_init_residual``).

    Returns:
        VideoResNet: R(2+1)D-18 network
    """
    return _video_resnet(
        "r2plus1d_18",
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[2, 2, 2, 2],
        stem=R2Plus1dStem,
        **kwargs,
    )
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.