FeedForward

class torchtune.modules.FeedForward(*, gate_proj: Module, down_proj: Module, up_proj: Optional[Module] = None, activation: Module = SiLU())[source]

This class implements the feed-forward network derived from Llama2.

Parameters:
  • gate_proj (nn.Module) – Projection from input dim to hidden dim, fed through activation and multiplied by up_proj.

  • down_proj (nn.Module) – Final projection to output dim.

  • up_proj (Optional[nn.Module]) – Projection from input dim to hidden dim, multiplied by activation(gate_proj).

  • activation (nn.Module) – Activation function to use. Default is nn.SiLU().
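
Example (a minimal construction sketch; the embed_dim and hidden_dim values below are illustrative assumptions, not values from this page, and bias-free nn.Linear projections are assumed, as is typical for Llama2-style MLPs):

    >>> import torch
    >>> from torch import nn
    >>> from torchtune.modules import FeedForward
    >>> embed_dim, hidden_dim = 4096, 11008
    >>> # Per the parameter descriptions above, this computes
    >>> # down_proj(SiLU(gate_proj(x)) * up_proj(x))
    >>> mlp = FeedForward(
    ...     gate_proj=nn.Linear(embed_dim, hidden_dim, bias=False),
    ...     down_proj=nn.Linear(hidden_dim, embed_dim, bias=False),
    ...     up_proj=nn.Linear(embed_dim, hidden_dim, bias=False),
    ... )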

forward(x: Tensor) → Tensor[source]

Parameters:
  • x (torch.Tensor) – input tensor with shape (..., in_dim), where in_dim is the input dimension of both gate_proj and up_proj.

Returns:
  output tensor with shape (..., out_dim), where out_dim is the output dimension of down_proj.

Return type:
  torch.Tensor
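
Example (continuing the construction sketch above; the batch and sequence sizes are arbitrary):

    >>> x = torch.randn(2, 16, embed_dim)   # shape (..., in_dim)
    >>> out = mlp(x)                         # shape (..., out_dim)
    >>> out.shape
    torch.Size([2, 16, 4096])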
