get_causal_mask_from_padding_mask
- torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) → Tensor [source]
Converts a padding mask of shape [bsz, seq_len] to a [bsz, seq_len, seq_len] causal attention mask suitable for consumption by scaled_dot_product_attention(). If target_seq_len is provided, this will return a mask of shape [bsz, seq_len, target_seq_len]. This is useful when generating masks for static KV caches, where the maximum length the caches have been set up with is longer than the current sequence.
- Parameters:
padding_mask (torch.Tensor) – Boolean tensor where False indicates the corresponding token in the sequence is a padding token and should be masked out in attention, with shape [bsz x seq_length]
target_seq_len (Optional[int]) – target sequence length to create attention mask with. Default None.
- Returns:
Boolean causal mask with shape [bsz, seq_length, seq_length], or [bsz, seq_length, target_seq_len] if target_seq_len was specified.
- Return type:
torch.Tensor
- Raises:
AssertionError – if target_seq_len > seq_len, the sequence length of the padding mask.
Example
>>> padding_mask = torch.tensor([[False, True, True, True]])
>>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5)
tensor([[[ True, False, False, False, False],
         [False,  True, False, False, False],
         [False,  True,  True, False, False],
         [False,  True,  True,  True, False]]])
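The sketch below illustrates one way the returned boolean mask might be passed to torch.nn.functional.scaled_dot_product_attention. It is not taken from the torchtune docs: the batch size, head count, and head dimension are arbitrary assumptions, and the extra head dimension added to the mask is only one possible way to make it broadcast over multi-head attention.

>>> import torch
>>> import torch.nn.functional as F
>>> from torchtune.generation import get_causal_mask_from_padding_mask
>>>
>>> # Illustrative shapes (assumed, not prescribed by the API)
>>> bsz, num_heads, seq_len, head_dim = 1, 2, 4, 8
>>> padding_mask = torch.tensor([[False, True, True, True]])  # first token is padding
>>>
>>> # [bsz, seq_len, seq_len] boolean mask; True means the position may be attended to
>>> mask = get_causal_mask_from_padding_mask(padding_mask)
>>>
>>> q = torch.randn(bsz, num_heads, seq_len, head_dim)
>>> k = torch.randn(bsz, num_heads, seq_len, head_dim)
>>> v = torch.randn(bsz, num_heads, seq_len, head_dim)
>>>
>>> # Insert a head dimension so the mask broadcasts across num_heads
>>> out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, :, :])
>>> out.shape
torch.Size([1, 2, 4, 8])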