
TransformerDecoder

class torchtune.modules.TransformerDecoder(*, tok_embeddings: Embedding, layers: Union[Module, List[Module], ModuleList], max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Union[Linear, Callable], num_layers: Optional[int] = None, output_hidden_states: Optional[List[int]] = None)[source]

Transformer Decoder derived from the Llama2 architecture.

Parameters:
  • tok_embeddings (nn.Embedding) – PyTorch embedding layer, to be used to move tokens to an embedding space.

  • layers (Union[nn.Module, List[nn.Module], nn.ModuleList]) – A single transformer decoder layer, an nn.ModuleList of layers, or a list of layers. Using an nn.ModuleList is recommended.

  • max_seq_len (int) – maximum sequence length the model will be run with, as used by the KVCache.

  • num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value. This is used to set up the KVCache.

  • head_dim (int) – embedding dimension for each head in self-attention. This is used to set up the KVCache.

  • norm (nn.Module) – Callable that applies normalization to the output of the decoder, before the final output projection.

  • output (Union[nn.Linear, Callable]) – Callable that applies a linear transformation to the output of the decoder.

  • num_layers (Optional[int]) – Number of Transformer Decoder layers; only define this when layers is not a list.

  • output_hidden_states (Optional[List[int]]) – List of layer indices whose hidden states should be included in the output.

Raises:
  • AssertionError – if num_layers is set and layers is a list, or if num_layers is not set and layers is a single nn.Module.

Note

Arg values are checked for correctness (e.g., attn_dropout belongs to [0, 1]) in the module where they are used. This helps reduce the number of raise statements in code and improves readability.
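
A minimal construction sketch; the constructor signatures of MultiHeadAttention, FeedForward, TransformerSelfAttentionLayer, RMSNorm, and RotaryPositionalEmbeddings used below are assumptions based on a recent torchtune release and may differ across versions:

    from torch import nn
    from torchtune.modules import (
        FeedForward,
        MultiHeadAttention,
        RMSNorm,
        RotaryPositionalEmbeddings,
        TransformerDecoder,
        TransformerSelfAttentionLayer,
    )

    vocab_size, embed_dim, num_heads, num_layers, max_seq_len = 32_000, 512, 8, 4, 2048
    head_dim = embed_dim // num_heads

    def build_layer() -> TransformerSelfAttentionLayer:
        # Standard MHA: num_kv_heads == num_heads, RoPE applied per head.
        attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_heads,
            head_dim=head_dim,
            q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
            k_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
            v_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
            output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
            pos_embeddings=RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len),
            max_seq_len=max_seq_len,
        )
        mlp = FeedForward(
            gate_proj=nn.Linear(embed_dim, 4 * embed_dim, bias=False),
            down_proj=nn.Linear(4 * embed_dim, embed_dim, bias=False),
            up_proj=nn.Linear(embed_dim, 4 * embed_dim, bias=False),
        )
        return TransformerSelfAttentionLayer(
            attn=attn,
            mlp=mlp,
            sa_norm=RMSNorm(embed_dim),
            mlp_norm=RMSNorm(embed_dim),
        )

    model = TransformerDecoder(
        tok_embeddings=nn.Embedding(vocab_size, embed_dim),
        layers=nn.ModuleList([build_layer() for _ in range(num_layers)]),
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim),
        output=nn.Linear(embed_dim, vocab_size, bias=False),
    )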

caches_are_enabled() bool[source]

Check if the key-value caches are set up. This is useful for efficient inference.

chunked_output(last_hidden_state: Tensor) List[Tensor][source]

Apply output projection in chunks. This should be applied in conjunction with CEWithChunkedOutputLoss as upcasting to fp32 is done there.

To use this method, you should first call set_num_output_chunks().

Parameters:

last_hidden_state (torch.Tensor) – last hidden state of the decoder, having shape [b, seq_len, embed_dim].

Returns:

List of num_chunks output tensors, each with shape [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.

Return type:

List[torch.Tensor]
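
A short usage sketch, assuming the model built in the construction sketch above and that CEWithChunkedOutputLoss is importable from torchtune.modules.loss (the import path is an assumption and may vary by version):

    import torch
    from torchtune.modules.loss import CEWithChunkedOutputLoss

    loss_fn = CEWithChunkedOutputLoss(num_output_chunks=8)
    model.set_num_output_chunks(loss_fn.num_output_chunks)  # call before the first forward pass

    b, s = 2, 128                                     # here s divides evenly into the chunks
    last_hidden_state = torch.randn(b, s, embed_dim)  # [b, seq_len, embed_dim]
    chunks = model.chunked_output(last_hidden_state)
    print(len(chunks), chunks[0].shape)               # 8 chunks, each [b, s / 8, vocab_size]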

forward(tokens: Tensor, *, mask: Optional[Tensor] = None, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Union[Tensor, List[Tensor]][source]
Parameters:
  • tokens (torch.Tensor) – input tensor with shape [b x s]

  • mask (Optional[_MaskType]) –

    Used to mask the scores after the query-key multiplication and before the softmax. This parameter is required during inference if caches have been set up. Either:

    A boolean tensor with shape [b x s x s], [b x s x self.decoder_max_cache_seq_len] if using KV-caching, or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.

    A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.

  • encoder_input (Optional[torch.Tensor]) – Optional input embeds from the encoder. Shape [b x s_e x d_e]

  • encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i, j means token i can attend to encoder embedding j. Mask has shape [b x s x s_e]. Default is None, but this is required during inference if the model has been set up with any layers which use encoder embeddings and caches have been set up.

  • input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. This parameter is required during inference if caches have been set up. Default is None.

Returns:

output tensor with shape [b x s x v] or a list of layer output tensors defined by output_hidden_states with the final output tensor appended to the list.

Return type:

Union[torch.Tensor, List[torch.Tensor]]

Note

At the very first step of inference, when the model is provided with a prompt, input_pos should contain the positions of all of the tokens in the prompt. For a single-batch prompt, or a batch of prompts with identical lengths, this will be torch.arange(prompt_length). For a batch of varying-length prompts, shorter prompts are left-padded and position ids are correspondingly right-shifted, so position ids should be of shape [b, padded_prompt_length]. This is because we will need to retrieve the positional embeddings for each input id. In subsequent steps, if the model has been set up with KV-caches, input_pos will contain the position(s) of the current token(s), e.g. torch.tensor([padded_prompt_length]). Otherwise, input_pos will contain all the position ids up to the current token.

Shape notation:
  • b: batch size

  • s: token sequence length

  • s_e: encoder sequence length

  • v: vocab size

  • d: token embed dim

  • d_e: encoder embed dim

  • m_s: max seq len
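
A training-style forward sketch, assuming the model and vocab_size from the construction sketch above and that chunked output has not been enabled:

    import torch

    b, s = 2, 16
    tokens = torch.randint(0, vocab_size, (b, s))  # [b x s]
    logits = model(tokens)                         # default causal mask, no KV-caches
    print(logits.shape)                            # [b x s x v] -> torch.Size([2, 16, 32000])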

reset_caches()[source]

Reset the key-value caches.

set_num_output_chunks(num_output_chunks: int) None[source]

Used to save memory in combination with CEWithChunkedOutputLoss. This should be called before the first forward pass, in the recipe.

setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None)[source]

Sets up key-value attention caches for inference. For each layer in self.layers:
  • TransformerSelfAttentionLayer will use decoder_max_seq_len.
  • TransformerCrossAttentionLayer will use encoder_max_seq_len.
  • FusionLayer will use both decoder_max_seq_len and encoder_max_seq_len.

Parameters:
  • batch_size (int) – batch size for the caches.

  • dtype (torch.dtype) – dtype for the caches.

  • encoder_max_seq_len (Optional[int]) – maximum encoder cache sequence length.

  • decoder_max_seq_len (Optional[int]) – maximum decoder cache sequence length.
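
An inference sketch, assuming the model and vocab_size from the construction sketch above; the mask construction below is one way to satisfy the [b x s x self.decoder_max_cache_seq_len] shape requirement, not the only one:

    import torch

    max_cache_len, prompt_len = 256, 8
    model.setup_caches(batch_size=1, dtype=torch.float32, decoder_max_seq_len=max_cache_len)
    assert model.caches_are_enabled()

    with torch.no_grad():
        # Prefill: the whole prompt at once, positions [0, prompt_len).
        prompt = torch.randint(0, vocab_size, (1, prompt_len))
        prefill_mask = torch.tril(
            torch.ones(prompt_len, max_cache_len, dtype=torch.bool)
        )[None, :, :]                                             # [b x s x decoder_max_cache_seq_len]
        logits = model(prompt, mask=prefill_mask, input_pos=torch.arange(prompt_len)[None, :])
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # [1, 1]

        # One decode step: a single token at position prompt_len, attending to
        # cache positions [0, prompt_len].
        step_mask = torch.zeros(1, 1, max_cache_len, dtype=torch.bool)
        step_mask[..., : prompt_len + 1] = True
        logits = model(next_token, mask=step_mask, input_pos=torch.tensor([[prompt_len]]))

    model.reset_caches()  # clear cached keys/values before decoding a new sequence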
