torch.nn.attention.flex_attention
- torch.nn.attention.flex_attention.flex_attention(query, key, value, score_mod=None, block_mask=None, scale=None, enable_gqa=False, return_lse=False, kernel_options=None)
This function implements scaled dot product attention with an arbitrary attention score modification function.
This function computes the scaled dot product attention between query, key, and value tensors with a user-defined attention score modification function. The score modification function is applied to each attention score after the scores have been calculated between the query and key tensors and scaled, and before the softmax.
The score_mod function should have the following signature:

    def score_mod(
        score: Tensor,
        batch: Tensor,
        head: Tensor,
        q_idx: Tensor,
        k_idx: Tensor
    ) -> Tensor:

- Where:
score: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
batch, head, q_idx, k_idx: Scalar tensors indicating the batch index, query head index, query index, and key/value index, respectively. These should have the torch.int data type and be located on the same device as the score tensor.
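To make the role of score_mod concrete, here is a minimal pure-Python sketch of the computation for a single batch and head. It only illustrates the semantics described above and is not how the library computes attention: `reference_attention` and `causal_score_mod` are hypothetical names, nested lists and Python numbers stand in for tensors, and the default scale of 1 / sqrt(E) is an assumption borrowed from standard scaled dot product attention.

```python
import math

def reference_attention(q, k, v, score_mod=None, scale=None, batch=0, head=0):
    """Naive attention for one (batch, head) pair.

    q is L x E, k is S x E, v is S x Ev, all given as nested Python lists.
    """
    E = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(E)  # assumed default, as in standard SDPA
    out = []
    for q_idx, q_row in enumerate(q):
        scores = []
        for k_idx, k_row in enumerate(k):
            s = scale * sum(a * b for a, b in zip(q_row, k_row))
            if score_mod is not None:
                # score_mod sees the scaled score plus its coordinates.
                s = score_mod(s, batch, head, q_idx, k_idx)
            scores.append(s)
        # Numerically stable softmax over the key dimension.
        m = max(scores)
        exps = [math.exp(s - m) if m != float("-inf") else 0.0 for s in scores]
        total = sum(exps)
        weights = [e / total if total > 0 else 0.0 for e in exps]
        out.append([
            sum(w * v_row[d] for w, v_row in zip(weights, v))
            for d in range(len(v[0]))
        ])
    return out

def causal_score_mod(score, batch, head, q_idx, k_idx):
    # Keys that lie after the query position get -inf, so softmax ignores them.
    return score if q_idx >= k_idx else float("-inf")

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0], [20.0]]
out = reference_attention(q, k, v, score_mod=causal_score_mod)
```

With the causal score_mod, the first query position can only see the first key, so its output is exactly the first value row; the second position blends both rows.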
- Parameters
query (Tensor) – Query tensor; shape (B, Hq, L, E).
key (Tensor) – Key tensor; shape (B, Hkv, S, E).
value (Tensor) – Value tensor; shape (B, Hkv, S, Ev).
score_mod (Optional[Callable]) – Function to modify attention scores. By default no score_mod is applied.
block_mask (Optional[BlockMask]) – BlockMask object that controls the block-sparsity pattern of the attention.
scale (Optional[float]) – Scaling factor applied prior to softmax. If None, the default value is set to 1 / sqrt(E).
enable_gqa (bool) – If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads.
return_lse (bool) – Whether to return the logsumexp of the attention scores. Default is False.
kernel_options (Optional[Dict[str, Any]]) – Options to pass into the Triton kernels.
- Returns
Attention output; shape (B, Hq, L, Ev).
- Return type
output (Tensor)
- Shape legend:
B: Batch size
Hq: Number of query heads
Hkv: Number of key/value heads
L: Target (query) sequence length
S: Source (key/value) sequence length
E: Embedding dimension of the query and key
Ev: Embedding dimension of the value
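The broadcast performed by enable_gqa can be pictured as a mapping from each query head to the key/value head it reads. The sketch below assumes the grouping used by standard GQA, where consecutive query heads share one key/value head and the number of query heads is a multiple of the number of key/value heads. `kv_head_for_query_head` is a hypothetical helper written for this illustration, not a library function.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the key/value head index that a given query head attends with."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("number of query heads must be a multiple of key/value heads")
    group_size = num_q_heads // num_kv_heads  # query heads per key/value head
    return q_head // group_size

# 8 query heads sharing 2 key/value heads: heads 0-3 use KV head 0, heads 4-7 use KV head 1.
mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
```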
Warning
torch.nn.attention.flex_attention is a prototype feature in PyTorch. Please look forward to a more stable implementation in a future version of PyTorch. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
BlockMask Utilities
- torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device='cuda', BLOCK_SIZE=128, _compile=False)
This function creates a block mask tuple from a mask_mod function.
- Parameters
mask_mod (Callable) – mask_mod function. This is a callable that defines the masking pattern for the attention mechanism. It takes four arguments: b (batch index), h (head index), q_idx (query index), and kv_idx (key/value index). It should return a boolean tensor indicating which attention connections are allowed (True) or masked out (False).
B (int) – Batch size.
H (int) – Number of query heads.
Q_LEN (int) – Sequence length of query.
KV_LEN (int) – Sequence length of key/value.
device (str) – Device to run the mask creation on.
BLOCK_SIZE (int or Tuple[int, int]) – Block size for the block mask. If a single int is provided it is used for both query and key/value.
- Returns
A BlockMask object that contains the block mask information.
- Return type
BlockMask
- Example Usage:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    # Allow each query position to attend to itself and to earlier positions.
    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda")
    query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
    key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
    value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
    output = flex_attention(query, key, value, block_mask=block_mask)
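Conceptually, building a block mask means evaluating mask_mod over each tile and recording whether the tile is fully masked, partially masked, or fully unmasked. The following pure-Python sketch shows that classification on a small causal example. `classify_blocks` is a hypothetical helper, its loops are far slower than any real implementation, and Python ints stand in for index tensors.

```python
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def classify_blocks(mask_mod, q_len, kv_len, block_size, b=0, h=0):
    """Label each (q_block, kv_block) tile as 'empty', 'partial', or 'full'."""
    n_q = -(-q_len // block_size)    # ceiling division
    n_kv = -(-kv_len // block_size)
    labels = []
    for qb in range(n_q):
        row = []
        for kb in range(n_kv):
            # Evaluate mask_mod at every position inside this tile.
            vals = [
                bool(mask_mod(b, h, qi, ki))
                for qi in range(qb * block_size, min((qb + 1) * block_size, q_len))
                for ki in range(kb * block_size, min((kb + 1) * block_size, kv_len))
            ]
            if all(vals):
                row.append("full")
            elif any(vals):
                row.append("partial")
            else:
                row.append("empty")
        labels.append(row)
    return labels

labels = classify_blocks(causal_mask, q_len=8, kv_len=8, block_size=4)
```

For the causal pattern, tiles above the diagonal are empty and can be skipped, tiles on the diagonal are partial and still need mask_mod, and tiles below the diagonal are full.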
- torch.nn.attention.flex_attention.create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device='cuda')
This function creates a mask tensor from a mod_fn function.
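- Parameters
mod_fn (Callable) – A score_mod or mask_mod function from which the mask is built.
B (int) – Batch size.
H (int) – Number of query heads.
Q_LEN (int) – Sequence length of query.
KV_LEN (int) – Sequence length of key/value.
device (str) – Device to run the mask creation on.
- Returns
A mask tensor with shape (B, H, M, N), where M = Q_LEN and N = KV_LEN.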
- Return type
mask (Tensor)
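For a mask_mod input, the dense result can be thought of as the function evaluated at every (b, h, q_idx, kv_idx) coordinate. A minimal pure-Python sketch of that idea follows, with nested lists standing in for the (B, H, Q_LEN, KV_LEN) tensor. `dense_mask_from_mod` and `sliding_window_mask` are hypothetical names used only for illustration, and the score_mod case is not covered here.

```python
def sliding_window_mask(b, h, q_idx, kv_idx, window=2):
    # A key is visible when it is at most `window - 1` positions behind the query.
    return 0 <= q_idx - kv_idx < window

def dense_mask_from_mod(mask_mod, B, H, q_len, kv_len):
    """Evaluate mask_mod at every coordinate and collect the booleans."""
    return [
        [
            [[bool(mask_mod(b, h, qi, ki)) for ki in range(kv_len)] for qi in range(q_len)]
            for h in range(H)
        ]
        for b in range(B)
    ]

mask = dense_mask_from_mod(sliding_window_mask, B=1, H=1, q_len=4, kv_len=4)
```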
- torch.nn.attention.flex_attention.create_nested_block_mask(mask_mod, B, H, q_nt, kv_nt=None, BLOCK_SIZE=128, _compile=False)
This function creates a nested tensor compatible block mask tuple from a mask_mod function. The returned BlockMask will be on the device specified by the input nested tensor.
- Parameters
mask_mod (Callable) – mask_mod function. This is a callable that defines the masking pattern for the attention mechanism. It takes four arguments: b (batch index), h (head index), q_idx (query index), and kv_idx (key/value index). It should return a boolean tensor indicating which attention connections are allowed (True) or masked out (False).
B (int) – Batch size.
H (int) – Number of query heads.
q_nt (torch.Tensor) – Jagged layout nested tensor (NJT) that defines the sequence length structure for query. The block mask will be constructed to operate on a “stacked sequence” of length sum(S) for sequence length S from the NJT.
kv_nt (torch.Tensor) – Jagged layout nested tensor (NJT) that defines the sequence length structure for key / value, allowing for cross attention. The block mask will be constructed to operate on a “stacked sequence” of length sum(S) for sequence length S from the NJT. If this is None, q_nt is used to define the structure for key / value as well. Default: None
BLOCK_SIZE (int or Tuple[int, int]) – Block size for the block mask. If a single int is provided it is used for both query and key/value.
- Returns
A BlockMask object that contains the block mask information.
- Return type
BlockMask
- Example Usage:

    import torch
    from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

    # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
    query = torch.nested.nested_tensor(..., layout=torch.jagged)
    key = torch.nested.nested_tensor(..., layout=torch.jagged)
    value = torch.nested.nested_tensor(..., layout=torch.jagged)

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
    output = flex_attention(query, key, value, block_mask=block_mask)

    # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
    query = torch.nested.nested_tensor(..., layout=torch.jagged)
    key = torch.nested.nested_tensor(..., layout=torch.jagged)
    value = torch.nested.nested_tensor(..., layout=torch.jagged)

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    # cross attention case: pass both query and key/value NJTs
    block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True)
    output = flex_attention(query, key, value, block_mask=block_mask)
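The “stacked sequence” idea can be sketched without nested tensors: the per-sequence lengths give offsets into one long sequence, and a wrapper around mask_mod only allows attention within the same sequence while passing sequence-local indices to the original function. This is an illustrative assumption about the behavior for the self-attention case, written in plain Python. `build_offsets`, `locate`, and `nested_mask_mod` are hypothetical names, and the library's internals may differ.

```python
from bisect import bisect_right
from itertools import accumulate

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def build_offsets(seq_lengths):
    """Start offset of each sequence in the stacked sequence, plus the total length."""
    return [0] + list(accumulate(seq_lengths))

def locate(stacked_idx, offsets):
    """Map a stacked index to (sequence id, index within that sequence)."""
    seq_id = bisect_right(offsets, stacked_idx) - 1
    return seq_id, stacked_idx - offsets[seq_id]

def nested_mask_mod(mask_mod, offsets):
    def wrapped(b, h, q_idx, kv_idx):
        q_seq, q_local = locate(q_idx, offsets)
        kv_seq, kv_local = locate(kv_idx, offsets)
        # Tokens from different sequences never attend to each other.
        return q_seq == kv_seq and bool(mask_mod(q_seq, h, q_local, kv_local))
    return wrapped

offsets = build_offsets([3, 2])  # two sequences, stacked length sum(S) = 5
mask = nested_mask_mod(causal_mask, offsets)
```

Here stacked positions 0 to 2 belong to the first sequence and positions 3 to 4 to the second, so `mask(0, 0, 3, 2)` is False even though 3 >= 2.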
- torch.nn.attention.flex_attention.and_masks(*mask_mods)
Returns a mask_mod that is the intersection of the provided mask_mods.
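As a rough illustration of what “intersection” means here, the sketch below combines two mask_mod functions so that a position is allowed only when every mask allows it. `and_masks_sketch` is a hypothetical pure-Python stand-in written for this example, not the library function, and it uses Python ints and bools where the real API works on tensors.

```python
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window_mask(b, h, q_idx, kv_idx):
    return abs(q_idx - kv_idx) <= 2

def and_masks_sketch(*mask_mods):
    """Return a mask_mod that allows a position only if all given masks allow it."""
    def combined(b, h, q_idx, kv_idx):
        return all(mask(b, h, q_idx, kv_idx) for mask in mask_mods)
    return combined

# Causal attention restricted to a local window of recent keys.
causal_window = and_masks_sketch(causal_mask, sliding_window_mask)
```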
BlockMask
- class torch.nn.attention.flex_attention.BlockMask(seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)
BlockMask is our format for representing a block-sparse attention mask. It is a cross between BCSR (block compressed sparse row) and a non-sparse format.
Basics
A block-sparse mask means that instead of representing the sparsity of individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is considered sparse only if every element within that block is sparse. This aligns well with hardware, which generally expects to perform contiguous loads and computation.
This format is primarily optimized for (1) simplicity and (2) kernel efficiency. Notably, it is not optimized for size: relative to an element-wise mask, it is always smaller by a factor of KV_BLOCK_SIZE * Q_BLOCK_SIZE and no more, however sparse the mask is. If size is a concern, the tensors can be made smaller by increasing the block size.
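A quick back-of-the-envelope calculation shows the size trade-off. The numbers below assume an 8192 x 8192 mask, and `block_grid_entries` is a hypothetical helper that counts tiles in the block grid for one batch and head.

```python
def block_grid_entries(q_len, kv_len, q_block, kv_block):
    """Number of tiles in the block grid (ceiling division on each axis)."""
    n_q = -(-q_len // q_block)
    n_kv = -(-kv_len // kv_block)
    return n_q * n_kv

dense_elements = 8192 * 8192                            # element-wise mask entries
blocks_128 = block_grid_entries(8192, 8192, 128, 128)   # 64 x 64 grid
blocks_256 = block_grid_entries(8192, 8192, 256, 256)   # 32 x 32 grid
reduction_128 = dense_elements // blocks_128            # equals 128 * 128
```

Doubling the block size in both dimensions cuts the number of block positions by a factor of four, at the cost of a coarser sparsity pattern.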
The essentials of our format are:
num_blocks_in_row: Tensor[ROWS]: Describes the number of blocks present in each row.
col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: col_indices[i] is the sequence of block positions for row i. Entries of col_indices[i] at positions num_blocks_in_row[i] and beyond are undefined.
For example, to reconstruct the original tensor from this format:
    dense_mask = torch.zeros(ROWS, COLS)
    for row in range(ROWS):
        for block_idx in range(num_blocks_in_row[row]):
            dense_mask[row, col_indices[row, block_idx]] = 1
Notably, this format makes it easier to implement a reduction along the rows of the mask.
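Going the other way, from a dense block-level mask to the two tensors, might look like the following. This is a plain-Python sketch with lists in place of tensors. `to_block_sparse` and `from_block_sparse` are hypothetical names, the row width is trimmed to the largest block count for brevity, and padding the unused tail with 0 is an arbitrary choice since those entries are undefined in the format.

```python
def to_block_sparse(dense_block_mask):
    """Convert a dense 0/1 block mask into (num_blocks_in_row, col_indices)."""
    present = [[c for c, v in enumerate(row) if v] for row in dense_block_mask]
    num_blocks_in_row = [len(cols) for cols in present]
    width = max(num_blocks_in_row, default=0)
    # Pad each row to a common width; padded entries carry no meaning.
    col_indices = [cols + [0] * (width - len(cols)) for cols in present]
    return num_blocks_in_row, col_indices

def from_block_sparse(num_blocks_in_row, col_indices, num_cols):
    """Rebuild the dense block mask, reading only the defined entries of each row."""
    dense = [[0] * num_cols for _ in num_blocks_in_row]
    for row, count in enumerate(num_blocks_in_row):
        for block_idx in range(count):
            dense[row][col_indices[row][block_idx]] = 1
    return dense

dense = [[1, 0, 0], [1, 1, 0], [0, 1, 1]]
num_blocks, cols = to_block_sparse(dense)
round_trip = from_block_sparse(num_blocks, cols, num_cols=3)
```

Because each row stores its own count and block positions, a reduction along a row only needs to walk the first num_blocks_in_row[i] entries of that row.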
Details
The basics of our format require only kv_num_blocks and kv_indices, but this object holds up to 8 tensors, forming 4 pairs:
1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as we reduce along the KV dimension.
2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and purely an optimization. Applying masking to every block is quite expensive. If we know which blocks are “full” and don’t require masking at all, we can skip applying mask_mod to these blocks. This requires the user to split out a separate mask_mod from the score_mod. For causal masks, this is about a 15% speedup.
3. [GENERATED] (q_num_blocks, q_indices): Required for the backwards pass, as computing dKV requires iterating over the mask along the Q dimension. These are autogenerated from 1.
4. [GENERATED] (full_q_num_blocks, full_q_indices): Same as 3, but for the “full” blocks of 2. These are autogenerated from 2.
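The generated Q-side tensors are, in effect, the transpose of the KV-side structure: for each KV block, list the Q blocks that touch it. A minimal sketch in plain Python follows, assuming a single batch and head and using the hypothetical name `transpose_blocks`. The library's own routine operates on tensors and may order or pad entries differently.

```python
def transpose_blocks(kv_num_blocks, kv_indices, num_kv_blocks):
    """Turn per-Q-block lists of KV blocks into per-KV-block lists of Q blocks."""
    q_lists = [[] for _ in range(num_kv_blocks)]
    for q_block, count in enumerate(kv_num_blocks):
        # Only the first `count` entries of each row are defined.
        for kv_block in kv_indices[q_block][:count]:
            q_lists[kv_block].append(q_block)
    q_num_blocks = [len(qs) for qs in q_lists]
    width = max(q_num_blocks, default=0)
    q_indices = [qs + [0] * (width - len(qs)) for qs in q_lists]
    return q_num_blocks, q_indices

# Causal pattern over a 3 x 3 block grid: Q block i touches KV blocks 0..i.
kv_num_blocks = [1, 2, 3]
kv_indices = [[0, 0, 0], [0, 1, 0], [0, 1, 2]]
q_num_blocks, q_indices = transpose_blocks(kv_num_blocks, kv_indices, num_kv_blocks=3)
```

In the transposed view, KV block 0 is touched by all three Q blocks and KV block 2 only by the last one, which is the iteration order a dKV computation needs.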
- as_tuple(flatten=True)
Returns a tuple of the attributes of the BlockMask.
- Parameters
flatten (bool) – If True, the (KV_BLOCK_SIZE, Q_BLOCK_SIZE) tuple is flattened into separate elements of the returned tuple.
- classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None)
Creates a BlockMask instance from key-value block information.
- Parameters
kv_num_blocks (Tensor) – Number of kv_blocks in each Q_BLOCK_SIZE row tile.
kv_indices (Tensor) – Indices of key-value blocks in each Q_BLOCK_SIZE row tile.
full_kv_num_blocks (Optional[Tensor]) – Number of full kv_blocks in each Q_BLOCK_SIZE row tile.
full_kv_indices (Optional[Tensor]) – Indices of full key-value blocks in each Q_BLOCK_SIZE row tile.
BLOCK_SIZE (Union[int, Tuple[int, int]]) – Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles.
mask_mod (Optional[Callable]) – Function to modify the mask.
- Returns
Instance with full Q information generated via _transposed_ordered.
- Return type
BlockMask
- Raises
RuntimeError – If kv_indices has fewer than 2 dimensions.
AssertionError – If only one of full_kv_* args is provided.
- property shape
- sparsity()
Computes the percentage of blocks that are sparse (i.e. not computed).
- Return type
float
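As a sketch of the arithmetic, the sparsity percentage compares the number of elements covered by computed blocks against the total mask size. The helper below is hypothetical (`sparsity_percent`) and assumes that both partial and full blocks count as computed and that the sequence lengths are multiples of the block size.

```python
def sparsity_percent(kv_num_blocks, full_kv_num_blocks, q_len, kv_len, q_block, kv_block):
    """Percentage of the mask area that falls in blocks that are never computed."""
    computed_blocks = sum(kv_num_blocks) + sum(full_kv_num_blocks or [])
    computed_elements = computed_blocks * q_block * kv_block
    total_elements = q_len * kv_len
    return 100.0 * (1.0 - computed_elements / total_elements)

# Causal mask on a 4 x 4 block grid: 4 partial blocks on the diagonal
# plus 0 + 1 + 2 + 3 = 6 full blocks below it, so 10 of 16 blocks are computed.
value = sparsity_percent([1, 1, 1, 1], [0, 1, 2, 3], 512, 512, 128, 128)
```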
- to(device)
Moves the BlockMask to the specified device.
- Parameters
device (torch.device or str) – The target device to move the BlockMask to. Can be a torch.device object or a string (e.g., ‘cpu’, ‘cuda:0’).
- Returns
A new BlockMask instance with all tensor components moved to the specified device.
- Return type
BlockMask
Note
This method does not modify the original BlockMask in-place. Instead, it returns a new BlockMask instance in which individual tensor attributes may or may not be moved to the specified device, depending on their current device placement.