
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "recipes/recipes/reasoning_about_shapes.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_recipes_recipes_reasoning_about_shapes.py:


Reasoning about Shapes in PyTorch
=================================

When writing models with PyTorch, the parameters of a layer often depend
on the shape of the output of the previous layer. For example, the
``in_features`` of an ``nn.Linear`` layer must match the ``size(-1)`` of
its input. For some layers, such as convolutions, computing the output
shape involves nontrivial equations.

One way around this is to run the forward pass with random inputs, but
this wastes memory and compute. Instead, we can use the ``meta`` device
to determine the output shapes of a layer without materializing any data.

.. GENERATED FROM PYTHON SOURCE LINES 17-31

.. code-block:: default


    import torch
    import timeit

    t = torch.rand(2, 3, 10, 10, device="meta")
    conv = torch.nn.Conv2d(3, 5, 2, device="meta")
    start = timeit.default_timer()
    out = conv(t)
    end = timeit.default_timer()

    print(out)
    print(f"Time taken: {end-start}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor(..., device='meta', size=(2, 5, 9, 9), grad_fn=)
    Time taken: 0.00011705300039466238




.. GENERATED FROM PYTHON SOURCE LINES 32-34

Because no data is materialized, passing arbitrarily large inputs does
not significantly change the time taken for shape computation.

.. GENERATED FROM PYTHON SOURCE LINES 34-44

.. code-block:: default


    t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")
    start = timeit.default_timer()
    out = conv(t_large)
    end = timeit.default_timer()

    print(out)
    print(f"Time taken: {end-start}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor(..., device='meta', size=(1024, 5, 65535, 65535), grad_fn=)
    Time taken: 7.298099990293849e-05




.. GENERATED FROM PYTHON SOURCE LINES 45-46

Consider an arbitrary network such as the following:

.. GENERATED FROM PYTHON SOURCE LINES 46-71

.. code-block:: default


    import torch.nn as nn
    import torch.nn.functional as F


    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = torch.flatten(x, 1) # flatten all dimensions except batch
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x









.. GENERATED FROM PYTHON SOURCE LINES 72-74

We can view the intermediate shapes within an entire network by
registering a forward hook on each layer that prints the shape of its
output.

.. GENERATED FROM PYTHON SOURCE LINES 74-89

.. code-block:: default


    def fw_hook(module, input, output):
        print(f"Shape of output to {module} is {output.shape}.")


    # Any tensor created within this torch.device context manager will be
    # on the meta device.
    with torch.device("meta"):
        net = Net()
        inp = torch.randn((1024, 3, 32, 32))

    # named_modules() also yields the network itself, so the hook fires
    # once more for the final output of ``net``.
    for name, layer in net.named_modules():
        layer.register_forward_hook(fw_hook)

    out = net(inp)




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of output to Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 6, 28, 28]).
    Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 6, 14, 14]).
    Shape of output to Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 16, 10, 10]).
    Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 16, 5, 5]).
    Shape of output to Linear(in_features=400, out_features=120, bias=True) is torch.Size([1024, 120]).
    Shape of output to Linear(in_features=120, out_features=84, bias=True) is torch.Size([1024, 84]).
    Shape of output to Linear(in_features=84, out_features=10, bias=True) is torch.Size([1024, 10]).
    Shape of output to Net(
      (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
      (fc1): Linear(in_features=400, out_features=120, bias=True)
      (fc2): Linear(in_features=120, out_features=84, bias=True)
      (fc3): Linear(in_features=84, out_features=10, bias=True)
    ) is torch.Size([1024, 10]).





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.010 seconds)


.. _sphx_glr_download_recipes_recipes_reasoning_about_shapes.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: reasoning_about_shapes.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: reasoning_about_shapes.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
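
The recipe mentions that convolution shapes follow a closed-form equation
and that a layer's parameters depend on the previous layer's output shape.
The sketch below is one possible way to connect the two ideas; it is not
part of the original recipe. It assumes the output-size formula given in
the ``nn.Conv2d`` documentation, and ``conv2d_out_size`` is a hypothetical
helper name introduced only for illustration. It first checks the formula
against a ``meta`` convolution, then uses a ``meta`` forward pass to infer
the ``in_features`` of the first linear layer instead of hard-coding
``16 * 5 * 5``.

.. code-block:: default

    import torch
    import torch.nn as nn


    def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
        # Hypothetical helper: output length along one spatial dimension,
        # following the formula in the nn.Conv2d documentation.
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


    # Check the formula against a convolution evaluated on the meta device.
    conv = nn.Conv2d(3, 5, 2, device="meta")
    out = conv(torch.empty(2, 3, 10, 10, device="meta"))
    assert out.shape[-1] == conv2d_out_size(10, kernel_size=2)

    # Infer the flattened feature count of the convolutional stack without
    # allocating any real data. The layers mirror the ``Net`` class above.
    with torch.device("meta"):
        features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
        )
        probe = torch.empty(1, 3, 32, 32)
        in_features = features(probe).shape[-1]

    # The inferred value sizes a real (non-meta) linear layer.
    fc1 = nn.Linear(in_features, 120)
    print(in_features)  # 400, i.e. 16 * 5 * 5

Inferring ``in_features`` this way keeps the linear layer consistent if the
input resolution or the convolution settings change, at the cost of one
extra ``meta`` forward pass at construction time.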