• Tutorials >
  • Creating Extensions Using NumPy and SciPy
Shortcuts

Creating Extensions Using NumPy and SciPy

Author: Adam Paszke

Updated by: Adam Dziedzic

In this tutorial, we shall go through two tasks:

  1. Create a neural network layer with no parameters.

    • This calls into NumPy as part of its implementation.

  2. Create a neural network layer that has learnable weights

    • This calls into SciPy as part of its implementation

import torch
from torch.autograd import Function

Parameter-less example

This layer doesn’t do anything particularly useful or mathematically correct: its backward pass is not the true derivative of its forward pass.

It is aptly named BadFFTFunction.

Layer Implementation

from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):
    """Deliberately incorrect autograd op built on NumPy's real 2-D FFT.

    The forward pass returns the magnitude of ``rfft2`` of the input; the
    backward pass simply applies ``irfft2`` to the incoming gradient. That is
    not the true derivative of the forward pass -- the class only shows how
    to call into NumPy from a custom ``Function``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``|rfft2(input)|`` as a tensor with ``input``'s dtype.

        ``rfft2`` transforms the last two dimensions, so the last dimension
        of the result has ``input.shape[-1] // 2 + 1`` entries.
        """
        # irfft2 cannot tell from the half-spectrum alone whether the original
        # last dimension was even or odd, so remember the exact spatial size;
        # without it an odd-sized input gets a gradient of the wrong shape.
        ctx.spatial_shape = tuple(input.shape[-2:])
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Map the gradient back to the input's shape with ``irfft2``."""
        # detach() so .numpy() also works if backward runs with create_graph=True.
        numpy_go = grad_output.detach().numpy()
        result = irfft2(numpy_go, s=ctx.spatial_shape)
        return torch.as_tensor(result, dtype=grad_output.dtype)

# since this layer does not have any parameters, we can
# simply declare this as a function, rather than as an ``nn.Module`` class


def incorrect_fft(input):
    """Run ``input`` through :class:`BadFFTFunction`.

    The layer has no learnable parameters, so a plain function wrapping the
    autograd ``Function`` is enough.

    Args:
        input: tensor to transform.

    Returns:
        The magnitude of the real 2-D FFT of ``input``, tracked by autograd.
    """
    apply_bad_fft = BadFFTFunction.apply
    transformed = apply_bad_fft(input)
    return transformed

Example usage of the created layer:

