Shortcuts

Source code for torchtune.modules.position_embeddings

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch import nn


class RotaryPositionalEmbeddings(nn.Module):
    """
    This class implements Rotary Positional Embeddings (RoPE)
    proposed in https://arxiv.org/abs/2104.09864.

    Reference implementation (used for correctness verification)
    can be found here:
    https://github.com/meta-llama/llama/blob/main/llama/model.py#L80

    In this implementation we cache the cos/sin values for each position up to
    ``max_seq_len`` by computing them during init.

    Args:
        dim (int): Embedding dimension. This is usually set to the dim of each
            head in the attention module computed as ``embed_dim // num_heads``
        max_seq_len (int): Maximum expected sequence length for the model. If
            ``forward`` is called without ``input_pos`` on a longer sequence,
            the cache is rebuilt to cover that sequence. When ``input_pos`` is
            given, every position id must be smaller than the number of
            cached positions.
        base (int): The base for the geometric progression used to compute
            the rotation angles
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 4096,
        base: int = 10_000,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self.rope_init()

    # TODO: delete this once all our recipes are moved off of FSDP1 since we
    # no longer need to explicitly name our param init method reset_parameters
    def reset_parameters(self):
        """Recompute ``theta`` and the cos/sin cache (delegates to ``rope_init``)."""
        self.rope_init()

    def rope_init(self):
        """Compute the per-pair rotation frequencies and build the cache."""
        # theta_i = base ** (-2i / dim) for i in [0, dim // 2)
        theta = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
        )
        self.register_buffer("theta", theta, persistent=False)
        self.build_rope_cache(self.max_seq_len)

    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
        """
        (Re)build the ``cache`` buffer holding cos/sin values for positions
        ``[0, max_seq_len)``.

        Args:
            max_seq_len (int): Number of positions to cache. Default is 4096.
        """
        # Do the angle computation in at least float32. If the module was cast
        # to a half-precision dtype, ``theta`` follows that cast, and building
        # integer position indices in bf16/fp16 would not represent larger
        # positions exactly. float32/float64 ``theta`` is left as-is.
        compute_dtype = torch.promote_types(self.theta.dtype, torch.float32)
        theta = self.theta.to(compute_dtype)

        # Create position indexes `[0, 1, ..., max_seq_len - 1]`
        seq_idx = torch.arange(max_seq_len, dtype=compute_dtype, device=theta.device)

        # Outer product of theta and position index; output tensor has
        # a shape of [max_seq_len, dim // 2]
        idx_theta = torch.einsum("i, j -> ij", seq_idx, theta).float()

        # cache includes both the cos and sin components and so the output shape is
        # [max_seq_len, dim // 2, 2]
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        self.register_buffer("cache", cache, persistent=False)

    def forward(
        self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape ``[b, s, n_h, h_d]``
            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b, s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.

        Returns:
            torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]``

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - n_h: num heads
            - h_d: head dim
        """
        # input tensor has shape [b, s, n_h, h_d]
        seq_len = x.size(1)

        # extract the values based on whether input_pos is set or not
        if input_pos is None:
            # Grow the cache when the sequence is longer than what is cached;
            # otherwise the slice below would be too short and the reshape
            # further down would fail with a shape error.
            if seq_len > self.cache.size(0):
                self.build_rope_cache(seq_len)
            rope_cache = self.cache[:seq_len]
        else:
            # The cache is not grown here: finding the largest position id
            # would require reading tensor values on every call.
            rope_cache = self.cache[input_pos]

        # reshape input; the last dimension is used for computing the output.
        # Cast to float to match the reference implementation
        # tensor has shape [b, s, n_h, h_d // 2, 2]
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)

        # reshape the cache for broadcasting
        # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples,
        # otherwise has shape [1, s, 1, h_d // 2, 2]
        rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)

        # Rotate each (even, odd) pair by its position-dependent angle:
        # [x0 * cos - x1 * sin, x1 * cos + x0 * sin]
        # tensor has shape [b, s, n_h, h_d // 2, 2]
        x_out = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )

        # tensor has shape [b, s, n_h, h_d]
        x_out = x_out.flatten(3)
        return x_out.type_as(x)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources