TransformerDecoder

class torchtune.modules.TransformerDecoder(tok_embeddings: Embedding, layer: TransformerDecoderLayer, num_layers: int, max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Linear)[source]

Transformer Decoder derived from the Llama2 architecture.

Parameters:
  • tok_embeddings (nn.Embedding) – PyTorch embedding layer, to be used to move tokens to an embedding space.

  • layer (TransformerDecoderLayer) – Transformer Decoder layer.

  • num_layers (int) – Number of Transformer Decoder layers.

  • max_seq_len (int) – Maximum sequence length the model will be run with, as used by KVCache().

  • num_heads (int) – Number of query heads. For MHA this is also the number of heads for key and value. This is used to set up the KVCache().

  • head_dim (int) – Embedding dimension for each head in self-attention. This is used to set up the KVCache().

  • norm (nn.Module) – Callable that applies normalization to the output of the decoder, before final MLP.

  • output (nn.Linear) – Callable that applies a linear transformation to the output of the decoder.

Note

Arg values are checked for correctness (e.g., attn_dropout belongs to [0, 1]) in the module where they are used. This helps reduce the number of raise statements in the code and improves readability.
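To make the constructor arguments concrete, here is a minimal pure-PyTorch sketch of the interface. This is not torchtune's implementation: the real TransformerDecoder stacks num_layers copies of a TransformerDecoderLayer and wires in a KVCache, whereas this toy uses an nn.Linear as a stand-in layer purely to show how the pieces fit together and the [b x s] → [b x s x v] shape contract.

```python
import torch
from torch import nn

class ToyDecoder(nn.Module):
    """Illustrative stand-in mirroring the documented constructor shape.

    Not torchtune's TransformerDecoder: the layer here is a placeholder and
    the KVCache machinery is omitted.
    """

    def __init__(self, tok_embeddings, layer, num_layers, norm, output):
        super().__init__()
        self.tok_embeddings = tok_embeddings
        # In this sketch the same module is reused; torchtune deep-copies
        # the layer num_layers times.
        self.layers = nn.ModuleList([layer for _ in range(num_layers)])
        self.norm = norm
        self.output = output

    def forward(self, tokens, input_pos=None):
        h = self.tok_embeddings(tokens)      # [b, s] -> [b, s, d]
        for layer in self.layers:
            h = layer(h)                     # [b, s, d] -> [b, s, d]
        return self.output(self.norm(h))     # [b, s, d] -> [b, s, v]

vocab, d = 128, 16
model = ToyDecoder(
    tok_embeddings=nn.Embedding(vocab, d),
    layer=nn.Linear(d, d),   # stand-in for a TransformerDecoderLayer
    num_layers=2,
    norm=nn.LayerNorm(d),    # the Llama2 architecture uses RMSNorm here
    output=nn.Linear(d, vocab),
)
tokens = torch.randint(0, vocab, (2, 5))     # [b=2, s=5]
logits = model(tokens)
print(logits.shape)                          # torch.Size([2, 5, 128])
```

The key takeaway is the division of labor: tok_embeddings maps token ids into embedding space, the decoder layers transform them, and norm followed by output projects back to vocabulary logits.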

forward(tokens: Tensor, input_pos: Optional[Tensor] = None) → Tensor[source]
Parameters:
  • tokens (Tensor) – input tensor with shape [b x s]

  • input_pos (Optional[Tensor]) – Optional tensor containing the position(s) of the current token(s). This is only used during inference. Default is None.

Note: At the very first step of inference, when the model is provided with a prompt, input_pos contains the positions of all of the tokens in the prompt (e.g., torch.arange(prompt_length)), since the KV values must be computed for each of those positions.

Returns:

output tensor with shape [b x s x v]

Return type:

Tensor

Raises:

ValueError – if causal_mask is set but input_pos is None

Notation used for tensor shapes:
  • b: batch size

  • s: sequence length

  • v: vocab size

  • d: embed dim

  • m_s: max seq len
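The input_pos convention described above can be sketched as follows. This shows only the tensors a caller would construct for the prefill step and for one subsequent decoding step; the model itself is omitted, so the forward calls are left as comments.

```python
import torch

# Prefill step: the prompt is [b x s], and input_pos covers every prompt
# position so the KV cache is populated for all of them.
prompt = torch.tensor([[5, 9, 11]])          # [b=1, s=3]
input_pos = torch.arange(prompt.size(1))     # tensor([0, 1, 2])
# logits = model(prompt, input_pos=input_pos)   # -> [1, 3, v]

# Incremental decoding: one new token per step, placed in the next
# cache slot after the prompt.
next_token = torch.tensor([[42]])            # [b=1, s=1]
input_pos = torch.tensor([prompt.size(1)])   # tensor([3])
# logits = model(next_token, input_pos=input_pos)  # -> [1, 1, v]
```

Passing input_pos like this is what lets the decoder reuse cached K/V projections instead of recomputing attention over the full sequence at every step.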