# Forward a random 8x8 input through the parameter-less layer.
input = torch.randn((8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result)
# Backpropagate a random upstream gradient shaped like the layer's output.
upstream_grad = torch.randn(result.shape)
result.backward(upstream_grad)
print(input)
tensor([[ 2.1007,  2.6758,  6.1699,  3.5140, 16.9024],
        [ 9.2379, 13.8331,  9.7708,  9.5670,  9.6625],
        [11.8466,  8.2359,  6.3284,  5.9224, 10.6756],
        [ 4.5172,  3.8650,  8.0120,  1.7539,  9.6564],
        [ 0.0610,  2.5647,  3.8037,  7.6060, 20.8580],
        [ 4.5172, 13.6494,  8.3293,  4.8767,  9.6564],
        [11.8466,  4.2204,  8.4846,  8.4208, 10.6756],
        [ 9.2379, 12.2773,  3.2013,  2.4911,  9.6625]],
       grad_fn=<BadFFTFunctionBackward>)
tensor([[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047],
        [-0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,  0.7624],
        [ 1.6423, -0.1596, -0.4974,  0.4396, -0.7581,  1.0783,  0.8008,  1.6806],
        [ 1.2791,  1.2964,  0.6105,  1.3347, -0.2316,  0.0418, -0.2516,  0.8599],
        [-1.3847, -0.8712, -0.2234,  1.7174,  0.3189, -0.4245,  0.3057, -0.7746],
        [-1.5576,  0.9956, -0.8798, -0.6011, -1.2742,  2.1228, -1.2347, -0.4879],
        [-0.9138, -0.6581,  0.0780,  0.5258, -0.4880,  1.1914, -0.8140, -0.7360],
        [-1.4032,  0.0360, -0.0635,  0.6756, -0.0978,  1.8446, -1.1845,  1.3835]],
       requires_grad=True)

Parametrized example

In deep learning literature, this layer is confusingly referred to as convolution, while the actual operation is cross-correlation (the only difference is that the filter is flipped for convolution, which is not the case for cross-correlation).

Below is the implementation of a layer with learnable weights: the cross-correlation filter (kernel) holds the weights, and a learnable bias is added to the result.

The backward pass computes the gradients with respect to the input, the filter, and the bias.

from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    """"Valid" 2-D cross-correlation plus a bias, computed with SciPy.

    ``input`` and ``filter`` are handed to ``scipy.signal.correlate2d``, which
    works on 2-D arrays; ``bias`` is added to the result with NumPy
    broadcasting. Each gradient returned by ``backward`` has the dtype of the
    tensor it belongs to.
    """

    @staticmethod
    def forward(ctx, input, filter, bias):
        """Return ``correlate2d(input, filter, 'valid') + bias`` in ``input``'s dtype."""
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``(input, filter, bias)``.

        A gradient is only computed for arguments autograd actually needs;
        the others are returned as ``None``.
        """
        input, filter, bias = ctx.saved_tensors
        needs_input, needs_filter, needs_bias = ctx.needs_input_grad
        numpy_go = grad_output.detach().numpy()

        grad_input = grad_filter = grad_bias = None
        if needs_input:
            # "full" convolution with the filter undoes the "valid" correlation;
            # equivalently:
            # correlate2d(numpy_go, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
            grad_input = torch.as_tensor(
                convolve2d(numpy_go, filter.numpy(), mode='full'), dtype=input.dtype)
        if needs_filter:
            # Cast to the filter's own dtype (not a fixed float32) so a
            # double-precision filter keeps a double-precision gradient.
            grad_filter = torch.as_tensor(
                correlate2d(input.numpy(), numpy_go, mode='valid'), dtype=filter.dtype)
        if needs_bias:
            # The bias is broadcast over every output element, so its gradient
            # is the sum of the upstream gradient (kept 2-D to match a (1, 1) bias).
            grad_bias = torch.as_tensor(
                np.sum(numpy_go, keepdims=True), dtype=bias.dtype)
        return grad_input, grad_filter, grad_bias


class ScipyConv2d(Module):
    """Learnable 2-D cross-correlation layer backed by ``ScipyConv2dFunction``.

    Args:
        filter_width: size of the filter's first dimension (its number of rows).
        filter_height: size of the filter's second dimension (its number of
            columns). NOTE(review): the names suggest (width, height) but the
            kernel is laid out as (rows, cols) = (filter_width, filter_height).

    Attributes:
        filter: learnable ``(filter_width, filter_height)`` kernel drawn from
            a standard normal distribution.
        bias: learnable ``(1, 1)`` bias drawn from a standard normal distribution.
    """

    def __init__(self, filter_width, filter_height):
        super().__init__()
        kernel_shape = (filter_width, filter_height)
        # The kernel is created (and registered) before the bias, which fixes
        # both the random-draw order and the order parameters() yields them in.
        self.filter = Parameter(torch.randn(kernel_shape))
        self.bias = Parameter(torch.randn((1, 1)))

    def forward(self, input):
        """Cross-correlate ``input`` with the learned filter and add the bias."""
        kernel, offset = self.filter, self.bias
        return ScipyConv2dFunction.apply(input, kernel, offset)

Example usage:

# Build a 3x3 layer, run it on a random 10x10 map and backpropagate.
module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
# The upstream gradient must match the output's shape (8x8 here, because a
# "valid" 3x3 cross-correlation trims 2 from each side length of a 10x10 map);
# derive it from ``output`` rather than hard-coding it, so changing the input
# or filter size does not break this call.
output.backward(torch.randn(output.size()))
print("Gradient for the input map: ", input.grad)
Filter and bias:  [Parameter containing:
tensor([[ 1.5980,  0.1115, -0.0392],
        [ 1.4112, -0.6556,  0.8576],
        [-1.6270, -1.3951, -0.2387]], requires_grad=True), Parameter containing:
tensor([[-0.5050]], requires_grad=True)]
Output from the convolution:  tensor([[ 2.6490,  0.9049, -2.4603,  4.9172, -2.6995, -0.5259,  0.8084,  2.9712],
        [-7.4314,  2.0491,  1.1998,  2.9348, -1.0925,  2.3583, -0.4701,  0.0620],
        [ 0.4747,  2.7422, -3.1890, -1.3733, -0.2700,  5.3886,  4.3234,  0.7127],
        [ 1.9092, -2.2850, -5.8786, -0.0514, -3.9709, -1.3217, -4.1159,  0.1144],
        [-5.2757, -7.3949,  4.8335, -1.4541, -1.7630, -2.9781, -3.6742,  0.6746],
        [-1.3930,  1.5531,  1.3611, -5.1453,  1.4953,  0.2460, -2.6832,  0.3069],
        [-2.9077, -2.0569, -2.8977, -1.1425,  2.6128,  1.0830, -3.7180,  4.4202],
        [-3.3766, -2.4025, -1.3721,  2.2309, -4.1819, -1.2281, -4.5500,  0.4863]],
       grad_fn=<ScipyConv2dFunctionBackward>)
Gradient for the input map:  tensor([[-1.0954,  0.8242, -2.3188, -2.7640, -2.4841,  0.5890, -0.0936,  1.2190,
          0.0921, -0.0307],
        [-3.5222,  0.8920, -1.8908, -0.6549, -0.2606,  0.7365, -4.9122,  1.3548,
         -0.5405,  0.6700],
        [-3.3450,  1.7799, -2.4677,  3.7305,  7.3472,  0.0637,  0.7740,  0.6444,
         -2.8296, -0.1736],
        [ 2.5613,  7.4243, -4.5369,  2.7349, -3.9802, -1.3710,  5.9015,  0.9176,
          1.8604,  0.0468],
        [ 5.4830,  1.3442,  3.4260,  5.3385,  0.8872,  3.3689, -2.8774, -1.6875,
         -0.2114, -0.0434],
        [ 2.6617, -6.2559, -1.9922, -2.9800, -2.8432, -2.9184, -1.1842,  0.9845,
         -0.5682, -0.0600],
        [ 1.9133, -1.7688,  6.1904,  0.9140,  2.1316,  3.6076, -0.3446, -1.0438,
          0.8684,  0.6735],
        [-2.4305, -7.8832, -0.8278, -6.6322, -0.5148,  0.2151, -4.0533, -3.7276,
         -2.6069, -1.2354],
        [-0.4031, -2.5208,  0.9026, -2.6378, -0.1116, -2.9007,  2.9866,  2.7905,
          4.0234, -0.5354],
        [-0.2927,  3.2093,  3.1423,  2.3881,  2.7282,  1.4053, -1.7006, -0.0692,
          1.0831,  0.2345]])

Check the gradients:

from torch.autograd.gradcheck import gradcheck

# gradcheck compares analytical gradients with finite differences, so
# everything it checks should be double precision -- including the layer's
# parameters, which ScipyConv2d creates as float32.
moduleConv = ScipyConv2d(3, 3).double()

input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
# Gradient w.r.t. the input map, going through the module's forward().
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
# gradcheck only perturbs the tensors passed to it, so calling the module
# never verifies the filter/bias gradients. Check them explicitly as well.
test = test and gradcheck(
    ScipyConv2dFunction.apply,
    (input[0], moduleConv.filter, moduleConv.bias),
    eps=1e-6,
    atol=1e-4,
)
print("Are the gradients correct: ", test)
Are the gradients correct:  True

Total running time of the script: ( 0 minutes 0.566 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources