Source code for torchtune.modules.position_embeddings
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from torch import nn
class RotaryPositionalEmbeddings(nn.Module):
"""
This class implements Rotary Positional Embeddings (RoPE)
proposed in https://arxiv.org/abs/2104.09864.
    Reference implementation (used for correctness verification)
can be found here:
https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
    In this implementation we cache the embeddings for each position up to
    ``max_seq_len`` by computing them during init.
Args:
        dim (int): Embedding dimension. This is usually set to the dim of each
            head in the attention module, computed as ``embed_dim // num_heads``.
        max_seq_len (int): Maximum expected sequence length for the
            model; if exceeded, the cached freqs will be recomputed.
            Default: 4096
        base (int): The base for the geometric progression used to compute
            the rotation angles. Default: 10_000
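
    Example (illustrative usage; shapes follow the ``forward`` docstring, and
    the values of ``head_dim`` and ``num_heads`` are arbitrary)::

        >>> head_dim, num_heads = 64, 8
        >>> rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=4096)
        >>> x = torch.randn(2, 16, num_heads, head_dim)  # [b, s, n_h, h_d]
        >>> rope(x).shape
        torch.Size([2, 16, 8, 64])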
"""
def __init__(
self,
dim: int,
max_seq_len: int = 4096,
base: int = 10_000,
) -> None:
super().__init__()
self.dim = dim
self.base = base
self.max_seq_len = max_seq_len
self.rope_init()
# TODO: delete this once all our recipes are moved off of FSDP1 since we
# no longer need to explicitly name our param init method reset_parameters
def reset_parameters(self):
self.rope_init()
def rope_init(self):
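        # theta_i = 1 / base^(2i / dim) for i in [0, dim // 2); e.g. with
        # dim=4 and base=10_000 this yields theta = [1.0, 0.01]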
theta = 1.0 / (
self.base
** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
)
self.register_buffer("theta", theta, persistent=False)
self.build_rope_cache(self.max_seq_len)
def build_rope_cache(self, max_seq_len: int = 4096) -> None:
# Create position indexes `[0, 1, ..., max_seq_len - 1]`
seq_idx = torch.arange(
max_seq_len, dtype=self.theta.dtype, device=self.theta.device
)
# Outer product of theta and position index; output tensor has
# a shape of [max_seq_len, dim // 2]
idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
# cache includes both the cos and sin components and so the output shape is
# [max_seq_len, dim // 2, 2]
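        # i.e. cache[p, i] = (cos(p * theta_i), sin(p * theta_i))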
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
self.register_buffer("cache", cache, persistent=False)
    def forward(
self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
x (torch.Tensor): input tensor with shape
``[b, s, n_h, h_d]``
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b, s].
During inference, this indicates the position of the current token.
                If None, assume the index of the token is its position id. Default is None.
Returns:
torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]``
Notation used for tensor shapes:
- b: batch size
- s: sequence length
- n_h: num heads
- h_d: head dim
"""
# input tensor has shape [b, s, n_h, h_d]
seq_len = x.size(1)
# extract the values based on whether input_pos is set or not
rope_cache = (
self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
)
# reshape input; the last dimension is used for computing the output.
# Cast to float to match the reference implementation
# tensor has shape [b, s, n_h, h_d // 2, 2]
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
# reshape the cache for broadcasting
# tensor has shape [b, s, 1, h_d // 2, 2] if packed samples,
# otherwise has shape [1, s, 1, h_d // 2, 2]
rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
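        # apply the rotation: each consecutive pair (x0, x1) is treated as a
        # complex number and rotated by the angle pos * theta, i.e.
        # (x0 * cos - x1 * sin, x1 * cos + x0 * sin)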
# tensor has shape [b, s, n_h, h_d // 2, 2]
x_out = torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0]
- xshaped[..., 1] * rope_cache[..., 1],
xshaped[..., 1] * rope_cache[..., 0]
+ xshaped[..., 0] * rope_cache[..., 1],
],
-1,
)
# tensor has shape [b, s, n_h, h_d]
x_out = x_out.flatten(3)
return x_out.type_as(x)
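

# A minimal usage sketch (illustrative, not part of the module): RoPE is
# typically applied to query and key tensors of shape [b, s, n_h, h_d];
# ``input_pos`` can supply explicit position ids, e.g. when decoding a single
# token during inference.
if __name__ == "__main__":
    b, s, n_h, h_d = 2, 16, 8, 64
    rope = RotaryPositionalEmbeddings(dim=h_d, max_seq_len=4096)

    # rotate full-sequence queries and keys; positions default to 0..s-1
    q = rope(torch.randn(b, s, n_h, h_d))  # [b, s, n_h, h_d]
    k = rope(torch.randn(b, s, n_h, h_d))  # [b, s, n_h, h_d]

    # rotate a single decoded token at an explicit position
    x_step = torch.randn(b, 1, n_h, h_d)
    input_pos = torch.full((b, 1), s, dtype=torch.long)  # position id s for each sample
    out = rope(x_step, input_pos=input_pos)  # [b, 1, n_h, h_d]
    print(q.shape, k.shape, out.shape)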