TransformerCrossAttentionLayer

class torchtune.modules.TransformerCrossAttentionLayer(attn: MultiHeadAttention, mlp: Module, *, ca_norm: Optional[Module] = None, mlp_norm: Optional[Module] = None, ca_scale: Optional[Module] = None, mlp_scale: Optional[Module] = None)

Cross-attention Transformer layer following the same conventions as the TransformerSelfAttentionLayer. Normalization is applied before the attention and feed-forward layers.

Parameters:
  • attn (MultiHeadAttention) – Attention module.

  • mlp (nn.Module) – Feed-forward module.

  • ca_norm (Optional[nn.Module]) – Normalization to be applied before cross-attention.

  • mlp_norm (Optional[nn.Module]) – Normalization to be applied before the feed-forward layer.

  • ca_scale (Optional[nn.Module]) – Module to scale cross-attention output.

  • mlp_scale (Optional[nn.Module]) – Module to scale the feed-forward output.

Raises:

AssertionError – if attn.pos_embeddings is set.
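
The data flow described by the parameters above can be sketched in plain Python. This is only an illustrative sketch, not the torchtune implementation: it assumes residual connections around both sub-layers (the usual Transformer arrangement, not stated on this page), and every name in it (cross_attention_layer_sketch, mean_attn, double_mlp) is hypothetical. Sequences are lists of per-token vectors for a single batch element.

```python
from typing import Callable, List

Seq = List[List[float]]  # [seq_length][embed_dim] for one batch element


def identity(t: Seq) -> Seq:
    # Stand-in used when an optional norm or scale module is not supplied.
    return t


def add(a: Seq, b: Seq) -> Seq:
    # Element-wise sum of two sequences with the same shape.
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


def cross_attention_layer_sketch(
    x: Seq,
    encoder_input: Seq,
    attn: Callable[[Seq, Seq], Seq],
    mlp: Callable[[Seq], Seq],
    ca_norm: Callable[[Seq], Seq] = identity,
    mlp_norm: Callable[[Seq], Seq] = identity,
    ca_scale: Callable[[Seq], Seq] = identity,
    mlp_scale: Callable[[Seq], Seq] = identity,
) -> Seq:
    # Normalize first, then attend from the tokens to the encoder embeddings.
    attn_out = attn(ca_norm(x), encoder_input)
    # Assumed residual connection around the scaled attention output.
    h = add(ca_scale(attn_out), x)
    # Normalize first, then apply the feed-forward module.
    mlp_out = mlp(mlp_norm(h))
    # Assumed residual connection around the scaled feed-forward output.
    return add(mlp_scale(mlp_out), h)


def mean_attn(q: Seq, kv: Seq) -> Seq:
    # Toy "attention": every token receives the mean encoder embedding.
    dim = len(kv[0])
    mean = [sum(row[d] for row in kv) / len(kv) for d in range(dim)]
    return [list(mean) for _ in q]


def double_mlp(t: Seq) -> Seq:
    # Toy feed-forward: doubles every value.
    return [[2.0 * v for v in row] for row in t]


x = [[1.0, 0.0], [0.0, 1.0]]                   # 2 tokens, embed_dim 2
enc = [[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]]     # 3 encoder embeddings
out = cross_attention_layer_sketch(x, enc, mean_attn, double_mlp)
```

The output keeps the shape of x, which matches the shape contract documented for forward.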

property cache_enabled: bool

Check whether the key-value caches are set up.

forward(x: Tensor, *, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, **kwargs: Dict) → Tensor

Parameters:
  • x (torch.Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]

  • encoder_input (Optional[torch.Tensor]) – Optional input embeds from the encoder. Shape [batch_size x token_sequence x embed_dim]

  • encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [batch_size x token_sequence x embed_sequence]. Default is None.

  • **kwargs (Dict) – transformer layer inputs that are relevant only to self-attention and are not used by this layer.

Returns:

output tensor with the same shape as the input, [batch_size x seq_length x embed_dim]

Return type:

torch.Tensor
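
The encoder_mask semantics (True at position i,j lets token i attend to embedding j) can be illustrated with a small masked softmax for a single batch element. This is a sketch in plain Python, not the torchtune code path; the helper name masked_attention_weights is hypothetical, and giving an all-False row zero weights is an assumption made only for this example.

```python
import math
from typing import List


def masked_attention_weights(
    scores: List[List[float]], mask: List[List[bool]]
) -> List[List[float]]:
    """Softmax over each row of scores, restricted to positions where mask is True."""
    weights = []
    for score_row, mask_row in zip(scores, mask):
        allowed = [s for s, m in zip(score_row, mask_row) if m]
        if not allowed:
            # Assumption for this sketch: a fully masked token attends to nothing.
            weights.append([0.0] * len(score_row))
            continue
        peak = max(allowed)  # subtract the max for numerical stability
        exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(score_row, mask_row)]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Shape [token_sequence x embed_sequence] = [2 x 3] for one batch element.
encoder_mask = [
    [True, True, False],   # token 0 may attend to embeddings 0 and 1
    [False, False, True],  # token 1 may attend to embedding 2 only
]
scores = [
    [0.0, 0.0, 5.0],
    [1.0, 2.0, 3.0],
]
weights = masked_attention_weights(scores, encoder_mask)
```

Masked positions receive zero weight no matter how large their raw score is, so token 0 ignores embedding 2 even though it has the highest score in that row.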

reset_cache()

Reset the key-value caches.

setup_cache(batch_size: int, dtype: dtype, *, encoder_max_seq_len: int, decoder_max_seq_len: int) → None

Set up key-value caches for attention calculation.

Parameters:
  • batch_size (int) – batch size for the caches.

  • dtype (torch.dtype) – dtype for the caches.

  • encoder_max_seq_len (int) – maximum cache sequence length.

  • decoder_max_seq_len (int) – this parameter is ignored in this layer.
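
The cache lifecycle implied by cache_enabled, setup_cache and reset_cache can be mimicked with a toy class. This is a hypothetical stand-in written for illustration, not the torchtune layer: the class name ToyCrossAttentionCache, the cache_shape helper, the string dtype, and the choice that reset_cache clears contents while keeping the allocation are all assumptions. The one behavior taken from this page is that the cache is sized by encoder_max_seq_len and that decoder_max_seq_len is ignored.

```python
from typing import List, Optional, Tuple


class ToyCrossAttentionCache:
    """Hypothetical stand-in that mirrors the documented method names only."""

    def __init__(self) -> None:
        self._keys: Optional[List[List[float]]] = None

    @property
    def cache_enabled(self) -> bool:
        # True once setup_cache has allocated storage.
        return self._keys is not None

    def setup_cache(
        self,
        batch_size: int,
        dtype: str,
        *,
        encoder_max_seq_len: int,
        decoder_max_seq_len: int,
    ) -> None:
        # decoder_max_seq_len is accepted but unused, as documented for this layer.
        # dtype is a plain string here; the real method takes a torch.dtype.
        self._keys = [[0.0] * encoder_max_seq_len for _ in range(batch_size)]

    def reset_cache(self) -> None:
        # Assumption: resetting zeroes the contents but keeps the allocation.
        if self._keys is not None:
            self._keys = [[0.0] * len(row) for row in self._keys]

    def cache_shape(self) -> Optional[Tuple[int, int]]:
        # Helper for this example: (batch_size, encoder_max_seq_len) or None.
        if self._keys is None:
            return None
        return (len(self._keys), len(self._keys[0]))


layer = ToyCrossAttentionCache()
enabled_before = layer.cache_enabled
layer.setup_cache(2, "float32", encoder_max_seq_len=4, decoder_max_seq_len=999)
enabled_after = layer.cache_enabled
shape = layer.cache_shape()
layer.reset_cache()
```

Changing decoder_max_seq_len in the call has no effect on the allocated shape, which reflects the note that the parameter is ignored by a cross-attention layer.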
