.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "advanced/numpy_extensions_tutorial.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_advanced_numpy_extensions_tutorial.py:


Creating Extensions Using NumPy and SciPy
=========================================
**Author**: `Adam Paszke `_

**Updated by**: `Adam Dziedzic `_

In this tutorial, we shall go through two tasks:

1. Create a neural network layer with no parameters.

   - This calls into **NumPy** as part of its implementation.

2. Create a neural network layer that has learnable weights.

   - This calls into **SciPy** as part of its implementation.

.. code-block:: default


    import torch
    from torch.autograd import Function


Parameter-less example
----------------------

This layer doesn't do anything particularly useful or mathematically
correct: its ``backward`` simply applies the inverse FFT to the incoming
gradient, which is not the true derivative of the forward pass. It is
aptly named ``BadFFTFunction``.

**Layer Implementation**

.. code-block:: default


    from numpy.fft import rfft2, irfft2


    class BadFFTFunction(Function):
        @staticmethod
        def forward(ctx, input):
            # detach so the tensor can be converted to a NumPy array
            numpy_input = input.detach().numpy()
            result = abs(rfft2(numpy_input))
            return torch.as_tensor(result, dtype=input.dtype)

        @staticmethod
        def backward(ctx, grad_output):
            numpy_go = grad_output.detach().numpy()
            result = irfft2(numpy_go)
            return torch.as_tensor(result, dtype=grad_output.dtype)

    # since this layer does not have any parameters, we can
    # simply declare this as a function, rather than as an ``nn.Module`` class


    def incorrect_fft(input):
        return BadFFTFunction.apply(input)


**Example usage of the created layer:**
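Before running it, note the shapes involved. ``rfft2`` keeps only the non-negative frequencies along the last axis, so an ``8 x 8`` input produces an ``8 x 5`` output, and ``irfft2`` in the backward pass maps an ``8 x 5`` gradient back to ``8 x 8`` (by default it assumes an even-length last axis). A quick NumPy-only check of those shapes, with an arbitrary random input:

.. code-block:: python

    import numpy as np
    from numpy.fft import rfft2, irfft2

    x = np.random.randn(8, 8)

    # forward: magnitude of the real 2-D FFT; the last axis shrinks to 8 // 2 + 1 = 5
    y = np.abs(rfft2(x))
    print(y.shape)  # (8, 5)

    # backward: irfft2 restores a last axis of length 2 * (5 - 1) = 8
    g = irfft2(np.random.randn(8, 5))
    print(g.shape)  # (8, 8)

This is why the example below prints an ``8 x 5`` tensor. For an odd-width input, ``irfft2`` without an explicit ``s`` argument would return a gradient of the wrong width, one more way in which the layer is "bad".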
.. code-block:: default


    input = torch.randn(8, 8, requires_grad=True)
    result = incorrect_fft(input)
    print(result)
    result.backward(torch.randn(result.size()))
    print(input)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[ 2.1007,  2.6758,  6.1699,  3.5140, 16.9024],
            [ 9.2379, 13.8331,  9.7708,  9.5670,  9.6625],
            [11.8466,  8.2359,  6.3284,  5.9224, 10.6756],
            [ 4.5172,  3.8650,  8.0120,  1.7539,  9.6564],
            [ 0.0610,  2.5647,  3.8037,  7.6060, 20.8580],
            [ 4.5172, 13.6494,  8.3293,  4.8767,  9.6564],
            [11.8466,  4.2204,  8.4846,  8.4208, 10.6756],
            [ 9.2379, 12.2773,  3.2013,  2.4911,  9.6625]],
           grad_fn=<BadFFTFunctionBackward>)
    tensor([[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047],
            [-0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,  0.7624],
            [ 1.6423, -0.1596, -0.4974,  0.4396, -0.7581,  1.0783,  0.8008,  1.6806],
            [ 1.2791,  1.2964,  0.6105,  1.3347, -0.2316,  0.0418, -0.2516,  0.8599],
            [-1.3847, -0.8712, -0.2234,  1.7174,  0.3189, -0.4245,  0.3057, -0.7746],
            [-1.5576,  0.9956, -0.8798, -0.6011, -1.2742,  2.1228, -1.2347, -0.4879],
            [-0.9138, -0.6581,  0.0780,  0.5258, -0.4880,  1.1914, -0.8140, -0.7360],
            [-1.4032,  0.0360, -0.0635,  0.6756, -0.0978,  1.8446, -1.1845,  1.3835]],
           requires_grad=True)


Parametrized example
--------------------

In deep learning literature, this layer is confusingly referred to as
convolution, while the actual operation is cross-correlation (the only
difference is that the filter is flipped for convolution, which is not
the case for cross-correlation).

The layer below has learnable weights: the filter (kernel) used in the
cross-correlation, plus a single bias value. The backward pass computes
the gradients with respect to the input, the filter, and the bias.
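To make the distinction concrete, here is a small sketch (not part of the original tutorial; the array sizes and random seed are arbitrary) checking with SciPy that convolving with a kernel gives the same result as cross-correlating with that kernel flipped along both axes:

.. code-block:: python

    import numpy as np
    from scipy.signal import convolve2d, correlate2d

    rng = np.random.default_rng(0)
    image = rng.standard_normal((5, 5))
    kernel = rng.standard_normal((3, 3))

    # convolution flips the kernel along both axes before sliding it over the image
    conv = convolve2d(image, kernel, mode='valid')
    # cross-correlation slides the kernel as-is, so flip it manually to match
    corr_flipped = correlate2d(image, np.flip(kernel), mode='valid')
    print(np.allclose(conv, corr_flipped))  # True

    # without the flip, the two operations generally disagree
    print(np.allclose(conv, correlate2d(image, kernel, mode='valid')))

The commented-out alternative for ``grad_input`` in ``backward`` relies on the same identity: a full convolution with the filter equals a full cross-correlation with the doubly flipped filter.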
.. code-block:: default


    import numpy as np
    from scipy.signal import convolve2d, correlate2d
    from torch.nn.modules.module import Module
    from torch.nn.parameter import Parameter


    class ScipyConv2dFunction(Function):
        @staticmethod
        def forward(ctx, input, filter, bias):
            # detach so we can cast to NumPy
            input, filter, bias = input.detach(), filter.detach(), bias.detach()
            result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
            result += bias.numpy()
            ctx.save_for_backward(input, filter, bias)
            return torch.as_tensor(result, dtype=input.dtype)

        @staticmethod
        def backward(ctx, grad_output):
            grad_output = grad_output.detach().numpy()
            input, filter, bias = ctx.saved_tensors
            grad_bias = np.sum(grad_output, keepdims=True)
            grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
            # the previous line can be expressed equivalently as:
            # grad_input = correlate2d(grad_output, np.flip(filter.numpy()), mode='full')
            grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
            # cast each gradient back to the dtype of the tensor it belongs to
            return (torch.as_tensor(grad_input, dtype=input.dtype),
                    torch.as_tensor(grad_filter, dtype=filter.dtype),
                    torch.as_tensor(grad_bias, dtype=bias.dtype))


    class ScipyConv2d(Module):
        def __init__(self, filter_width, filter_height):
            super().__init__()
            self.filter = Parameter(torch.randn(filter_width, filter_height))
            self.bias = Parameter(torch.randn(1, 1))

        def forward(self, input):
            return ScipyConv2dFunction.apply(input, self.filter, self.bias)


**Example usage:**

.. code-block:: default


    module = ScipyConv2d(3, 3)
    print("Filter and bias: ", list(module.parameters()))
    input = torch.randn(10, 10, requires_grad=True)
    output = module(input)
    print("Output from the convolution: ", output)
    output.backward(torch.randn(8, 8))
    print("Gradient for the input map: ", input.grad)


.. rst-class:: sphx-glr-script-out
.. code-block:: none

    Filter and bias:  [Parameter containing:
    tensor([[ 1.5980,  0.1115, -0.0392],
            [ 1.4112, -0.6556,  0.8576],
            [-1.6270, -1.3951, -0.2387]], requires_grad=True), Parameter containing:
    tensor([[-0.5050]], requires_grad=True)]
    Output from the convolution:  tensor([[ 2.6490,  0.9049, -2.4603,  4.9172, -2.6995, -0.5259,  0.8084,  2.9712],
            [-7.4314,  2.0491,  1.1998,  2.9348, -1.0925,  2.3583, -0.4701,  0.0620],
            [ 0.4747,  2.7422, -3.1890, -1.3733, -0.2700,  5.3886,  4.3234,  0.7127],
            [ 1.9092, -2.2850, -5.8786, -0.0514, -3.9709, -1.3217, -4.1159,  0.1144],
            [-5.2757, -7.3949,  4.8335, -1.4541, -1.7630, -2.9781, -3.6742,  0.6746],
            [-1.3930,  1.5531,  1.3611, -5.1453,  1.4953,  0.2460, -2.6832,  0.3069],
            [-2.9077, -2.0569, -2.8977, -1.1425,  2.6128,  1.0830, -3.7180,  4.4202],
            [-3.3766, -2.4025, -1.3721,  2.2309, -4.1819, -1.2281, -4.5500,  0.4863]],
           grad_fn=<ScipyConv2dFunctionBackward>)
    Gradient for the input map:  tensor([[-1.0954,  0.8242, -2.3188, -2.7640, -2.4841,  0.5890, -0.0936,  1.2190,
              0.0921, -0.0307],
            [-3.5222,  0.8920, -1.8908, -0.6549, -0.2606,  0.7365, -4.9122,  1.3548,
             -0.5405,  0.6700],
            [-3.3450,  1.7799, -2.4677,  3.7305,  7.3472,  0.0637,  0.7740,  0.6444,
             -2.8296, -0.1736],
            [ 2.5613,  7.4243, -4.5369,  2.7349, -3.9802, -1.3710,  5.9015,  0.9176,
              1.8604,  0.0468],
            [ 5.4830,  1.3442,  3.4260,  5.3385,  0.8872,  3.3689, -2.8774, -1.6875,
             -0.2114, -0.0434],
            [ 2.6617, -6.2559, -1.9922, -2.9800, -2.8432, -2.9184, -1.1842,  0.9845,
             -0.5682, -0.0600],
            [ 1.9133, -1.7688,  6.1904,  0.9140,  2.1316,  3.6076, -0.3446, -1.0438,
              0.8684,  0.6735],
            [-2.4305, -7.8832, -0.8278, -6.6322, -0.5148,  0.2151, -4.0533, -3.7276,
             -2.6069, -1.2354],
            [-0.4031, -2.5208,  0.9026, -2.6378, -0.1116, -2.9007,  2.9866,  2.7905,
              4.0234, -0.5354],
            [-0.2927,  3.2093,  3.1423,  2.3881,  2.7282,  1.4053, -1.7006, -0.0692,
              1.0831,  0.2345]])


**Check the gradients:**
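``gradcheck`` compares the gradients returned by ``backward`` against finite-difference estimates, which is why the check uses ``torch.double`` inputs and a small ``eps``. The same idea can be sketched by hand in NumPy for the filter-gradient formula used in ``backward`` (``correlate2d(input, grad_output, mode='valid')``); the array sizes, seed, and step size below are illustrative choices, not values from the tutorial:

.. code-block:: python

    import numpy as np
    from scipy.signal import correlate2d

    rng = np.random.default_rng(0)
    x = rng.standard_normal((6, 6))  # input map
    w = rng.standard_normal((3, 3))  # filter
    g = rng.standard_normal((4, 4))  # upstream gradient, same shape as the 'valid' output


    def loss(w):
        # scalar whose gradient w.r.t. w is the vector-Jacobian product computed in backward
        return np.sum(g * correlate2d(x, w, mode='valid'))


    # analytical filter gradient, using the same formula as ScipyConv2dFunction.backward
    analytic = correlate2d(x, g, mode='valid')

    # central finite differences, perturbing one filter entry at a time
    eps = 1e-6
    numeric = np.zeros_like(w)
    for i in range(w.shape[0]):
        for j in range(w.shape[1]):
            dw = np.zeros_like(w)
            dw[i, j] = eps
            numeric[i, j] = (loss(w + dw) - loss(w - dw)) / (2 * eps)

    print(np.allclose(analytic, numeric, atol=1e-6))  # True

The built-in check below automates this for every element of the tensors passed as inputs. Since only the input map is passed, it verifies ``grad_input``; the filter and bias gradients are not perturbed by that call.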
.. code-block:: default


    from torch.autograd.gradcheck import gradcheck

    moduleConv = ScipyConv2d(3, 3)

    input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
    test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
    print("Are the gradients correct: ", test)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Are the gradients correct:  True


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.577 seconds)