get_causal_mask_from_padding_mask

torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) → Tensor

Converts a padding mask of shape [bsz, seq_len] to a [bsz, seq_len, seq_len] causal attention mask suitable for consumption by scaled_dot_product_attention(). If target_seq_len is provided, this will return a mask of shape [bsz, seq_len, target_seq_len]. This is useful when generating masks for static KV caches, where the maximum length the caches have been set up with is longer than the current sequence.
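For context, here is a minimal usage sketch showing the resulting mask being passed to scaled_dot_product_attention(). The shapes (bsz=1, num_heads=2, seq_len=4, head_dim=8) are chosen purely for illustration, and the mask is unsqueezed so that it broadcasts over the head dimension:

>>> import torch
>>> import torch.nn.functional as F
>>> from torchtune.generation import get_causal_mask_from_padding_mask
>>> padding_mask = torch.tensor([[False, True, True, True]])  # False marks padding
>>> mask = get_causal_mask_from_padding_mask(padding_mask)    # [bsz, seq_len, seq_len]
>>> q = k = v = torch.randn(1, 2, 4, 8)  # [bsz, num_heads, seq_len, head_dim]
>>> # unsqueeze so the boolean mask broadcasts over the head dimension
>>> out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, :, :])
>>> out.shape
torch.Size([1, 2, 4, 8])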

Parameters:
  • padding_mask (torch.Tensor) – Boolean tensor of shape [bsz, seq_len] where False indicates that the corresponding token in the sequence is a padding token and should be masked out in attention.

  • target_seq_len (Optional[int]) – target sequence length to create the attention mask with. Default: None.

Returns:

Boolean causal mask with shape
  • [bsz, seq_len, seq_len], or

  • [bsz, seq_len, target_seq_len] if target_seq_len was specified.

Return type:

torch.Tensor

Raises:

AssertionError – if target_seq_len is smaller than seq_len, the sequence length of the padding mask.

Example

>>> padding_mask = torch.tensor([[False, True, True, True]])
>>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5)
tensor([[[ True, False, False, False, False],
         [False,  True, False, False, False],
         [False,  True,  True, False, False],
         [False,  True,  True,  True, False]]])
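As the example output suggests, the mask can be thought of as the logical AND of a lower-triangular causal mask and the padding mask along the key dimension, with the diagonal forced to True so that rows belonging to padding tokens are never fully masked. The following illustrative sketch (not torchtune's implementation) reproduces the output above:

>>> import torch
>>> padding_mask = torch.tensor([[False, True, True, True]])
>>> bsz, seq_len = padding_mask.shape
>>> target_seq_len = 5
>>> # lower-triangular causal mask, padded with False columns up to target_seq_len
>>> mask = torch.tril(torch.ones(bsz, seq_len, target_seq_len, dtype=torch.bool))
>>> # mask out padding tokens along the key dimension
>>> mask[..., :seq_len] &= padding_mask[:, None, :]
>>> # keep the diagonal True so padding positions still attend to themselves
>>> _ = mask.diagonal(dim1=1, dim2=2).fill_(True)
>>> mask
tensor([[[ True, False, False, False, False],
         [False,  True, False, False, False],
         [False,  True,  True, False, False],
         [False,  True,  True,  True, False]]])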
