TransformerSelfAttentionLayer

class torchtune.modules.TransformerSelfAttentionLayer(attn: MultiHeadAttention, mlp: Module, *, sa_norm: Optional[Module] = None, mlp_norm: Optional[Module] = None, sa_scale: Optional[Module] = None, mlp_scale: Optional[Module] = None)[source]

Transformer layer derived from the Llama2 model. Normalization is applied before the attention and feed-forward layers. A construction sketch follows the parameter list below.

Parameters:
  • attn (MultiHeadAttention) – Attention module.

  • mlp (nn.Module) – Feed-forward module.

  • sa_norm (Optional[nn.Module]) – Normalization to be applied before self-attention.

  • mlp_norm (Optional[nn.Module]) – Normalization to be applied before the feed-forward layer.

  • sa_scale (Optional[nn.Module]) – Module to scale self-attention output.

  • mlp_scale (Optional[nn.Module]) – Module to scale the feed-forward output.
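
As an illustration of the parameters above, here is a minimal sketch of building this layer. The dimensions, the plain nn.Sequential MLP, and the RMSNorm choices are assumptions for the example, not a canonical Llama2 configuration.

import torch
from torch import nn

from torchtune.modules import (
    MultiHeadAttention,
    RMSNorm,
    TransformerSelfAttentionLayer,
)

embed_dim, num_heads, head_dim = 512, 8, 64

# Self-attention block; projection shapes follow num_heads * head_dim.
attn = MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_heads,
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
    max_seq_len=2048,
)

# A plain two-layer MLP stands in for the usual gated feed-forward block.
mlp = nn.Sequential(
    nn.Linear(embed_dim, 4 * embed_dim),
    nn.GELU(),
    nn.Linear(4 * embed_dim, embed_dim),
)

layer = TransformerSelfAttentionLayer(
    attn=attn,
    mlp=mlp,
    sa_norm=RMSNorm(embed_dim),
    mlp_norm=RMSNorm(embed_dim),
)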

caches_are_enabled() bool[source]

Checks if the key value caches on self.attn are enabled. See torchtune.modules.TransformerDecoder.caches_are_enabled().

caches_are_setup() bool[source]

Check if the key value caches are set up on self.attn. See torchtune.modules.TransformerDecoder.caches_are_setup().

forward(x: Tensor, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None, **kwargs: Dict) Tensor[source]

Parameters:
  • x (torch.Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]

  • mask (Optional[_MaskType]) –

    Used to mask the scores after the query-key multiplication and before the softmax. Either:

    A boolean tensor with shape [b x s x s], [b x s x self.decoder_max_cache_seq_len], or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.

    A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.

  • input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. If None, the index of each token is used as its position id. Default is None.

  • **kwargs (Dict) – transformer layer inputs not relevant to self attention.

Returns:

output tensor with the same shape as the input: [batch_size x seq_length x embed_dim]

Return type:

torch.Tensor
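
A sketch of calling forward(), continuing from the layer built above; the batch size, sequence length, and explicit causal boolean mask are illustrative assumptions.

b, s = 2, 16
x = torch.randn(b, s, embed_dim)

# Lower-triangular boolean mask: True means "attend", so token i sees tokens j <= i.
mask = torch.tril(torch.ones(s, s, dtype=torch.bool)).expand(b, s, s)

out = layer(x, mask=mask)
print(out.shape)  # torch.Size([2, 16, 512])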

reset_cache()[source]

Reset the key value caches.

setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: int, decoder_max_seq_len: int) None[source]

Set up key value caches for attention calculation; a usage sketch follows the parameter list.

Parameters:
  • batch_size (int) – batch size for the caches.

  • dtype (torch.dtype) – dtype for the caches.

  • encoder_max_seq_len (int) – this parameter is ignored in this layer.

  • decoder_max_seq_len (int) – maximum cache sequence length.
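
A hedged sketch of the cache lifecycle for the layer built earlier; the batch size and sequence lengths are arbitrary choices, and encoder_max_seq_len is supplied per the signature above even though this layer ignores it.

layer.setup_caches(
    batch_size=2,
    dtype=torch.float32,
    encoder_max_seq_len=2048,  # ignored by this self-attention layer
    decoder_max_seq_len=2048,
)

print(layer.caches_are_setup())    # True after setup_caches has run
print(layer.caches_are_enabled())  # True unless caching has been disabled elsewhere

# Clear the cached keys and values before starting a new generation.
layer.reset_cache()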
