# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn, Tensor
class RMSNorm(nn.Module):
    """
    Implements Root Mean Square Normalization introduced in
    https://arxiv.org/abs/1910.07467.

    Reference implementation (used for correctness verification)
    can be found here:
    https://github.com/facebookresearch/llama/blob/main/llama/model.py

    Each vector along the last dimension of the input is divided by the
    square root of (its mean squared value plus ``eps``), then multiplied
    element-wise by a learnable ``scale`` parameter of size ``dim``.

    Args:
        dim (int): embedding size
        eps (float): small value to avoid division by zero. Default: 1e-6
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain, initialized to ones so that the module
        # starts out as a pure normalization.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): input tensor to normalize; the normalization is
                computed over its last dimension.

        Returns:
            Tensor: The output tensor after applying RMSNorm.
        """
        # computation is in fp32
        x_fp32 = x.float()
        # rsqrt(mean(x^2) + eps) is the reciprocal RMS; keepdim=True lets it
        # broadcast back over the last dimension.
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast the normalized values back to the input's dtype before the
        # learnable scale is applied.
        x_normed = (x_fp32 * inv_rms).type_as(x)
        return x_normed * self.scale
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.