import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import torch
import torch.fx
import torch.nn as nn

from ...ops import MLP, StochasticDepth
from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "MViT",
    "MViT_V1_B_Weights",
    "mvit_v1_b",
    "MViT_V2_S_Weights",
    "mvit_v2_s",
]


@dataclass
class MSBlockConfig:
    num_heads: int
    input_channels: int
    output_channels: int
    kernel_q: List[int]
    kernel_kv: List[int]
    stride_q: List[int]
    stride_kv: List[int]


def _prod(s: Sequence[int]) -> int:
    product = 1
    for v in s:
        product *= v
    return product


def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> Tuple[torch.Tensor, int]:
    tensor_dim = x.dim()
    if tensor_dim == target_dim - 1:
        x = x.unsqueeze(expand_dim)
    elif tensor_dim != target_dim:
        raise ValueError(f"Unsupported input dimension {x.shape}")
    return x, tensor_dim


def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor:
    if tensor_dim == target_dim - 1:
        x = x.squeeze(expand_dim)
    return x


torch.fx.wrap("_unsqueeze")
torch.fx.wrap("_squeeze")


class Pool(nn.Module):
    def __init__(
        self,
        pool: nn.Module,
        norm: Optional[nn.Module],
        activation: Optional[nn.Module] = None,
        norm_before_pool: bool = False,
    ) -> None:
        super().__init__()
        self.pool = pool
        layers = []
        if norm is not None:
            layers.append(norm)
        if activation is not None:
            layers.append(activation)
        self.norm_act = nn.Sequential(*layers) if layers else None
        self.norm_before_pool = norm_before_pool

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        x, tensor_dim = _unsqueeze(x, 4, 1)

        # Separate the class token and reshape the input
        class_token, x = torch.tensor_split(x, indices=(1,), dim=2)
        x = x.transpose(2, 3)
        B, N, C = x.shape[:3]
        x = x.reshape((B * N, C) + thw).contiguous()

        # normalizing prior pooling is useful when we use BN which can be absorbed to speed up inference
        if self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        # apply the pool on the input and add back the token
        x = self.pool(x)
        T, H, W = x.shape[2:]
        x = x.reshape(B, N, C, -1).transpose(2, 3)
        x = torch.cat((class_token, x), dim=2)

        if not self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        x = _squeeze(x, 4, 1, tensor_dim)
        return x, (T, H, W)


def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
    if embedding.shape[0] == d:
        return embedding

    return (
        nn.functional.interpolate(
            embedding.permute(1, 0).unsqueeze(0),
            size=d,
            mode="linear",
        )
        .squeeze(0)
        .permute(1, 0)
    )


def _add_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    q_thw: Tuple[int, int, int],
    k_thw: Tuple[int, int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rel_pos_t: torch.Tensor,
) -> torch.Tensor:
    # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
    q_t, q_h, q_w = q_thw
    k_t, k_h, k_w = k_thw
    dh = int(2 * max(q_h, k_h) - 1)
    dw = int(2 * max(q_w, k_w) - 1)
    dt = int(2 * max(q_t, k_t) - 1)

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio

    # Interpolate rel pos if needed.
    rel_pos_h = _interpolate(rel_pos_h, dh)
    rel_pos_w = _interpolate(rel_pos_w, dw)
    rel_pos_t = _interpolate(rel_pos_t, dt)
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]
    Rt = rel_pos_t[dist_t.long()]

    B, n_head, _, dim = q.shape

    r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh)  # [B, H, q_t, qh, qw, k_h]
    rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw)  # [B, H, q_t, qh, qw, k_w]

    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    # Combine rel pos.
    rel_pos = (
        rel_h_q[:, :, :, :, :, None, :, None]
        + rel_w_q[:, :, :, :, :, None, None, :]
        + rel_q_t[:, :, :, :, :, :, None, None]
    ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)

    # Add it to attention
    attn[:, :, 1:, 1:] += rel_pos

    return attn


def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
    if residual_with_cls_embed:
        x.add_(shortcut)
    else:
        x[:, :, 1:, :] += shortcut[:, :, 1:, :]
    return x


torch.fx.wrap("_add_rel_pos")
torch.fx.wrap("_add_shortcut")


class MultiscaleAttention(nn.Module):
    def __init__(
        self,
        input_size: List[int],
        embed_dim: int,
        output_dim: int,
        num_heads: int,
        kernel_q: List[int],
        kernel_kv: List[int],
        stride_q: List[int],
        stride_kv: List[int],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.head_dim = output_dim // num_heads
        self.scaler = 1.0 / math.sqrt(self.head_dim)
        self.residual_pool = residual_pool
        self.residual_with_cls_embed = residual_with_cls_embed

        self.qkv = nn.Linear(embed_dim, 3 * output_dim)
        layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)]
        if dropout > 0.0:
            layers.append(nn.Dropout(dropout, inplace=True))
        self.project = nn.Sequential(*layers)

        self.pool_q: Optional[nn.Module] = None
        if _prod(kernel_q) > 1 or _prod(stride_q) > 1:
            padding_q = [int(q // 2) for q in kernel_q]
            self.pool_q = Pool(
                nn.Conv3d(
                    self.head_dim,
                    self.head_dim,
                    kernel_q,  # type: ignore[arg-type]
                    stride=stride_q,  # type: ignore[arg-type]
                    padding=padding_q,  # type: ignore[arg-type]
                    groups=self.head_dim,
                    bias=False,
                ),
                norm_layer(self.head_dim),
            )

        self.pool_k: Optional[nn.Module] = None
        self.pool_v: Optional[nn.Module] = None
        if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1:
            padding_kv = [int(kv // 2) for kv in kernel_kv]
            self.pool_k = Pool(
                nn.Conv3d(
                    self.head_dim,
                    self.head_dim,
                    kernel_kv,  # type: ignore[arg-type]
                    stride=stride_kv,  # type: ignore[arg-type]
                    padding=padding_kv,  # type: ignore[arg-type]
                    groups=self.head_dim,
                    bias=False,
                ),
                norm_layer(self.head_dim),
            )
            self.pool_v = Pool(
                nn.Conv3d(
                    self.head_dim,
                    self.head_dim,
                    kernel_kv,  # type: ignore[arg-type]
                    stride=stride_kv,  # type: ignore[arg-type]
                    padding=padding_kv,  # type: ignore[arg-type]
                    groups=self.head_dim,
                    bias=False,
                ),
                norm_layer(self.head_dim),
            )

        self.rel_pos_h: Optional[nn.Parameter] = None
        self.rel_pos_w: Optional[nn.Parameter] = None
        self.rel_pos_t: Optional[nn.Parameter] = None
        if rel_pos_embed:
            size = max(input_size[1:])
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            spatial_dim = 2 * max(q_size, kv_size) - 1
            temporal_dim = 2 * input_size[0] - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim))
            self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim))
            nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
            nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
            nn.init.trunc_normal_(self.rel_pos_t, std=0.02)

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        B, N, C = x.shape
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2)

        if self.pool_k is not None:
            k, k_thw = self.pool_k(k, thw)
        else:
            k_thw = thw
        if self.pool_v is not None:
            v = self.pool_v(v, thw)[0]
        if self.pool_q is not None:
            q, thw = self.pool_q(q, thw)

        attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
        if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None:
            attn = _add_rel_pos(
                attn,
                q,
                thw,
                k_thw,
                self.rel_pos_h,
                self.rel_pos_w,
                self.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        x = torch.matmul(attn, v)
        if self.residual_pool:
            _add_shortcut(x, q, self.residual_with_cls_embed)

        x = x.transpose(1, 2).reshape(B, -1, self.output_dim)
        x = self.project(x)

        return x, thw


class MultiscaleBlock(nn.Module):
    def __init__(
        self,
        input_size: List[int],
        cnf: MSBlockConfig,
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.proj_after_attn = proj_after_attn

        self.pool_skip: Optional[nn.Module] = None
        if _prod(cnf.stride_q) > 1:
            kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q]
            padding_skip = [int(k // 2) for k in kernel_skip]
            self.pool_skip = Pool(
                nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None  # type: ignore[arg-type]
            )

        attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels

        self.norm1 = norm_layer(cnf.input_channels)
        self.norm2 = norm_layer(attn_dim)
        self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)

        self.attn = MultiscaleAttention(
            input_size,
            cnf.input_channels,
            attn_dim,
            cnf.num_heads,
            kernel_q=cnf.kernel_q,
            kernel_kv=cnf.kernel_kv,
            stride_q=cnf.stride_q,
            stride_kv=cnf.stride_kv,
            rel_pos_embed=rel_pos_embed,
            residual_pool=residual_pool,
            residual_with_cls_embed=residual_with_cls_embed,
            dropout=dropout,
            norm_layer=norm_layer,
        )
        self.mlp = MLP(
            attn_dim,
            [4 * attn_dim, cnf.output_channels],
            activation_layer=nn.GELU,
            dropout=dropout,
            inplace=None,
        )

        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

        self.project: Optional[nn.Module] = None
        if cnf.input_channels != cnf.output_channels:
            self.project = nn.Linear(cnf.input_channels, cnf.output_channels)

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x)
        x_attn, thw_new = self.attn(x_norm1, thw)
        x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1)
        x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0]
        x = x_skip + self.stochastic_depth(x_attn)

        x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x)
        x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2)

        return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new


class PositionalEncoding(nn.Module):
    def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
        super().__init__()
        self.spatial_size = spatial_size
        self.temporal_size = temporal_size

        self.class_token = nn.Parameter(torch.zeros(embed_size))
        self.spatial_pos: Optional[nn.Parameter] = None
        self.temporal_pos: Optional[nn.Parameter] = None
        self.class_pos: Optional[nn.Parameter] = None
        if not rel_pos_embed:
            self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size))
            self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
            self.class_pos = nn.Parameter(torch.zeros(embed_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1)
        x = torch.cat((class_token, x), dim=1)

        if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None:
            hw_size, embed_size = self.spatial_pos.shape
            pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
            pos_embedding.add_(
                self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size)
            )
            pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
            x.add_(pos_embedding)

        return x


class MViT(nn.Module):
    def __init__(
        self,
        spatial_size: Tuple[int, int],
        temporal_size: int,
        block_setting: Sequence[MSBlockConfig],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.5,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        num_classes: int = 400,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7),
        patch_embed_stride: Tuple[int, int, int] = (2, 4, 4),
        patch_embed_padding: Tuple[int, int, int] = (1, 3, 3),
    ) -> None:
        """
        MViT main class.

        Args:
            spatial_size (tuple of ints): The spatial size of the input as ``(H, W)``.
            temporal_size (int): The temporal size ``T`` of the input.
            block_setting (sequence of MSBlockConfig): The Network structure.
            residual_pool (bool): If True, use MViTv2 pooling residual connection.
            residual_with_cls_embed (bool): If True, the addition on the residual connection will include
                the class embedding.
            rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings.
            proj_after_attn (bool): If True, apply the projection after the attention.
            dropout (float): Dropout rate. Default: 0.5.
            attention_dropout (float): Attention dropout rate. Default: 0.0.
            stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
            num_classes (int): The number of classes.
            block (callable, optional): Module specifying the layer which consists of the attention and mlp.
            norm_layer (callable, optional): Module specifying the normalization layer to use.
            patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input.
            patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input.
            patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input.
        """
        super().__init__()
        # This implementation employs a different parameterization scheme than the one used at PyTorch Video:
        # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py
        # We remove any experimental configuration that didn't make it to the final variants of the models.
        # To represent the configuration of the architecture we use the simplified form suggested at Table 1 of the paper.
        _log_api_usage_once(self)
        total_stage_blocks = len(block_setting)
        if total_stage_blocks == 0:
            raise ValueError("The configuration parameter can't be empty.")

        if block is None:
            block = MultiscaleBlock

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # Patch Embedding module
        self.conv_proj = nn.Conv3d(
            in_channels=3,
            out_channels=block_setting[0].input_channels,
            kernel_size=patch_embed_kernel,
            stride=patch_embed_stride,
            padding=patch_embed_padding,
        )

        input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)]

        # Spatio-Temporal Class Positional Encoding
        self.pos_encoding = PositionalEncoding(
            embed_size=block_setting[0].input_channels,
            spatial_size=(input_size[1], input_size[2]),
            temporal_size=input_size[0],
            rel_pos_embed=rel_pos_embed,
        )

        # Encoder module
        self.blocks = nn.ModuleList()
        for stage_block_id, cnf in enumerate(block_setting):
            # adjust stochastic depth probability based on the depth of the stage block
            sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)

            self.blocks.append(
                block(
                    input_size=input_size,
                    cnf=cnf,
                    residual_pool=residual_pool,
                    residual_with_cls_embed=residual_with_cls_embed,
                    rel_pos_embed=rel_pos_embed,
                    proj_after_attn=proj_after_attn,
                    dropout=attention_dropout,
                    stochastic_depth_prob=sd_prob,
                    norm_layer=norm_layer,
                )
            )

            if len(cnf.stride_q) > 0:
                input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)]
        self.norm = norm_layer(block_setting[-1].output_channels)

        # Classifier module
        self.head = nn.Sequential(
            nn.Dropout(dropout, inplace=True),
            nn.Linear(block_setting[-1].output_channels, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.LayerNorm):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, PositionalEncoding):
                for weights in m.parameters():
                    nn.init.trunc_normal_(weights, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W)
        x = _unsqueeze(x, 5, 2)[0]
        # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0])
        x = self.conv_proj(x)
        x = x.flatten(2).transpose(1, 2)

        # add positional encoding
        x = self.pos_encoding(x)

        # pass patches through the encoder
        thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size
        for block in self.blocks:
            x, thw = block(x, thw)
        x = self.norm(x)

        # classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.head(x)

        return x


def _mvit(
    block_setting: List[MSBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> MViT:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"])
        _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"])
    spatial_size = kwargs.pop("spatial_size", (224, 224))
    temporal_size = kwargs.pop("temporal_size", 16)

    model = MViT(
        spatial_size=spatial_size,
        temporal_size=temporal_size,
        block_setting=block_setting,
        residual_pool=kwargs.pop("residual_pool", False),
        residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True),
        rel_pos_embed=kwargs.pop("rel_pos_embed", False),
        proj_after_attn=kwargs.pop("proj_after_attn", False),
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
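# ----------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the torchvision API): how `MSBlockConfig` and `MViT` fit
# together. The two-block configuration below is hypothetical and chosen only to keep the example
# tiny; the real 16-block variants are defined in `mvit_v1_b` / `mvit_v2_s` further down.
def _example_tiny_mvit() -> torch.Tensor:  # hypothetical helper, shown for illustration only
    tiny_blocks = [
        MSBlockConfig(
            num_heads=1,
            input_channels=96,
            output_channels=192,
            kernel_q=[3, 3, 3],
            kernel_kv=[3, 3, 3],
            stride_q=[1, 2, 2],
            stride_kv=[1, 4, 4],
        ),
        MSBlockConfig(
            num_heads=2,
            input_channels=192,
            output_channels=192,
            kernel_q=[3, 3, 3],
            kernel_kv=[3, 3, 3],
            stride_q=[1, 1, 1],
            stride_kv=[1, 2, 2],
        ),
    ]
    model = MViT(
        spatial_size=(224, 224),
        temporal_size=16,
        block_setting=tiny_blocks,
        residual_pool=True,
        residual_with_cls_embed=False,
        rel_pos_embed=True,
        proj_after_attn=True,
        num_classes=400,
    )
    model.eval()
    # The default patch embedding uses stride (2, 4, 4), so a (1, 3, 16, 224, 224) clip becomes
    # 8 * 56 * 56 = 25088 patch tokens plus one class token before entering the blocks.
    clip = torch.rand(1, 3, 16, 224, 224)
    with torch.no_grad():
        return model(clip)  # (1, 400) class logits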
class MViT_V1_B_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.45, 0.45, 0.45),
            std=(0.225, 0.225, 0.225),
        ),
        meta={
            "min_size": (224, 224),
            "min_temporal_size": 16,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
            ),
            "num_params": 36610672,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 78.477,
                    "acc@5": 93.582,
                }
            },
            "_ops": 70.599,
            "_file_size": 139.764,
        },
    )
    DEFAULT = KINETICS400_V1
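# Illustrative sketch (not part of torchvision): the `Weights` entry above exposes its metadata
# through `.meta`, which is where the category names and the reported accuracy numbers live.
def _example_inspect_v1_weights() -> None:  # hypothetical helper, shown for illustration only
    weights = MViT_V1_B_Weights.KINETICS400_V1
    print(len(weights.meta["categories"]))  # 400 Kinetics-400 class names
    print(weights.meta["_metrics"]["Kinetics-400"])  # {'acc@1': 78.477, 'acc@5': 93.582}
    print(weights.meta["num_params"])  # 36610672 trainable parameters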
class MViT_V2_S_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.45, 0.45, 0.45),
            std=(0.225, 0.225, 0.225),
        ),
        meta={
            "min_size": (224, 224),
            "min_temporal_size": 16,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
            ),
            "num_params": 34537744,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 80.757,
                    "acc@5": 94.665,
                }
            },
            "_ops": 64.224,
            "_file_size": 131.884,
        },
    )
    DEFAULT = KINETICS400_V1
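# Illustrative sketch (not part of torchvision): `Weights.get_state_dict` (used by `_mvit` above
# to load pretrained parameters) downloads and caches the checkpoint, which can also be handy for
# inspecting parameter names and shapes without instantiating the model. Network access is assumed.
def _example_peek_v2_s_checkpoint() -> None:  # hypothetical helper, shown for illustration only
    weights = MViT_V2_S_Weights.KINETICS400_V1
    state_dict = weights.get_state_dict(progress=True, check_hash=True)
    for name in list(state_dict)[:5]:
        print(name, tuple(state_dict[name].shape))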
[docs]@register_model()@handle_legacy_interface(weights=("pretrained",MViT_V1_B_Weights.KINETICS400_V1))defmvit_v1_b(*,weights:Optional[MViT_V1_B_Weights]=None,progress:bool=True,**kwargs:Any)->MViT:""" Constructs a base MViTV1 architecture from `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__. .. betastatus:: video module Args: weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.video.MViT_V1_B_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.video.MViT`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_ for more details about this class. .. autoclass:: torchvision.models.video.MViT_V1_B_Weights :members: """weights=MViT_V1_B_Weights.verify(weights)config:Dict[str,List]={"num_heads":[1,2,2,4,4,4,4,4,4,4,4,4,4,4,8,8],"input_channels":[96,192,192,384,384,384,384,384,384,384,384,384,384,384,768,768],"output_channels":[192,192,384,384,384,384,384,384,384,384,384,384,384,768,768,768],"kernel_q":[[],[3,3,3],[],[3,3,3],[],[],[],[],[],[],[],[],[],[],[3,3,3],[]],"kernel_kv":[[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],],"stride_q":[[],[1,2,2],[],[1,2,2],[],[],[],[],[],[],[],[],[],[],[1,2,2],[]],"stride_kv":[[1,8,8],[1,4,4],[1,4,4],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,1,1],[1,1,1],],}block_setting=[]foriinrange(len(config["num_heads"])):block_setting.append(MSBlockConfig(num_heads=config["num_heads"][i],input_channels=config["input_channels"][i],output_channels=config["output_channels"][i],kernel_q=config["kernel_q"][i],kernel_kv=config["kernel_kv"][i],stride_q=config["stride_q"][i],stride_kv=config["stride_kv"][i],))return_mvit(spatial_size=(224,224),temporal_size=16,block_setting=block_setting,residual_pool=False,residual_with_cls_embed=False,stochastic_depth_prob=kwargs.pop("stochastic_depth_prob",0.2),weights=weights,progress=progress,**kwargs,)
[docs]@register_model()@handle_legacy_interface(weights=("pretrained",MViT_V2_S_Weights.KINETICS400_V1))defmvit_v2_s(*,weights:Optional[MViT_V2_S_Weights]=None,progress:bool=True,**kwargs:Any)->MViT:"""Constructs a small MViTV2 architecture from `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__ and `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection <https://arxiv.org/abs/2112.01526>`__. .. betastatus:: video module Args: weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.video.MViT_V2_S_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.video.MViT`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_ for more details about this class. .. autoclass:: torchvision.models.video.MViT_V2_S_Weights :members: """weights=MViT_V2_S_Weights.verify(weights)config:Dict[str,List]={"num_heads":[1,2,2,4,4,4,4,4,4,4,4,4,4,4,8,8],"input_channels":[96,96,192,192,384,384,384,384,384,384,384,384,384,384,384,768],"output_channels":[96,192,192,384,384,384,384,384,384,384,384,384,384,384,768,768],"kernel_q":[[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],],"kernel_kv":[[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],[3,3,3],],"stride_q":[[1,1,1],[1,2,2],[1,1,1],[1,2,2],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,2,2],[1,1,1],],"stride_kv":[[1,8,8],[1,4,4],[1,4,4],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,2,2],[1,1,1],[1,1,1],],}block_setting=[]foriinrange(len(config["num_heads"])):block_setting.append(MSBlockConfig(num_heads=config["num_heads"][i],input_channels=config["input_channels"][i],output_channels=config["output_channels"][i],kernel_q=config["kernel_q"][i],kernel_kv=config["kernel_kv"][i],stride_q=config["stride_q"][i],stride_kv=config["stride_kv"][i],))return_mvit(spatial_size=(224,224),temporal_size=16,block_setting=block_setting,residual_pool=True,residual_with_cls_embed=False,rel_pos_embed=True,proj_after_attn=True,stochastic_depth_prob=kwargs.pop("stochastic_depth_prob",0.2),weights=weights,progress=progress,**kwargs,)