TransformerDecoder
- class torchtune.modules.TransformerDecoder(*, tok_embeddings: Embedding, layers: Union[Module, List[Module], ModuleList], max_seq_len: int, num_heads: int, head_dim: int, norm: Module, output: Union[Linear, Callable], num_layers: Optional[int] = None, output_hidden_states: Optional[List[int]] = None)
Transformer Decoder derived from the Llama2 architecture.
- Parameters:
tok_embeddings (nn.Embedding) – PyTorch embedding layer used to map tokens to an embedding space.
layers (Union[nn.Module, List[nn.Module], nn.ModuleList]) – A single transformer decoder layer, an nn.ModuleList of layers, or a list of layers. Using an nn.ModuleList is recommended.
max_seq_len (int) – Maximum sequence length the model will be run with, as used by KVCache().
num_heads (int) – Number of query heads. For MHA this is also the number of heads for key and value. This is used to set up the KVCache().
head_dim (int) – Embedding dimension for each head in self-attention. This is used to set up the KVCache().
norm (nn.Module) – Callable that applies normalization to the output of the decoder, before the final output projection.
output (Union[nn.Linear, Callable]) – Callable that applies a linear transformation to the output of the decoder.
num_layers (Optional[int]) – Number of Transformer Decoder layers. Only define this when layers is not a list.
output_hidden_states (Optional[List[int]]) – List of layer indices to include in the output.
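Because max_seq_len, num_heads and head_dim size the KVCache() buffers, a rough memory estimate is often useful before calling setup_caches(). The sketch below assumes each layer caches one key tensor and one value tensor, each of shape [b, num_heads, max_seq_len, head_dim]; that layout, the helper name kv_cache_bytes, and the Llama2-7B-like example sizes are illustrative assumptions, not details taken from this page.

```python
def kv_cache_bytes(
    batch_size: int,
    num_layers: int,
    num_heads: int,
    max_seq_len: int,
    head_dim: int,
    bytes_per_elem: int = 2,  # 2 for bf16/fp16, 4 for fp32
) -> int:
    """Estimate KV-cache memory, assuming one key and one value tensor per layer,
    each of shape [batch_size, num_heads, max_seq_len, head_dim] (hypothetical layout)."""
    elems_per_tensor = batch_size * num_heads * max_seq_len * head_dim
    # Factor of 2: one key cache and one value cache per layer.
    return 2 * num_layers * elems_per_tensor * bytes_per_elem


# Illustrative sizes: 32 layers, 32 heads, head_dim 128, bf16 cache.
size = kv_cache_bytes(batch_size=1, num_layers=32, num_heads=32, max_seq_len=4096, head_dim=128)
print(size / 2**30, "GiB")  # → 2.0 GiB
```

The estimate grows linearly in batch size and max_seq_len, which is why both are fixed up front when the caches are set up.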
- Raises:
AssertionError – num_layers is set and layers is a list
AssertionError – num_layers is not set and layers is an nn.Module
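The two Raises entries amount to one rule: pass either a single layer together with num_layers, or a list of layers with num_layers left unset. A minimal sketch of that validation follows; the helper resolve_layers is hypothetical, treating nn.ModuleList like a list is an assumption, and replicating a single layer with copy.deepcopy is one reasonable way to build the stack, not necessarily the library's exact code.

```python
import copy
from typing import List, Optional, Union

from torch import nn


def resolve_layers(
    layers: Union[nn.Module, List[nn.Module], nn.ModuleList],
    num_layers: Optional[int] = None,
) -> nn.ModuleList:
    """Hypothetical helper mirroring the documented layers/num_layers rule."""
    if isinstance(layers, (list, nn.ModuleList)):
        # A list already fixes the depth, so num_layers must not be given.
        if num_layers is not None:
            raise AssertionError("num_layers is set and layers is a list")
        return nn.ModuleList(layers)
    # A single layer needs num_layers to know how many copies to make.
    if num_layers is None:
        raise AssertionError("num_layers is not set and layers is an nn.Module")
    # Deep copies give each layer its own, independent parameters.
    return nn.ModuleList(copy.deepcopy(layers) for _ in range(num_layers))


stack = resolve_layers(nn.Linear(8, 8), num_layers=4)
print(len(stack), stack[0] is stack[1])  # → 4 False
```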
Note
Arg values are checked for correctness (e.g. attn_dropout belongs to [0,1]) in the module where they are used. This helps reduce the number of raise statements in code and improves readability.
- caches_are_enabled() → bool
Checks whether the key-value caches are set up. This is useful for efficient inference.
- chunked_output(last_hidden_state: Tensor) → List[Tensor]
Applies the output projection in chunks. This should be used in conjunction with CEWithChunkedOutputLoss, as upcasting to fp32 is done there. To use this method, first call set_num_output_chunks().
- Parameters:
last_hidden_state (torch.Tensor) – last hidden state of the decoder, having shape [b, seq_len, embed_dim].
- Returns:
- List of num_chunks output tensors, each with shape
[b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
- Return type:
List[torch.Tensor]
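Conceptually, chunking splits the sequence dimension and projects each piece separately, so the full [b, seq_len, vocab] logits tensor is never produced by a single projection call. A minimal sketch, assuming the output projection is a plain nn.Linear and chunks are taken along dimension 1 with torch.chunk; the function name chunked_projection is hypothetical and the library's implementation may differ.

```python
from typing import List

import torch
from torch import nn


def chunked_projection(
    last_hidden_state: torch.Tensor, output: nn.Module, num_chunks: int
) -> List[torch.Tensor]:
    # Split [b, seq_len, embed_dim] along the sequence dimension.
    chunks = last_hidden_state.chunk(num_chunks, dim=1)
    # Project each chunk separately: [b, seq_len/num_chunks, out_dim] per chunk.
    return [output(chunk) for chunk in chunks]


b, seq_len, embed_dim, vocab = 2, 16, 32, 100
h = torch.randn(b, seq_len, embed_dim)
proj = nn.Linear(embed_dim, vocab, bias=False)
outs = chunked_projection(h, proj, num_chunks=4)
print([tuple(o.shape) for o in outs])  # → [(2, 4, 100), (2, 4, 100), (2, 4, 100), (2, 4, 100)]
```

Because the projection acts on each position independently, concatenating the chunks along dimension 1 reproduces the unchunked result up to floating-point error. torch.chunk may return fewer or unequal chunks when seq_len is not divisible by num_chunks.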
- forward(tokens: Tensor, *, mask: Optional[Tensor] = None, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Union[Tensor, List[Tensor]]
- Parameters:
tokens (torch.Tensor) – input tensor with shape [b x s]
mask (Optional[_MaskType]) – Used to mask the scores after the query-key multiplication and before the softmax. This parameter is required during inference if caches have been set up. Either:
A boolean tensor with shape [b x s x s], or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.
A BlockMask for document masking in a packed sequence, created via create_block_mask. We use flex_attention() when computing attention with block masks.
Default is None.
encoder_input (Optional[torch.Tensor]) – Optional input embeddings from the encoder. Shape [b x s_e x d_e]
encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None, but this is required during inference if the model has been set up with any layers which use encoder embeddings and caches have been set up.
input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. This parameter is required during inference if caches have been set up. Default is None.
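For the boolean form of mask, True at row i, column j means token i may attend to token j, so a causal mask of shape [b x s x s] is a lower-triangular matrix repeated over the batch. A minimal sketch using torch.tril; the helper name causal_mask is hypothetical. The same True-means-attend convention is shown for an encoder_mask of shape [b x s x s_e], with s_e chosen arbitrarily.

```python
import torch


def causal_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean matrix: row i is True for columns 0..i,
    # so token i attends to itself and to every earlier token only.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Broadcast (as a view) to the documented [b x s x s] shape.
    return mask.unsqueeze(0).expand(batch_size, seq_len, seq_len)


m = causal_mask(batch_size=2, seq_len=4)
print(m.shape)  # → torch.Size([2, 4, 4])
print(m[0, 1].tolist())  # → [True, True, False, False]

# encoder_mask uses the same convention with shape [b x s x s_e];
# an all-True mask lets every token attend to every encoder embedding.
s_e = 3
encoder_mask = torch.ones(2, 4, s_e, dtype=torch.bool)
```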
- Returns:
output tensor with shape [b x s x v], or a list of layer output tensors defined by output_hidden_states with the final output tensor appended to the list.
- Return type:
Union[torch.Tensor, List[torch.Tensor]]
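When output_hidden_states is set, the return value is a list rather than a single tensor. A minimal sketch of that control flow, with plain Python callables standing in for the decoder layers and for norm followed by output; the function run_decoder is hypothetical, and recording the hidden state that enters layer i (rather than the one leaving it) is an assumption.

```python
from typing import Callable, List, Optional, Sequence, Union


def run_decoder(
    h: float,
    layers: Sequence[Callable[[float], float]],
    final: Callable[[float], float],
    output_hidden_states: Optional[List[int]] = None,
) -> Union[float, List[float]]:
    selected = output_hidden_states or []
    hidden = []
    for i, layer in enumerate(layers):
        # Assumption: record the hidden state entering layer i.
        if i in selected:
            hidden.append(h)
        h = layer(h)
    out = final(h)  # stands in for norm followed by the output projection
    if not hidden:
        return out  # no hidden states requested: return the output alone
    return hidden + [out]  # final output appended to the list


layers = [lambda x: x + 1 for _ in range(3)]
final = lambda x: x * 10
print(run_decoder(0, layers, final))  # → 30
print(run_decoder(0, layers, final, [0, 2]))  # → [0, 2, 30]
```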
Note
At the very first step of inference, when the model is provided with a prompt, input_pos should contain the positions of all of the tokens in the prompt. For a single-batch prompt, or a batch of prompts with identical lengths, this will be torch.arange(prompt_length). For a batch of varying-length prompts, shorter prompts are left-padded and position ids are correspondingly right-shifted, so positional ids should be of shape [b, padded_prompt_length]. This is because we need to retrieve the positional embeddings for each input id. In subsequent steps, if the model has been set up with KV-caches, input_pos will contain the position(s) of the current token(s), e.g. torch.tensor([padded_prompt_length]). Otherwise, input_pos will contain all the position ids up to the current token.
- Shape notation:
b: batch size
s: token sequence length
s_e: encoder sequence length
v: vocab size
d: token embed dim
d_e: encoder embed dim
m_s: max seq len
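The note on left-padded prompts can be made concrete. A minimal sketch, assuming a boolean padding mask of shape [b, padded_prompt_length] that is True for real tokens and False for left padding; the helper name positions_from_padding_mask is hypothetical. Positions count from each prompt's first real token, which right-shifts them relative to torch.arange(padded_prompt_length); padding slots are set to 0, on the assumption that the attention mask excludes them anyway.

```python
import torch


def positions_from_padding_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    # cumsum counts the real tokens seen so far; subtracting 1 makes the
    # first real token position 0.
    pos = padding_mask.long().cumsum(dim=-1) - 1
    # Left-padding slots come out as -1; clamp them to 0.
    return pos.clamp(min=0)


# Two prompts of lengths 3 and 5, left-padded to length 5.
padding_mask = torch.tensor([
    [False, False, True, True, True],
    [True, True, True, True, True],
])
input_pos = positions_from_padding_mask(padding_mask)
print(input_pos.tolist())  # → [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]
```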
- set_num_output_chunks(num_output_chunks: int) → None
Used to save memory in combination with CEWithChunkedOutputLoss. This should be called in the recipe, before the first forward pass.
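CEWithChunkedOutputLoss itself is not documented on this page. The sketch below only illustrates the idea the docstrings describe: per-chunk logits, such as those returned by chunked_output, are upcast to fp32 one chunk at a time inside the loss, so a full-size fp32 copy of the logits never exists at once. The function name chunked_cross_entropy, the ignore_index default of -100, and averaging over non-ignored tokens are assumptions.

```python
from typing import List

import torch
import torch.nn.functional as F


def chunked_cross_entropy(
    logit_chunks: List[torch.Tensor],  # each [b, s_chunk, vocab], possibly bf16
    labels: torch.Tensor,  # [b, s]
    ignore_index: int = -100,
) -> torch.Tensor:
    # Split labels with the same per-chunk lengths as the logits.
    label_chunks = labels.split([c.size(1) for c in logit_chunks], dim=1)
    losses = [
        # Upcast one chunk at a time to fp32 before computing its summed loss.
        F.cross_entropy(
            logits.float().reshape(-1, logits.size(-1)),
            target.reshape(-1),
            ignore_index=ignore_index,
            reduction="sum",
        )
        for logits, target in zip(logit_chunks, label_chunks)
    ]
    # Average over non-ignored tokens; clamp avoids dividing by zero.
    num_tokens = (labels != ignore_index).sum().clamp(min=1)
    return torch.stack(losses).sum() / num_tokens


torch.manual_seed(0)
b, s, vocab = 2, 8, 50
logits = torch.randn(b, s, vocab)
labels = torch.randint(0, vocab, (b, s))
labels[:, :2] = -100  # e.g. ignore prompt positions
loss = chunked_cross_entropy(list(logits.chunk(4, dim=1)), labels)
reference = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=-100)
print(torch.allclose(loss, reference, atol=1e-6))  # → True
```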
- setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None)
Sets up key-value attention caches for inference. For each layer in self.layers:
  - TransformerSelfAttentionLayer will use decoder_max_seq_len.
  - TransformerCrossAttentionLayer will use encoder_max_seq_len.
  - FusionLayer will use both decoder_max_seq_len and encoder_max_seq_len.
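The per-layer rule above is a dispatch on layer type. A minimal sketch in plain Python, where layer types are represented by their class names as strings rather than the actual classes; the helper name cache_lengths and the example lengths are hypothetical.

```python
from typing import Dict, Optional


def cache_lengths(
    layer_type: str,
    encoder_max_seq_len: Optional[int],
    decoder_max_seq_len: Optional[int],
) -> Dict[str, Optional[int]]:
    # Which max sequence length(s) each layer type's KV-cache uses,
    # following the rule listed for setup_caches.
    if layer_type == "TransformerSelfAttentionLayer":
        return {"decoder_max_seq_len": decoder_max_seq_len}
    if layer_type == "TransformerCrossAttentionLayer":
        return {"encoder_max_seq_len": encoder_max_seq_len}
    if layer_type == "FusionLayer":
        return {
            "encoder_max_seq_len": encoder_max_seq_len,
            "decoder_max_seq_len": decoder_max_seq_len,
        }
    raise ValueError(f"unknown layer type: {layer_type}")


for name in ["TransformerSelfAttentionLayer", "TransformerCrossAttentionLayer", "FusionLayer"]:
    print(name, cache_lengths(name, encoder_max_seq_len=1600, decoder_max_seq_len=2048))
```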