
TransformerDecoder

class torchtune.modules.TransformerDecoder(*, tok_embeddings: Embedding, layers: Union[Module, List[Module], ModuleList], max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Union[Linear, Callable], num_layers: Optional[int] = None, output_hidden_states: Optional[List[int]] = None)

Transformer Decoder derived from the Llama2 architecture.

Parameters:
  • tok_embeddings (nn.Embedding) – PyTorch embedding layer used to map token IDs into the embedding space.

  • layers (Union[nn.Module, List[nn.Module], nn.ModuleList]) – A single Transformer decoder layer, an nn.ModuleList of layers, or a list of layers. Using an nn.ModuleList is recommended.

  • max_seq_len (int) – maximum sequence length the model will be run with, as used by KVCache()

  • num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value. This is used to set up the KVCache().

  • head_dim (int) – embedding dimension for each head in self-attention. This is used to set up the KVCache().

  • norm (nn.Module) – Callable that applies normalization to the output of the decoder, before the final projection by output.

  • output (Union[nn.Linear, Callable]) – Callable that applies a linear transformation to the output of the decoder.

  • num_layers (Optional[int]) – Number of Transformer decoder layers. Only define this when layers is not a list.

  • output_hidden_states (Optional[List[int]]) – List of layer indices whose hidden states are included in the output.
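The rule for layers and num_layers can be illustrated with a plain-Python sketch. The helper normalize_layers and the ToyLayer class are hypothetical, a plain list stands in for nn.ModuleList, ValueError is simply the exception type chosen for this sketch, and the assumption that a single layer is expanded into num_layers independent deep copies is ours rather than something stated above.

```python
import copy
from typing import Any, List, Optional


def normalize_layers(layers: Any, num_layers: Optional[int] = None) -> List[Any]:
    """Hypothetical helper mirroring the documented layers/num_layers rule."""
    if isinstance(layers, list):
        # A plain list stands in for List[nn.Module] / nn.ModuleList: it already holds every layer.
        if num_layers is not None:
            raise ValueError("num_layers should only be defined when layers is not a list")
        return list(layers)
    # A single layer: stack num_layers copies (assumption: independent deep copies).
    if num_layers is None:
        raise ValueError("num_layers is required when layers is a single layer")
    return [copy.deepcopy(layers) for _ in range(num_layers)]


class ToyLayer:
    """Hypothetical stand-in for a decoder layer with its own parameters."""

    def __init__(self) -> None:
        self.weights = [0.0]


stack = normalize_layers(ToyLayer(), num_layers=3)
stack[0].weights[0] = 1.0  # mutating one copy leaves the others untouched
```

Deep copies keep each layer's parameters independent, which is what a stack of distinct decoder layers needs.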

Raises:

Note

Arg values are checked for correctness (e.g., attn_dropout lies in [0, 1]) in the module where they are used. This helps reduce the number of raise statements in the code and improves readability.

caches_are_enabled() → bool

Checks if the key-value caches are enabled. Once KV-caches have been set up, the relevant attention modules will be “enabled” and all forward passes will update the caches. This behaviour can be disabled without altering the state of the KV-caches by “disabling” the KV-caches using torchtune.modules.common_utils.disable_kv_cache(), after which caches_are_enabled() returns False.

caches_are_setup() → bool

Checks if the key-value caches are set up. This means setup_caches() has been called and the relevant attention modules in the model have created their KVCache.
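The difference between caches being set up and caches being enabled can be modelled as two independent flags. This toy sketch uses hypothetical ToyAttention and ToyDecoder classes and a toy_disable_kv_cache() context manager that only imitates the idea of torchtune.modules.common_utils.disable_kv_cache(); it is not that function's implementation.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional


class ToyAttention:
    """Hypothetical attention module that may own a KV-cache."""

    def __init__(self) -> None:
        self.kv_cache: Optional[List[float]] = None
        self.cache_enabled = False


class ToyDecoder:
    """Hypothetical decoder tracking the two cache states described above."""

    def __init__(self, num_layers: int) -> None:
        self.layers = [ToyAttention() for _ in range(num_layers)]

    def setup_caches(self, max_seq_len: int) -> None:
        # Allocate a cache on every layer and enable it.
        for layer in self.layers:
            layer.kv_cache = [0.0] * max_seq_len
            layer.cache_enabled = True

    def caches_are_setup(self) -> bool:
        return all(layer.kv_cache is not None for layer in self.layers)

    def caches_are_enabled(self) -> bool:
        return all(layer.cache_enabled for layer in self.layers)


@contextmanager
def toy_disable_kv_cache(model: ToyDecoder) -> Iterator[ToyDecoder]:
    """Hypothetical analogue of disable_kv_cache(): pause updates, keep cache contents."""
    previous = [layer.cache_enabled for layer in model.layers]
    for layer in model.layers:
        layer.cache_enabled = False
    try:
        yield model
    finally:
        # Restore whatever state each layer had before entering the context.
        for layer, was_enabled in zip(model.layers, previous):
            layer.cache_enabled = was_enabled
```

Inside the context manager the caches still exist (caches_are_setup() stays True) while caches_are_enabled() is False, matching the behaviour described above.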

chunked_output(last_hidden_state: Tensor) → List[Tensor]

Applies the output projection in chunks. This should be used together with CEWithChunkedOutputLoss, since upcasting to fp32 is done there.

To use this method, you should first call set_num_output_chunks().

Parameters:

last_hidden_state (torch.Tensor) – last hidden state of the decoder, having shape [b, seq_len, embed_dim].

Returns:

List of num_chunks output tensors, each with shape [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.

Return type:

List[torch.Tensor]
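To make the chunking concrete, here is a pure-Python sketch of projecting hidden states chunk by chunk along the sequence dimension. The function names and the list-of-lists "tensors" are hypothetical; splitting along seq_len with a chunk size of ceil(seq_len / num_chunks) is an assumption modelled on torch.chunk.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def linear(x: Vector, weight: Matrix) -> Vector:
    # weight is laid out [out_dim][embed_dim], like nn.Linear.weight (no bias here).
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def chunked_output(hidden: List[List[Vector]], weight: Matrix, num_chunks: int) -> List[List[List[Vector]]]:
    """Project hidden states [b][seq_len][embed_dim] in chunks along seq_len.

    Chunk size is ceil(seq_len / num_chunks), so the last chunk may be shorter and,
    as with torch.chunk, fewer than num_chunks chunks may be returned.
    """
    seq_len = len(hidden[0])
    size = math.ceil(seq_len / num_chunks)
    chunks = []
    for start in range(0, seq_len, size):
        # Slice every sample in the batch over the same sequence window, then project it.
        chunks.append([[linear(tok, weight) for tok in sample[start:start + size]] for sample in hidden])
    return chunks


hidden = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]] * 2  # b=2, seq_len=4, embed_dim=2
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # out_dim=3
out = chunked_output(hidden, weight, num_chunks=2)  # 2 chunks, each [2][2][3]
```

Concatenating the chunks along the sequence dimension recovers the full projection. The intended saving appears to be that a chunk-aware loss such as CEWithChunkedOutputLoss can upcast to fp32 one chunk at a time rather than all at once.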

forward(tokens: Tensor, *, mask: Optional[Tensor] = None, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Union[Tensor, List[Tensor]]
Parameters:
  • tokens (torch.Tensor) – input tensor with shape [b x s]

  • mask (Optional[_MaskType]) –

    Used to mask the scores after the query-key multiplication and before the softmax. This parameter is required during inference if caches have been set up. Either:

    A boolean tensor with shape [b x s x s], or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.

    A BlockMask, created via create_block_mask, for document masking in a packed sequence. flex_attention() is used when computing attention with block masks. Default is None.

  • encoder_input (Optional[torch.Tensor]) – Optional input embeds from the encoder. Shape [b x s_e x d_e]

  • encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None, but this is required during inference if the model has been set up with any layers that use encoder embeddings and caches have been set up.

  • input_pos (Optional[torch.Tensor]) – Optional tensor containing the position ids of each token. During training, this indicates the position of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. This parameter is required during inference if caches have been set up. Default is None.
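The boolean mask convention (True in row i, column j means token i attends to token j) and the packed-sequence meaning of input_pos can be sketched for a single sample in plain Python. These helpers are hypothetical and build dense masks; the real argument is a batched [b x s x s] tensor or a flex-attention BlockMask, and the packed mask here is only the dense equivalent of a document mask, not a BlockMask.

```python
from typing import List

BoolMask = List[List[bool]]


def causal_mask(s: int) -> BoolMask:
    # mask[i][j] is True when token i may attend to token j, i.e. j <= i.
    return [[j <= i for j in range(s)] for i in range(s)]


def packed_document_mask(doc_lengths: List[int]) -> BoolMask:
    # Causal within each packed document; no attention across document boundaries.
    doc_ids = [d for d, n in enumerate(doc_lengths) for _ in range(n)]
    s = len(doc_ids)
    return [[j <= i and doc_ids[i] == doc_ids[j] for j in range(s)] for i in range(s)]


def packed_input_pos(doc_lengths: List[int]) -> List[int]:
    # Position of each token relative to the start of its own sample in the pack.
    return [p for n in doc_lengths for p in range(n)]
```

For example, packing a 3-token and a 2-token sample gives input_pos [0, 1, 2, 0, 1], and the second sample's tokens cannot see the first sample's tokens.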

Returns:

output tensor with shape [b x s x v], or a list of layer output tensors defined by output_hidden_states with the final output tensor appended to the list.

Return type:

Union[torch.Tensor, List[torch.Tensor]]
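How output_hidden_states shapes the return value can be sketched with scalar stand-ins for hidden states. run_decoder is hypothetical, the final norm is folded into output for brevity, and recording the hidden state that enters layer i (rather than the one leaving it) is an assumption of this sketch.

```python
from typing import Callable, List, Sequence, Union

Layer = Callable[[float], float]


def run_decoder(h: float, layers: Sequence[Layer], output: Callable[[float], float],
                output_hidden_states: Sequence[int] = ()) -> Union[float, List[float]]:
    hidden: List[float] = []
    for i, layer in enumerate(layers):
        # Assumption: record the hidden state entering layer i.
        if i in output_hidden_states:
            hidden.append(h)
        h = layer(h)
    out = output(h)  # final norm + projection, collapsed into one callable here
    if not hidden:
        # No hidden states requested: return the single output "tensor".
        return out
    # Otherwise the final output is appended to the list of hidden states.
    return hidden + [out]


layers = [lambda x: x + 1.0] * 3
single = run_decoder(0.0, layers, lambda x: x * 10.0)
with_hidden = run_decoder(0.0, layers, lambda x: x * 10.0, output_hidden_states=[0, 2])
```

Here single is 30.0, while with_hidden is a list ending with that same final output.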

Note

At the very first step of inference, when the model is provided with a prompt, input_pos should contain the positions of all of the tokens in the prompt. For a single prompt, or a batch of prompts with identical lengths, this will be torch.arange(prompt_length). For a batch of prompts with varying lengths, shorter prompts are left-padded and their position ids are correspondingly right-shifted, so the position ids have shape [b, padded_prompt_length]. This is because the positional embeddings must be retrieved for each input id. In subsequent steps, if the model has been set up with KV-caches, input_pos contains the position(s) of the current token(s), e.g. torch.tensor([padded_prompt_length]). Otherwise, input_pos contains all the position ids up to the current token.
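For left-padded prompts, one way to derive position ids is a running count of real tokens, so that each prompt's first real token gets position 0. This is a sketch under assumptions: the helper name is hypothetical, padding_mask marks real tokens with True, and padded slots are given position 0.

```python
from itertools import accumulate
from typing import List


def position_ids_from_padding_mask(padding_mask: List[List[bool]]) -> List[List[int]]:
    """padding_mask[b][t] is True for real tokens and False for left padding."""
    positions = []
    for row in padding_mask:
        counts = accumulate(int(m) for m in row)  # running count of real tokens
        # Real tokens get 0, 1, 2, ...; padded slots get 0 because they are multiplied by 0.
        positions.append([(c - 1) * int(m) for c, m in zip(counts, row)])
    return positions


batch = [
    [False, False, True, True],  # shorter prompt, left-padded by two
    [True, True, True, True],    # full-length prompt
]
input_pos = position_ids_from_padding_mask(batch)  # shape [b, padded_prompt_length]
```

For a row without padding this reduces to arange(prompt_length), consistent with the note above.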

Shape notation:
  • b: batch size

  • s: token sequence length

  • s_e: encoder sequence length

  • v: vocab size

  • d: token embed dim

  • d_e: encoder embed dim

  • m_s: max seq len

reset_caches()

Resets KV-cache buffers on the relevant attention modules to zero and resets cache positions to zero, without deleting or reallocating the cache tensors.

Raises:

RuntimeError – if KV-caches are not set up. Use setup_caches() to set up caches first.
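Resetting without reallocating means overwriting the existing buffers in place. This sketch uses a hypothetical ToyKVCache dataclass with plain lists standing in for cache tensors; the error message mirrors the documented RuntimeError.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyKVCache:
    """Hypothetical stand-in for one layer's KV-cache."""

    k: List[float]
    v: List[float]
    pos: int = 0


def reset_caches(caches: Optional[List[ToyKVCache]]) -> None:
    if caches is None:
        raise RuntimeError("KV-caches are not set up. Use setup_caches() to set up caches first.")
    for cache in caches:
        # Zero each buffer element by element so the same list objects are reused.
        for buf in (cache.k, cache.v):
            for i in range(len(buf)):
                buf[i] = 0.0
        cache.pos = 0


cache = ToyKVCache(k=[1.0, 2.0], v=[3.0, 4.0], pos=2)
k_buffer = cache.k  # keep a reference to check the buffer is not replaced
reset_caches([cache])
```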

set_num_output_chunks(num_output_chunks: int) → None

Sets the number of chunks used by chunked_output(). Used to save memory in combination with CEWithChunkedOutputLoss; call it in the recipe before the first forward pass.

setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None)
Sets up key-value attention caches for inference. For each layer in self.layers:
Parameters:
  • batch_size (int) – batch size for the caches.

  • dtype (torch.dtype) – dtype for the caches.

  • encoder_max_seq_len (Optional[int]) – maximum encoder cache sequence length.

  • decoder_max_seq_len (Optional[int]) – maximum decoder cache sequence length.
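The constructor arguments max_seq_len, num_heads, and head_dim, together with batch_size and dtype here, determine how much memory the caches occupy. Assuming each layer holds one key and one value buffer of shape [batch_size, num_heads, max_seq_len, head_dim] (an assumption about the cache layout; with grouped-query attention the number of key/value heads would replace num_heads), the size can be estimated as below. The configuration used (32 layers, 32 heads, head_dim 128, a 4096-token cache, bf16) is illustrative, not a measurement.

```python
# Bytes per element for a few common cache dtypes.
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2}


def kv_cache_bytes(num_layers: int, batch_size: int, num_kv_heads: int,
                   max_seq_len: int, head_dim: int, dtype: str) -> int:
    # Assumption: one key and one value buffer per layer,
    # each shaped [batch_size, num_kv_heads, max_seq_len, head_dim].
    elements_per_layer = 2 * batch_size * num_kv_heads * max_seq_len * head_dim
    return num_layers * elements_per_layer * DTYPE_BYTES[dtype]


size = kv_cache_bytes(num_layers=32, batch_size=1, num_kv_heads=32,
                      max_seq_len=4096, head_dim=128, dtype="bfloat16")
```

That configuration gives 2 GiB for batch_size=1, and the total grows linearly with batch_size and with the cache sequence length.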
