CausalVariant

class torch.nn.attention.bias.CausalVariant(value)

Enum for causal variants used in attention mechanisms.

Defines two types of causal biases:

UPPER_LEFT: Represents an upper-left triangular bias, the standard causal attention mask, in which the diagonal of allowed positions starts at the top-left corner of the matrix. The equivalent PyTorch code for constructing this bias is:

# size = (query_seq_len, key_seq_len)
torch.tril(torch.ones(size, dtype=torch.bool))

For instance, with size=(3, 4), the materialized bias tensor will be:

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0]]
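
As a dependency-free sketch of the rule behind this mask, the hypothetical helper upper_left_mask below marks query position i as allowed to attend key position j exactly when j <= i, which is what torch.tril with the default diagonal=0 produces.

```python
def upper_left_mask(query_len: int, key_len: int) -> list[list[int]]:
    # Hypothetical helper: query row i may attend key column j only when j <= i,
    # so the diagonal is anchored at the top-left corner (row 0, column 0).
    return [[1 if j <= i else 0 for j in range(key_len)] for i in range(query_len)]


for row in upper_left_mask(3, 4):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
```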

LOWER_RIGHT: Represents a lower-right triangular bias, in which the included values are aligned to the lower-right corner of the matrix.

The equivalent PyTorch code for constructing this bias is:

# size = (query_seq_len, key_seq_len)
diagonal_offset = size[1] - size[0]
torch.tril(
    torch.ones(size, dtype=torch.bool),
    diagonal=diagonal_offset,
)

For instance, with size=(3, 4), the materialized bias tensor will be:

[[1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]
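
A dependency-free sketch of the same rule: the hypothetical helper lower_right_mask shifts the diagonal right by key_len - query_len (the diagonal_offset above), so query row i may attend key column j when j <= i + offset. The last query row therefore always sees every key.

```python
def lower_right_mask(query_len: int, key_len: int) -> list[list[int]]:
    # Hypothetical helper mirroring torch.tril(..., diagonal=key_len - query_len):
    # the diagonal is anchored at the bottom-right corner of the matrix.
    offset = key_len - query_len
    return [
        [1 if j <= i + offset else 0 for j in range(key_len)]
        for i in range(query_len)
    ]


for row in lower_right_mask(3, 4):
    print(row)
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

When the query sequence is longer than the key/value sequence, the offset is negative; for example, lower_right_mask(4, 3) gives a first row of all zeros, meaning that query position is not allowed to attend any key.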

Note that these variants are equivalent when the query and key/value tensors have the same sequence length: the bias matrix is then square, diagonal_offset is 0, and both constructions produce the same lower-triangular matrix.
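
To see how the two alignments behave inside attention, the sketch below builds both masks by hand with torch.tril and passes them to torch.nn.functional.scaled_dot_product_attention as a boolean attn_mask, where True marks a key position that may be attended. The helper names causal_mask and attend are hypothetical, the example does not use the CausalVariant enum itself, and it assumes a PyTorch version that provides scaled_dot_product_attention (2.0 or later).

```python
import torch
import torch.nn.functional as F


def causal_mask(query_len: int, key_len: int, lower_right: bool) -> torch.Tensor:
    # Hypothetical helper: boolean mask, True = this (query, key) pair may attend.
    # lower_right=False mirrors UPPER_LEFT, lower_right=True mirrors LOWER_RIGHT.
    offset = key_len - query_len if lower_right else 0
    ones = torch.ones(query_len, key_len, dtype=torch.bool)
    return torch.tril(ones, diagonal=offset)


def attend(q, k, v, lower_right: bool) -> torch.Tensor:
    # The 2D mask broadcasts over the batch and head dimensions.
    mask = causal_mask(q.shape[-2], k.shape[-2], lower_right)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


torch.manual_seed(0)
# Shapes are (batch, heads, seq_len, head_dim).
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Square case: 4 query positions and 4 key/value positions.
same = torch.allclose(attend(q, k, v, False), attend(q, k, v, True))

# Non-square case: 3 query positions, 4 key/value positions.
q_short = q[:, :, :3]
differ = not torch.allclose(
    attend(q_short, k, v, False), attend(q_short, k, v, True)
)
print(same, differ)
```

With equal lengths the two masks are identical, so the outputs match exactly. With 3 query and 4 key/value positions, the lower-right mask lets every query row see one extra key, so the outputs differ.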

Warning

This enum is a prototype and subject to change.
