TransformerCrossAttentionLayer¶
- class torchtune.modules.TransformerCrossAttentionLayer(attn: MultiHeadAttention, mlp: Module, *, ca_norm: Optional[Module] = None, mlp_norm: Optional[Module] = None, ca_scale: Optional[Module] = None, mlp_scale: Optional[Module] = None)[source]¶
Cross-attention Transformer layer following the same conventions as TransformerSelfAttentionLayer. Normalization is applied before both the attention and the feed-forward layer.
- Parameters:
attn (MultiHeadAttention) – Attention module.
mlp (nn.Module) – Feed-forward module.
ca_norm (Optional[nn.Module]) – Normalization to be applied before cross-attention.
mlp_norm (Optional[nn.Module]) – Normalization to be applied before the feed-forward layer.
ca_scale (Optional[nn.Module]) – Module to scale cross-attention output.
mlp_scale (Optional[nn.Module]) – Module to scale the feed-forward output.
- Raises:
AssertionError – if attn.pos_embeddings is set.
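A minimal construction sketch, assuming torchtune's MultiHeadAttention, FeedForward, and RMSNorm building blocks; all dimensions, projection shapes, and hyperparameters below are illustrative choices, not values prescribed by the class:

```python
import torch
from torch import nn
from torchtune.modules import (
    FeedForward,
    MultiHeadAttention,
    RMSNorm,
    TransformerCrossAttentionLayer,
)

# Illustrative sizes only.
embed_dim, num_heads, head_dim = 512, 8, 64

# Cross-attention module. pos_embeddings is left as None, since the
# layer raises an AssertionError if attn.pos_embeddings is set.
attn = MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_heads,
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
    is_causal=False,  # cross-attention does not use a causal mask
)

# Gated feed-forward block with an illustrative 4x hidden width.
mlp = FeedForward(
    gate_proj=nn.Linear(embed_dim, 4 * embed_dim),
    down_proj=nn.Linear(4 * embed_dim, embed_dim),
    up_proj=nn.Linear(embed_dim, 4 * embed_dim),
)

layer = TransformerCrossAttentionLayer(
    attn=attn,
    mlp=mlp,
    ca_norm=RMSNorm(embed_dim),
    mlp_norm=RMSNorm(embed_dim),
)
```

When ca_scale and mlp_scale are omitted they default to identity modules, so the attention and feed-forward outputs enter the residual stream unscaled; pass a module (e.g. a learnable gate) to scale those branches.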
- caches_are_enabled() bool [source]¶
Checks if the key-value caches on self.attn are enabled. See torchtune.modules.TransformerDecoder.caches_are_enabled().
- caches_are_setup() bool [source]¶
Checks if the key-value caches are set up on self.attn. See torchtune.modules.TransformerDecoder.caches_are_setup().
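Both checks return False on a freshly constructed layer, before any caches have been created (cache setup is typically driven by the parent TransformerDecoder). A small sanity-check sketch, reusing `layer` from the construction example above:

```python
# No KV-caches exist yet on a freshly built layer.
assert not layer.caches_are_setup()
assert not layer.caches_are_enabled()
```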
- forward(x: Tensor, *, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, **kwargs: Dict) Tensor [source]¶
- Parameters:
x (torch.Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]
encoder_input (Optional[torch.Tensor]) – Optional input embeds from the encoder. Shape [batch_size x token_sequence x embed_dim]
encoder_mask (Optional[torch.Tensor]) – Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position (i, j) means decoder token i can attend to encoder embedding j. Mask has shape [batch_size x token_sequence x embed_sequence]. Default is None.
**kwargs (Dict) – transformer layer inputs not relevant to cross-attention.
- Returns:
output tensor with the same shape as the input: [batch_size x seq_length x embed_dim]
- Return type:
torch.Tensor
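Continuing the construction sketch above, a forward-pass example; the batch and sequence sizes are arbitrary, and the all-True mask simply lets every token attend to every encoder embedding:

```python
batch_size, seq_len, enc_seq_len = 2, 16, 32

x = torch.randn(batch_size, seq_len, embed_dim)  # decoder hidden states
encoder_input = torch.randn(batch_size, enc_seq_len, embed_dim)  # encoder embeds
# True at (b, i, j): decoder token i may attend to encoder embedding j.
encoder_mask = torch.ones(batch_size, seq_len, enc_seq_len, dtype=torch.bool)

out = layer(x, encoder_input=encoder_input, encoder_mask=encoder_mask)
assert out.shape == x.shape  # [batch_size x seq_length x embed_dim]
```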