torch.nn.attention.bias.CausalBias

class torch.nn.attention.bias.CausalBias(variant, seq_len_q, seq_len_kv)

A bias representing causal attention patterns. For an overview of the bias structure, see the CausalVariant enum.

This class is used for defining causal (triangular) attention biases. To construct the bias, use one of the two factory functions: causal_upper_left() and causal_lower_right().
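The two variants differ only in where the triangle's diagonal is anchored when seq_len_q and seq_len_kv are not equal. The following is a minimal pure-Python sketch of that difference, assuming the alignment described by the CausalVariant enum (upper-left anchors the diagonal at the top-left corner, lower-right at the bottom-right corner). The helpers causal_mask and render are hypothetical names used for illustration only; they are not part of torch.nn.attention.bias.

```python
def causal_mask(seq_len_q, seq_len_kv, lower_right=False):
    """Build a boolean mask as nested lists; True means the query may attend.

    Hypothetical helper for illustration, not a PyTorch API.
    """
    # Lower-right alignment shifts the diagonal by the length difference,
    # so the last query position can see every key position.
    offset = seq_len_kv - seq_len_q if lower_right else 0
    return [
        [j <= i + offset for j in range(seq_len_kv)]
        for i in range(seq_len_q)
    ]


def render(mask):
    """Format a mask as rows of 1s and 0s for printing."""
    return "\n".join(" ".join("1" if x else "0" for x in row) for row in mask)


upper_left = causal_mask(3, 5)
lower_right = causal_mask(3, 5, lower_right=True)

print("upper-left:")
print(render(upper_left))
print("lower-right:")
print(render(lower_right))
```

When seq_len_q equals seq_len_kv the offset is zero, so both variants yield the same square lower-triangular mask.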

Example:

import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v, attn_bias)

Warning

This class is a prototype and subject to change.
