
TransformerDecoder

class torchtune.modules.TransformerDecoder(tok_embeddings: Embedding, layer: TransformerDecoderLayer, num_layers: int, max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Linear)[source]

Transformer Decoder derived from the Llama2 architecture.

Parameters:
  • tok_embeddings (nn.Embedding) – PyTorch embedding layer, to be used to move tokens to an embedding space.

  • layer (TransformerDecoderLayer) – Transformer Decoder layer.

  • num_layers (int) – Number of Transformer Decoder layers.

  • max_seq_len (int) – Maximum sequence length the model will be run with, as used by KVCache().

  • num_heads (int) – Number of query heads. For MHA this is also the number of heads for key and value. This is used to set up the KVCache().

  • head_dim (int) – Embedding dimension for each head in self-attention. This is used to set up the KVCache().

  • norm (nn.Module) – Callable that applies normalization to the output of the decoder, before the final output projection.

  • output (nn.Linear) – Callable that applies a linear transformation to the output of the decoder.

Note

Argument values are checked for correctness (e.g., attn_dropout must lie in [0, 1]) in the module where they are used. This helps reduce the number of raise statements in the code and improves readability.
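
To make the constructor concrete, the following is a minimal construction sketch using the signature documented above. The hyperparameters are illustrative Llama2-7B-style values, and build_decoder_layer is a hypothetical stand-in for assembling a TransformerDecoderLayer from torchtune's attention and feed-forward building blocks (in practice the model builders under torchtune.models handle this assembly).

from torch import nn
from torchtune.modules import RMSNorm, TransformerDecoder, TransformerDecoderLayer

# Illustrative Llama2-7B-style sizes; not prescriptive.
vocab_size, embed_dim, num_layers, num_heads = 32_000, 4096, 32, 32
head_dim = embed_dim // num_heads
max_seq_len = 2048

# Hypothetical helper: stands in for assembling a TransformerDecoderLayer
# (self-attention + feed-forward + norms) from torchtune's building blocks.
decoder_layer: TransformerDecoderLayer = build_decoder_layer(embed_dim, num_heads, max_seq_len)

model = TransformerDecoder(
    tok_embeddings=nn.Embedding(vocab_size, embed_dim),   # token ids -> embeddings
    layer=decoder_layer,        # single layer definition, stacked num_layers times
    num_layers=num_layers,
    max_seq_len=max_seq_len,    # upper bound used when sizing KV caches
    num_heads=num_heads,        # query heads, used to size KV caches
    head_dim=head_dim,          # per-head dim, used to size KV caches
    norm=RMSNorm(embed_dim),    # applied to decoder output before the projection
    output=nn.Linear(embed_dim, vocab_size, bias=False),  # hidden states -> logits
)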

caches_are_enabled() bool[source]

Check if the key-value caches are set up.

forward(tokens: Tensor, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Tensor[source]
Parameters:
  • tokens (Tensor) – input tensor with shape [b x s]

  • mask (Optional[Tensor]) – Optional boolean tensor which contains the attention mask with shape [b x s x s]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None.

  • input_pos (Optional[Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. If None, each token's index is assumed to be its position id. Default is None.

Note: At the very first step of inference, when the model is provided with a prompt, input_pos should contain the positions of all of the tokens in the prompt (e.g. torch.arange(prompt_length)). This is because the KV values for each position need to be computed.

Returns:

output tensor with shape [b x s x v]

Return type:

Tensor

Raises:

ValueError – if causal_mask is set (i.e., the caches have been set up via setup_caches()) but input_pos is None

Notation used for tensor shapes:
  • b: batch size

  • s: sequence length

  • v: vocab size

  • d: embed dim

  • m_s: max seq len
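
Continuing the construction sketch above, a hedged example of calling forward(). With no mask or input_pos, a causal mask and sequential position ids are used by default; the second call builds equivalent inputs explicitly in the documented shapes.

import torch

bsz, seq_len = 2, 128
tokens = torch.randint(0, vocab_size, (bsz, seq_len))    # [b x s]
logits = model(tokens)                                    # [b x s x v], default causal mask
assert logits.shape == (bsz, seq_len, vocab_size)

# Explicit inputs in the documented shapes; a lower-triangular boolean mask and
# sequential positions should match the default behaviour.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).expand(bsz, -1, -1)  # [b x s x s]
input_pos = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)                         # [b x s]
logits = model(tokens, mask=mask, input_pos=input_pos)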

reset_caches()[source]

Reset the key-value caches.

setup_caches(batch_size: int, dtype: dtype) None[source]

Set up key-value caches for attention calculation.

Parameters:
  • batch_size (int) – batch size for the caches.

  • dtype (torch.dtype) – dtype for the caches.
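
A hedged inference sketch tying setup_caches(), caches_are_enabled(), forward() with input_pos, and reset_caches() together, again reusing the model from the construction sketch above. The greedy decoding loop is illustrative only.

import torch

model.eval()
# Allocate caches in the model's parameter dtype so attention inputs match.
model.setup_caches(batch_size=1, dtype=next(model.parameters()).dtype)
assert model.caches_are_enabled()

prompt = torch.randint(0, vocab_size, (1, 16))            # [b x s] stand-in prompt

with torch.no_grad():
    # Prefill: pass positions for every prompt token so their KV values are cached.
    logits = model(prompt, input_pos=torch.arange(prompt.size(1)))
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick, [b x 1]

    generated = [next_token]
    for step in range(32):                                 # decode 32 more tokens
        # Each step passes only the new token and its (single) position id.
        input_pos = torch.tensor([prompt.size(1) + step])
        logits = model(next_token, input_pos=input_pos)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_token)

model.reset_caches()   # clear cached KV values before running an unrelated prompt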
