
Reasoning about Shapes in PyTorch

When writing models with PyTorch, the parameters of a given layer often depend on the shape of the previous layer's output. For example, the in_features of an nn.Linear layer must match the size(-1) of its input. For some layers, such as convolutions, computing the output shape involves more complex equations.
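For Conv2d (and MaxPool2d), the PyTorch documentation gives each output spatial size as floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). A minimal sketch of that formula as a plain helper follows; the name conv_out_size is hypothetical and not part of PyTorch.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial dimension of Conv2d/MaxPool2d.

    Follows the formula in the PyTorch Conv2d docs; floor division (//)
    implements floor() because the numerator is non-negative for valid inputs.
    """
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# A 10x10 input through a kernel of size 2 (stride 1, no padding) gives 9x9.
print(conv_out_size(10, kernel_size=2))  # 9

# Tracing a 32x32 image through conv(5) -> pool(2, 2) -> conv(5) -> pool(2, 2).
h = 32
h = conv_out_size(h, 5)            # 28
h = conv_out_size(h, 2, stride=2)  # 14
h = conv_out_size(h, 5)            # 10
h = conv_out_size(h, 2, stride=2)  # 5
print(h)  # 5
```

Doing this by hand for every layer is error-prone, which is the motivation for letting PyTorch compute the shapes instead.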

One way around this is to run the forward pass with random inputs, but this is wasteful in terms of memory and compute.

Instead, we can make use of the meta device to determine the output shapes of a layer without materializing any data.

import torch
import timeit

t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
tensor(..., device='meta', size=(2, 5, 9, 9), grad_fn=<ConvolutionBackward0>)
Time taken: 0.0001575129999764613
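As a sanity check, the shape reported on the meta device can be compared with an ordinary CPU forward pass on a real random input, which is the wasteful approach mentioned earlier. A minimal sketch, reusing the same Conv2d(3, 5, 2) configuration:

```python
import torch
import torch.nn as nn

# Conventional approach: a real (CPU) layer applied to a real random input.
conv_cpu = nn.Conv2d(3, 5, 2)
out_cpu = conv_cpu(torch.rand(2, 3, 10, 10))

# Meta approach: same layer configuration, but no data is materialized.
conv_meta = nn.Conv2d(3, 5, 2, device="meta")
out_meta = conv_meta(torch.rand(2, 3, 10, 10, device="meta"))

print(out_cpu.shape)   # torch.Size([2, 5, 9, 9])
print(out_meta.shape)  # torch.Size([2, 5, 9, 9])
print(out_meta.is_meta, out_cpu.is_meta)  # True False
```

Both paths agree on the shape; only the CPU path allocates weights and activations.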

Observe that since data is not materialized, passing arbitrarily large inputs will not significantly alter the time taken for shape computation.

t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")
start = timeit.default_timer()
out = conv(t_large)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
tensor(..., device='meta', size=(1024, 5, 65535, 65535),
       grad_fn=<ConvolutionBackward0>)
Time taken: 0.00010942800054181134
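To put the second example in perspective, the input tensor alone would need 3 * 2**44 bytes (48 TiB) if it were materialized in the default float32 dtype. A short sketch computing this from the meta tensor's metadata, assuming the default dtype has not been changed:

```python
import torch

t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")

# numel() and element_size() only need shape and dtype, so they work on meta tensors.
nbytes = t_large.numel() * t_large.element_size()
print(t_large.dtype)           # torch.float32
print(nbytes)                  # 52776558133248
print(nbytes / 2**40, "TiB")   # 48.0 TiB
```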

Consider an arbitrary network such as the following:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
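The hard-coded 16 * 5 * 5 in fc1 only holds for 32x32 inputs. One way to derive it instead is to run the convolutional part on the meta device and read off the flattened size. A minimal sketch; rebuilding that part of Net as an nn.Sequential named features is a hypothetical refactoring for illustration:

```python
import torch
import torch.nn as nn

# The convolutional part of Net, rebuilt as an nn.Sequential so it can run on its own.
with torch.device("meta"):
    features = nn.Sequential(
        nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    )
    dummy = torch.empty(1, 3, 32, 32)  # batch of one 32x32 RGB image, on meta

# Flatten all dimensions except batch, as Net.forward does.
n_flat = torch.flatten(features(dummy), 1).shape[1]
print(n_flat)  # 400, i.e. 16 * 5 * 5

fc1 = nn.Linear(n_flat, 120)  # a real (CPU) layer sized from the derived value
print(fc1)  # Linear(in_features=400, out_features=120, bias=True)
```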

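We can view the intermediate shapes within an entire network by registering a forward hook on each module that prints the shape of its output. Because named_modules() also yields the top-level Net, the last line printed is the network's final output shape; functional calls such as F.relu are not modules, so they do not appear.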

def fw_hook(module, input, output):
    print(f"Shape of output to {module} is {output.shape}.")


# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))

for name, layer in net.named_modules():
    layer.register_forward_hook(fw_hook)

out = net(inp)
Shape of output to Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 6, 28, 28]).
Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 6, 14, 14]).
Shape of output to Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 16, 10, 10]).
Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 16, 5, 5]).
Shape of output to Linear(in_features=400, out_features=120, bias=True) is torch.Size([1024, 120]).
Shape of output to Linear(in_features=120, out_features=84, bias=True) is torch.Size([1024, 84]).
Shape of output to Linear(in_features=84, out_features=10, bias=True) is torch.Size([1024, 10]).
Shape of output to Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
) is torch.Size([1024, 10]).
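A variation on the hook approach collects the shapes into a dictionary instead of printing them, skips the root module, and removes the hooks afterwards using the handles that register_forward_hook returns. A minimal sketch; the helper name record_output_shapes and the small nn.Sequential model are hypothetical examples, not part of the tutorial's Net:

```python
import torch
import torch.nn as nn


def record_output_shapes(model: nn.Module, inp: torch.Tensor) -> dict:
    """Run one forward pass and return {submodule name: output shape}."""
    shapes = {}
    handles = []
    for name, module in model.named_modules():
        if name == "":  # named_modules() yields the root module with name ""
            continue

        # Bind `name` as a default argument so each hook keeps its own name.
        def hook(mod, args, output, name=name):
            shapes[name] = tuple(output.shape)

        handles.append(module.register_forward_hook(hook))
    try:
        model(inp)
    finally:
        # Remove the hooks so the model is left as it was.
        for handle in handles:
            handle.remove()
    return shapes


with torch.device("meta"):
    model = nn.Sequential(
        nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
        nn.Flatten(), nn.Linear(6 * 14 * 14, 10),
    )
    inp = torch.randn(8, 3, 32, 32)

shapes = record_output_shapes(model, inp)
for name, shape in shapes.items():
    print(name, shape)
```

One design caveat: a module called more than once, such as Net's shared self.pool, keeps only the shape from its last call in the dictionary.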
