TransformerSelfAttentionLayer¶
- class torchtune.modules.TransformerSelfAttentionLayer(attn: MultiHeadAttention, mlp: Module, *, sa_norm: Optional[Module] = None, mlp_norm: Optional[Module] = None, sa_scale: Optional[Module] = None, mlp_scale: Optional[Module] = None)[source]¶
Transformer layer derived from the Llama2 model. Normalization is applied before both the self-attention and the feed-forward layer.
- Parameters:
attn (MultiHeadAttention) – Attention module.
mlp (nn.Module) – Feed-forward module.
sa_norm (Optional[nn.Module]) – Normalization to be applied before self-attention.
mlp_norm (Optional[nn.Module]) – Normalization to be applied before the feed-forward layer.
sa_scale (Optional[nn.Module]) – Module to scale self-attention output.
mlp_scale (Optional[nn.Module]) – Module to scale the feed-forward output.
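For orientation, here is a minimal construction and forward-pass sketch. It assumes torchtune's MultiHeadAttention, RMSNorm, and RotaryPositionalEmbeddings modules; the keyword arguments shown (and the 4 * embed_dim MLP hidden size) are illustrative choices, not prescribed values, and may need adjusting to match the installed torchtune release.

```python
import torch
from torch import nn
from torchtune.modules import (
    MultiHeadAttention,
    RMSNorm,
    RotaryPositionalEmbeddings,
    TransformerSelfAttentionLayer,
)

# Illustrative sizes; not taken from any particular model config.
embed_dim, num_heads, head_dim, max_seq_len = 512, 8, 64, 2048

attn = MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_heads,  # no grouped-query attention in this sketch
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
    pos_embeddings=RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len),
    max_seq_len=max_seq_len,
)

# Any nn.Module mapping embed_dim -> embed_dim works as the feed-forward block.
mlp = nn.Sequential(
    nn.Linear(embed_dim, 4 * embed_dim),
    nn.SiLU(),
    nn.Linear(4 * embed_dim, embed_dim),
)

layer = TransformerSelfAttentionLayer(
    attn=attn,
    mlp=mlp,
    sa_norm=RMSNorm(dim=embed_dim),   # pre-attention normalization
    mlp_norm=RMSNorm(dim=embed_dim),  # pre-feed-forward normalization
)

x = torch.randn(2, 16, embed_dim)  # [batch_size x seq_length x embed_dim]
out = layer(x)                     # same shape as x
```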
- caches_are_enabled() bool [source]¶
Checks if the key value caches on self.attn are enabled. See torchtune.modules.TransformerDecoder.caches_are_enabled().
- caches_are_setup() bool [source]¶
Checks if the key value caches are set up on self.attn. See torchtune.modules.TransformerDecoder.caches_are_setup().
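As a small illustrative check (reusing the layer from the construction sketch above), both queries return False on a freshly built layer; KV caches are typically allocated and enabled through the parent TransformerDecoder.setup_caches.

```python
# `layer` comes from the construction sketch above.
# No KV cache has been allocated on self.attn yet, so both checks report False.
print(layer.caches_are_setup())    # False
print(layer.caches_are_enabled())  # False
```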
- forward(x: Tensor, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None, **kwargs: Dict) Tensor [source]¶
- Parameters:
x (torch.Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]
mask (Optional[_MaskType]) –
Used to mask the scores after the query-key multiplication and before the softmax. Either:
A boolean tensor with shape [b x s x s], or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.
A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.
input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, with shape [b x s]. During inference, this indicates the position of the current token. If None, the index of the token is assumed to be its position id. Default is None.
**kwargs (Dict) – transformer layer inputs not relevant to self attention.
- Returns:
- output tensor with the same shape as the input: [batch_size x seq_length x embed_dim]
- Return type:
- torch.Tensor
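Below is a hedged sketch of calling forward() with an explicit boolean causal mask and per-token position ids; layer and embed_dim are reused from the construction sketch above, and the batch size and sequence length are arbitrary illustration values.

```python
import torch

b, s = 2, 16                       # illustrative batch size and sequence length
x = torch.randn(b, s, embed_dim)   # [batch_size x seq_length x embed_dim]

# Boolean causal mask: True in row i, column j means token i attends to token j.
causal_mask = torch.tril(torch.ones(s, s)).bool().expand(b, s, s)

# Position ids per token; here each sequence simply uses positions 0..s-1.
input_pos = torch.arange(s).unsqueeze(0).expand(b, s)

out = layer(x, mask=causal_mask, input_pos=input_pos)
assert out.shape == (b, s, embed_dim)  # output keeps the input shape
```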