TransformerDecoderLayer
- class torchtune.modules.TransformerDecoderLayer(attn: CausalSelfAttention, mlp: Module, sa_norm: Module, mlp_norm: Module)[source]
Transformer layer derived from the Llama2 model. Normalization is applied before both the self-attention and the feed-forward layers.
- Parameters:
attn (CausalSelfAttention) – Attention module.
mlp (nn.Module) – Feed-forward module.
sa_norm (nn.Module) – Normalization to be applied before self-attention.
mlp_norm (nn.Module) – Normalization to be applied before the feed-forward layer.
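A minimal construction sketch follows. The CausalSelfAttention, RMSNorm, and RotaryPositionalEmbeddings constructor arguments shown here are assumptions modeled on torchtune's Llama2 component builders and may differ across versions; the plain nn.Sequential stands in for the Llama2 gated feed-forward, since mlp, sa_norm, and mlp_norm only need to be nn.Module instances.

from torch import nn
from torchtune.modules import (
    CausalSelfAttention,
    RMSNorm,
    RotaryPositionalEmbeddings,
    TransformerDecoderLayer,
)

embed_dim, num_heads, max_seq_len = 512, 8, 2048
head_dim = embed_dim // num_heads

# Assumed CausalSelfAttention arguments; see its own documentation for the
# authoritative signature.
attn = CausalSelfAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_heads,  # no grouped-query attention in this sketch
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
    pos_embeddings=RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len),
    max_seq_len=max_seq_len,
)

# Any nn.Module works for the feed-forward block; a simple two-layer MLP
# stands in for the Llama2 gated feed-forward here.
mlp = nn.Sequential(
    nn.Linear(embed_dim, 4 * embed_dim, bias=False),
    nn.SiLU(),
    nn.Linear(4 * embed_dim, embed_dim, bias=False),
)

layer = TransformerDecoderLayer(
    attn=attn,
    mlp=mlp,
    sa_norm=RMSNorm(embed_dim),
    mlp_norm=RMSNorm(embed_dim),
)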
- forward(x: Tensor, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Tensor[source]
- Parameters:
x (Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]
mask (Optional[Tensor]) – Optional attention mask tensor; only used during inference. Default is None.
input_pos (Optional[Tensor]) – Optional tensor containing the position of the current token; only used during inference. Default is None.
- Returns:
- output tensor with same shape as input [batch_size x seq_length x embed_dim]
- Return type:
Tensor
- Notation used for tensor shapes:
b: batch size
s: sequence length
d: embed dim
Todo
Make position of norm configurable
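Continuing the construction sketch above, a call to forward() preserves the [b, s, d] shape. The commented dataflow below is an illustrative description of the pre-norm, residual structure of a Llama2-style layer, not a verbatim copy of the implementation.

import torch

x = torch.randn(2, 16, 512)  # [batch_size, seq_length, embed_dim]
out = layer(x)               # same shape: torch.Size([2, 16, 512])

# Conceptually, the layer normalizes before each sub-block and adds
# residual connections (illustrative sketch of the dataflow):
#   h   = x + attn(sa_norm(x), mask=mask, input_pos=input_pos)
#   out = h + mlp(mlp_norm(h))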