FeedForward
- class torchtune.modules.FeedForward(*, gate_proj: Module, down_proj: Module, up_proj: Optional[Module] = None, activation: Module = SiLU())
This class implements the feed-forward network derived from Llama2.
- Parameters:
gate_proj (nn.Module) – Projection from input dim to hidden dim, fed through activation and multiplied by up_proj.
down_proj (nn.Module) – Final projection to output dim.
up_proj (Optional[nn.Module]) – Projection from input dim to hidden dim, multiplied by activation(gate_proj).
activation (nn.Module) – Activation function to use. Default is nn.SiLU().
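The parameter descriptions above imply a gated computation of the form down_proj(activation(gate_proj(x)) * up_proj(x)). The following is a minimal sketch of that computation in plain PyTorch, not torchtune's actual source. The class name `GatedFeedForwardSketch`, the dimensions, and the bias-free `nn.Linear` layers are illustrative assumptions, as is skipping the multiplication when `up_proj` is None (inferred from the parameter being optional).

```python
import torch
from torch import nn


class GatedFeedForwardSketch(nn.Module):
    """Hypothetical stand-in that mirrors the documented parameters."""

    def __init__(self, *, gate_proj, down_proj, up_proj=None, activation=None):
        super().__init__()
        self.gate_proj = gate_proj
        self.down_proj = down_proj
        self.up_proj = up_proj
        # Default activation matches the documented default, nn.SiLU().
        self.activation = activation if activation is not None else nn.SiLU()

    def forward(self, x):
        # Gate branch: project to the hidden dim, then apply the activation.
        h = self.activation(self.gate_proj(x))
        # Assumption: with no up_proj, the gate branch is used on its own.
        if self.up_proj is not None:
            h = h * self.up_proj(x)
        # Final projection to the output dim.
        return self.down_proj(h)


# Illustrative dimensions, chosen only for this example.
in_dim, hidden_dim, out_dim = 16, 64, 16
ffn = GatedFeedForwardSketch(
    gate_proj=nn.Linear(in_dim, hidden_dim, bias=False),
    down_proj=nn.Linear(hidden_dim, out_dim, bias=False),
    up_proj=nn.Linear(in_dim, hidden_dim, bias=False),
)

x = torch.randn(2, 5, in_dim)  # shape (..., in_dim)
y = ffn(x)                     # shape (..., out_dim)
print(y.shape)  # torch.Size([2, 5, 16])
```

Only the last dimension changes: leading batch and sequence dimensions pass through untouched, and both gate_proj and up_proj must accept the same input dimension so that their outputs can be multiplied elementwise.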
- forward(x: Tensor) → Tensor
- Parameters:
x (torch.Tensor) – input tensor with shape (..., in_dim), where in_dim is the input dimension of both gate_proj and up_proj.
- Returns:
output tensor with shape (..., out_dim), where out_dim is the output dimension of down_proj.
- Return type: