Shortcuts

Source code for torchtune.modules.rms_norm

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torch import nn, Tensor


class RMSNorm(nn.Module):
    """
    Root Mean Square layer normalization, as introduced in
    https://arxiv.org/abs/1910.07467.

    The reference implementation this module is checked against for
    correctness lives at
    https://github.com/facebookresearch/llama/blob/main/llama/model.py

    Args:
        dim (int): embedding size
        eps (float): small value to avoid division by zero. Default: 1e-6
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain, initialised to ones (identity scaling).
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """
        Normalize ``x`` by the root mean square of its last dimension and
        apply the learned ``scale``.

        Args:
            x (Tensor): input tensor to normalize

        Returns:
            Tensor: The output tensor after applying RMSNorm.
        """
        # The normalization itself is computed in fp32 (``x.float()``); the
        # result is cast back to x's dtype before the learned scale is applied.
        upcast = x.float()
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (upcast * inv_rms).type_as(x)
        return normalized * self.scale

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources