.. _sphx_glr_beginner_introyt_introyt1_tutorial.py:

Introduction to PyTorch
=======================

Follow along with the video on `YouTube `__.

PyTorch Tensors
---------------

Follow along with the video beginning at `03:50 `__.

First, we’ll import PyTorch.

.. code-block:: default

    import torch

Let’s see a few basic tensor manipulations. First, just a few of the ways to create tensors:

.. code-block:: default

    z = torch.zeros(5, 3)
    print(z)
    print(z.dtype)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]])
    torch.float32

Above, we create a 5x3 matrix filled with zeros, and query its datatype to find out that the zeros are 32-bit floating point numbers, which is PyTorch’s default.

What if you wanted integers instead? You can always override the default:

.. code-block:: default

    i = torch.ones((5, 3), dtype=torch.int16)
    print(i)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1]], dtype=torch.int16)

You can see that when we do change the default, the tensor helpfully reports this when printed.

It’s common to initialize learning weights randomly, often with a specific seed for the pseudorandom number generator (PRNG) so that results are reproducible:

.. code-block:: default

    torch.manual_seed(1729)
    r1 = torch.rand(2, 2)
    print('A random tensor:')
    print(r1)

    r2 = torch.rand(2, 2)
    print('\nA different random tensor:')
    print(r2) # new values

    torch.manual_seed(1729)
    r3 = torch.rand(2, 2)
    print('\nShould match r1:')
    print(r3) # repeats values of r1 because of re-seed

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    A random tensor:
    tensor([[0.3126, 0.3791],
            [0.3087, 0.0736]])

    A different random tensor:
    tensor([[0.4216, 0.0691],
            [0.2332, 0.4047]])

    Should match r1:
    tensor([[0.3126, 0.3791],
            [0.3087, 0.0736]])

PyTorch tensors perform arithmetic operations intuitively. Tensors of the same shape may be added, multiplied, etc., element by element. Operations with scalars are distributed over the tensor:

.. code-block:: default

    ones = torch.ones(2, 3)
    print(ones)

    twos = torch.ones(2, 3) * 2 # every element is multiplied by 2
    print(twos)

    threes = ones + twos       # addition allowed because the shapes match
    print(threes)              # tensors are added element-wise
    print(threes.shape)        # this has the same dimensions as input tensors

    r1 = torch.rand(2, 3)
    r2 = torch.rand(3, 2)
    # uncomment this line to get a runtime error: shapes (2, 3) and (3, 2) are incompatible
    # r3 = r1 + r2

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[1., 1., 1.],
            [1., 1., 1.]])
    tensor([[2., 2., 2.],
            [2., 2., 2.]])
    tensor([[3., 3., 3.],
            [3., 3., 3.]])
    torch.Size([2, 3])

Here’s a small sample of the mathematical operations available:

.. code-block:: default

    r = (torch.rand(2, 2) - 0.5) * 2 # values between -1 and 1
    print('A random matrix, r:')
    print(r)

    # Common mathematical operations are supported:
    print('\nAbsolute value of r:')
    print(torch.abs(r))

    # ...as are trigonometric functions:
    print('\nInverse sine of r:')
    print(torch.asin(r))

    # ...and linear algebra operations like determinant and singular value decomposition
    print('\nDeterminant of r:')
    print(torch.det(r))
    print('\nSingular value decomposition of r:')
    print(torch.svd(r))

    # ...and statistical and aggregate operations
    # (torch.std_mean() returns the standard deviation first, then the mean):
    print('\nStandard deviation and average of r:')
    print(torch.std_mean(r))
    print('\nMaximum value of r:')
    print(torch.max(r))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    A random matrix, r:
    tensor([[ 0.9956, -0.2232],
            [ 0.3858, -0.6593]])

    Absolute value of r:
    tensor([[0.9956, 0.2232],
            [0.3858, 0.6593]])

    Inverse sine of r:
    tensor([[ 1.4775, -0.2251],
            [ 0.3961, -0.7199]])

    Determinant of r:
    tensor(-0.5703)

    Singular value decomposition of r:
    torch.return_types.svd(
    U=tensor([[-0.8353, -0.5497],
            [-0.5497,  0.8353]]),
    S=tensor([1.1793, 0.4836]),
    V=tensor([[-0.8851, -0.4654],
            [ 0.4654, -0.8851]]))

    Standard deviation and average of r:
    (tensor(0.7217), tensor(0.1247))

    Maximum value of r:
    tensor(0.9956)

There’s a good deal more to know about the power of PyTorch tensors, including how to set them up for parallel computations on GPU - we’ll be going into more depth in another video.

PyTorch Models
--------------

Follow along with the video beginning at `10:00 `__.

Let’s talk about how we can express models in PyTorch.

.. code-block:: default

    import torch                     # for all things PyTorch
    import torch.nn as nn            # for torch.nn.Module, the parent object for PyTorch models
    import torch.nn.functional as F  # for the activation function

.. figure:: /_static/img/mnist.png
   :alt: le-net-5 diagram

*Figure: LeNet-5*

Above is a diagram of LeNet-5, one of the earliest convolutional neural nets, and one of the drivers of the explosion in deep learning. It was built to read small images of handwritten numbers (the MNIST dataset), and correctly classify which digit was represented in the image.

Here’s the abridged version of how it works:

- Layer C1 is a convolutional layer, meaning that it scans the input image for features it learned during training. It outputs a map of where it saw each of its learned features in the image. This “activation map” is downsampled in layer S2.
- Layer C3 is another convolutional layer, this time scanning the downsampled activation map from C1 for *combinations* of features. It also outputs an activation map describing the spatial locations of these feature combinations, which is downsampled in layer S4.
- Finally, the fully-connected layers at the end, F5, F6, and OUTPUT, are a *classifier* that takes the final activation map, and classifies it into one of ten bins representing the 10 digits.

How do we express this simple neural network in code?

.. code-block:: default

    class LeNet(nn.Module):

        def __init__(self):
            super(LeNet, self).__init__()
            # 1 input image channel (black & white), 6 output channels, 5x5 square convolution
            # kernel
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            # an affine operation: y = Wx + b
            self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels, each a 5x5 map after two conv/pool stages
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            # Max pooling over a (2, 2) window
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            # If the window is square, you can specify its size with a single number
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = x.view(-1, self.num_flat_features(x))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

        def num_flat_features(self, x):
            size = x.size()[1:]  # all dimensions except the batch dimension
            num_features = 1
            for s in size:
                num_features *= s
            return num_features

Looking over this code, you should be able to spot some structural similarities with the diagram above.

This demonstrates the structure of a typical PyTorch model:

- It inherits from ``torch.nn.Module`` - modules may be nested - in fact, even the ``Conv2d`` and ``Linear`` layer classes inherit from ``torch.nn.Module``.
- A model will have an ``__init__()`` function, where it instantiates its layers, and loads any data artifacts it might need (e.g., an NLP model might load a vocabulary).
- A model will have a ``forward()`` function. This is where the actual computation happens: an input is passed through the network layers and various functions to generate an output.
- Other than that, you can build out your model class like any other Python class, adding whatever properties and methods you need to support your model’s computation.

Let’s instantiate this object and run a sample input through it.

.. code-block:: default

    net = LeNet()
    print(net)                         # what does the object tell us about itself?

    input = torch.rand(1, 1, 32, 32)   # stand-in for a 32x32 black & white image
    print('\nImage batch shape:')
    print(input.shape)

    output = net(input)                # we don't call forward() directly; calling net does it
    print('\nRaw output:')
    print(output)
    print(output.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    LeNet(
      (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
      (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
      (fc1): Linear(in_features=400, out_features=120, bias=True)
      (fc2): Linear(in_features=120, out_features=84, bias=True)
      (fc3): Linear(in_features=84, out_features=10, bias=True)
    )

    Image batch shape:
    torch.Size([1, 1, 32, 32])

    Raw output:
    tensor([[ 0.0898, 0.0318, 0.1485, 0.0301, -0.0085, -0.1135, -0.0296, 0.0164, 0.0039, 0.0616]], grad_fn=)
    torch.Size([1, 10])

There are a few important things happening above:

First, we instantiate the ``LeNet`` class, and we print the ``net`` object. A subclass of ``torch.nn.Module`` will report the layers it has created and their shapes and parameters. This can provide a handy overview of a model if you want to get the gist of its processing.

Below that, we create a dummy input representing a 32x32 image with 1 color channel. Normally, you would load an image tile and convert it to a tensor of this shape.
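
As a rough sketch of that conversion, the snippet below assumes the tile arrives as a 32x32 array of 8-bit grayscale pixel values; ``tile`` and ``tiles`` are hypothetical random stand-ins for real images. It scales the pixels to floats in [0, 1], much as ``transforms.ToTensor()`` does for 8-bit images, and then adds the channel and batch dimensions with ``unsqueeze()``:

.. code-block:: python

    import torch

    # Hypothetical stand-in for a real 32x32 grayscale image tile (8-bit pixels).
    tile = torch.randint(0, 256, (32, 32), dtype=torch.uint8)

    # Scale pixel values from [0, 255] to floats in [0, 1].
    x = tile.float() / 255.0   # shape: (32, 32)

    # Add a channel dimension (1 channel), then a batch dimension (batch of 1).
    x = x.unsqueeze(0)         # shape: (1, 32, 32)
    x = x.unsqueeze(0)         # shape: (1, 1, 32, 32)
    print(x.shape)             # torch.Size([1, 1, 32, 32])

    # Several single-channel tiles can instead be stacked into one batch.
    tiles = [torch.randint(0, 256, (32, 32), dtype=torch.uint8) for _ in range(16)]
    batch = torch.stack([(t.float() / 255.0).unsqueeze(0) for t in tiles])
    print(batch.shape)         # torch.Size([16, 1, 32, 32])

Either tensor has the layout that ``net`` above expects.
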

You may have noticed an extra dimension to our tensor: the *batch dimension.* PyTorch models assume they are working on *batches* of data; for example, a batch of 16 of our image tiles would have the shape ``(16, 1, 32, 32)``. Since we’re only using one image, we create a batch of 1 with shape ``(1, 1, 32, 32)``.

We ask the model for an inference by calling it like a function: ``net(input)``. The output of this call represents the model’s confidence that the input represents a particular digit. (Since this instance of the model hasn’t learned anything yet, we shouldn’t expect to see any signal in the output.) Looking at the shape of ``output``, we can see that it also has a batch dimension, the size of which should always match the input batch dimension. If we had passed in an input batch of 16 instances, ``output`` would have a shape of ``(16, 10)``.

Datasets and Dataloaders
------------------------

Follow along with the video beginning at `14:00 `__.

Below, we’ll demonstrate how to use one of the ready-to-download, open-access datasets from TorchVision, how to transform its images for consumption by your model, and how to use a ``DataLoader`` to feed batches of data to your model.

The first thing we need to do is transform our incoming images into PyTorch tensors.

.. code-block:: default

    #%matplotlib inline

    import torch
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

Here, we specify two transformations for our input:

- ``transforms.ToTensor()`` converts images loaded by Pillow into PyTorch tensors.
- ``transforms.Normalize()`` adjusts the values of the tensor so that, for each color channel, their average over the dataset is zero and their standard deviation is 1.0. Most activation functions have their strongest gradients around x = 0, so centering our data there can speed learning. The values passed to the transform are the per-channel means (first tuple) and standard deviations (second tuple) of the RGB values of the images in the dataset. You can calculate these values yourself by running these few lines of code:

  .. code-block:: python

      import torch
      import torchvision
      import torchvision.transforms as transforms

      transform = transforms.Compose([transforms.ToTensor()])
      trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                              download=True, transform=transform)

      # stack all train images together into a tensor of shape
      # (50000, 3, 32, 32)
      x = torch.stack([img for img, _ in trainset])

      # get the mean and standard deviation of each channel
      mean = torch.mean(x, dim=(0, 2, 3))  # tensor([0.4914, 0.4822, 0.4465])
      std = torch.std(x, dim=(0, 2, 3))    # tensor([0.2470, 0.2435, 0.2616])

There are many more transforms available, including cropping, centering, rotation, and reflection.

Next, we’ll create an instance of the CIFAR10 dataset. This is a set of 32x32 color image tiles representing 10 classes of objects: 6 of animals (bird, cat, deer, dog, frog, horse) and 4 of vehicles (airplane, automobile, ship, truck):

.. code-block:: default

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz

    0%| | 0/170498071 [00:00`__.

Let’s put all the pieces together and train a model:

.. code-block:: default

    #%matplotlib inline

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    import torchvision
    import torchvision.transforms as transforms

    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np

First, we’ll need training and test datasets. If you haven’t already, run the cell below to make sure the dataset is downloaded. (It may take a minute.)

.. code-block:: default

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Files already downloaded and verified
    Files already downloaded and verified

We’ll run our check on the output from ``DataLoader``:

.. code-block:: default

    import matplotlib.pyplot as plt
    import numpy as np

    # functions to show an image


    def imshow(img):
        img = img / 2 + 0.5     # unnormalize: undo Normalize with mean 0.5, std 0.5
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))


    # get some random training images
    dataiter = iter(trainloader)
    images, labels = next(dataiter)

    # show images
    imshow(torchvision.utils.make_grid(images))
    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

.. image-sg:: /beginner/introyt/images/sphx_glr_introyt1_tutorial_002.png
   :alt: introyt1 tutorial
   :srcset: /beginner/introyt/images/sphx_glr_introyt1_tutorial_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      cat   cat  deer  frog

This is the model we’ll train. If it looks familiar, that’s because it’s a variant of LeNet (discussed earlier in this video) adapted for 3-color images.

.. code-block:: default

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x


    net = Net()

The last ingredients we need are a loss function and an optimizer:

.. code-block:: default

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

The loss function, as discussed earlier in this video, is a measure of how far from our ideal output the model’s prediction was. Cross-entropy loss is a typical loss function for classification models like ours.

The **optimizer** is what drives the learning. Here we have created an optimizer that implements *stochastic gradient descent,* one of the more straightforward optimization algorithms. Besides parameters of the algorithm, like the learning rate (``lr``) and momentum, we also pass in ``net.parameters()``, which is a collection of all the learning weights in the model; these are the values the optimizer adjusts.

Finally, all of this is assembled into the training loop. Go ahead and run this cell, as it will likely take a few minutes to execute:

.. code-block:: default

    for epoch in range(2):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    print('Finished Training')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [1,  2000] loss: 2.195
    [1,  4000] loss: 1.876
    [1,  6000] loss: 1.655
    [1,  8000] loss: 1.576
    [1, 10000] loss: 1.519
    [1, 12000] loss: 1.466
    [2,  2000] loss: 1.421
    [2,  4000] loss: 1.376
    [2,  6000] loss: 1.336
    [2,  8000] loss: 1.335
    [2, 10000] loss: 1.326
    [2, 12000] loss: 1.270
    Finished Training

Here, we are doing only **2 training epochs** (line 1), that is, two passes over the training dataset. Each pass has an inner loop that **iterates over the training data** (line 4), serving batches of transformed input images and their correct labels.

**Zeroing the gradients** (line 9) is an important step. PyTorch accumulates gradients: each call to ``backward()`` adds to the gradients already stored on the model’s parameters. If we do not reset them for every batch, gradients from earlier batches will keep piling up, giving incorrect gradient values and making learning impossible.

In line 12, we **ask the model for its predictions** on this batch. In the following line (13), we compute the loss, a measure of how far ``outputs`` (the model prediction) is from ``labels`` (the correct output).

In line 14, we do the ``backward()`` pass, and calculate the gradients that will direct the learning.

In line 15, the optimizer performs one learning step: it uses the gradients from the ``backward()`` call to nudge the learning weights in the direction it thinks will reduce the loss.
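
To see why zeroing matters, here is a minimal sketch using a single scalar weight ``w``, one made-up sample (``x = 2``, ``y = 5``), and a squared-error loss; all of these, and the optimizer name ``sgd``, are hypothetical and chosen only so the gradient is easy to check by hand: d(loss)/dw = 2(wx - y)x = -12 at w = 1. Calling ``backward()`` twice without a reset adds the two gradients together, while calling ``zero_grad()`` first restores the correct value:

.. code-block:: python

    import torch

    # A single learnable weight and one fixed (hypothetical) training sample.
    w = torch.tensor(1.0, requires_grad=True)
    x, y = torch.tensor(2.0), torch.tensor(5.0)
    sgd = torch.optim.SGD([w], lr=0.1)

    def compute_loss():
        # Squared error; d(loss)/dw = 2 * (w*x - y) * x = -12 at w = 1
        return (w * x - y) ** 2

    grads = []

    compute_loss().backward()
    grads.append(w.grad.item())   # -12.0

    # No reset: the new gradient is added to the one already stored in w.grad.
    compute_loss().backward()
    grads.append(w.grad.item())   # -24.0

    # Reset first (zero_grad() zeroes .grad or sets it to None,
    # depending on the PyTorch version), then compute the gradient again.
    sgd.zero_grad()
    compute_loss().backward()
    grads.append(w.grad.item())   # -12.0

    print(grads)  # [-12.0, -24.0, -12.0]

The training loop above depends on the same behavior: without ``optimizer.zero_grad()``, each batch’s gradient would be added to the sum of all the previous ones.
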

The remainder of the loop does some light reporting: the epoch number, how many mini-batches have been completed, and the average loss over the last 2000 mini-batches.

**When you run the cell above,** you should see something like this:

.. code-block:: none

    [1,  2000] loss: 2.235
    [1,  4000] loss: 1.940
    [1,  6000] loss: 1.713
    [1,  8000] loss: 1.573
    [1, 10000] loss: 1.507
    [1, 12000] loss: 1.442
    [2,  2000] loss: 1.378
    [2,  4000] loss: 1.364
    [2,  6000] loss: 1.349
    [2,  8000] loss: 1.319
    [2, 10000] loss: 1.284
    [2, 12000] loss: 1.267
    Finished Training

Note that the loss decreases steadily over the run, indicating that our model is continuing to improve its performance on the training dataset.

As a final step, we should check that the model is actually doing *general* learning, and not simply “memorizing” the dataset. Memorizing the training data is called **overfitting,** and usually indicates that the dataset is too small (not enough examples for general learning), or that the model has more learning parameters than it needs to correctly model the dataset.

This is the reason datasets are split into training and test subsets: to test the generality of the model, we ask it to make predictions on data it hasn’t trained on:

.. code-block:: default

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Accuracy of the network on the 10000 test images: 54 %

If you followed along, you should see that the model is roughly 50% accurate at this point. That’s not exactly state-of-the-art, but it’s far better than the 10% accuracy we’d expect from guessing at random among 10 classes. This demonstrates that some general learning did happen in the model.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 1 minutes 53.964 seconds)