TransformerDecoder¶
- class torchtune.modules.TransformerDecoder(*, tok_embeddings: Embedding, layers: Union[Module, List[Module], ModuleList], max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Union[Linear, Callable], num_layers: Optional[int] = None, output_hidden_states: Optional[List[int]] = None)[source]¶
Transformer Decoder derived from the Llama2 architecture.
- Parameters:
tok_embeddings (nn.Embedding) – PyTorch embedding layer, to be used to move tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module], nn.ModuleList]) – A single transformer Decoder layer, an nn.ModuleList of layers or a list of layers. It is recommended to use an nn.ModuleList.
max_seq_len (int) – maximum sequence length the model will be run with, as used by KVCache().
num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value. This is used to setup the KVCache().
head_dim (int) – embedding dimension for each head in self-attention. This is used to setup the KVCache().
norm (nn.Module) – Callable that applies normalization to the output of the decoder, before the final output projection.
output (Union[nn.Linear, Callable]) – Callable that applies a linear transformation to the output of the decoder.
num_layers (Optional[int]) – Number of Transformer Decoder layers; only specify when layers is a single nn.Module rather than a list.
output_hidden_states (Optional[List[int]]) – List of layer indices whose hidden states should be included in the output.
- Raises:
AssertionError – num_layers is set and layers is a list
AssertionError – num_layers is not set and layers is an nn.Module
Note
Arg values are checked for correctness (e.g. attn_dropout belongs to [0,1]) in the module where they are used. This helps reduce the number of raise statements in code and improves readability.
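For orientation, below is a minimal construction sketch for a toy decoder built from a list of self-attention layers. The constructor signatures of the helper modules (MultiHeadAttention, TransformerSelfAttentionLayer, FeedForward, RMSNorm, RotaryPositionalEmbeddings) are assumptions and may differ slightly between torchtune versions; the sizes are arbitrary.

```python
# Toy sizes for illustration only; helper-module signatures are assumptions.
import torch
from torch import nn
from torchtune.modules import (
    FeedForward,
    MultiHeadAttention,
    RMSNorm,
    RotaryPositionalEmbeddings,
    TransformerDecoder,
    TransformerSelfAttentionLayer,
)

vocab_size, embed_dim, num_heads, num_layers, max_seq_len = 128, 64, 4, 2, 256
head_dim = embed_dim // num_heads

def build_layer() -> TransformerSelfAttentionLayer:
    attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_heads,  # plain MHA: as many KV heads as query heads
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=RotaryPositionalEmbeddings(head_dim, max_seq_len=max_seq_len),
        max_seq_len=max_seq_len,
    )
    mlp = FeedForward(
        gate_proj=nn.Linear(embed_dim, 4 * embed_dim),
        down_proj=nn.Linear(4 * embed_dim, embed_dim),
        up_proj=nn.Linear(embed_dim, 4 * embed_dim),
    )
    return TransformerSelfAttentionLayer(
        attn=attn, mlp=mlp, sa_norm=RMSNorm(embed_dim), mlp_norm=RMSNorm(embed_dim)
    )

decoder = TransformerDecoder(
    tok_embeddings=nn.Embedding(vocab_size, embed_dim),
    layers=nn.ModuleList([build_layer() for _ in range(num_layers)]),  # list form: omit num_layers
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    head_dim=head_dim,
    norm=RMSNorm(embed_dim),
    output=nn.Linear(embed_dim, vocab_size, bias=False),
)

logits = decoder(torch.randint(0, vocab_size, (2, 16)))  # [b, s, v], default causal mask
```

The method examples below reuse the names decoder, vocab_size, and max_seq_len from this sketch, and each assumes a freshly constructed decoder.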
- caches_are_enabled() bool [source]¶
Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant attention modules will be “enabled” and all forward passes will update the caches. This behaviour can be disabled without altering the state of the KV-caches by “disabling” the KV-caches using torchtune.modules.common_utils.disable_kv_cache(), upon which caches_are_enabled would return False.
- caches_are_setup() bool [source]¶
Check if the key value caches are setup. This means setup_caches has been called, and the relevant attention modules in the model have created their KVCache.
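A sketch of how these two checks relate to setup_caches() and disable_kv_cache() (assumed here to be a context manager taking the model), reusing the toy decoder from the construction sketch above:

```python
import torch
from torchtune.modules.common_utils import disable_kv_cache

assert not decoder.caches_are_setup() and not decoder.caches_are_enabled()

# Allocating caches both sets them up and enables them.
decoder.setup_caches(batch_size=1, dtype=torch.float32)
assert decoder.caches_are_setup() and decoder.caches_are_enabled()

# Run forward passes without reading or updating the caches.
with disable_kv_cache(decoder):
    assert decoder.caches_are_setup()        # caches are still allocated...
    assert not decoder.caches_are_enabled()  # ...but forward will not touch them
assert decoder.caches_are_enabled()          # re-enabled on exit
```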
- chunked_output(last_hidden_state: Tensor) List[Tensor] [source]¶
Apply output projection in chunks. This should be applied in conjunction with CEWithChunkedOutputLoss as upcasting to fp32 is done there. To use this method, you should first call set_num_output_chunks().
- Parameters:
last_hidden_state (torch.Tensor) – last hidden state of the decoder, having shape [b, seq_len, embed_dim].
- Returns:
- List of num_chunks output tensors, each with shape
[b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
- Return type:
List[torch.Tensor]
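A small sketch of the chunking itself, reusing the toy decoder above with a random stand-in for the final hidden state (num_chunks is arbitrary; in practice it should match the chunked loss):

```python
import torch

num_chunks = 8
decoder.set_num_output_chunks(num_chunks)

b, seq_len = 2, 64                               # seq_len chosen divisible by num_chunks
last_hidden_state = torch.randn(b, seq_len, 64)  # [b, seq_len, embed_dim]

chunks = decoder.chunked_output(last_hidden_state)
print(len(chunks), chunks[0].shape)  # 8 chunks, each [b, seq_len / num_chunks, vocab_size]
```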
- forward(tokens: Tensor, *, mask: Optional[Tensor] = None, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Union[Tensor, List[Tensor]] [source]¶
- Parameters:
tokens (torch.Tensor) – input tensor with shape
[b x s]
mask (Optional[_MaskType]) – Used to mask the scores after the query-key multiplication and before the softmax. This parameter is required during inference if caches have been setup. Either:
A boolean tensor with shape [b x s x s], [b x s x self.decoder_max_cache_seq_len], or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.
A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.
encoder_input (Optional[torch.Tensor]) – Optional input embeds from the encoder. Shape [b x s_e x d_e].
encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None, but this is required during inference if the model has been setup with any layers which use encoder embeddings and caches have been setup.
input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. This parameter is required during inference if caches have been setup. Default is None.
- Returns:
- output tensor with shape [b x s x v] or a list of layer output tensors defined by output_hidden_states with the final output tensor appended to the list.
- Return type:
Union[torch.Tensor, List[torch.Tensor]]
Note
At the very first step of inference, when the model is provided with a prompt, input_pos should contain the positions of all of the tokens in the prompt. For a single-batch prompt, or a batch of prompts with identical lengths, this will be torch.arange(prompt_length). For a batch of varying-length prompts, shorter prompts are left-padded and position ids are correspondingly right-shifted, thus positional ids should be of shape [b, padded_prompt_length]. This is because we will need to retrieve the positional embeddings for each input id. In the subsequent steps, if the model has been setup with KV-caches, input_pos will contain the position(s) of the current token(s), torch.tensor([padded_prompt_length]). Otherwise, input_pos will contain all the position ids up to the current token.
- Shape notation:
b: batch size
s: token sequence length
s_e: encoder sequence length
v: vocab size
d: token embed dim
d_e: encoder embed dim
m_s: max seq len
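Putting the mask and input_pos rules together, here is a rough prefill-plus-one-decode-step sketch on the toy decoder from above (starting from a decoder with no caches set up; the boolean mask follows the convention described in this section and assumes the cache length defaults to max_seq_len):

```python
import torch

prompt = torch.randint(0, vocab_size, (1, 8))  # [b x s]
decoder.setup_caches(batch_size=1, dtype=torch.float32)

# Causal mask over the full cache length; sliced per step below.
causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# Prefill: positions of every prompt token, mask shape [b x s x max_seq_len].
prompt_len = prompt.shape[1]
logits = decoder(
    prompt,
    mask=causal[None, :prompt_len, :],
    input_pos=torch.arange(prompt_len).unsqueeze(0),
)

# One decode step: a single new token at position prompt_len.
next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)  # [b x 1]
step_logits = decoder(
    next_tok,
    mask=causal[None, prompt_len : prompt_len + 1, :],
    input_pos=torch.tensor([[prompt_len]]),
)
print(step_logits.shape)  # [b x 1 x v]
```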
- reset_caches()[source]¶
Resets KV-cache buffers on relevant attention modules to zero, and resets cache positions to zero, without deleting or reallocating cache tensors.
- Raises:
RuntimeError – if KV-caches are not setup. Use setup_caches() to setup caches first.
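Continuing the inference sketch above, a common pattern is to zero the caches between prompts instead of re-allocating them:

```python
# Done with the first prompt; reuse the same cache buffers for the next one.
decoder.reset_caches()

next_prompt = torch.randint(0, vocab_size, (1, 8))
logits = decoder(
    next_prompt,
    mask=causal[None, : next_prompt.shape[1], :],
    input_pos=torch.arange(next_prompt.shape[1]).unsqueeze(0),
)
```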
- set_num_output_chunks(num_output_chunks: int) None [source]¶
Used to save memory in combination with CEWithChunkedOutputLoss. This should be called before the first forward pass, in the recipe.
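In a recipe this is typically wired to the chunked loss, for example (assuming CEWithChunkedOutputLoss lives in torchtune.modules.loss and exposes its num_output_chunks argument as an attribute):

```python
from torchtune.modules.loss import CEWithChunkedOutputLoss

loss_fn = CEWithChunkedOutputLoss(num_output_chunks=8)
decoder.set_num_output_chunks(loss_fn.num_output_chunks)  # call before the first forward pass
```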
- setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None)[source]¶
- Sets up key-value attention caches for inference. For each layer in self.layers:
TransformerSelfAttentionLayer will use decoder_max_seq_len.
TransformerCrossAttentionLayer will use encoder_max_seq_len.
FusionLayer will use both decoder_max_seq_len and encoder_max_seq_len.
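For a decoder-only layer stack, only decoder_max_seq_len is relevant; capping it below max_seq_len shrinks the cache allocation when generations are known to be short. A sketch on the toy decoder above:

```python
import torch

# Allocate caches for at most 128 positions even though the model supports
# max_seq_len tokens (assumes no caches have been set up on this decoder yet).
decoder.setup_caches(
    batch_size=4,
    dtype=torch.bfloat16,
    decoder_max_seq_len=128,
)
assert decoder.caches_are_setup() and decoder.caches_are_enabled()
```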