
Fusing Convolution and Batch Norm using Custom Function

Fusing adjacent convolution and batch norm layers together is typically an inference-time optimization to improve run-time. It is usually achieved by eliminating the batch norm layer entirely and updating the weight and bias of the preceding convolution [0]. However, this technique is not applicable to training, because it relies on fixed (running) batch norm statistics, whereas in training mode the statistics are recomputed from every batch.
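For reference, the inference-time folding described above can be sketched as follows. This is a minimal sketch, assuming an `nn.Conv2d` followed by an affine `nn.BatchNorm2d` in eval mode; `fold_bn_into_conv` is a hypothetical helper name, and it uses the standard `sqrt(var + eps)` convention of `nn.BatchNorm2d` rather than the variant used later in this tutorial.

```python
import torch
import torch.nn as nn

def fold_bn_into_conv(conv, bn):
    # Hypothetical helper: returns one Conv2d equivalent to bn(conv(x))
    # when bn is in eval mode (fixed running statistics) and affine=True.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    with torch.no_grad():
        b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        # Per-output-channel scale: gamma / sqrt(running_var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused.weight.copy_(conv.weight * scale[:, None, None, None])
        fused.bias.copy_((b - bn.running_mean) * scale + bn.bias)
    return fused

# Usage: give the batch norm non-trivial statistics, then compare outputs
torch.manual_seed(0)
conv = nn.Conv2d(3, 4, 3, bias=False)
bn = nn.BatchNorm2d(4)
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1, 1)
bn.eval()
x = torch.randn(2, 3, 8, 8)
fused = fold_bn_into_conv(conv, bn)
print(torch.allclose(fused(x), bn(conv(x)), atol=1e-5))
```

Because the folded convolution bakes in the running statistics, it only matches the original pair in eval mode.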

In this tutorial, we will show a different technique to fuse the two layers that can be applied during training. Rather than improving runtime, the objective of this optimization is to reduce memory usage.

This optimization is based on the observation that both convolution and batch norm (as well as many other ops) need to save a copy of their input during forward for use in the backward pass. Since batch norm's input is the convolution's output, every convolution-batch norm pair keeps two activation-sized tensors alive. For large batch sizes, these saved inputs account for most of the memory usage, so avoiding one of these allocations for every convolution-batch norm pair can give a significant reduction.
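One way to see which tensors autograd keeps alive for backward is `torch.autograd.graph.saved_tensors_hooks`, which calls a pack hook on every saved tensor. The sketch below (layer sizes are arbitrary choices for illustration) records the shapes saved by a plain `nn.Conv2d` followed by `nn.BatchNorm2d`. The exact list depends on the PyTorch version and backend, but it should include both the convolution's input and the convolution's output, which is batch norm's input.

```python
import torch
import torch.nn as nn

saved_shapes = []

def pack_hook(t):
    # Called for every tensor autograd saves for the backward pass
    saved_shapes.append(tuple(t.shape))
    return t

def unpack_hook(t):
    return t

conv = nn.Conv2d(3, 8, 3, bias=False)
bn = nn.BatchNorm2d(8, affine=False, track_running_stats=False)
X = torch.randn(16, 3, 32, 32, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    out = bn(conv(X))

# Look for (16, 3, 32, 32) (conv input) and (16, 8, 30, 30) (bn input);
# other entries, such as the weight or per-channel statistics, vary by backend
print(saved_shapes)
```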

In this tutorial, we avoid this extra allocation by combining convolution and batch norm into a single layer (as a custom function). In the forward of this combined layer, we perform the usual convolution and batch norm computations, except that we save only the inputs to the convolution. To obtain the input of batch norm, which is needed to backpropagate through it, we recompute the convolution forward during the backward pass.

It is important to note that the usage of this optimization is situational. Although avoiding one saved buffer always reduces the memory allocated at the end of the forward pass, the peak memory allocated is not always reduced. See the final section for more details.

For simplicity, in this tutorial we hardcode bias=False, stride=1, padding=0, dilation=1, and groups=1 for Conv2D. For BatchNorm2D, we hardcode eps=1e-3, momentum=0.1, affine=False, and track_running_stats=False. Another small difference from nn.BatchNorm2d is that we add epsilon to the denominator outside of the square root, i.e. we divide by sqrt(var) + eps instead of sqrt(var + eps). We also normalize with the unbiased variance, whereas nn.BatchNorm2d normalizes with the biased variance in training mode.
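As a point of reference, the batch norm forward used throughout this tutorial can be written with ordinary tensor ops. This is a sketch for illustration only (`reference_batch_norm` is a hypothetical name); the custom functions below compute the same normalization by hand.

```python
import torch

def reference_batch_norm(X, eps=1e-3):
    # Per-channel statistics, reducing over the N, H and W dimensions
    mean = X.mean(dim=(0, 2, 3), keepdim=True)
    # Unbiased variance, as used in this tutorial
    var = X.var(dim=(0, 2, 3), unbiased=True, keepdim=True)
    # eps is added outside the square root (nn.BatchNorm2d adds it inside)
    return (X - mean) / (torch.sqrt(var) + eps)

torch.manual_seed(0)
X = torch.randn(8, 3, 5, 5, dtype=torch.double) * 4 + 2
out = reference_batch_norm(X)
print(out.mean(dim=(0, 2, 3)))  # approximately 0 for every channel
print(out.std(dim=(0, 2, 3)))   # slightly below 1, since we divide by sqrt(var) + eps
```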

[0] https://nenadmarkus.com/p/fusing-batchnorm-and-conv/

Backward Formula Implementation for Convolution

Implementing a custom function requires us to implement the backward ourselves. In this case, we need the backward formulas for both Conv2D and BatchNorm2D. Eventually we will chain them together in our unified backward function, but below we first implement each as its own custom function so we can validate their correctness individually.

import torch
from torch.autograd.function import once_differentiable
import torch.nn.functional as F

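def convolution_backward(grad_out, X, weight):
    # Gradient w.r.t. the weight: correlate the input with grad_out, treating
    # the batch dimension as the channel dimension that gets summed over
    grad_weight = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)
    # Gradient w.r.t. the input: a transposed convolution of grad_out with the weight
    grad_X = F.conv_transpose2d(grad_out, weight)
    return grad_X, grad_weight

class Conv2D(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, weight):
        ctx.save_for_backward(X, weight)
        return F.conv2d(X, weight)

    # Use @once_differentiable by default unless we intend to double backward
    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, weight = ctx.saved_tensors
        return convolution_backward(grad_out, X, weight)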

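When testing with gradcheck, it is important to use double precision, because gradcheck compares against finite-difference estimates that are not accurate enough in single precision: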

weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(10, 3, 7, 7, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Conv2D.apply, (X, weight))
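gradcheck compares against finite differences. As a complementary sketch, the same formulas can be compared with the analytic gradients produced by PyTorch's built-in conv2d backward. The tensor shapes are arbitrary assumptions, and only the stride=1, padding=0 case covered by this tutorial is checked.

```python
import torch
import torch.nn.functional as F

def convolution_backward(grad_out, X, weight):
    # Same formulas as above (stride=1, padding=0, dilation=1, groups=1)
    grad_weight = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)
    grad_X = F.conv_transpose2d(grad_out, weight)
    return grad_X, grad_weight

torch.manual_seed(0)
X = torch.rand(4, 3, 7, 7, dtype=torch.double, requires_grad=True)
weight = torch.rand(5, 3, 3, 3, dtype=torch.double, requires_grad=True)
out = F.conv2d(X, weight)
grad_out = torch.rand_like(out)

# Reference gradients from PyTorch's own conv2d backward
ref_grad_X, ref_grad_weight = torch.autograd.grad(out, (X, weight), grad_out)
grad_X, grad_weight = convolution_backward(grad_out, X.detach(), weight.detach())
print(torch.allclose(grad_X, ref_grad_X), torch.allclose(grad_weight, ref_grad_weight))
```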

Backward Formula Implementation for Batch Norm

Batch norm has two modes: training mode and eval mode. In training mode, the sample statistics are a function of the inputs. In eval mode, we use the saved running statistics, which are not a function of the inputs. This makes the backward for eval mode significantly simpler. Below we implement and test only the training mode case.
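In eval mode, the backward reduces to rescaling grad_out per channel, because the running statistics are constants. A minimal sketch, assuming the same eps-outside-the-square-root convention as this tutorial; `batch_norm_eval_backward` is a hypothetical helper that the tutorial does not define, checked here against autograd.

```python
import torch

def batch_norm_eval_backward(grad_out, running_var, eps=1e-3):
    # Hypothetical helper. In eval mode
    # out = (X - running_mean) / (sqrt(running_var) + eps)
    # is a per-channel affine function of X, so its backward is a per-channel scale.
    return grad_out / (torch.sqrt(running_var) + eps)[None, :, None, None]

torch.manual_seed(0)
eps = 1e-3
X = torch.randn(4, 3, 5, 5, dtype=torch.double, requires_grad=True)
running_mean = torch.randn(3, dtype=torch.double)
running_var = torch.rand(3, dtype=torch.double) + 0.5
denom = (torch.sqrt(running_var) + eps)[None, :, None, None]
out = (X - running_mean[None, :, None, None]) / denom
grad_out = torch.randn_like(out)

ref, = torch.autograd.grad(out, X, grad_out)
print(torch.allclose(batch_norm_eval_backward(grad_out, running_var, eps), ref))
```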

def unsqueeze_all(t):
    # Helper function to ``unsqueeze`` all the dimensions that we reduce over
    return t[None, :, None, None]

def batch_norm_backward(grad_out, X, sum, sqrt_var, N, eps):
    # We use the formula: ``out = (X - mean(X)) / (sqrt(var(X)) + eps)``
    # in batch norm 2D forward. To simplify our derivation, we follow the
    # chain rule and compute the gradients as follows before accumulating
    # them all into a final grad_input.
    #  1) ``grad of out wrt var(X)`` * ``grad of var(X) wrt X``
    #  2) ``grad of out wrt mean(X)`` * ``grad of mean(X) wrt X``
    #  3) ``grad of out wrt X in the numerator`` * ``grad of X wrt X``
    # We then rewrite the formulas to use as few extra buffers as possible
    tmp = ((X - unsqueeze_all(sum) / N) * grad_out).sum(dim=(0, 2, 3))
    tmp *= -1
    d_denom = tmp / (sqrt_var + eps)**2  # ``d_denom = -num / denom**2``
    # It is useful to delete tensors when you no longer need them with ``del``
    # For example, we could've done ``del tmp`` here because we won't use it later
    # In this case, it's not a big difference because ``tmp`` only has size of (C,)
    # The important thing is to avoid allocating NCHW-sized tensors unnecessarily
    d_var = d_denom / (2 * sqrt_var)  # ``denom = torch.sqrt(var) + eps``
    # Compute ``d_mean_dx`` before allocating the final NCHW-sized grad_input buffer
    d_mean_dx = grad_out / unsqueeze_all(sqrt_var + eps)
    d_mean_dx = unsqueeze_all(-d_mean_dx.sum(dim=(0, 2, 3)) / N)
    # ``d_mean_dx`` has already been reassigned to a C-sized buffer so no need to worry

    # ``(1) unbiased_var(x) = ((X - unsqueeze_all(mean))**2).sum(dim=(0, 2, 3)) / (N - 1)``
    grad_input = X * unsqueeze_all(d_var * N)
    grad_input += unsqueeze_all(-d_var * sum)
    grad_input *= 2 / ((N - 1) * N)
    # (2) mean (see above)
    grad_input += d_mean_dx
    # (3) Add 'grad_out / <factor>' without allocating an extra buffer
    grad_input *= unsqueeze_all(sqrt_var + eps)
    grad_input += grad_out
    grad_input /= unsqueeze_all(sqrt_var + eps)  # ``sqrt_var + eps > 0!``
    return grad_input

class BatchNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, eps=1e-3):
        # Don't save ``keepdim`` values for backward
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        # The input is needed to compute the gradient in backward
        ctx.save_for_backward(X)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, = ctx.saved_tensors
        # No gradient is needed for eps
        return batch_norm_backward(grad_out, X, ctx.sum, ctx.sqrt_var, ctx.N, ctx.eps), None
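For reference, here are the per-channel formulas that batch_norm_backward implements. The sums run over all N·H·W elements of one channel, so N here denotes that element count (the code's N = X.numel() / X.size(1)), not the batch size, and g_j denotes the incoming gradient grad_out. The labels (1), (2), (3) match the comments in the code; in (1), the dependence of the variance on X through μ drops out because Σ_j (X_j − μ) = 0.

```latex
\mu = \frac{1}{N}\sum_{j} X_j, \qquad
\sigma = \sqrt{\frac{1}{N-1}\sum_{j} (X_j - \mu)^2}, \qquad
d = \sigma + \epsilon, \qquad
y_i = \frac{X_i - \mu}{d}, \qquad
g_j = \frac{\partial L}{\partial y_j}

\frac{\partial L}{\partial \sigma^2}
  = \frac{1}{2\sigma} \cdot \frac{-\sum_j g_j (X_j - \mu)}{d^2}

\frac{\partial L}{\partial X_i}
  = \underbrace{\frac{\partial L}{\partial \sigma^2} \cdot \frac{2 (X_i - \mu)}{N - 1}}_{(1)}
  \; \underbrace{- \frac{1}{N d} \sum_j g_j}_{(2)}
  \; + \; \underbrace{\frac{g_i}{d}}_{(3)}
```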

Testing with gradcheck

a = torch.rand(1, 2, 3, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(BatchNorm.apply, (a,), fast_mode=False)
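gradcheck relies on finite differences. Another check is to compare the hand-written backward with the analytic gradient autograd computes for the same forward formula written with ordinary differentiable ops. In this sketch, the forward below is assumed to match BatchNorm.forward, and batch_norm_backward is copied from above with its comments removed.

```python
import torch

def unsqueeze_all(t):
    return t[None, :, None, None]

def batch_norm_backward(grad_out, X, sum, sqrt_var, N, eps):
    # Copy of the tutorial's formula (comments removed)
    tmp = ((X - unsqueeze_all(sum) / N) * grad_out).sum(dim=(0, 2, 3))
    tmp *= -1
    d_denom = tmp / (sqrt_var + eps)**2
    d_var = d_denom / (2 * sqrt_var)
    d_mean_dx = grad_out / unsqueeze_all(sqrt_var + eps)
    d_mean_dx = unsqueeze_all(-d_mean_dx.sum(dim=(0, 2, 3)) / N)
    grad_input = X * unsqueeze_all(d_var * N)
    grad_input += unsqueeze_all(-d_var * sum)
    grad_input *= 2 / ((N - 1) * N)
    grad_input += d_mean_dx
    grad_input *= unsqueeze_all(sqrt_var + eps)
    grad_input += grad_out
    grad_input /= unsqueeze_all(sqrt_var + eps)
    return grad_input

torch.manual_seed(0)
eps = 1e-3
X = torch.randn(4, 3, 5, 5, dtype=torch.double, requires_grad=True)
# Same forward as BatchNorm.forward, but tracked by autograd
mean = X.mean(dim=(0, 2, 3))
sqrt_var = X.var(dim=(0, 2, 3), unbiased=True).sqrt()
out = (X - unsqueeze_all(mean)) / unsqueeze_all(sqrt_var + eps)
grad_out = torch.randn_like(out)
ref, = torch.autograd.grad(out, X, grad_out)

with torch.no_grad():
    N = X.numel() / X.size(1)
    mine = batch_norm_backward(grad_out, X, X.sum(dim=(0, 2, 3)), sqrt_var, N, eps)
print(torch.allclose(mine, ref))
```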

Fusing Convolution and BatchNorm

Now that the bulk of the work has been done, we can combine the two. Note that in (1) we save only a single activation-sized buffer for backward, but this also means we have to recompute the convolution forward in (5). Also note that (2), (3), (4), and (6) are exactly the same code as in the examples above.

class FusedConvBN2DFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, conv_weight, eps=1e-3):
        assert X.ndim == 4  # N, C, H, W
        # (1) Only need to save this single activation-sized buffer (X) for backward!
        # conv_weight is a parameter that stays alive anyway, so saving it is free
        ctx.save_for_backward(X, conv_weight)

        # (2) Exact same Conv2D forward from example above
        X = F.conv2d(X, conv_weight)
        # (3) Exact same BatchNorm2D forward from example above
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        # Try to do as many things in-place as possible
        # Instead of `out = (X - a) / b`, doing `out = X - a; out /= b`
        # avoids allocating one extra NCHW-sized buffer here
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, conv_weight, = ctx.saved_tensors
        # (5) We need to recompute conv to obtain the input of batch norm
        X_conv_out = F.conv2d(X, conv_weight)
        # (4) Batch norm backward
        grad_out = batch_norm_backward(grad_out, X_conv_out, ctx.sum, ctx.sqrt_var,
                                       ctx.N, ctx.eps)
        # (6) Conv2d backward
        grad_X, grad_weight = convolution_backward(grad_out, X, conv_weight)
        # No gradient is needed for eps
        return grad_X, grad_weight, None
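The save-the-input, recompute-in-backward pattern used here is not specific to convolution and batch norm. The sketch below shows it in miniature, with a matrix multiply standing in for the convolution and `torch.sin` (whose backward, like batch norm's, needs its input) standing in for batch norm. `LinearSinRecompute` is a hypothetical name used only for illustration: plain autograd would save the intermediate `h = x @ weight` for sin's backward, while this version keeps only `x` and `weight` and recomputes `h`.

```python
import torch
from torch.autograd.function import once_differentiable

class LinearSinRecompute(torch.autograd.Function):
    # Hypothetical example: out = sin(x @ weight), saving only x and weight

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        h = x @ weight            # intermediate that is NOT saved
        return torch.sin(h)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        h = x @ weight                    # recompute the intermediate
        grad_h = grad_out * torch.cos(h)  # backward through sin
        grad_x = grad_h @ weight.t()      # backward through the matmul
        grad_weight = x.t() @ grad_h
        return grad_x, grad_weight

x = torch.rand(6, 4, dtype=torch.double, requires_grad=True)
weight = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(LinearSinRecompute.apply, (x, weight)))
```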

The next step is to wrap our functional variant in a stateful nn.Module.

import torch.nn as nn
import math

class FusedConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, exp_avg_factor=0.1,
                 eps=1e-3, device=None, dtype=None):
        super(FusedConvBN, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Conv parameters
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.conv_weight = nn.Parameter(torch.empty(*weight_shape, **factory_kwargs))
        # Batch norm parameters
        num_features = out_channels
        self.num_features = num_features
        self.eps = eps
        # Initialize the convolution weight (torch.empty leaves it uninitialized)
        self.reset_parameters()

    def forward(self, X):
        return FusedConvBN2DFunction.apply(X, self.conv_weight, self.eps)

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.conv_weight, a=math.sqrt(5))

Use gradcheck to validate the correctness of our backward formula.

weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(2, 3, 4, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(FusedConvBN2DFunction.apply, (X, weight))

Testing out our new Layer

Use FusedConvBN to train a basic network. The code below is a lightly modified version of the example here: https://github.com/pytorch/examples/tree/master/mnist

import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

# Record memory allocated at the end of the forward pass
memory_allocated = [[],[]]

class Net(nn.Module):
    def __init__(self, fused=True):
        super(Net, self).__init__()
        self.fused = fused
        if fused:
            self.convbn1 = FusedConvBN(1, 32, 3)
            self.convbn2 = FusedConvBN(32, 64, 3)
        else:
            self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
            self.bn1 = nn.BatchNorm2d(32, affine=False, track_running_stats=False)
            self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(64, affine=False, track_running_stats=False)
        self.fc1 = nn.Linear(9216, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        if self.fused:
            x = self.convbn1(x)
        else:
            x = self.conv1(x)
            x = self.bn1(x)
        F.relu_(x)
        if self.fused:
            x = self.convbn2(x)
        else:
            x = self.conv2(x)
            x = self.bn2(x)
        F.relu_(x)
        x = F.max_pool2d(x, 2)
        x = x.flatten(1)
        x = self.fc1(x)
        F.relu_(x)
        x = self.dropout(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        # Record memory allocated at the end of the forward pass (CUDA only)
        if torch.cuda.is_available():
            memory_allocated[0 if self.fused else 1].append(torch.cuda.memory_allocated())
        return output

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 2 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    # Use inference mode instead of no_grad, for free improved test-time performance
    with torch.inference_mode():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_kwargs = {'batch_size': 2048}
test_kwargs = {'batch_size': 2048}

if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                          transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                          transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
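The test function above uses `torch.inference_mode()` rather than `torch.no_grad()`. Both disable gradient tracking; inference mode additionally skips some autograd bookkeeping (view tracking and version counter updates), which is where the speedup comes from. A small sketch of the observable difference, using a throwaway `nn.Linear` as a stand-in model (an assumption for illustration):

```python
import torch

layer = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

with torch.no_grad():
    a = layer(x)
with torch.inference_mode():
    b = layer(x)

# Neither result tracks gradients...
print(a.requires_grad, b.requires_grad)
# ...but only the inference_mode result is an "inference tensor", which has
# restrictions when used outside inference mode (e.g. it cannot be saved for backward)
print(torch.is_inference(a), torch.is_inference(b))
```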

A Comparison of Memory Usage

If CUDA is enabled, we print out memory usage for both fused=True and fused=False. For an example run on an NVIDIA GeForce RTX 3070 with NVIDIA CUDA® Deep Neural Network library (cuDNN) 8.0.5: fused peak memory: 1.56GB, unfused peak memory: 2.68GB.

It is important to note that the peak memory usage for this model may vary depending on the specific cuDNN convolution algorithm used. For shallower models, the peak memory allocated by the fused model may even exceed that of the unfused model! This is because the workspace memory that certain cuDNN convolution algorithms allocate can be large enough to “hide” the typical peak you would expect near the start of the backward pass. The run below (cuDNN 8.5.0) is one such case: the fused model's peak memory is higher than the unfused model's, while its memory at the end of the forward pass is lower.

For this reason, we also record and display the memory allocated at the end of the forward pass as an approximation, and to demonstrate that we indeed allocate one fewer buffer per fused conv-bn pair.
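The two quantities can also be measured for a single training step. A minimal sketch, assuming a CUDA device (it returns None otherwise); `forward_and_peak_memory` is a hypothetical helper, and the tiny linear model is a stand-in chosen only so the example is self-contained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def forward_and_peak_memory(model, data, target):
    # Hypothetical helper: returns (bytes allocated at the end of forward,
    # peak bytes allocated since the reset, which includes model and data).
    # Only meaningful on CUDA; returns None otherwise.
    if not torch.cuda.is_available():
        return None
    torch.cuda.reset_peak_memory_stats()
    output = model(data)
    end_of_forward = torch.cuda.memory_allocated()
    F.nll_loss(output, target).backward()
    peak = torch.cuda.max_memory_allocated()
    return end_of_forward, peak

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1)).to(device)
data = torch.randn(64, 1, 28, 28, device=device)
target = torch.randint(0, 10, (64,), device=device)
print(forward_and_peak_memory(model, data, target))
```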

from statistics import mean

torch.backends.cudnn.enabled = True

if use_cuda:
    peak_memory_allocated = []

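    for fused in (True, False):
        # Use the same seed for both runs
        torch.manual_seed(123456)

        model = Net(fused=fused).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=1.0)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

        for epoch in range(1):
            train(model, device, train_loader, optimizer, epoch)
            test(model, device, test_loader)
            scheduler.step()
        # Record the peak for this configuration, then reset it for the next one
        peak_memory_allocated.append(torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()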
    print("cuDNN version:", torch.backends.cudnn.version())
    print("Peak memory allocated:")
    print(f"fused: {peak_memory_allocated[0]/1024**3:.2f}GB, unfused: {peak_memory_allocated[1]/1024**3:.2f}GB")
    print("Memory allocated at end of forward pass:")
    print(f"fused: {mean(memory_allocated[0])/1024**3:.2f}GB, unfused: {mean(memory_allocated[1])/1024**3:.2f}GB")
Train Epoch: 0 [0/60000 (0%)]   Loss: 2.352060
Train Epoch: 0 [4096/60000 (7%)]        Loss: 7.321198
Train Epoch: 0 [8192/60000 (13%)]       Loss: 4.253123
Train Epoch: 0 [12288/60000 (20%)]      Loss: 2.916881
Train Epoch: 0 [16384/60000 (27%)]      Loss: 2.643569
Train Epoch: 0 [20480/60000 (33%)]      Loss: 1.819675
Train Epoch: 0 [24576/60000 (40%)]      Loss: 1.546603
Train Epoch: 0 [28672/60000 (47%)]      Loss: 1.501737
Train Epoch: 0 [32768/60000 (53%)]      Loss: 1.495718
Train Epoch: 0 [36864/60000 (60%)]      Loss: 1.421847
Train Epoch: 0 [40960/60000 (67%)]      Loss: 1.260746
Train Epoch: 0 [45056/60000 (73%)]      Loss: 1.199898
Train Epoch: 0 [49152/60000 (80%)]      Loss: 0.951574
Train Epoch: 0 [53248/60000 (87%)]      Loss: 0.846121
Train Epoch: 0 [57344/60000 (93%)]      Loss: 0.794893

Test set: Average loss: 0.4699, Accuracy: 8564/10000 (86%)

Train Epoch: 0 [0/60000 (0%)]   Loss: 2.352356
Train Epoch: 0 [4096/60000 (7%)]        Loss: 7.323082
Train Epoch: 0 [8192/60000 (13%)]       Loss: 4.011040
Train Epoch: 0 [12288/60000 (20%)]      Loss: 2.065250
Train Epoch: 0 [16384/60000 (27%)]      Loss: 2.220759
Train Epoch: 0 [20480/60000 (33%)]      Loss: 1.881386
Train Epoch: 0 [24576/60000 (40%)]      Loss: 1.610603
Train Epoch: 0 [28672/60000 (47%)]      Loss: 1.693122
Train Epoch: 0 [32768/60000 (53%)]      Loss: 1.609484
Train Epoch: 0 [36864/60000 (60%)]      Loss: 1.208902
Train Epoch: 0 [40960/60000 (67%)]      Loss: 1.209118
Train Epoch: 0 [45056/60000 (73%)]      Loss: 0.998975
Train Epoch: 0 [49152/60000 (80%)]      Loss: 0.852635
Train Epoch: 0 [53248/60000 (87%)]      Loss: 0.847645
Train Epoch: 0 [57344/60000 (93%)]      Loss: 0.859491

Test set: Average loss: 0.4077, Accuracy: 8830/10000 (88%)

cuDNN version: 8500

Peak memory allocated:
fused: 3.36GB, unfused: 2.68GB
Memory allocated at end of forward pass:
fused: 0.59GB, unfused: 0.96GB

Total running time of the script: ( 0 minutes 39.074 seconds)
