Source code for torchtune.modules.position_embeddings
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch import nn, Tensor
class RotaryPositionalEmbeddings(nn.Module):
    """
    This class implements Rotary Positional Embeddings (RoPE)
    proposed in https://arxiv.org/abs/2104.09864.

    Reference implementation (used for correctness verification)
    can be found here:
    https://github.com/meta-llama/llama/blob/main/llama/model.py#L80

    In this implementation we cache the embeddings for each position up to
    ``max_seq_len`` by computing them during init.

    Args:
        dim (int): Embedding dimension. This is usually set to the dim of each
            head in the attention module, computed as ``embed_dim // num_heads``.
        max_seq_len (int): Maximum expected sequence length for the model;
            if exceeded, the cached freqs will be recomputed.
        base (int): The base for the geometric progression used to compute
            the rotation angles.
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 4096,
        base: int = 10_000,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self._rope_init()

    # We need to explicitly define reset_parameters for FSDP initialization, see
    # https://github.com/pytorch/pytorch/blob/797d4fbdf423dd9320ebe383fb57ffb1135c4a99/torch/distributed/fsdp/_init_utils.py#L885
    def reset_parameters(self):
        self._rope_init()

    def _rope_init(self):
        theta = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
        )
        self.register_buffer("theta", theta, persistent=False)
        self.build_rope_cache(self.max_seq_len)

    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
        # Create position indexes `[0, 1, ..., max_seq_len - 1]`
        seq_idx = torch.arange(
            max_seq_len, dtype=self.theta.dtype, device=self.theta.device
        )

        # Outer product of theta and position index; output tensor has
        # a shape of [max_seq_len, dim // 2]
        idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()

        # cache includes both the cos and sin components and so the output shape is
        # [max_seq_len, dim // 2, 2]
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        self.register_buffer("cache", cache, persistent=False)
    def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            x (Tensor): input tensor with shape [b, s, n_h, h_d]
            input_pos (Optional[Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b, s].
                During inference, this indicates the position of the current token.
                If None, assume the index of the token is its position id. Default is None.

        Returns:
            Tensor: output tensor with RoPE applied

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - n_h: num heads
            - h_d: head dim

        TODO: The implementation below can be made more efficient for inference.
        """
        # input tensor has shape [b, s, n_h, h_d]
        seq_len = x.size(1)

        # extract the values based on whether input_pos is set or not
        rope_cache = (
            self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
        )

        # reshape input; the last dimension is used for computing the output.
        # Cast to float to match the reference implementation
        # tensor has shape [b, s, n_h, h_d // 2, 2]
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)

        # reshape the cache for broadcasting
        # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples,
        # otherwise has shape [1, s, 1, h_d // 2, 2]
        rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)

        # tensor has shape [b, s, n_h, h_d // 2, 2]
        x_out = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )

        # tensor has shape [b, s, n_h, h_d]
        x_out = x_out.flatten(3)
        return x_out.type_as(x)
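A minimal usage sketch of the module above. The import path follows this file's module name; the batch, sequence, and head sizes are illustrative, not taken from the source.

import torch

from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings

# Illustrative sizes: 2 sequences of 16 tokens, 8 heads of dimension 64.
batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64

# dim is the per-head dimension, i.e. embed_dim // num_heads.
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=4096)

# Per-head query (or key) projections with shape [b, s, n_h, h_d].
q = torch.randn(batch_size, seq_len, num_heads, head_dim)

# Training-style call: with input_pos=None, token i is assumed to sit at position i.
q_rope = rope(q)
assert q_rope.shape == q.shape

# Inference-style call: pass explicit position ids, e.g. decoding one new token
# at position 16 for every sequence in the batch.
input_pos = torch.full((batch_size, 1), 16, dtype=torch.long)
q_step = torch.randn(batch_size, 1, num_heads, head_dim)
q_step_rope = rope(q_step, input_pos=input_pos)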