Source code for torchvision.models.vision_transformer
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn

from ..ops.misc import Conv2dNormActivation, MLP
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

__all__ = [
    "VisionTransformer",
    "ViT_B_16_Weights",
    "ViT_B_32_Weights",
    "ViT_L_16_Weights",
    "ViT_L_32_Weights",
    "ViT_H_14_Weights",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
    "vit_h_14",
]


class ConvStemConfig(NamedTuple):
    out_channels: int
    kernel_size: int
    stride: int
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU


class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiheadAttention() by default
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))


class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x


def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0])
    image_size = kwargs.pop("image_size", 224)

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META: Dict[str, Any] = {
    "categories": _IMAGENET_CATEGORIES,
}

_COMMON_SWAG_META = {
    **_COMMON_META,
    "recipe": "https://github.com/facebookresearch/SWAG",
    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
}
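As a quick, self-contained sanity check of the shape handling in ``_process_input`` and ``forward`` above (an illustrative sketch only; the tiny hyperparameters below are arbitrary and not a torchvision configuration):

import torch
from torchvision.models.vision_transformer import VisionTransformer

# image_size=32, patch_size=8 -> (32 // 8) ** 2 = 16 patch tokens, +1 class token = 17.
tiny_vit = VisionTransformer(
    image_size=32,
    patch_size=8,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    mlp_dim=128,
    num_classes=10,
)
assert tiny_vit.seq_length == 17

x = torch.randn(4, 3, 32, 32)        # (n, c, h, w)
tokens = tiny_vit._process_input(x)  # (n, n_h * n_w, hidden_dim) == (4, 16, 64)
logits = tiny_vit(x)                 # (4, 10) after the class-token head
print(tokens.shape, logits.shape)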
class ViT_B_16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 86567656,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.072,
                    "acc@5": 95.318,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 86859496,
            "min_size": (384, 384),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.304,
                    "acc@5": 97.650,
                }
            },
            "_ops": 55.484,
            "_file_size": 331.398,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 86567656,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 96.180,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
class ViT_B_32_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88224232,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.912,
                    "acc@5": 92.466,
                }
            },
            "_ops": 4.409,
            "_file_size": 336.604,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
class ViT_L_16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=242),
        meta={
            **_COMMON_META,
            "num_params": 304326632,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.662,
                    "acc@5": 94.638,
                }
            },
            "_ops": 61.555,
            "_file_size": 1161.023,
            "_docs": """
                These weights were trained from scratch by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
        transforms=partial(
            ImageClassification,
            crop_size=512,
            resize_size=512,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 305174504,
            "min_size": (512, 512),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.064,
                    "acc@5": 98.512,
                }
            },
            "_ops": 361.986,
            "_file_size": 1164.258,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 304326632,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.146,
                    "acc@5": 97.422,
                }
            },
            "_ops": 61.555,
            "_file_size": 1161.023,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
class ViT_L_32_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 306535400,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.972,
                    "acc@5": 93.07,
                }
            },
            "_ops": 15.378,
            "_file_size": 1169.449,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
class ViT_H_14_Weights(WeightsEnum):
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
        transforms=partial(
            ImageClassification,
            crop_size=518,
            resize_size=518,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 633470440,
            "min_size": (518, 518),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.552,
                    "acc@5": 98.694,
                }
            },
            "_ops": 1016.717,
            "_file_size": 2416.643,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 632045800,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.708,
                    "acc@5": 97.730,
                }
            },
            "_ops": 167.295,
            "_file_size": 2411.209,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_SWAG_E2E_V1
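Each weight enum above exposes its accuracy, parameter count, and minimum input size through ``meta``. As a small illustrative sketch (not part of the listed source), the entries of ``ViT_H_14_Weights`` can be compared like this:

from torchvision.models import ViT_H_14_Weights

for w in ViT_H_14_Weights:
    acc1 = w.meta["_metrics"]["ImageNet-1K"]["acc@1"]
    print(f"{w.name}: acc@1={acc1}, params={w.meta['num_params']}, min_size={w.meta['min_size']}")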
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_16 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_B_16_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_B_16_Weights
        :members:
    """
    weights = ViT_B_16_Weights.verify(weights)

    return _vision_transformer(
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        weights=weights,
        progress=progress,
        **kwargs,
    )
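For reference, a typical inference sketch with this builder and its bundled preprocessing (illustrative only; the random tensor below stands in for a decoded image):

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)
model.eval()

preprocess = weights.transforms()    # ImageClassification preset with crop_size=224
img = torch.rand(3, 256, 256)        # stand-in for a decoded image tensor
batch = preprocess(img).unsqueeze(0)

with torch.no_grad():
    scores = model(batch).softmax(dim=-1)
top_class = scores.argmax(dim=-1).item()
print(weights.meta["categories"][top_class])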
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_32 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_32_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_B_32_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_B_32_Weights
        :members:
    """
    weights = ViT_B_32_Weights.verify(weights)

    return _vision_transformer(
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        weights=weights,
        progress=progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_16 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_16_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_L_16_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_L_16_Weights
        :members:
    """
    weights = ViT_L_16_Weights.verify(weights)

    return _vision_transformer(
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        weights=weights,
        progress=progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_32 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_32_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_L_32_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_L_32_Weights
        :members:
    """
    weights = ViT_L_32_Weights.verify(weights)

    return _vision_transformer(
        patch_size=32,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        weights=weights,
        progress=progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_h_14 architecture from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_H_14_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.ViT_H_14_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ViT_H_14_Weights
        :members:
    """
    weights = ViT_H_14_Weights.verify(weights)

    return _vision_transformer(
        patch_size=14,
        num_layers=32,
        num_heads=16,
        hidden_dim=1280,
        mlp_dim=5120,
        weights=weights,
        progress=progress,
        **kwargs,
    )
def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, the state of the classification heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the positions embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        if seq_length_1d * seq_length_1d != seq_length:
            raise ValueError(
                f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d} and seq_length = {seq_length}"
            )

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state
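A sketch of how ``interpolate_embeddings`` might be used to adapt the ViT-B/16 checkpoint (trained at 224x224) to a larger input resolution; the 384-pixel target here is an arbitrary example, not a published configuration:

import torch
from torchvision.models import ViT_B_16_Weights
from torchvision.models.vision_transformer import VisionTransformer, interpolate_embeddings

# Fetch the 224x224 checkpoint and resample its positional embeddings for 384x384 inputs.
state_dict = ViT_B_16_Weights.IMAGENET1K_V1.get_state_dict(progress=True, check_hash=True)
new_state = interpolate_embeddings(image_size=384, patch_size=16, model_state=state_dict)

# Build a ViT-B/16 at the larger image size (default num_classes=1000 matches the checkpoint head).
model_384 = VisionTransformer(
    image_size=384,
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
)
model_384.load_state_dict(new_state)

out = model_384(torch.randn(1, 3, 384, 384))
print(out.shape)  # torch.Size([1, 1000])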