(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
Author: Driss Guessous
Summary
In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.
Overview
At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention Is All You Need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.
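For reference, the paper defines attention as softmax(Q K^T / sqrt(d_k)) V, where d_k is the size of the last dimension of the query. The sketch below writes that definition as an unfused implementation and checks it against the fused function. It assumes no mask and no dropout. The helper name `naive_sdpa` is hypothetical, and the comparison relies on `F.scaled_dot_product_attention` using its default scale of 1/sqrt(d_k).

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value):
    """Hypothetical unfused reference: softmax(Q @ K^T / sqrt(d_k)) @ V (no mask, no dropout)."""
    d_k = query.size(-1)
    # Raw attention scores, scaled by 1/sqrt(d_k).
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    # Normalize over the key dimension, then take the weighted sum of the values.
    weights = torch.softmax(scores, dim=-1)
    return weights @ value


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))  # (batch, heads, seq_len, d_k)
reference = naive_sdpa(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(reference, fused, atol=1e-5))
```

The naive version materializes the full (seq_len x seq_len) score matrix. Avoiding that materialization is one of the main reasons the fused kernels can be faster and use less memory.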
Fused implementations
For CUDA tensor inputs, the function will dispatch into one of the following implementations:
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Memory-Efficient Attention
- A PyTorch implementation defined in C++
Note: This tutorial requires PyTorch 2.0.0 or later.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
```
tensor([[[-1.3321, -0.3489, 0.3015, -0.3912, 0.9867, 0.3137, -0.0691,
-1.2593],
[-1.0882, 0.2506, 0.6491, 0.1360, 0.5238, -0.2448, -0.0820,
-0.6171],
[-1.0012, 0.3990, 0.6441, -0.0277, 0.5325, -0.2564, -0.0607,
-0.6404]],
[[ 0.6091, 0.0708, 0.6188, 0.3252, -0.1598, 0.4197, -0.2335,
0.0630],
[ 0.5285, 0.3890, -0.2649, 0.3706, -0.3839, 0.1963, -0.6242,
0.2312],
[ 0.4048, 0.0762, 0.3777, 0.4689, -0.2978, 0.2754, -0.6429,
0.1037]]], device='cuda:0')
Explicit Dispatcher Control
While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure that the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through the implementations and measure the performance of each one.
```python
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32  # size of the last dimension of each head

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
```
The default implementation runs in 4878.902 microseconds
The math implementation runs in 19252.439 microseconds
The flash attention implementation runs in 4880.070 microseconds
The memory efficient implementation runs in 4831.269 microseconds
Hardware dependence
Depending on which machine you ran the above cell on and what hardware is available, your results might be different.
- If you don't have a GPU and are running on CPU, then the context manager will have no effect and all three runs should return similar timings.
- Depending on which compute capability your graphics card supports, flash attention or memory-efficient attention might have failed.
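To see what your own machine offers before you read the timings, you can collect the relevant facts in one place. The sketch below uses a hypothetical helper, `describe_sdpa_environment`. It reads the global backend switches from `torch.backends.cuda` (the `sdp_kernel` context manager sets these temporarily and restores them on exit) and, when a GPU is present, the GPU's compute capability. A backend that is enabled here can still be rejected for particular inputs, for example because of dtype or head dimension. The warnings PyTorch prints give the actual reason.

```python
import torch


def describe_sdpa_environment():
    """Hypothetical helper: gather facts that influence which SDPA backend can run."""
    info = {
        "cuda_available": torch.cuda.is_available(),
        # Global on/off switches for each SDPA backend.
        "flash_enabled": torch.backends.cuda.flash_sdp_enabled(),
        "mem_efficient_enabled": torch.backends.cuda.mem_efficient_sdp_enabled(),
        "math_enabled": torch.backends.cuda.math_sdp_enabled(),
    }
    if info["cuda_available"]:
        # (major, minor) compute capability of the current GPU, e.g. (8, 0) on an A100.
        info["compute_capability"] = torch.cuda.get_device_capability()
    return info


print(describe_sdpa_environment())
```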
Causal Self Attention
Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy's NanoGPT repository.
```python
class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        # c_attn produces 3 * embed_dimension features, hence the division by 3 * num_heads
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        # (batch, seq_len, embed_dimension) -> (batch, num_heads, seq_len, head_dim)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # In this example, dropout and causal masking are only applied in training mode
        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads * head_dim)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y
```
```python
num_heads = 8
head_dim = 64
embed_dimension = num_heads * head_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to(device).to(dtype).eval()
print(model)
```
CausalSelfAttention(
(c_attn): Linear(in_features=512, out_features=1536, bias=False)
(c_proj): Linear(in_features=512, out_features=512, bias=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
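The `is_causal=True` flag used by this module stands in for an explicit lower-triangular mask. The small standalone check below makes that concrete. It compares `is_causal=True` with a boolean `attn_mask`, where `True` marks positions that may take part in attention. It then confirms that changing the last key and value leaves every earlier output position unchanged. The shapes are arbitrary assumptions, and the example runs in float32 on CPU.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 6, 8
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# Boolean mask: True marks key positions a query may attend to.
# Lower-triangular, so query position i only sees key positions 0..i.
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0)

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
print(torch.allclose(out_flag, out_mask, atol=1e-5))

# Perturb only the last ("future") key and value.
k2, v2 = k.clone(), v.clone()
k2[..., -1, :] += 1.0
v2[..., -1, :] += 1.0
out_perturbed = F.scaled_dot_product_attention(q, k2, v2, is_causal=True)
# Earlier positions cannot see the last token, so their outputs are unchanged.
print(torch.allclose(out_flag[..., :-1, :], out_perturbed[..., :-1, :], atol=1e-6))
```

Because `CausalSelfAttention.forward` above sets `is_causal = False` outside training mode, this masking only applies to the module while it is in `train()` mode.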
NestedTensor and Dense tensor support
SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable-length sequences without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see torch.nested and the NestedTensors Tutorial.
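To make the padding trade-off concrete, the sketch below builds a small NestedTensor from three sequences of different lengths and converts it to the padded dense layout with `torch.nested.to_padded_tensor`. It then shows the dense route, which has to mask padded key positions with a boolean `attn_mask` so that padding does not leak into the result. The sizes are arbitrary, the variable names are illustrative, and the example runs in float32 on CPU. PyTorch may emit a prototype-stage warning for nested tensors.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim = 4
lengths = [2, 5, 3]
# Three sequences of different lengths; the NestedTensor stores no padding.
sequences = [torch.randn(n, embed_dim) for n in lengths]
nt = torch.nested.nested_tensor(sequences)
print(nt.is_nested)
print([t.shape[0] for t in nt.unbind()])  # per-sequence lengths are preserved

# The dense equivalent pads every sequence to the longest one.
padded = torch.nested.to_padded_tensor(nt, 0.0)
print(padded.shape)

# Mask padded key positions: True means "this key may be attended to".
max_len = padded.shape[1]
key_is_real = torch.arange(max_len)[None, :] < torch.tensor(lengths)[:, None]  # (batch, max_len)
attn_mask = key_is_real[:, None, :]  # (batch, 1, max_len), broadcast over query positions
out_dense = F.scaled_dot_product_attention(padded, padded, padded, attn_mask=attn_mask)

# The unpadded rows match running each sequence through SDPA on its own.
for i, seq in enumerate(sequences):
    ref = F.scaled_dot_product_attention(seq[None], seq[None], seq[None])[0]
    assert torch.allclose(out_dense[i, : lengths[i]], ref, atol=1e-5)
```

The dense route spends compute on padded rows and needs the extra mask. Those are the costs that NestedTensor inputs are meant to avoid.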
```python
import random


def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )


random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support NestedTensor for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
```
/var/lib/jenkins/workspace/intermediate_source/scaled_dot_product_attention_tutorial.py:226: UserWarning:
The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:177.)
/var/lib/jenkins/workspace/intermediate_source/scaled_dot_product_attention_tutorial.py:174: UserWarning:
Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:545.)
/var/lib/jenkins/workspace/intermediate_source/scaled_dot_product_attention_tutorial.py:174: UserWarning:
Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:338.)
/var/lib/jenkins/workspace/intermediate_source/scaled_dot_product_attention_tutorial.py:174: UserWarning:
Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:547.)
/var/lib/jenkins/workspace/intermediate_source/scaled_dot_product_attention_tutorial.py:174: UserWarning:
We are not enabling nested Tensors for Flash Attention because of cuda memory errors. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:155.)
FlashAttention is not supported. See warnings for reasons.
Using SDPA with torch.compile
With the release of PyTorch 2.0, a new feature called torch.compile() has been introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.
```python
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


# Let's compile it
compiled_model = torch.compile(model)
# Warm-up call: compilation happens on the first call, so keep it out of the timing
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
```
The non compiled module runs in 497.354 microseconds
The compiled module runs in 497.380 microseconds
The exact execution time depends on the machine. On mine, the non-compiled module ran in 166.616 microseconds and the compiled module ran in 166.726 microseconds. That is not what we were expecting, so let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.
```python
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use chrome://tracing to view the results:
#
#   prof.export_chrome_trace("compiled_causal_attention_trace.json")
```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Non-Compilied Causal Attention 15.64% 2.055ms 69.88% 9.182ms 9.182ms 0.000us 0.00% 12.465ms 12.465ms 1
aten::matmul 1.93% 254.000us 21.11% 2.774ms 55.480us 0.000us 0.00% 7.798ms 155.960us 50
aten::mm 13.49% 1.772ms 16.92% 2.223ms 44.460us 7.798ms 62.56% 7.798ms 155.960us 50
aten::linear 1.76% 231.000us 24.00% 3.154ms 63.080us 0.000us 0.00% 7.620ms 152.400us 50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 5.585ms 44.81% 5.585ms 223.400us 25
aten::scaled_dot_product_attention 1.34% 176.000us 22.64% 2.975ms 119.000us 0.000us 0.00% 4.667ms 186.680us 25
aten::_scaled_dot_product_flash_attention 4.41% 580.000us 21.30% 2.799ms 111.960us 0.000us 0.00% 4.667ms 186.680us 25
aten::_flash_attention_forward 2.63% 345.000us 6.36% 836.000us 33.440us 4.567ms 36.64% 4.567ms 182.680us 25
void fmha_fwd_loop_kernel<FMHA_kernel_traits<256, 64... 0.00% 0.000us 0.00% 0.000us 0.000us 4.567ms 36.64% 4.567ms 182.680us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3... 0.00% 0.000us 0.00% 0.000us 0.000us 2.213ms 17.75% 2.213ms 88.520us 25
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 13.139ms
Self CUDA time total: 12.465ms
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Compiled Causal Attention 16.25% 2.168ms 85.81% 11.450ms 11.450ms 0.000us 0.00% 12.475ms 12.475ms 1
CompiledFunction 28.43% 3.793ms 68.97% 9.202ms 368.080us 0.000us 0.00% 12.475ms 499.000us 25
aten::mm 9.10% 1.214ms 13.03% 1.738ms 34.760us 7.808ms 62.59% 7.808ms 156.160us 50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 5.587ms 44.79% 5.587ms 223.480us 25
aten::_scaled_dot_product_flash_attention 3.29% 439.000us 20.27% 2.705ms 108.200us 0.000us 0.00% 4.667ms 186.680us 25
aten::_flash_attention_forward 2.43% 324.000us 6.24% 832.000us 33.280us 4.567ms 36.61% 4.567ms 182.680us 25
void fmha_fwd_loop_kernel<FMHA_kernel_traits<256, 64... 0.00% 0.000us 0.00% 0.000us 0.000us 4.567ms 36.61% 4.567ms 182.680us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3... 0.00% 0.000us 0.00% 0.000us 0.000us 2.221ms 17.80% 2.221ms 88.840us 25
aten::arange 3.14% 419.000us 13.54% 1.807ms 18.070us 100.000us 0.80% 198.000us 1.980us 100
void (anonymous namespace)::elementwise_kernel_with_... 0.00% 0.000us 0.00% 0.000us 0.000us 100.000us 0.80% 100.000us 2.000us 50
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 13.343ms
Self CUDA time total: 12.475ms
The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis reveals that the majority of time spent on the GPU is concentrated on the same set of functions for both modules. The reason is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model launches large, efficient CUDA kernels, as CausalSelfAttention does in this case, then PyTorch's overhead is already hidden behind the kernel execution time, and torch.compile has little left to remove.
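The profiler tables above also show which SDPA implementation actually ran, for example `aten::_scaled_dot_product_flash_attention`. The sketch below narrows that check to a single call. The helper `attention_ops_used` is hypothetical: it profiles one call and keeps only the op names that contain `scaled_dot_product`. The exact names depend on the PyTorch version, device, and inputs, so treat the output as something to inspect, not a fixed list to match.

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile


def attention_ops_used(query, key, value):
    """Hypothetical helper: profile one SDPA call and list the SDPA-related ops recorded."""
    activities = [ProfilerActivity.CPU]
    if query.is_cuda:
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        F.scaled_dot_product_attention(query, key, value)
    # key_averages() groups events by name; keep only the attention-related ones.
    return sorted({evt.key for evt in prof.key_averages() if "scaled_dot_product" in evt.key})


device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(2, 4, 64, 32, device=device) for _ in range(3))
print(attention_ops_used(q, k, v))
```

Calling this inside one of the `sdp_kernel` contexts from earlier shows whether the backend you requested was the one that ran.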
In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module reduced the time per training step from 6090.49ms to 3273.17ms! This was measured on commit ae3a8d5 of NanoGPT, training on the Shakespeare dataset.
Conclusion
In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdp_kernel context manager can be used to ensure that a certain implementation is used on GPU. We also built a simple CausalSelfAttention module that works with NestedTensor and is compatible with torch.compile. In the process, we have shown how the profiling tools can be used to explore the performance characteristics of a user-defined module.