Source code for torchvision.models.swin_transformer
import math
from functools import partial
from typing import Any, Callable, List, Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from ..ops.misc import MLP, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "SwinTransformer",
    "Swin_T_Weights",
    "Swin_S_Weights",
    "Swin_B_Weights",
    "Swin_V2_T_Weights",
    "Swin_V2_S_Weights",
    "Swin_V2_B_Weights",
    "swin_t",
    "swin_s",
    "swin_b",
    "swin_v2_t",
    "swin_v2_s",
    "swin_v2_b",
]


def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
    H, W, _ = x.shape[-3:]
    x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
    x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
    x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
    x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
    x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
    x = torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C
    return x


torch.fx.wrap("_patch_merging_pad")


def _get_relative_position_bias(
    relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> torch.Tensor:
    N = window_size[0] * window_size[1]
    relative_position_bias = relative_position_bias_table[relative_position_index]  # type: ignore[index]
    relative_position_bias = relative_position_bias.view(N, N, -1)
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
    return relative_position_bias


torch.fx.wrap("_get_relative_position_bias")


class PatchMerging(nn.Module):
    """Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
        super().__init__()
        _log_api_usage_once(self)
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor with expected layout of [..., H, W, C]
        Returns:
            Tensor with layout of [..., H/2, W/2, 2*C]
        """
        x = _patch_merging_pad(x)
        x = self.norm(x)
        x = self.reduction(x)  # ... H/2 W/2 2*C
        return x


class PatchMergingV2(nn.Module):
    """Patch Merging Layer for Swin Transformer V2.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
        super().__init__()
        _log_api_usage_once(self)
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)  # difference

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor with expected layout of [..., H, W, C]
        Returns:
            Tensor with layout of [..., H/2, W/2, 2*C]
        """
        x = _patch_merging_pad(x)
        x = self.reduction(x)  # ... H/2 W/2 2*C
        x = self.norm(x)
        return x


def shifted_window_attention(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    logit_scale: Optional[torch.Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        input (Tensor[N, H, W, C]): The input tensor of 4 dimensions.
        qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
        proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
        relative_position_bias (Tensor): The learned relative position bias added to attention.
        window_size (List[int]): Window size.
        num_heads (int): Number of attention heads.
        shift_size (List[int]): Shift size for shifted window attention.
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
        proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
        logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
        training (bool, optional): Training flag used by the dropout parameters. Default: True.
    Returns:
        Tensor[N, H, W, C]: The output tensor after shifted window attention.
    """
    B, H, W, C = input.shape
    # pad feature maps to multiples of window size
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
    _, pad_H, pad_W, _ = x.shape

    shift_size = shift_size.copy()
    # If window size is larger than feature size, there is no need to shift window
    if window_size[0] >= pad_H:
        shift_size[0] = 0
    if window_size[1] >= pad_W:
        shift_size[1] = 0

    # cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

    # partition windows
    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
    x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C)  # B*nW, Ws*Ws, C

    # multi-head attention
    if logit_scale is not None and qkv_bias is not None:
        qkv_bias = qkv_bias.clone()
        length = qkv_bias.numel() // 3
        qkv_bias[length : 2 * length].zero_()
    qkv = F.linear(x, qkv_weight, qkv_bias)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    if logit_scale is not None:
        # cosine attention
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
        attn = attn * logit_scale
    else:
        q = q * (C // num_heads) ** -0.5
        attn = q.matmul(k.transpose(-2, -1))
    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # generate attention mask
        attn_mask = x.new_zeros((pad_H, pad_W))
        h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
        w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
        count = 0
        for h in h_slices:
            for w in w_slices:
                attn_mask[h[0] : h[1], w[0] : w[1]] = count
                count += 1
        attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
        attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
        attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))
    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

    # unpad features
    x = x[:, :H, :W, :].contiguous()
    return x


torch.fx.wrap("shifted_window_attention")


class ShiftedWindowAttention(nn.Module):
    """
    See :func:`shifted_window_attention`.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        if len(window_size) != 2 or len(shift_size) != 2:
            raise ValueError("window_size and shift_size must be of length 2")
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        self.define_relative_position_bias_table()
        self.define_relative_position_index()

    def define_relative_position_bias_table(self):
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def define_relative_position_index(self):
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1).flatten()  # Wh*Ww*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self) -> torch.Tensor:
        return _get_relative_position_bias(
            self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Tensor with layout of [B, H, W, C]
        Returns:
            Tensor with same layout as input, i.e. [B, H, W, C]
        """
        relative_position_bias = self.get_relative_position_bias()
        return shifted_window_attention(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            self.window_size,
            self.num_heads,
            shift_size=self.shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            training=self.training,
        )


class ShiftedWindowAttentionV2(ShiftedWindowAttention):
    """
    See :func:`shifted_window_attention_v2`.
"""def__init__(self,dim:int,window_size:List[int],shift_size:List[int],num_heads:int,qkv_bias:bool=True,proj_bias:bool=True,attention_dropout:float=0.0,dropout:float=0.0,):super().__init__(dim,window_size,shift_size,num_heads,qkv_bias=qkv_bias,proj_bias=proj_bias,attention_dropout=attention_dropout,dropout=dropout,)self.logit_scale=nn.Parameter(torch.log(10*torch.ones((num_heads,1,1))))# mlp to generate continuous relative position biasself.cpb_mlp=nn.Sequential(nn.Linear(2,512,bias=True),nn.ReLU(inplace=True),nn.Linear(512,num_heads,bias=False))ifqkv_bias:length=self.qkv.bias.numel()//3self.qkv.bias[length:2*length].data.zero_()defdefine_relative_position_bias_table(self):# get relative_coords_tablerelative_coords_h=torch.arange(-(self.window_size[0]-1),self.window_size[0],dtype=torch.float32)relative_coords_w=torch.arange(-(self.window_size[1]-1),self.window_size[1],dtype=torch.float32)relative_coords_table=torch.stack(torch.meshgrid([relative_coords_h,relative_coords_w],indexing="ij"))relative_coords_table=relative_coords_table.permute(1,2,0).contiguous().unsqueeze(0)# 1, 2*Wh-1, 2*Ww-1, 2relative_coords_table[:,:,:,0]/=self.window_size[0]-1relative_coords_table[:,:,:,1]/=self.window_size[1]-1relative_coords_table*=8# normalize to -8, 8relative_coords_table=(torch.sign(relative_coords_table)*torch.log2(torch.abs(relative_coords_table)+1.0)/3.0)self.register_buffer("relative_coords_table",relative_coords_table)defget_relative_position_bias(self)->torch.Tensor:relative_position_bias=_get_relative_position_bias(self.cpb_mlp(self.relative_coords_table).view(-1,self.num_heads),self.relative_position_index,# type: ignore[arg-type]self.window_size,)relative_position_bias=16*torch.sigmoid(relative_position_bias)returnrelative_position_biasdefforward(self,x:Tensor):""" Args: x (Tensor): Tensor with layout of [B, H, W, C] Returns: Tensor with same layout as input, i.e. [B, H, W, C] """relative_position_bias=self.get_relative_position_bias()returnshifted_window_attention(x,self.qkv.weight,self.proj.weight,relative_position_bias,self.window_size,self.num_heads,shift_size=self.shift_size,attention_dropout=self.attention_dropout,dropout=self.dropout,qkv_bias=self.qkv.bias,proj_bias=self.proj.bias,logit_scale=self.logit_scale,training=self.training,)classSwinTransformerBlock(nn.Module):""" Swin Transformer Block. Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. window_size (List[int]): Window size. shift_size (List[int]): Shift size for shifted window attention. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. attn_layer (nn.Module): Attention layer. 
            Default: ShiftedWindowAttention
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
    ):
        super().__init__()
        _log_api_usage_once(self)

        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(
            dim,
            window_size,
            shift_size,
            num_heads,
            attention_dropout=attention_dropout,
            dropout=dropout,
        )
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.mlp.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x: Tensor):
        x = x + self.stochastic_depth(self.attn(self.norm1(x)))
        x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
        return x


class SwinTransformerBlockV2(SwinTransformerBlock):
    """
    Swin Transformer V2 Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (List[int]): Window size.
        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
    ):
        super().__init__(
            dim,
            num_heads,
            window_size,
            shift_size,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            attention_dropout=attention_dropout,
            stochastic_depth_prob=stochastic_depth_prob,
            norm_layer=norm_layer,
            attn_layer=attn_layer,
        )

    def forward(self, x: Tensor):
        # Here is the difference, we apply norm after the attention in V2.
        # In V1 we applied norm before the attention.
        x = x + self.stochastic_depth(self.norm1(self.attn(x)))
        x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
        return x


class SwinTransformer(nn.Module):
    """
    Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
    Shifted Windows" <https://arxiv.org/abs/2103.14030>`_ paper.

    Args:
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
        num_classes (int): Number of classes for classification head. Default: 1000.
        block (nn.Module, optional): SwinTransformer Block. Default: None.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
"""def__init__(self,patch_size:List[int],embed_dim:int,depths:List[int],num_heads:List[int],window_size:List[int],mlp_ratio:float=4.0,dropout:float=0.0,attention_dropout:float=0.0,stochastic_depth_prob:float=0.1,num_classes:int=1000,norm_layer:Optional[Callable[...,nn.Module]]=None,block:Optional[Callable[...,nn.Module]]=None,downsample_layer:Callable[...,nn.Module]=PatchMerging,):super().__init__()_log_api_usage_once(self)self.num_classes=num_classesifblockisNone:block=SwinTransformerBlockifnorm_layerisNone:norm_layer=partial(nn.LayerNorm,eps=1e-5)layers:List[nn.Module]=[]# split image into non-overlapping patcheslayers.append(nn.Sequential(nn.Conv2d(3,embed_dim,kernel_size=(patch_size[0],patch_size[1]),stride=(patch_size[0],patch_size[1])),Permute([0,2,3,1]),norm_layer(embed_dim),))total_stage_blocks=sum(depths)stage_block_id=0# build SwinTransformer blocksfori_stageinrange(len(depths)):stage:List[nn.Module]=[]dim=embed_dim*2**i_stagefori_layerinrange(depths[i_stage]):# adjust stochastic depth probability based on the depth of the stage blocksd_prob=stochastic_depth_prob*float(stage_block_id)/(total_stage_blocks-1)stage.append(block(dim,num_heads[i_stage],window_size=window_size,shift_size=[0ifi_layer%2==0elsew//2forwinwindow_size],mlp_ratio=mlp_ratio,dropout=dropout,attention_dropout=attention_dropout,stochastic_depth_prob=sd_prob,norm_layer=norm_layer,))stage_block_id+=1layers.append(nn.Sequential(*stage))# add patch merging layerifi_stage<(len(depths)-1):layers.append(downsample_layer(dim,norm_layer))self.features=nn.Sequential(*layers)num_features=embed_dim*2**(len(depths)-1)self.norm=norm_layer(num_features)self.permute=Permute([0,3,1,2])# B H W C -> B C H Wself.avgpool=nn.AdaptiveAvgPool2d(1)self.flatten=nn.Flatten(1)self.head=nn.Linear(num_features,num_classes)forminself.modules():ifisinstance(m,nn.Linear):nn.init.trunc_normal_(m.weight,std=0.02)ifm.biasisnotNone:nn.init.zeros_(m.bias)defforward(self,x):x=self.features(x)x=self.norm(x)x=self.permute(x)x=self.avgpool(x)x=self.flatten(x)x=self.head(x)returnxdef_swin_transformer(patch_size:List[int],embed_dim:int,depths:List[int],num_heads:List[int],window_size:List[int],stochastic_depth_prob:float,weights:Optional[WeightsEnum],progress:bool,**kwargs:Any,)->SwinTransformer:ifweightsisnotNone:_ovewrite_named_param(kwargs,"num_classes",len(weights.meta["categories"]))model=SwinTransformer(patch_size=patch_size,embed_dim=embed_dim,depths=depths,num_heads=num_heads,window_size=window_size,stochastic_depth_prob=stochastic_depth_prob,**kwargs,)ifweightsisnotNone:model.load_state_dict(weights.get_state_dict(progress=progress,check_hash=True))returnmodel_COMMON_META={"categories":_IMAGENET_CATEGORIES,}
class Swin_T_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 28288354,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.474,
                    "acc@5": 95.776,
                }
            },
            "_ops": 4.491,
            "_file_size": 108.19,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
class Swin_S_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_s-5e29d889.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 49606258,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.196,
                    "acc@5": 96.360,
                }
            },
            "_ops": 8.741,
            "_file_size": 189.786,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
class Swin_B_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 87768224,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.582,
                    "acc@5": 96.640,
                }
            },
            "_ops": 15.431,
            "_file_size": 335.364,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
class Swin_V2_T_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth",
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 28351570,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.072,
                    "acc@5": 96.132,
                }
            },
            "_ops": 5.94,
            "_file_size": 108.626,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
class Swin_V2_S_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth",
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 49737442,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.712,
                    "acc@5": 96.816,
                }
            },
            "_ops": 11.546,
            "_file_size": 190.675,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
class Swin_V2_B_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_v2_b-781e5279.pth",
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=272, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 87930848,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.112,
                    "acc@5": 96.864,
                }
            },
            "_ops": 20.325,
            "_file_size": 336.372,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
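Each enum entry above bundles a checkpoint URL, an inference preset, and metadata. The sketch below shows how such an entry is typically consumed; it assumes the standard torchvision multi-weight API (Weights.transforms() and Weights.meta) and is illustrative only.

# Sketch (assumes the torchvision multi-weight API shown in the enums above).
from torchvision.models import Swin_V2_B_Weights

weights = Swin_V2_B_Weights.DEFAULT   # alias for IMAGENET1K_V1
preprocess = weights.transforms()     # ImageClassification preset: crop 256, resize 272, bicubic
num_params = weights.meta["num_params"]                          # 87930848
top1 = weights.meta["_metrics"]["ImageNet-1K"]["acc@1"]          # 84.112
print(num_params, top1)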
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_T_Weights.IMAGENET1K_V1))
def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_tiny architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_T_Weights
        :members:
    """
    weights = Swin_T_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        stochastic_depth_prob=0.2,
        weights=weights,
        progress=progress,
        **kwargs,
    )
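A short end-to-end sketch following the docstring above: build a pretrained swin_t and classify a single image tensor. The random `img` tensor is a placeholder assumption standing in for a real [3, H, W] image.

# Sketch: pretrained swin_t classification (img is a placeholder tensor).
import torch
from torchvision.models import swin_t, Swin_T_Weights

weights = Swin_T_Weights.IMAGENET1K_V1
model = swin_t(weights=weights).eval()
img = torch.rand(3, 300, 300)                     # placeholder for a real image
batch = weights.transforms()(img).unsqueeze(0)    # resize 232, crop 224, normalize
with torch.no_grad():
    class_id = model(batch).softmax(-1).argmax().item()
print(weights.meta["categories"][class_id])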
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_S_Weights.IMAGENET1K_V1))
def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_small architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_S_Weights
        :members:
    """
    weights = Swin_S_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        stochastic_depth_prob=0.3,
        weights=weights,
        progress=progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_B_Weights.IMAGENET1K_V1))
def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_base architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_B_Weights
        :members:
    """
    weights = Swin_B_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[7, 7],
        stochastic_depth_prob=0.5,
        weights=weights,
        progress=progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_V2_T_Weights.IMAGENET1K_V1))
def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_v2_tiny architecture from
    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_V2_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_V2_T_Weights
        :members:
    """
    weights = Swin_V2_T_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 8],
        stochastic_depth_prob=0.2,
        weights=weights,
        progress=progress,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_V2_S_Weights.IMAGENET1K_V1))
def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_v2_small architecture from
    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_V2_S_Weights
        :members:
    """
    weights = Swin_V2_S_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 8],
        stochastic_depth_prob=0.3,
        weights=weights,
        progress=progress,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_V2_B_Weights.IMAGENET1K_V1))
def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_v2_base architecture from
    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_V2_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_V2_B_Weights
        :members:
    """
    weights = Swin_V2_B_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[8, 8],
        stochastic_depth_prob=0.5,
        weights=weights,
        progress=progress,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
        **kwargs,
    )
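The V2 builders above reuse the same SwinTransformer class and only plug in SwinTransformerBlockV2 and PatchMergingV2. The trunk can also serve as a feature extractor through the `features` submodule, which keeps the [B, H/32, W/32, C] layout before the final norm/permute/pool. The sketch below assumes a 256x256 input for swin_v2_b (embed_dim=128, so the last stage has 1024 channels).

# Sketch: using the swin_v2_b trunk as a feature extractor (256x256 input assumed).
import torch
from torchvision.models import swin_v2_b, Swin_V2_B_Weights

model = swin_v2_b(weights=Swin_V2_B_Weights.IMAGENET1K_V1).eval()
with torch.no_grad():
    feats = model.features(torch.randn(1, 3, 256, 256))
print(feats.shape)  # expected: torch.Size([1, 8, 8, 1024])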