RMSNorm
- class torchtune.modules.RMSNorm(dim: int, eps: float = 1e-06)
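Root Mean Square Normalization computed in fp32. The input is upcast to fp32, divided by its root mean square over the last dimension, cast back to the input's dtype, and then multiplied by a learnable per-feature scale.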
See: https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html
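The computation described above can be sketched in plain PyTorch. This is a minimal illustration that assumes the standard RMSNorm formula, y = x / sqrt(mean(x**2 over the last dim) + eps) * scale. The class name `SketchRMSNorm` and the attribute name `scale` are hypothetical, and this is not torchtune's source.

```python
import torch
from torch import nn


class SketchRMSNorm(nn.Module):
    """Hypothetical re-implementation of RMS normalization computed in fp32."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale (hypothetical name), initialized to ones.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast to fp32 so the mean of squares is computed in full precision.
        x_fp32 = x.float()
        # 1 / RMS over the last dimension; eps guards against division by zero.
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Cast the normalized values back to the input dtype, then apply the scale.
        return (x_fp32 * inv_rms).type_as(x) * self.scale


torch.manual_seed(0)
norm = SketchRMSNorm(dim=4)
x = torch.randn(2, 3, 4)
y = norm(x)
print(y.shape)  # torch.Size([2, 3, 4])
```

Computing the mean of squares in fp32 avoids precision loss when the input is bf16 or fp16. Note that in this sketch, multiplying a low-precision output by an fp32 `scale` promotes the result to fp32 under PyTorch's type promotion rules.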
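- Parameters:
dim (int) – embedding size, i.e. the size of the last dimension of the input, over which the RMS is computed
eps (float) – small value added to the mean square to avoid division by zero. Default: 1e-6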
- forward(x: Tensor) → Tensor
- Parameters:
x (torch.Tensor) – input tensor to normalize
- Returns:
The normalized and scaled tensor, with the same shape as x.
- Return type:
torch.Tensor
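As a worked example of what "normalized and scaled" means, the plain-Python sketch below applies the RMS formula to a single vector. The helper `rms_normalize` is hypothetical and uses lists instead of tensors purely to make the arithmetic easy to follow; `eps` is set to 0 here so the numbers come out exactly.

```python
import math


def rms_normalize(x: list[float], scale: list[float], eps: float = 1e-6) -> list[float]:
    """Hypothetical helper: RMS-normalize one vector, then apply a per-element scale."""
    # Mean of squares over the vector (the "last dimension").
    mean_square = sum(v * v for v in x) / len(x)
    # Reciprocal of the root mean square; eps avoids division by zero.
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms * s for v, s in zip(x, scale)]


# x = [3, 4]: mean square = (9 + 16) / 2 = 12.5, RMS = sqrt(12.5) ≈ 3.5355
out = rms_normalize([3.0, 4.0], scale=[1.0, 1.0], eps=0.0)
print([round(v, 4) for v in out])  # [0.8485, 1.1314]
```

The output has the same length as the input, and its own root mean square is 1 before scaling.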