RMSNorm

class torchtune.modules.RMSNorm(dim: int, eps: float = 1e-06)

Implements Root Mean Square Normalization introduced in https://arxiv.org/abs/1910.07467.

A reference implementation, used for correctness verification, can be found here: https://github.com/facebookresearch/llama/blob/main/llama/model.py
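For reference, the normalization can be written as below. Here $d$ is `dim`, $\epsilon$ is `eps`, and $g$ is a learnable per-feature gain. Unlike LayerNorm, no mean is subtracted, so the input is rescaled but not re-centered. Placing $\epsilon$ inside the square root is an assumption that matches common implementations; the paper defines the RMS without it.

```latex
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^{2} + \epsilon}}\; g_i,
\qquad i = 1, \dots, d
```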

Parameters:
  • dim (int) – embedding size, i.e. the size of the input's last dimension, over which the RMS is computed

  • eps (float) – small value to avoid division by zero. Default: 1e-6
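The two parameters above can be tied together in a minimal PyTorch sketch. The class name `SketchRMSNorm` is hypothetical and this is not torchtune's source. Computing in float32 before casting back to the input dtype, and initializing the gain to ones, are assumptions that mirror common implementations.

```python
import torch
from torch import nn


class SketchRMSNorm(nn.Module):
    """Hypothetical minimal RMSNorm sketch (not torchtune's actual source)."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain over the last dimension (assumed init: ones).
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: compute in float32 for stability, then cast back to x's dtype.
        x_fp32 = x.float()
        # Mean of squares over the last dimension; eps keeps rsqrt finite for all-zero rows.
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Scale each row to unit RMS, then apply the per-feature gain.
        return (x_fp32 * inv_rms).type_as(x) * self.scale


norm = SketchRMSNorm(dim=8)
y = norm(torch.randn(2, 4, 8))
print(y.shape)  # torch.Size([2, 4, 8])
```

Because `scale` has length `dim`, it broadcasts across every leading dimension, so only the last dimension of the input has to match `dim`.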

forward(x: Tensor) → Tensor
Parameters:

x (Tensor) – input tensor to normalize

Returns:

The output tensor after applying RMSNorm, with the same shape as x.

Return type:

Tensor
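To make the return value concrete, the arithmetic forward performs on a single row can be worked through in plain Python. The helper `rms_norm_row` is hypothetical, and the gain is assumed to still be at an initial value of ones.

```python
import math


def rms_norm_row(x: list[float], scale: list[float], eps: float = 1e-6) -> list[float]:
    """Hypothetical helper: RMSNorm of one row (the last dimension)."""
    # Root mean square of the row, with eps added inside the square root.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # Divide each element by the RMS, then apply the per-feature gain.
    return [v / rms * g for v, g in zip(x, scale)]


out = rms_norm_row([3.0, 4.0], scale=[1.0, 1.0])
print([round(v, 4) for v in out])  # [0.8485, 1.1314]
```

For `[3.0, 4.0]` the RMS is sqrt((9 + 16) / 2) ≈ 3.5355, so each element is divided by that value. The row mean (3.5) is never subtracted, which is the key difference from LayerNorm.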