TransformerDecoder
- class torchtune.modules.TransformerDecoder(tok_embeddings: Embedding, layer: TransformerDecoderLayer, num_layers: int, max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Linear)
Transformer Decoder derived from the Llama2 architecture.
- Parameters:
tok_embeddings (nn.Embedding) – PyTorch embedding layer that maps token IDs into the embedding space.
layer (TransformerDecoderLayer) – Transformer Decoder layer.
num_layers (int) – Number of Transformer Decoder layers.
max_seq_len (int) – maximum sequence length the model will be run with, as used by KVCache().
num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value. This is used to set up the KVCache().
head_dim (int) – embedding dimension for each head in self-attention. This is used to set up the KVCache().
norm (nn.Module) – Callable that applies normalization to the output of the decoder layers, before the final output projection.
output (nn.Linear) – Callable that applies a linear transformation to the normalized decoder output, producing logits over the vocabulary.
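The parameters above fix the order of the forward computation. Token IDs are first looked up in tok_embeddings. They then pass through num_layers decoder layers, are normalized by norm, and are finally projected by output to vocabulary logits of shape [b x s x v]. The following is a minimal pure-Python sketch of that order only and is not torchtune's implementation. ToyDecoder, toy_layer and rms_norm are hypothetical names. The toy layer performs no attention, and the weights are random. RMS normalization is used because it is the norm Llama2 uses.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def rms_norm(h: Vector, eps: float = 1e-6) -> Vector:
    # Scale the vector so its root-mean-square is ~1 (no learned weight in this sketch).
    rms = math.sqrt(sum(x * x for x in h) / len(h) + eps)
    return [x / rms for x in h]


def toy_layer(h: Vector) -> Vector:
    # Hypothetical stand-in for a TransformerDecoderLayer: a per-position residual update.
    # A real layer would also mix information across positions via self-attention.
    return [x + math.tanh(x) for x in h]


class ToyDecoder:
    def __init__(self, vocab_size: int, embed_dim: int,
                 layer: Callable[[Vector], Vector], num_layers: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        # tok_embeddings: a [v x d] lookup table, one row per token ID
        self.tok_embeddings = [[rng.uniform(-1, 1) for _ in range(embed_dim)]
                               for _ in range(vocab_size)]
        # num_layers layers; here the same stateless function is reused,
        # whereas a real model gives each layer its own weights
        self.layers = [layer] * num_layers
        self.norm = rms_norm
        # output: a [v x d] weight matrix projecting each position to vocab logits
        self.output = [[rng.uniform(-1, 1) for _ in range(embed_dim)]
                       for _ in range(vocab_size)]

    def forward(self, tokens: List[List[int]]) -> List[List[Vector]]:
        # tokens: [b x s] token IDs -> returns [b x s x v] logits
        logits = []
        for seq in tokens:
            seq_logits = []
            for tok in seq:
                h = list(self.tok_embeddings[tok])  # [d]
                for layer in self.layers:
                    h = layer(h)
                h = self.norm(h)  # normalize before the final projection
                seq_logits.append([sum(w * x for w, x in zip(row, h))
                                   for row in self.output])  # [v]
            logits.append(seq_logits)
        return logits
```

For example, `ToyDecoder(vocab_size=5, embed_dim=4, layer=toy_layer, num_layers=2).forward([[0, 1, 2], [3, 4, 0]])` returns a nested list of shape 2 x 3 x 5. This matches the [b x s] to [b x s x v] shapes documented for forward.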
Note
Arg values are checked for correctness (e.g. attn_dropout belongs to [0, 1]) in the module where they are used. This helps reduce the number of raise statements in the code and improves readability.
- forward(tokens: Tensor, input_pos: Optional[Tensor] = None) → Tensor
- Parameters:
tokens (Tensor) – input tensor with shape [b x s]
input_pos (Optional[Tensor]) – Optional tensor containing the position of the current token. This is only used during inference. Default is None.
Note: At the very first step of inference, when the model is provided with a prompt, input_pos contains the positions of all of the tokens in the prompt (e.g. torch.arange(prompt_length)). This is because the KV values must be computed for every position in the prompt.
- Returns:
output tensor with shape [b x s x v]
- Return type:
Tensor
- Raises:
ValueError – if causal_mask is set but input_pos is None
- Notation used for tensor shapes:
b: batch size
s: sequence length
v: vocab size
d: embed dim
m_s: max seq len
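The input_pos note can be made concrete with a small simulation. The sketch below shows which positions are written into a preallocated cache at each inference step. It also shows why max_seq_len, num_heads and head_dim together determine the cache size. ToyKVCache, positions_for_step and check_input_pos are hypothetical plain-Python stand-ins, not torchtune's KVCache API. The sketch assumes MHA, so key/value heads equal query heads. The cache_enabled flag only stands in for "causal_mask is set".

```python
from typing import List, Optional

Entry = List[List[float]]  # one [num_heads x head_dim] key (or value) entry


class ToyKVCache:
    """Hypothetical cache with max_seq_len preallocated slots for one sequence."""

    def __init__(self, max_seq_len: int, num_heads: int, head_dim: int) -> None:
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.k: List[Optional[Entry]] = [None] * max_seq_len
        self.filled = 0  # number of leading positions written so far

    @property
    def num_values(self) -> int:
        # Storage per sequence for keys alone: m_s * num_heads * head_dim.
        return self.max_seq_len * self.num_heads * self.head_dim

    def update(self, input_pos: List[int], k_new: List[Entry]) -> None:
        # Write one entry at each position listed in input_pos.
        if not input_pos:
            return
        for pos, k in zip(input_pos, k_new):
            if pos >= self.max_seq_len:
                raise ValueError(f"position {pos} exceeds max_seq_len={self.max_seq_len}")
            self.k[pos] = k
        self.filled = max(self.filled, max(input_pos) + 1)


def positions_for_step(prompt_length: int, step: int) -> List[int]:
    # Step 0 processes the whole prompt (the equivalent of torch.arange(prompt_length));
    # each later step feeds one new token at the next position.
    if step == 0:
        return list(range(prompt_length))
    return [prompt_length + step - 1]


def check_input_pos(cache_enabled: bool, input_pos: Optional[List[int]]) -> None:
    # Mirrors the documented ValueError: once caching is set up, input_pos is required.
    if cache_enabled and input_pos is None:
        raise ValueError("input_pos must be provided when causal_mask is set")
```

With a prompt of length 4, steps 0, 1 and 2 yield positions [0, 1, 2, 3], [4] and [5]. After those steps, six cache slots are filled. A cache built with max_seq_len=8, num_heads=2 and head_dim=4 holds 8 * 2 * 4 = 64 key values per sequence.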