(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
Author: Driss Guessous
Summary
In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.
Overview
At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention is all you need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.
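As a rough illustration of what "written in PyTorch using existing functions" could look like, the sketch below implements the definition softmax(QK^T / sqrt(d)) V directly and compares it with the built-in function on CPU. The helper name naive_sdpa is ours and not part of PyTorch, and the sketch deliberately ignores masking and dropout.

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value):
    """Unfused reference: softmax(Q K^T / sqrt(d)) V, without mask or dropout."""
    scale = 1.0 / math.sqrt(query.size(-1))
    # Attention scores between every query position and every key position.
    scores = query @ key.transpose(-2, -1) * scale
    # Normalize over the key dimension so each query's weights sum to 1.
    weights = torch.softmax(scores, dim=-1)
    return weights @ value


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 8) for _ in range(3))

reference = naive_sdpa(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)

# The two results should agree up to floating point error.
print(torch.allclose(reference, fused, atol=1e-5))
```

The naive version materializes the full scores matrix, which is the memory traffic that the fused kernels are designed to avoid.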
Fused implementations
For CUDA tensor inputs, the function will dispatch into one of the following implementations:
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Memory-Efficient Attention
- A PyTorch implementation defined in C++
Note
This tutorial requires PyTorch 2.0.0 or later.
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-1.3321, -0.3489, 0.3015, -0.3912, 0.9867, 0.3137, -0.0691,
-1.2593],
[-1.0882, 0.2506, 0.6491, 0.1360, 0.5238, -0.2448, -0.0820,
-0.6171],
[-1.0012, 0.3990, 0.6441, -0.0277, 0.5325, -0.2564, -0.0607,
-0.6404]],
[[ 0.6091, 0.0708, 0.6188, 0.3252, -0.1598, 0.4197, -0.2335,
0.0630],
[ 0.5285, 0.3890, -0.2649, 0.3706, -0.3839, 0.1963, -0.6242,
0.2312],
[ 0.4048, 0.0762, 0.3777, 0.4689, -0.2978, 0.2754, -0.6429,
0.1037]]], device='cuda:0')
Explicit Dispatcher Control
While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via the use of a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through measuring performance.
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6
# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel
with sdpa_kernel(SDPBackend.MATH):
    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2274.495 microseconds
The math implementation runs in 19258.668 microseconds
The flash attention implementation runs in 2274.449 microseconds
The memory efficient implementation runs in 4325.576 microseconds
Hardware dependence
Depending on what machine you ran the above cell on and what hardware is available, your results might be different.
- If you don’t have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory efficient might have failed.
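To see what your own machine offers before interpreting the timings, a small check along the following lines may help. It is only a sketch: the flags from torch.backends.cuda report whether a backend is globally enabled, not whether it supports a particular input or GPU, and the dictionary name backend_enabled is ours.

```python
import torch

# Global on/off switches for each SDPA backend. These report whether a backend
# is enabled, not whether it can handle a given input shape, dtype, or device.
backend_enabled = {
    "flash": torch.backends.cuda.flash_sdp_enabled(),
    "mem_efficient": torch.backends.cuda.mem_efficient_sdp_enabled(),
    "math": torch.backends.cuda.math_sdp_enabled(),
}
print(backend_enabled)

if torch.cuda.is_available():
    # Compute capability determines which fused kernels the GPU can run.
    major, minor = torch.cuda.get_device_capability()
    print(f"{torch.cuda.get_device_name()} has compute capability {major}.{minor}")
else:
    print("No CUDA device found; SDPA will run on CPU.")
```

If a fused backend is enabled but still raises a RuntimeError inside the context manager, the accompanying warnings are usually the best place to look for the reason.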
Causal Self Attention
Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy’s NanoGPT repository.
class CausalSelfAttention(nn.Module):
    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal
    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)
        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)
        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False
        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)
        y = self.resid_dropout(self.c_proj(y))
        return y
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
(c_attn): Linear(in_features=512, out_features=1536, bias=False)
(c_proj): Linear(in_features=512, out_features=512, bias=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
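The reshaping done in forward can be checked in isolation. The sketch below uses small, made-up sizes on CPU to show how a fused projection of width 3 * embed_dimension is split into per-head query, key, and value tensors and then merged back. It uses reshape for the merge so that it does not depend on the memory layout of the attention output.

```python
import torch

batch_size, seq_len, num_heads, head_dim = 2, 5, 4, 8
embed_dim = num_heads * head_dim

# Stand-in for the output of the fused c_attn projection.
qkv = torch.randn(batch_size, seq_len, 3 * embed_dim)

# Split the last dimension into equally sized query, key, and value parts.
query, key, value = qkv.chunk(3, dim=-1)

# (batch, seq, embed) -> (batch, heads, seq, head_dim), the layout SDPA expects.
query_heads = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
print(query_heads.shape)

# Undo the split: (batch, heads, seq, head_dim) -> (batch, seq, embed).
merged = query_heads.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
print(torch.equal(merged, query))
```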
NestedTensor and Dense tensor support
SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors see torch.nested and NestedTensors Tutorial.
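To make the padding savings concrete, here is a minimal CPU sketch with made-up sequence lengths. It builds a nested tensor from three sequences and compares the number of stored elements with the padded dense equivalent. Creating a nested tensor may emit a prototype-stage warning, depending on your PyTorch version.

```python
import torch

seq_lens = [2, 5, 3]
embed_dim = 4
sequences = [torch.randn(n, embed_dim) for n in seq_lens]

# A nested tensor stores each sequence at its own length.
nt = torch.nested.nested_tensor(sequences)

# The dense equivalent pads every sequence to the longest one in the batch.
padded = torch.nested.to_padded_tensor(nt, 0.0)
print(padded.shape)  # torch.Size([3, 5, 4])

stored_elements = sum(t.numel() for t in nt.unbind())
print(stored_elements, padded.numel())  # 40 60
```

Here a third of the padded tensor is padding; with longer and more varied sequences the gap grows.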
import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nested/__init__.py:220: UserWarning:
The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
Random NT runs in 556.794 microseconds
Random Dense runs in 936.054 microseconds
Using SDPA with torch.compile
With the release of PyTorch 2.0, a new feature called torch.compile() has been introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let’s compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")
compiled_model = torch.compile(model)
# Let's compile it: the first call triggers compilation
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in 403.889 microseconds
The compiled module runs in 488.895 microseconds
The exact execution time depends on the machine; however, the results on mine were:
The non compiled module runs in 166.616 microseconds
The compiled module runs in 166.726 microseconds
That is not what we were expecting. Let’s dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)
with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
#
#     prof.export_chrome_trace("compiled_causal_attention_trace.json")
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Non-Compilied Causal Attention 0.00% 0.000us 0.00% 0.000us 0.000us 10.237ms 50.34% 10.237ms 10.237ms 1
Non-Compilied Causal Attention 21.11% 2.260ms 78.35% 8.390ms 8.390ms 0.000us 0.00% 10.100ms 10.100ms 1
aten::linear 1.07% 114.574us 27.94% 2.992ms 59.840us 0.000us 0.00% 7.758ms 155.155us 50
aten::matmul 2.38% 255.198us 23.94% 2.563ms 51.266us 0.000us 0.00% 7.758ms 155.155us 50
aten::mm 14.43% 1.545ms 19.17% 2.053ms 41.055us 7.758ms 38.15% 7.758ms 155.155us 50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 5.540ms 27.24% 5.540ms 221.599us 25
aten::scaled_dot_product_attention 2.25% 241.179us 19.82% 2.122ms 84.888us 0.000us 0.00% 2.342ms 93.694us 25
aten::_scaled_dot_product_flash_attention 3.24% 347.161us 17.57% 1.881ms 75.241us 0.000us 0.00% 2.342ms 93.694us 25
aten::_flash_attention_forward 4.02% 430.066us 12.50% 1.339ms 53.562us 2.342ms 11.52% 2.342ms 93.694us 25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::... 0.00% 0.000us 0.00% 0.000us 0.000us 2.342ms 11.52% 2.342ms 93.694us 25
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 10.708ms
Self CUDA time total: 20.337ms
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Compiled Causal Attention 0.00% 0.000us 0.00% 0.000us 0.000us 10.177ms 50.20% 10.177ms 10.177ms 1
Compiled Causal Attention 7.84% 847.066us 73.07% 7.892ms 7.892ms 0.000us 0.00% 10.096ms 10.096ms 1
Torch-Compiled Region 7.20% 777.269us 63.79% 6.889ms 275.568us 0.000us 0.00% 10.096ms 403.858us 25
CompiledFunction 26.77% 2.891ms 55.67% 6.012ms 240.487us 0.000us 0.00% 10.096ms 403.858us 25
aten::mm 8.06% 870.686us 12.96% 1.399ms 27.987us 7.757ms 38.26% 7.757ms 155.141us 50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 5.541ms 27.33% 5.541ms 221.658us 25
aten::_scaled_dot_product_flash_attention 2.22% 240.064us 15.95% 1.722ms 68.885us 0.000us 0.00% 2.339ms 93.576us 25
aten::_flash_attention_forward 4.02% 434.103us 11.91% 1.286ms 51.435us 2.339ms 11.54% 2.339ms 93.576us 25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::... 0.00% 0.000us 0.00% 0.000us 0.000us 2.339ms 11.54% 2.339ms 93.576us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3... 0.00% 0.000us 0.00% 0.000us 0.000us 2.216ms 10.93% 2.216ms 88.625us 25
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 10.800ms
Self CUDA time total: 20.274ms
The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis reveals that the majority of time spent on the GPU is concentrated on the same set of functions for both modules. The reason for this is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch can be hidden, so there is little left for torch.compile to remove.
In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy’s NanoGPT repository, compiling the module took the time per train step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT, training on the Shakespeare dataset.
Using SDPA with attn_bias subclasses
# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# Note: the current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
# is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#
from torch.nn.attention.bias import causal_lower_right, causal_upper_left
batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
print(type(upper_left_bias))
print(type(lower_right_bias))
assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)
# As you can see from the previous output, both are the same type, ``torch.nn.attention.bias.CausalBias``,
# and subclass ``torch.Tensor``.
# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)
# Upper left bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this concept is that when you use upper left bias,
# the 0th token in the query is aligned to the 0th token in the key.
# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# For lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, the mask at ``attn_score[-1][-1]`` is True, since the last token in q is at the same position
# as the last token in k), even if the sequence lengths of q and k are different.
# These objects are intended to be used with SDPA
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)
# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False]])
tensor([[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True]])
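The two masks printed above can also be reproduced by hand, which may help build intuition for the alignment. The helper manual_causal_mask below is hypothetical and is not how CausalBias is implemented internally. It only shifts the diagonal of a lower triangular boolean matrix by the difference in sequence lengths.

```python
import torch


def manual_causal_mask(seq_len_q, seq_len_kv, lower_right=False):
    """Boolean causal mask where True means the query may attend to the key."""
    # Upper left alignment keeps the main diagonal; lower right alignment
    # shifts it so that the last query lines up with the last key.
    diagonal = seq_len_kv - seq_len_q if lower_right else 0
    return torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool).tril(diagonal=diagonal)


upper_left_mask = manual_causal_mask(2, 10)
lower_right_mask = manual_causal_mask(2, 10, lower_right=True)
print(upper_left_mask)
print(lower_right_mask)
```

With 2 queries and 10 keys, the lower right variant lets the first query see 9 keys and the last query see all 10, which matches the decoding case where new tokens attend to the whole cached prefix.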
Conclusion
In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdpa_kernel context manager can be used to assert a certain implementation is used on GPU. We also built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. In the process we have shown how the profiling tools can be used to explore the performance characteristics of a user defined module.