
Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile()

Author: Mikayla Gawarecki


This tutorial currently requires you to use the PyTorch nightly build.

What you will learn
  • Learn about the low-level building blocks PyTorch provides to build custom transformer layers (nested tensors, scaled_dot_product_attention, torch.compile(), and FlexAttention)

  • Discover how the above improve memory usage and performance using MultiHeadAttention as an example

  • Explore advanced customizations using the aforementioned building blocks

Prerequisites

  • PyTorch v2.6.0 or later

Over the past few years, the PyTorch team has developed various lower-level features that, when composed, can create a variety of transformer variants. These include:

  • Nested Tensors with the torch.jagged layout (AKA NJTs)

  • scaled_dot_product_attention

  • torch.compile()

  • FlexAttention

This tutorial will give a brief overview of the above technologies and demonstrate how they can be composed to yield flexible and performant transformer layers with improved user experience.

One may observe that the torch.nn module currently provides various Transformer-related layers. In particular, it includes TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, Transformer and MultiheadAttention. This family of layers was initially implemented following the Attention is All You Need paper. The components discussed in this tutorial provide improved user experience, flexibility and performance over the existing nn layers.

Is this tutorial for me?

If you are wondering about what building blocks the torch library provides for writing your own transformer layers and best practices, you are in the right place. Please keep reading!

If you are looking for an out-of-the-box implementation of a popular transformer architecture, note that there are many open-source libraries that provide them, including:

If you are only interested in performant attention score modifications, please check out the FlexAttention blog that contains a gym of masks.

Introducing the Building Blocks

First, we will briefly introduce the four technologies mentioned in the introduction.

Nested tensors generalize the shape of regular dense tensors, allowing for representation of ragged-sized data with the same tensor UX. In the context of transformers, we can think of nested tensors as a tool for representing variable sequence lengths. They eliminate the need for the bug-prone practices of explicit padding and masking (think key_padding_mask in nn.MultiheadAttention).

scaled_dot_product_attention is a primitive for $\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$, where E is the per-head embedding dimension and B is an optional additive attention bias or mask, that dispatches into either fused implementations of the operator or a fallback implementation. It works out of the box in eager mode (i.e. the default mode of using PyTorch where operations are executed on the fly as they are encountered) and also integrates seamlessly with torch.compile(). As of PyTorch 2.6, it also supports grouped query attention natively.
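To make the formula concrete, here is a small CPU sketch (random inputs, arbitrary sizes) that checks F.scaled_dot_product_attention against the explicit expression, passing B as a float attn_mask (which SDPA adds to the scaled scores), and checks that is_causal=True behaves like a B that is -inf strictly above the diagonal:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, L, E = 2, 4, 5, 16  # batch, heads, sequence length, per-head embedding dim
q, k, v = (torch.randn(N, H, L, E) for _ in range(3))
B = torch.randn(L, L)  # additive bias, broadcast over batch and heads

# A float attn_mask is added to the scaled scores before the softmax.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=B)

# Explicit reference: softmax(Q K^T / sqrt(E) + B) V
scores = q @ k.transpose(-2, -1) / math.sqrt(E) + B
ref = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))  # True

# is_causal=True is equivalent to B = -inf strictly above the diagonal, 0 elsewhere.
causal_bias = torch.full((L, L), float("-inf")).triu(1)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bias)
print(torch.allclose(out_causal, out_masked, atol=1e-5))  # True
```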

torch.compile() is a compiler introduced in version 2.0 that is able to capture a graph of PyTorch code and perform various optimizations on it, such as fusing together sequences of ops. Nested tensors with the torch.jagged layout and scaled_dot_product_attention work seamlessly with compile. In the context of transformers, the benefit of using compile with nested tensors and SDPA is that compile can remove the framework overhead one sees in eager mode and fuse sequences of ops in transformers together, such as projection and activation.

FlexAttention is a primitive that allows users to modify attention scores prior to the softmax operation. It generalizes the additive B term above for scaled_dot_product_attention, allowing for arbitrary calculation. It requires compile to achieve good performance.

The above building blocks are “All You Need” (as of October 2024)

The main premise in this section is that most transformer variations are GPT-style, consisting of layers like Embedding, Positional Encoding, Attention Blocks and Feed Forward networks. If we were to try to classify the differences in this space, we might land on something like:

  1. Layer type (activation functions such as SwiGLU, normalization functions such as RMSNorm, and positional encodings such as sinusoidal or rotary).

  2. Layer ordering, such as where to apply norms and positional encoding.

  3. Modifications to attention score, such as ALiBi, Relative Positional Bias and so on.

In a pre-compiler environment, you might write a custom transformer and notice that it functions correctly but is slow. To address this, you might develop a custom fused kernel for the specific series of operations. In a compiler environment, you can simply write the straightforward implementation, compile it, and benefit from improved performance.


Remember that MultiheadAttention takes in a query, key, and value, and consists of an input projection, a scaled_dot_product_attention operator and an output projection. The main takeaway we want to demonstrate here is the improvement gained by replacing padded/masked inputs with nested tensors. The improvements are threefold:

  • User Experience: Remember that nn.MultiheadAttention requires query, key, and value to be dense torch.Tensors. It also provides a key_padding_mask that is used to mask out padding tokens in the key that arise due to different sequence lengths within a batch. Since there is no query_padding_mask in nn.MHA, users have to take care to mask/slice the outputs appropriately to account for query sequence lengths. NestedTensor cleanly removes the need for these error-prone padding masks.

  • Memory: Instead of materializing a dense [B, S, D] tensor with a [B, S] padding mask (where B is batch size, S is max sequence length in the batch and D is embedding size), nested tensors allow you to cleanly represent the batch of varying sequence lengths. As a result, the input and intermediate activations will use less memory.

  • Performance: Since padding is not materialized and unnecessary computation on padding is skipped, performance and memory usage improve.
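A back-of-the-envelope count makes the memory argument concrete. This sketch reuses the statistics printed by the self-attention benchmark later in this tutorial (512 sequences, maximum length 128, 10436 real tokens, embedding size 512); they come from one random draw, so your numbers will differ. It only counts the elements of a single [B, S, D] activation and ignores the small offsets tensor that a nested tensor carries:

```python
# Statistics from one sample run of the self-attention benchmark in this tutorial.
N, S, D = 512, 128, 512  # batch size, max sequence length, embedding size
total_tokens = 10436     # sum of the real sequence lengths in the batch

padded_elems = N * S * D         # dense [B, S, D] tensor (plus a [B, S] mask)
nested_elems = total_tokens * D  # jagged values buffer of shape [sum(S_i), D]

print(padded_elems, nested_elems)                                           # 33554432 5343232
print(f"{padded_elems / nested_elems:.2f}x fewer elements")                 # 6.28x fewer elements
print(f"{1 - total_tokens / (N * S):.1%} of the padded tensor is padding")  # 84.1% of the padded tensor is padding
```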

We’ll demonstrate the above by building upon the MultiheadAttention layer in the Nested Tensor tutorial and comparing it to the nn.MultiheadAttention layer.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
        device (torch.device, optional): Device on which to create the parameters. Default: None
        dtype (torch.dtype, optional): Dtype of the parameters. Default: None
    """

    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        E_total: int,
        nheads: int,
        dropout: float = 0.0,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.nheads = nheads
        self.dropout = dropout
        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
        if self._qkv_same_embed_dim:
            # one packed projection produces query, key, and value in a single matmul
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.bias = bias

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask=None,
        is_causal=False,
    ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
            attn_mask (torch.Tensor, optional): attention mask broadcastable to
                (``N``, ``nheads``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        if self._qkv_same_embed_dim:
            if query is key and key is value:
                # self-attention: a single matmul, then split into query, key, value
                result = self.packed_proj(query)
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                # different inputs: slice the packed weights and project each input separately
                q_weight, k_weight, v_weight = torch.chunk(
                    self.packed_proj.weight, 3, dim=0
                )
                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(
                        self.packed_proj.bias, 3, dim=0
                    )
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = (
                    F.linear(query, q_weight, q_bias),
                    F.linear(key, k_weight, k_bias),
                    F.linear(value, v_weight, v_bias),
                )
        else:
            query = self.q_proj(query)
            key = self.k_proj(key)
            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA (dropout is only applied in training mode)
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output


In this section, we include a utility to generate semi-realistic data using a Zipf distribution for sentence lengths. This is used to generate the nested query, key, and value tensors. We also include a benchmark utility.

import numpy as np

def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)

# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The resulting nested tensors have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    if query_seq_len_1:
        # every query sequence has length 1
        query = torch.nested.nested_tensor(
            [torch.randn(1, E_q, dtype=dtype, device=device) for l in sentence_lengths],
            layout=torch.jagged,
        )
    else:
        query = torch.nested.nested_tensor(
            [
                torch.randn(l.item(), E_q, dtype=dtype, device=device)
                for l in sentence_lengths
            ],
            layout=torch.jagged,
        )

    key = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_k, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    value = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_v, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    return query, key, value, sentence_lengths

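import math
import timeit

def benchmark(func, *args, **kwargs):
    # CUDA kernels run asynchronously: synchronize before starting and before stopping
    # the timer, and reset the peak-memory counter so it only reflects this call.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin), torch.cuda.max_memory_allocated()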
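The benchmark helper above times a single call, which can be noisy. As an optional alternative (a sketch, not what the results in this tutorial were produced with), torch.utils.benchmark.Timer repeats the statement and reports statistics such as the median; per its documentation it also performs warmups and synchronizes CUDA when necessary. benchmark_median is a hypothetical name used only here:

```python
import torch
import torch.nn as nn
from torch.utils import benchmark as tbench

def benchmark_median(func, *args, min_run_time=0.2, **kwargs):
    """Median wall-clock time (seconds) of func(*args, **kwargs) over repeated runs."""
    timer = tbench.Timer(
        stmt="func(*args, **kwargs)",
        globals={"func": func, "args": args, "kwargs": kwargs},
    )
    # blocked_autorange decides how many repetitions fit into min_run_time
    return timer.blocked_autorange(min_run_time=min_run_time).median

layer = nn.Linear(64, 64)
x = torch.randn(32, 64)
print(f"median: {benchmark_median(layer, x, min_run_time=0.05) * 1e6:.1f} us")
```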

We will now demonstrate the performance improvements of using nested tensors in the MultiheadAttention layer + compile for self-attention. We compare this against the traditional nn.MultiheadAttention + compile with padding and masking.

N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
d_model = E_q
nheads = 8
dropout = 0.0
bias = True
device = "cuda"
query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
S = sentence_lengths.max().item()
print(
    f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}"
)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

mha_layer = MultiHeadAttention(
    E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device="cuda"
)
vanilla_mha_layer = nn.MultiheadAttention(
    E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device="cuda"
)

# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(
    vanilla_mha_layer.out_proj.weight.clone().detach()
)
mha_layer.packed_proj.weight = nn.Parameter(
    vanilla_mha_layer.in_proj_weight.clone().detach()
)
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
mha_layer.packed_proj.bias = nn.Parameter(
    vanilla_mha_layer.in_proj_bias.clone().detach()
)

new_mha_layer = torch.compile(mha_layer)
# warmup compile
nested_result_warmup = new_mha_layer(query, query, query, is_causal=True)

# benchmark
nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, query, query, is_causal=True
)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float("-inf"))
for i, s in enumerate(sentence_lengths):
    attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s)
attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S)

vanilla_mha_layer = torch.compile(vanilla_mha_layer)
# warmup compile
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_query,
    padded_query,
    attn_mask=attn_mask,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    is_causal=True,
)

# benchmark
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_query,
    padded_query,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    attn_mask=attn_mask,
    is_causal=True,
)

print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Total sequence length in nested query 10436, max sequence length 128
padded_time=0.01603, padded_peak_memory=3.88 GB
nested_time=0.00233, nested_peak_memory=0.93 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 6.88
Nested peak memory reduction 2.96 GB

For reference, here are some sample outputs on A100:

padded_time=0.03454, padded_peak_memory=4.14 GB
nested_time=0.00612, nested_peak_memory=0.76 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 5.65
Nested peak memory reduction 3.39 GB

We can also see the same for the backward pass:

for i, entry_length in enumerate(sentence_lengths):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

_, padded_bw_time, padded_bw_peak_mem = benchmark(
    lambda: padded_result.sum().backward()
)
_, nested_bw_time, nested_bw_peak_mem = benchmark(
    lambda: padded_nested_result.sum().backward()
)

print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB")
print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB")
print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}")
print(
    f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB"
)

print(
    "Difference in out_proj.weight.grad",
    (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad).abs().max().item(),
)
print(
    "Difference in packed_proj.weight.grad",
    (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad).abs().max().item(),
)
print(
    "Difference in out_proj.bias.grad",
    (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad).abs().max().item(),
)
print(
    "Difference in packed_proj.bias.grad",
    (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad).abs().max().item(),
)
padded_bw_time=1.81020, padded_bw_peak_mem=4.69 GB
nested_bw_time=0.11084, nested_bw_peak_mem=3.14 GB
Nested backward speedup: 16.33
Nested backward peak memory reduction 1.55 GB
Difference in out_proj.weight.grad 0.000396728515625
Difference in packed_proj.weight.grad 0.00146484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.0029296875

Sample outputs on A100:

padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
Nested backward speedup: 144.13
Nested backward peak memory reduction 1.86 GB
Difference in out_proj.weight.grad 0.000244140625
Difference in packed_proj.weight.grad 0.001556396484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.001953125

GPT-style layer

A basic GPT-style transformer layer consists of a causal self-attention layer followed by a feed-forward network (FFN) with skip connections. Implementing this is fairly straightforward using the MultiheadAttention layer above and gives equivalent results to an nn.TransformerEncoderLayer with is_causal=True.
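A minimal sketch of such a layer follows. To keep it self-contained, it inlines a small causal self-attention built directly on F.scaled_dot_product_attention over dense (N, L, E) tensors instead of reusing the MultiHeadAttention class above, and it uses the post-norm ordering and ReLU feed-forward network that nn.TransformerEncoderLayer uses by default (dropout omitted). The class names and sizes are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    """Minimal causal self-attention on dense (N, L, E) inputs."""

    def __init__(self, d_model, nheads):
        super().__init__()
        assert d_model % nheads == 0
        self.nheads = nheads
        self.packed_proj = nn.Linear(d_model, 3 * d_model)  # q, k, v in one matmul
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        q, k, v = self.packed_proj(x).chunk(3, dim=-1)
        # (N, L, E) -> (N, nheads, L, E_head)
        q, k, v = (t.unflatten(-1, [self.nheads, -1]).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(out.transpose(1, 2).flatten(-2))

class GPTLayer(nn.Module):
    """Post-norm layer: x = norm1(x + attn(x)); x = norm2(x + ffn(x))."""

    def __init__(self, d_model, nheads, dim_ff):
        super().__init__()
        self.attn = CausalSelfAttention(d_model, nheads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.norm1(x + self.attn(x))
        return self.norm2(x + self.ffn(x))

torch.manual_seed(0)
layer = GPTLayer(d_model=32, nheads=4, dim_ff=64).eval()
x = torch.randn(2, 6, 32)
y = layer(x)

# Causality check: changing the last token must not change any earlier output.
x2 = x.clone()
x2[:, -1] = torch.randn(2, 32)
y2 = layer(x2)
print(y.shape, torch.allclose(y[:, :-1], y2[:, :-1], atol=1e-5))  # torch.Size([2, 6, 32]) True
```

Moving the norms inside the residual branches (x = x + attn(norm1(x))) gives the pre-norm variant, which corresponds to norm_first=True in nn.TransformerEncoderLayer.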

We demonstrate examples of implementing the rest of the nn layers here but omit that from this tutorial for brevity.

Going one step further

So far, we have demonstrated how to implement a performant MultiheadAttention layer that follows the traditional nn.MultiheadAttention. Going back to our classification of modifications to the transformer architecture, remember that we classified the modifications into layer type, layer ordering, and modifications to the attention score. We trust that changing layer type and layer ordering (such as swapping LayerNorm for RMSNorm) is fairly straightforward.
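For example, a pre-norm block can take the normalization class as a constructor argument, so switching from LayerNorm to RMSNorm is a one-argument change. The RMSNorm below is a hand-written sketch (recent PyTorch releases also provide nn.RMSNorm), and PreNormFFN is a hypothetical block used only for illustration:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """y = x / sqrt(mean(x^2) + eps) * weight (no mean subtraction, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class PreNormFFN(nn.Module):
    """Pre-norm residual block; the norm type is injected via ``norm_cls``."""

    def __init__(self, dim, norm_cls=nn.LayerNorm):
        super().__init__()
        self.norm = norm_cls(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ffn(self.norm(x))

x = torch.randn(2, 5, 16)
ln_block = PreNormFFN(16, norm_cls=nn.LayerNorm)
rms_block = PreNormFFN(16, norm_cls=RMSNorm)  # the only change
print(ln_block(x).shape, rms_block(x).shape)

# With unit weight, each token vector has root-mean-square ~1 after RMSNorm.
rms = RMSNorm(16)(x).pow(2).mean(-1).sqrt()
print(torch.allclose(rms, torch.ones_like(rms), atol=1e-3))
```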

In this section, we will discuss various functionalities using the aforementioned building blocks, including the following:

  • Cross Attention

  • Fully masked rows no longer cause NaNs

  • Modifying attention score: ALiBi with FlexAttention and NJT

  • Packed Projection

Cross Attention

Cross attention is a form of attention where the query and key/value tensors are from different sequences.

One example of this is in nn.TransformerDecoderLayer where the query comes from the decoder and the key/value come from the encoder.

The above MultiheadAttention layer nicely generalizes to this case with nested tensors for both query and key/value.

query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)

print(
    f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}"
)
print(
    f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}"
)
out = new_mha_layer(query, key, value, is_causal=False)
Total sequence length in nested query 10617, max sequence length 165
Total sequence length in nested key/value 10176, max sequence length 137

As above, we can compare this against the vanilla compiled nn.MultiheadAttention.

query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]

# warmup compile
warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)

nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, key, value, is_causal=False
)
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)
padded_nested_result = nested_result.to_padded_tensor(0.0)
for i, entry_length in enumerate(q_len):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Max difference between vanilla and nested result 0.0
Nested speedup: 4.87
Nested peak memory reduction 1.20 GB

Sample outputs on A100:

Max difference between vanilla and nested result 0.0
Nested speedup: 4.01
Nested peak memory reduction 1.40 GB

Fully masked rows no longer cause NaNs

There has been a long-standing issue with nn.MultiheadAttention and scaled_dot_product_attention where, if a row was fully masked out, the output of the attention layer would be NaN. See issue. This is because the softmax over an empty set is undefined.

Thanks to this PR, this is no longer the case. Instead, the output corresponding to fully masked rows in scaled_dot_product_attention will be 0. This also applies to nn.MHA in cases where it does not employ the “fast-path”.
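The root cause is easy to reproduce with a plain softmax: if every score in a row is -inf, the row maximum is -inf and the max-subtraction step computes -inf - (-inf) = NaN. The hypothetical safe_softmax helper below shows the behavior described above (fully masked rows produce 0 instead of NaN, so their attention output is 0); it illustrates the idea and is not PyTorch's internal implementation:

```python
import torch

def safe_softmax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax that returns 0 instead of NaN on rows where every score is -inf."""
    fully_masked = torch.isneginf(scores).all(dim=dim, keepdim=True)
    probs = torch.softmax(scores, dim=dim)
    return probs.masked_fill(fully_masked, 0.0)

neg_inf = float("-inf")
scores = torch.tensor([
    [0.5, neg_inf, 1.0],          # partially masked row
    [neg_inf, neg_inf, neg_inf],  # fully masked row (e.g. an empty sequence)
])

print(torch.softmax(scores, dim=-1))  # second row is all NaN
print(safe_softmax(scores))           # second row is all 0
```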

Using a custom MHA layer with NJTs is strongly recommended over the existing “fast-path” in nn.MultiheadAttention as NJT’s ability to model raggedness appropriately makes it possible to properly express empty sequences.

FlexAttention + NJT

NJT also composes with the FlexAttention module. FlexAttention generalizes scaled_dot_product_attention by allowing arbitrary modifications to the attention score. The example below takes the alibi_mod that implements ALiBi from attention gym and uses it with nested input tensors.

from torch.nn.attention.flex_attention import flex_attention

def generate_alibi_bias(H: int):
    """Returns an alibi bias score_mod given the number of heads H

    Args:
        H: number of heads

    Returns:
        alibi_bias: alibi bias score_mod
    """

    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (q_idx - kv_idx) * scale
        return score + bias

    return alibi_mod

query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
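To see what alibi_mod computes, the sketch below materializes the same bias as a dense (H, L, L) tensor for a short sequence and passes it to scaled_dot_product_attention as the additive B term. This is only practical for small L on regular dense tensors (FlexAttention avoids building this tensor); alibi_bias is a hypothetical helper that mirrors the slope formula and the (q_idx - kv_idx) sign used in alibi_mod above:

```python
import torch
import torch.nn.functional as F

def alibi_bias(H: int, L: int) -> torch.Tensor:
    """Dense (H, L, L) bias: bias[h, q, k] = (q - k) * 2^(-8 * (h + 1) / H)."""
    slopes = torch.exp2(-(torch.arange(H) + 1) * 8.0 / H)  # one slope per head
    pos = torch.arange(L)
    rel = (pos[:, None] - pos[None, :]).float()  # rel[q, k] = q - k
    return slopes[:, None, None] * rel  # broadcast to (H, L, L)

H, L, E_head = 8, 4, 16
bias = alibi_bias(H, L)
print(bias[0, 3, 0].item())  # 3 * 2^-1 = 1.5
print(bias[7, 1, 0].item())  # 1 * 2^-8 = 0.00390625

# Dense analogue of flex_attention(..., score_mod=alibi_mod) on regular tensors:
q, k, v = (torch.randn(2, H, L, E_head) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)  # bias broadcasts over batch
```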

In addition, one can also use the block_mask utility of FlexAttention with NJTs via the create_nested_block_mask function. This is useful for taking advantage of the sparsity of the mask to speed up the attention computation. In particular, the function creates a sparse block mask for a “stacked sequence” of all the variable length sequences in the NJT combined into one, while properly masking out inter-sequence attention. In the following example, we show how to create a causal block mask using this utility.

from torch.nn.attention.flex_attention import create_nested_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex = flex_attention(query, key, value, block_mask=block_mask)
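Conceptually, the "stacked sequence" mask is a boolean matrix over all tokens of the batch laid end to end: a query may attend to a key only if both belong to the same sequence and the causal condition holds. The sketch below builds that matrix densely for toy lengths to show its structure (most entries are False, which is the sparsity a block mask exploits). create_nested_block_mask builds a block-sparse representation rather than this dense tensor, and stacked_causal_mask is a hypothetical helper, not part of the FlexAttention API:

```python
import torch

def stacked_causal_mask(lengths):
    """Dense (T, T) boolean mask over sequences concatenated end to end (T = sum(lengths)).
    True means "may attend": same sequence AND kv position <= q position."""
    lengths = torch.as_tensor(lengths)
    # seq_id[t] = index of the sequence that stacked token t belongs to
    seq_id = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
    idx = torch.arange(int(lengths.sum()))
    same_seq = seq_id[:, None] == seq_id[None, :]
    causal = idx[:, None] >= idx[None, :]
    return same_seq & causal

mask = stacked_causal_mask([2, 3])
print(mask.int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```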

Packed Projection

Packed projection is a technique that makes use of the fact that when the input for projection (matrix multiplications) is the same (self-attention), we can pack the projection weights and biases into single tensors. It is especially useful when the individual projections are memory bound rather than compute bound. There are two examples that we will demonstrate here:

  • Input projection for MultiheadAttention

  • SwiGLU activation in feed-forward network of Transformer Layer

Input projection for MultiheadAttention

When doing self-attention, the query, key, and value are the same tensor. Each of these tensors is projected with a Linear(E_q, E_total) layer. Instead, we can pack this into one layer, which is what we do in the MultiheadAttention layer above.
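Why packing is valid: stacking the three weight matrices (and biases) along the output dimension and splitting the result with torch.chunk gives exactly the three separate projections. The sketch below (arbitrary sizes, CPU) checks this numerically; it is the reverse of the slicing that the MultiHeadAttention layer above performs with torch.chunk(self.packed_proj.weight, 3, dim=0) when query, key, and value differ:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
E, E_total = 16, 16
q_proj, k_proj, v_proj = (nn.Linear(E, E_total) for _ in range(3))

# Stack the three (E_total, E) weight matrices into one (3 * E_total, E) matrix.
packed = nn.Linear(E, 3 * E_total)
with torch.no_grad():
    packed.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    packed.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

x = torch.randn(4, 7, E)
q, k, v = packed(x).chunk(3, dim=-1)  # one matmul, then split the last dim
for got, proj in zip((q, k, v), (q_proj, k_proj, v_proj)):
    assert torch.allclose(got, proj(x), atol=1e-5)
```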

Let us compare the performance of the packed projection against the usual method:

class InputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

    def forward(self, x):
        # three separate matmuls over the same input
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)

class PackedInputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)

    def forward(self, query):
        # one matmul, then split the result into query, key, value
        return torch.chunk(self.packed_proj(query), 3, dim=-1)

B, D, dtype = 256, 8192, torch.bfloat16

in_proj = torch.compile(InputProjection(D, D, device="cuda", dtype=torch.bfloat16))
packed_in_proj = torch.compile(
    PackedInputProjection(D, D, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sequence_lengths = gen_batch(B, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
in_proj(q)
packed_in_proj(q)

# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
# On my A100 prints 1.05x speedup
print(
    f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x"
)
InputProjection: 0.034046 s, PackedInputProjection: 0.032765 s, speedup: 1.04x

SwiGLU feed forward network of Transformer Layer

Swish-Gated Linear Unit (SwiGLU) is a non-linear activation function that is increasingly popular in the feed-forward network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as:

class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round hidden_dim up to the nearest multiple of ``multiple_of``
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        # SiLU-gated product of two projections, then project back to ``dim``
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

An alternative way of implementing this that uses packed projection is:

class PackedSwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round hidden_dim up to the nearest multiple of ``multiple_of``
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        # w1 and w3 packed into a single projection
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

    def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)
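Both SwiGLUFFN and PackedSwiGLUFFN size their hidden layer with the same rule: scale the requested hidden_dim by 2/3 (three SwiGLU matrices of width 2h/3 hold as many weights as the two matrices of a standard FFN of width h, since 3 * (2h/3) = 2h), then round up to a multiple of multiple_of. The hypothetical helper below reproduces just that arithmetic:

```python
def swiglu_hidden_dim(hidden_dim, multiple_of, ffn_dim_multiplier=None):
    """Hidden-size rule used by SwiGLUFFN / PackedSwiGLUFFN above."""
    # 3 matrices of width 2h/3 hold as many weights as 2 matrices of width h
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the next multiple of ``multiple_of``
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(swiglu_hidden_dim(4 * 128, 256))   # 512 (341 rounded up)
print(swiglu_hidden_dim(4 * 4096, 256))  # 11008
```

With dim = 4096 this yields 11008, the FFN width used by Llama-7B-style configurations. For the D = 128 benchmark, 341 is rounded back up to 512, so at that small size the rounding undoes the 2/3 reduction.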

We can compare the performance of the two implementations as follows. Depending on your hardware, you might see different results. On an A100 I see 1.12x speedup for D=128.

D = 128

swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16))
packed_swigluffn = torch.compile(
    PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
swigluffn(q)
packed_swigluffn(q)

# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
# On my A100 prints 1.08x speedup
print(
    f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x"
)
SwiGLUFFN: 0.0010763959999167128 s, PackedSwiGLUFFN: 0.0009788050001589 s, speedup: 1.10x

Extended examples

We intend to update this tutorial to demonstrate more examples of how to use the various performant building blocks, such as KV-caching and grouped query attention. Further, there are several good examples of using these building blocks to implement various transformer architectures. Some examples include:


In this tutorial, we have introduced the low-level building blocks PyTorch provides for writing transformer layers and demonstrated examples of how to compose them. It is our hope that this tutorial has educated the reader on the ease with which flexible and performant transformer layers can be implemented by users of PyTorch.
