Source code for torch.distributed.tensor.parallel.multihead_attention_tp
# Copyright (c) Meta Platforms, Inc. and affiliates
# pyre-ignore-all-errors[6]
import math
from typing import Optional, Union

import torch
from torch.distributed._tensor import DTensor as DT
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor.parallel._view_with_dim_change import (
    _view_with_sharding_dim_change,
)

__all__ = ["TensorParallelMultiheadAttention"]


# TODO: Add a test for equivalence between our Multihead Attention
# and other mainstream implementations (Megatron-LM or PyTorch).
def _stride_same_as_shard(
    tensor: torch.Tensor, tp_size: int, chunk_dim: int, cat_dim: int
) -> torch.Tensor:
    """
    Adjust the local tensor's strides to match the sharded case,
    so that subsequent view results stay the same.
    """
    if isinstance(tensor, DT):
        return tensor
    view_size = list(tensor.size())
    view_size[chunk_dim] //= tp_size
    return torch.cat(
        [t.view(*view_size) for t in tensor.chunk(tp_size, dim=chunk_dim)],
        dim=cat_dim,
    ).contiguous()
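For a plain local tensor (the non-DTensor branch above), the helper splits the tensor into tp_size chunks along chunk_dim and re-concatenates them along cat_dim, so later view calls see the same layout they would under sharding. A minimal illustration, assuming the _stride_same_as_shard defined above is in scope (the concrete shapes below are chosen only for the example):

    import torch

    x = torch.arange(24).reshape(2, 3, 4)  # local tensor, shape [2, 3, 4]
    # With tp_size=2, chunk_dim=2, cat_dim=1: two [2, 3, 2] chunks are taken
    # along dim 2 and concatenated along dim 1, giving a contiguous [2, 6, 2].
    y = _stride_same_as_shard(x, tp_size=2, chunk_dim=2, cat_dim=1)
    print(y.shape)  # torch.Size([2, 6, 2])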
class TensorParallelMultiheadAttention(torch.nn.Module):
    """
    Multi-head Attention block from Transformer models.

    Since we need some customizations for the attention layer,
    we are writing a customized but mathematically equivalent
    attention module as defined in torch.nn.

    Note that:
    We currently only support self-attention with limited input
    args, and we assume that the input tensor has three dimensions.
    Although we do implement the logic for multihead attention,
    it was not fully tested.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        tp_size: int = 1,
        self_attention: bool = True,
    ) -> None:
        super().__init__()
        self.device: torch.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )
        self.num_heads = num_heads
        self.hidden_size = embed_dim
        self.hidden_size_per_attention_head: int = self.hidden_size // num_heads
        self.scale: float = self.hidden_size_per_attention_head**-0.5
        if self_attention:
            self.qkv: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim * 3, bias=add_bias_kv, device=self.device
            )
            torch.nn.init.xavier_uniform_(self.qkv.weight)
            if add_bias_kv:
                torch.nn.init.zeros_(self.qkv.bias)
        else:
            self.query: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            self.key: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            self.value: torch.nn.Module = torch.nn.Linear(
                embed_dim, embed_dim, bias=add_bias_kv, device=self.device
            )
            torch.nn.init.xavier_uniform_(self.query.weight)
            torch.nn.init.xavier_uniform_(self.key.weight)
            torch.nn.init.xavier_uniform_(self.value.weight)
            if add_bias_kv:
                torch.nn.init.zeros_(self.query.bias)
                torch.nn.init.zeros_(self.key.bias)
                torch.nn.init.zeros_(self.value.bias)
        self.proj: torch.nn.Module = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, device=self.device
        )
        torch.nn.init.kaiming_uniform_(self.proj.weight, a=math.sqrt(5))
        if bias:
            torch.nn.init.zeros_(self.proj.bias)
        self.tp_size = tp_size
        self.hidden_size = embed_dim
        self.norm_factor: float = math.sqrt(self.hidden_size_per_attention_head)
        self.self_attention = self_attention

    def forward(
        self,
        query: Union[torch.Tensor, DT],
        key: Union[torch.Tensor, DT],
        value: Union[torch.Tensor, DT],
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Union[torch.Tensor, DT]:
        b, sq, h = query.shape
        sk = key.size(1)
        nh = self.num_heads
        hn = self.hidden_size_per_attention_head

        # x: [b, sq/sk/sv, h]

        # ===================
        # Permute. [sq/sk/sv, b, h]
        # ===================
        if not self.self_attention:
            # =====================
            # Query, Key, and Value
            # =====================
            query = query.permute(1, 0, 2).contiguous()
            key = key.permute(1, 0, 2).contiguous()
            value = value.permute(1, 0, 2).contiguous()

            # Attention heads [sq/sk/sv, b, h] --> [sq/sk/sv * b, (nh * hn)]
            query = query.view(-1, h)
            key = key.view(-1, h)
            value = value.view(-1, h)

            query_layer = _view_with_sharding_dim_change(
                self.query(query), 1, (sq, b * nh, hn)
            )
            key_layer = _view_with_sharding_dim_change(
                self.key(key), 1, (sk, b * nh, hn)
            )
            value_layer = _view_with_sharding_dim_change(
                self.value(value), 1, (sk, b * nh, hn)
            )
        else:
            assert torch.equal(query, key) and torch.equal(
                query, value
            ), "inputs are different for self-attention."

            # =====================
            # Query
            # =====================
            query = query.permute(1, 0, 2).contiguous()

            # Attention heads [sq, b, h] --> [sq * b, (nh * 3 * hn)]
            query = query.view(-1, h)
            mixed_x_layer = self.qkv(query)

            # [sq * b, 3 * h] --> [sq, b, nh, 3 * hn]
            mixed_x_layer = _view_with_sharding_dim_change(
                mixed_x_layer, 2, (sq, b, nh, 3 * hn)
            )

            # [sq, b, nh, 3 * hn] --> 3 [sq, b, nh, hn]
            last_dim = mixed_x_layer.dim() - 1
            last_dim_size = mixed_x_layer.size(last_dim) // 3
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                last_dim_size, dim=last_dim
            )

            query_layer = _stride_same_as_shard(query_layer, self.tp_size, 2, 1)
            key_layer = _stride_same_as_shard(key_layer, self.tp_size, 2, 1)
            value_layer = _stride_same_as_shard(value_layer, self.tp_size, 2, 1)

            # [sq, b, nh, hn] -> [sq, b * nh, hn]
            query_layer = _view_with_sharding_dim_change(
                query_layer, 1, (sq, b * nh, -1)
            )
            key_layer = _view_with_sharding_dim_change(key_layer, 1, (sq, b * nh, -1))
            value_layer = _view_with_sharding_dim_change(
                value_layer, 1, (sq, b * nh, -1)
            )

        # ===================================
        # Raw attention scores. [b, nh, s, s]
        # ===================================

        factor = self.tp_size if isinstance(query_layer, DT) else 1
        # preallocating result tensor: [b * nh, sq, sk]
        matmul_result = torch.empty(
            b * nh // factor,
            sq,
            sk,
            dtype=query_layer.dtype,
            device=self.device,
        )

        if isinstance(query_layer, DT):
            matmul_result = DT.from_local(
                matmul_result,
                query_layer.device_mesh,
                [Shard(0)],
                run_check=False,
            )

        # Raw attention scores. [b * nh, sq, sk]
        attn = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * nh, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * nh, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # ===============
        # Attention probs
        # ===============
        attn = attn.softmax(dim=-1)

        # =========================
        # Context layer. [sq * b, hidden]
        # =========================

        # bmm: [b * nh, sq, hn]
        context_layer = torch.bmm(attn, value_layer.transpose(0, 1))

        # change view [nh, b, sq, hn]
        context_layer = context_layer.view(nh, b, sq, hn)

        # [nh, b, sq, hn] --> [sq, b, nh, hn]
        context_layer = context_layer.permute(2, 1, 0, 3).contiguous()

        # [sq, b, nh, hn] --> [sq * b, hidden]
        context_layer = _view_with_sharding_dim_change(
            context_layer.contiguous(), 1, (-1, self.hidden_size)
        )

        # =================
        # Projection. [sq, b, h]
        # =================
        output = self.proj(context_layer).view(sq, b, h)

        # ===================
        # Permute. [b, sq, h]
        # ===================
        output = output.permute(1, 0, 2)

        return output

    def copy(self, that: torch.nn.MultiheadAttention) -> None:
        # TODO: the current implementation assumes `self` is a self-attention module
        assert (
            self.hidden_size == that.embed_dim
        ), "embed_dim must be equal in TensorParallelMultiheadAttention.copy()!"
        if that.in_proj_weight is not None:
            self.qkv.register_parameter("weight", that.in_proj_weight)
        if that.in_proj_bias is not None:
            self.qkv.register_parameter("bias", that.in_proj_bias)
        if that.out_proj.weight is not None:
            # TODO: The use of Parameter is to avoid a `mypy` issue caused
            # by the `Tensor` type annotation on Linear.weight, to which
            # a Parameter object is actually assigned
            self.proj.register_parameter(
                "weight", torch.nn.Parameter(that.out_proj.weight)
            )
        if that.out_proj.bias is not None:
            self.proj.register_parameter("bias", that.out_proj.bias)
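A minimal single-process usage sketch (tp_size=1, plain tensors, no DTensor sharding). The constructor arguments follow the signature above; the import path is assumed from the module name in the page title, and the concrete sizes are illustrative rather than an official API example:

    import torch
    from torch.distributed.tensor.parallel.multihead_attention_tp import (
        TensorParallelMultiheadAttention,
    )

    attn = TensorParallelMultiheadAttention(
        embed_dim=16,
        num_heads=4,
        device=torch.device("cpu"),  # pin to CPU so inputs and module match
        tp_size=1,                   # no tensor parallelism in this sketch
        self_attention=True,         # forward asserts query/key/value are equal
    )
    x = torch.randn(2, 8, 16)        # [batch, seq_len, embed_dim]
    out = attn(x, x, x)              # self-attention: pass the same tensor three times
    print(out.shape)                 # torch.Size([2, 8, 16])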