Source code for torchvision.models.vision_transformer
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional

import torch
import torch.nn as nn

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ..utils import _log_api_usage_once

__all__ = [
    "VisionTransformer",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
]

# Checkpoint URLs, keyed by the ``arch`` string passed to ``_vision_transformer``.
model_urls = {
    "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
    "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
    "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
    "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}


class ConvStemConfig(NamedTuple):
    """Settings for one ``ConvNormActivation`` layer of the optional conv stem."""

    # Passed through as the keyword arguments of the same name to ConvNormActivation.
    out_channels: int
    kernel_size: int
    stride: int
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU


class MLPBlock(nn.Sequential):
    """Transformer MLP block.

    Runs, in order: Linear(in_dim -> mlp_dim), GELU, Dropout,
    Linear(mlp_dim -> in_dim), Dropout.

    Args:
        in_dim (int): Input and output feature size.
        mlp_dim (int): Feature size between the two linear layers.
        dropout (float): Probability used by both dropout layers.
    """

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        # Attribute assignment registers the sub-modules with nn.Sequential in this order.
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)


class EncoderBlock(nn.Module):
    """Transformer encoder block.

    Applies norm -> self-attention -> dropout with a residual connection,
    then norm -> MLP with a second residual connection.

    Args:
        num_heads (int): Number of attention heads.
        hidden_dim (int): Embedding size of the tokens.
        mlp_dim (int): Inner size of the MLP block.
        dropout (float): Dropout probability after attention and inside the MLP.
        attention_dropout (float): Dropout probability passed to the attention layer.
        norm_layer (Callable[..., nn.Module]): Factory called with ``hidden_dim``
            to build each normalization layer.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        """Run the block on a 3-dimensional ``(batch_size, seq_length, hidden_dim)`` tensor."""
        # The attention layer is built with batch_first=True, so the batch dimension
        # comes first; the message must describe that layout.
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Adds a learned position embedding to the input, applies dropout, runs
    ``num_layers`` encoder blocks and finishes with a normalization layer.

    Args:
        seq_length (int): Number of tokens; sizes the position embedding.
        num_layers (int): Number of ``EncoderBlock`` layers.
        num_heads (int): Number of attention heads per block.
        hidden_dim (int): Embedding size of the tokens.
        mlp_dim (int): Inner size of each block's MLP.
        dropout (float): Dropout probability.
        attention_dropout (float): Dropout probability inside attention.
        norm_layer (Callable[..., nn.Module]): Factory called with ``hidden_dim``
            to build normalization layers.
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiAttention() by default
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        """Encode a 3-dimensional ``(batch_size, seq_length, hidden_dim)`` tensor."""
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))


class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929.

    Args:
        image_size (int): Expected height and width of the (square) input images.
            Must be divisible by ``patch_size``.
        patch_size (int): Side length of each square patch.
        num_layers (int): Number of encoder blocks.
        num_heads (int): Number of attention heads per block.
        hidden_dim (int): Embedding size of the tokens.
        mlp_dim (int): Inner size of each block's MLP.
        dropout (float): Dropout probability. Default: 0.0.
        attention_dropout (float): Dropout probability inside attention. Default: 0.0.
        num_classes (int): Output size of the final linear head. Default: 1000.
        representation_size (Optional[int]): If given, a ``pre_logits`` linear layer
            of this size followed by Tanh is inserted before the head.
        norm_layer (Callable[..., nn.Module]): Factory for normalization layers.
        conv_stem_configs (Optional[List[ConvStemConfig]]): If given, the patch
            projection is a stack of ``ConvNormActivation`` layers followed by a
            1x1 convolution instead of a single strided convolution.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    ConvNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            # The classification head starts at zero.
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Project an ``(n, c, h, w)`` image batch to ``(n, n_h * n_w, hidden_dim)`` patch tokens."""
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, "Wrong image height!")
        torch._assert(w == self.image_size, "Wrong image width!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        """Classify an image batch; returns the output of ``self.heads`` on the class token."""
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x


def _vision_transformer(
    arch: str,
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    """Build a ``VisionTransformer`` and optionally load its checkpoint.

    Args:
        arch (str): Key into ``model_urls`` used when ``pretrained`` is True.
        patch_size (int): Forwarded to ``VisionTransformer``.
        num_layers (int): Forwarded to ``VisionTransformer``.
        num_heads (int): Forwarded to ``VisionTransformer``.
        hidden_dim (int): Forwarded to ``VisionTransformer``.
        mlp_dim (int): Forwarded to ``VisionTransformer``.
        pretrained (bool): If True, download and load the state dict for ``arch``.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        **kwargs: ``image_size`` (default 224) is extracted; everything else is
            forwarded to ``VisionTransformer``.

    Raises:
        ValueError: If ``pretrained`` is True and ``arch`` has no entry in ``model_urls``.
    """
    image_size = kwargs.pop("image_size", 224)

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if pretrained:
        if arch not in model_urls:
            raise ValueError(f"No checkpoint is available for model type '{arch}'!")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)

    return model
def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Builds the vit_b_16 model described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    The fixed settings are: patch size 16, 12 layers, 12 heads, hidden size 768
    and MLP size 3072. Extra keyword arguments are forwarded unchanged to
    ``_vision_transformer``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Architecture constants that define this variant.
    arch_settings = {
        "arch": "vit_b_16",
        "patch_size": 16,
        "num_layers": 12,
        "num_heads": 12,
        "hidden_dim": 768,
        "mlp_dim": 3072,
    }
    return _vision_transformer(
        **arch_settings,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Builds the vit_b_32 model described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    The fixed settings are: patch size 32, 12 layers, 12 heads, hidden size 768
    and MLP size 3072. Extra keyword arguments are forwarded unchanged to
    ``_vision_transformer``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Architecture constants that define this variant.
    arch_settings = {
        "arch": "vit_b_32",
        "patch_size": 32,
        "num_layers": 12,
        "num_heads": 12,
        "hidden_dim": 768,
        "mlp_dim": 3072,
    }
    return _vision_transformer(
        **arch_settings,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Builds the vit_l_16 model described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    The fixed settings are: patch size 16, 24 layers, 16 heads, hidden size 1024
    and MLP size 4096. Extra keyword arguments are forwarded unchanged to
    ``_vision_transformer``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Architecture constants that define this variant.
    arch_settings = {
        "arch": "vit_l_16",
        "patch_size": 16,
        "num_layers": 24,
        "num_heads": 16,
        "hidden_dim": 1024,
        "mlp_dim": 4096,
    }
    return _vision_transformer(
        **arch_settings,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Builds the vit_l_32 model described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    The fixed settings are: patch size 32, 24 layers, 16 heads, hidden size 1024
    and MLP size 4096. Extra keyword arguments are forwarded unchanged to
    ``_vision_transformer``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Architecture constants that define this variant.
    arch_settings = {
        "arch": "vit_l_32",
        "patch_size": 32,
        "num_layers": 24,
        "num_heads": 16,
        "hidden_dim": 1024,
        "mlp_dim": 4096,
    }
    return _vision_transformer(
        **arch_settings,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolating positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    When the sequence length implied by ``image_size`` and ``patch_size`` differs from
    the checkpoint's, ``model_state["encoder.pos_embedding"]`` is replaced **in place**
    with the interpolated embedding, so the caller's mapping is modified.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Any mode accepted by
            ``torch.nn.functional.interpolate`` for 4-D input may be used. Default: bicubic.
        reset_heads (bool): If true, not copying the state of heads. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
        This is ``model_state`` itself unless ``reset_heads`` is true, in which case it is
        a new ``OrderedDict`` without the entries whose key starts with ``"heads"``.

    Raises:
        ValueError: If the first dimension of the position embedding is not 1.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the positions embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # ``align_corners`` is only accepted by the interpolating modes; passing it
        # with e.g. "nearest" or "area" makes ``interpolate`` raise. Leave it unset
        # (None) for every other mode so that all modes are usable.
        modes_with_align_corners = ("linear", "bilinear", "bicubic", "trilinear")
        align_corners = True if interpolation_mode in modes_with_align_corners else None

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=align_corners,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

    if reset_heads:
        # Drop every entry belonging to the classification heads, keeping order.
        model_state = OrderedDict((k, v) for k, v in model_state.items() if not k.startswith("heads"))

    return model_state
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.