torch.nn.functional.scaled_dot_product_attention
- torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor
Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be specified as a keyword argument.
# Efficient implementation equivalent to the following:
import math
import torch

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Additive bias, created with the dtype and device of query
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        # Out-of-place updates so that attn_mask may broadcast over batch dimensions
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
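As a quick sanity check of the equivalence stated above, the following sketch compares the function against a direct softmax(Q Kᵀ / sqrt(E)) V computation on small CPU tensors. The helper name `naive_attention`, the tensor sizes, the seed, and the tolerance are illustrative assumptions, not part of the API.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(query, key, value):
    # Hypothetical helper: unmasked attention without dropout, written out directly.
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
    return attn_weight @ value


torch.manual_seed(0)
query = torch.randn(2, 4, 5, 8)  # (N, heads, L, E)
key = torch.randn(2, 4, 7, 8)    # (N, heads, S, E)
value = torch.randn(2, 4, 7, 8)  # (N, heads, S, Ev)

expected = naive_attention(query, key, value)
actual = F.scaled_dot_product_attention(query, key, value)

# The two results should agree up to floating point rounding.
matches = torch.allclose(actual, expected, atol=1e-4)
print(matches)
```

Exact bitwise equality is not expected here, since the optimized kernels may order floating point operations differently.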
Warning
This function is beta and subject to change.
Warning
This function always applies dropout according to the specified dropout_p argument. To disable dropout during evaluation, be sure to pass a value of 0.0 when the module that makes the function call is not in training mode.
For example:
class MyModel(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, ...):
        return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0))
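The pattern above can be made concrete with a small runnable sketch. The module name `SelfAttention`, the choice to attend an input to itself without projections, and the tensor sizes are assumptions made purely for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    # Hypothetical minimal module: attends x to itself, no learned projections.
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # Dropout is only requested while the module is in training mode.
        dropout_p = self.p if self.training else 0.0
        return F.scaled_dot_product_attention(x, x, x, dropout_p=dropout_p)


torch.manual_seed(0)
x = torch.randn(1, 2, 4, 8)  # (N, heads, L, E)
model = SelfAttention(p=0.5)

model.eval()  # sets model.training to False, so dropout_p becomes 0.0
out_a = model(x)
out_b = model(x)

# With dropout disabled, repeated evaluation passes give the same result.
eval_is_deterministic = torch.equal(out_a, out_b)
print(eval_is_deterministic)
```

Calling `model.train()` switches the module back to passing `p` as the dropout probability.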
Note
There are currently three supported implementations of scaled dot product attention:
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
Memory-Efficient Attention
A PyTorch implementation defined in C++ matching the above formulation
The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used.
All implementations are enabled by default. Scaled dot product attention attempts to automatically select the best implementation based on the inputs. To give finer-grained control over which implementation is used, the following functions are provided for enabling and disabling implementations. The context manager is the preferred mechanism:
torch.nn.attention.sdpa_kernel(): A context manager used to enable or disable any of the implementations.
torch.backends.cuda.enable_flash_sdp(): Globally enables or disables FlashAttention.
torch.backends.cuda.enable_mem_efficient_sdp(): Globally enables or disables Memory-Efficient Attention.
torch.backends.cuda.enable_math_sdp(): Globally enables or disables the PyTorch C++ implementation.
Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation, disable the PyTorch C++ implementation using torch.nn.attention.sdpa_kernel(). In the event that a fused implementation is not available, a warning will be raised with the reasons why the fused implementation cannot run.
Due to the nature of fusing floating point operations, the output of this function may be different depending on what backend kernel is chosen. The C++ implementation supports torch.float64 and can be used when higher precision is required. For more information please see Numerical accuracy.
Note
In some circumstances when given tensors on a CUDA device and using cuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True. See Reproducibility for more information.
- Parameters
query (Tensor) – Query tensor; shape $(N, \ldots, L, E)$.
key (Tensor) – Key tensor; shape $(N, \ldots, S, E)$.
value (Tensor) – Value tensor; shape $(N, \ldots, S, E_v)$.
attn_mask (optional Tensor) – Attention mask; shape must be broadcastable to the shape of attention weights, which is $(N, \ldots, L, S)$. Two types of masks are supported: a boolean mask where a value of True indicates that the element should take part in attention, and a float mask of the same type as query, key, value that is added to the attention score.
dropout_p (float) – Dropout probability; if greater than 0.0, dropout is applied.
is_causal (bool) – If set to true, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment (see torch.nn.attention.bias.CausalBias) when the mask is a non-square matrix. An error is thrown if both attn_mask and is_causal are set.
scale (optional float, keyword-only) – Scaling factor applied prior to softmax. If None, the default value is set to $\frac{1}{\sqrt{E}}$.
- Returns
Attention output; shape $(N, \ldots, L, E_v)$.
- Return type
output (Tensor)
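- Shape legend:
$N$: Batch size; $\ldots$: Any number of other batch dimensions (optional)
$S$: Source sequence length
$L$: Target sequence length
$E$: Embedding dimension of the query and key
$E_v$: Embedding dimension of the value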
Examples
>>> # Optionally use the context manager to ensure one of the fused kernels is run
>>> import torch
>>> import torch.nn.functional as F
>>> from torch.nn.attention import SDPBackend, sdpa_kernel
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
...     F.scaled_dot_product_attention(query, key, value)
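The example above requires a CUDA device. As a CPU-only sketch of the shape conventions, the following uses different source and target lengths and a value embedding size that differs from the query and key embedding size. All sizes are arbitrary, and the extra dimension `H` stands for one of the optional batch dimensions (for instance, attention heads).

```python
import torch
import torch.nn.functional as F

N, H, L, S, E, Ev = 2, 3, 5, 7, 16, 24
query = torch.rand(N, H, L, E)   # target length L, embedding size E
key = torch.rand(N, H, S, E)     # source length S, embedding size E
value = torch.rand(N, H, S, Ev)  # source length S, value embedding size Ev

out = F.scaled_dot_product_attention(query, key, value)

# The output keeps the target length of query and the embedding size of value.
print(tuple(out.shape))  # (2, 3, 5, 24)
```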