.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_basics_quickstart_tutorial.py:


`Learn the Basics `_ ||
**Quickstart** ||
`Tensors `_ ||
`Datasets & DataLoaders `_ ||
`Transforms `_ ||
`Build Model `_ ||
`Autograd `_ ||
`Optimization `_ ||
`Save & Load Model `_

Quickstart
===================

This section runs through the API for common tasks in machine learning. Refer to the links in each section to dive deeper.

Working with data
-----------------

PyTorch has two `primitives to work with data `_:
``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``.
``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around
the ``Dataset``.

.. code-block:: default

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor

PyTorch offers domain-specific libraries such as `TorchText `_,
`TorchVision `_, and `TorchAudio `_,
all of which include datasets. For this tutorial, we will be using a TorchVision dataset.

The ``torchvision.datasets`` module contains ``Dataset`` objects for many real-world vision datasets, such as
CIFAR and COCO (`full list here `_). In this tutorial, we
use the FashionMNIST dataset. Every TorchVision ``Dataset`` accepts two arguments, ``transform`` and
``target_transform``, to modify the samples and labels respectively.

.. code-block:: default

    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
    Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
    Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
    Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
    Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

We pass the ``Dataset`` as an argument to ``DataLoader``. This wraps an iterable over our dataset and supports
automatic batching, sampling, shuffling, and multiprocess data loading. Here we define a batch size of 64, i.e. each
element yielded by the dataloader iterable is a batch of 64 features and 64 labels.

.. code-block:: default

    batch_size = 64

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
    Shape of y: torch.Size([64]) torch.int64

Read more about `loading data in PyTorch `_.

--------------

Creating Models
------------------

To define a neural network in PyTorch, we create a class that inherits
from `nn.Module `_. We define the layers of the network
in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate
operations in the neural network, we move it to the GPU if one is available.

.. code-block:: default

    # Get cpu or gpu device for training.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")

    # Define model
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10)
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits

    model = NeuralNetwork().to(device)
    print(model)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Using cuda device
    NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    )

Read more about `building neural networks in PyTorch `_.

--------------

Optimizing the Model Parameters
----------------------------------------

To train a model, we need a `loss function `_
and an `optimizer `_.

.. code-block:: default

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In a single training loop, the model makes predictions on the training dataset (fed to it in batches) and
backpropagates the prediction error to adjust its parameters.

.. code-block:: default

    def train(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), batch * len(X)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

We also check the model's performance against the test dataset to ensure it is learning.

.. code-block:: default

    def test(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

The training process is conducted over several passes through the training data (*epochs*). During each epoch, the model learns
parameters to make better predictions. We print the model's accuracy and loss at each epoch; we'd like to see the
accuracy increase and the loss decrease with every epoch.

.. code-block:: default

    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print("Done!")

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Epoch 1
    -------------------------------
    loss: 2.296905 [    0/60000]
    loss: 2.286656 [ 6400/60000]
    loss: 2.267489 [12800/60000]
    loss: 2.266925 [19200/60000]
    loss: 2.252868 [25600/60000]
    loss: 2.227990 [32000/60000]
    loss: 2.228337 [38400/60000]
    loss: 2.195647 [44800/60000]
    loss: 2.176529 [51200/60000]
    loss: 2.169788 [57600/60000]
    Test Error:
     Accuracy: 46.9%, Avg loss: 2.152595

    Epoch 2
    -------------------------------
    loss: 2.153080 [    0/60000]
    loss: 2.148765 [ 6400/60000]
    loss: 2.088664 [12800/60000]
    loss: 2.110441 [19200/60000]
    loss: 2.061634 [25600/60000]
    loss: 2.007747 [32000/60000]
    loss: 2.025345 [38400/60000]
    loss: 1.949520 [44800/60000]
    loss: 1.933236 [51200/60000]
    loss: 1.891448 [57600/60000]
    Test Error:
     Accuracy: 58.2%, Avg loss: 1.875499

    Epoch 3
    -------------------------------
    loss: 1.897956 [    0/60000]
    loss: 1.881915 [ 6400/60000]
    loss: 1.754909 [12800/60000]
    loss: 1.802715 [19200/60000]
    loss: 1.697911 [25600/60000]
    loss: 1.652342 [32000/60000]
    loss: 1.666433 [38400/60000]
    loss: 1.570258 [44800/60000]
    loss: 1.586206 [51200/60000]
    loss: 1.499081 [57600/60000]
    Test Error:
     Accuracy: 60.0%, Avg loss: 1.505721

    Epoch 4
    -------------------------------
    loss: 1.565360 [    0/60000]
    loss: 1.545178 [ 6400/60000]
    loss: 1.385838 [12800/60000]
    loss: 1.467312 [19200/60000]
    loss: 1.347984 [25600/60000]
    loss: 1.344307 [32000/60000]
    loss: 1.359637 [38400/60000]
    loss: 1.281757 [44800/60000]
    loss: 1.315530 [51200/60000]
    loss: 1.225622 [57600/60000]
    Test Error:
     Accuracy: 62.6%, Avg loss: 1.244680

    Epoch 5
    -------------------------------
    loss: 1.316831 [    0/60000]
    loss: 1.308723 [ 6400/60000]
    loss: 1.136861 [12800/60000]
    loss: 1.249873 [19200/60000]
    loss: 1.122254 [25600/60000]
    loss: 1.148926 [32000/60000]
    loss: 1.173565 [38400/60000]
    loss: 1.106467 [44800/60000]
    loss: 1.143653 [51200/60000]
    loss: 1.064591 [57600/60000]
    Test Error:
     Accuracy: 64.4%, Avg loss: 1.081457

    Done!

Read more about `Training your model `_.
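
Why does training start at a loss of about 2.30? For one sample, cross-entropy is the negative log of the probability the model assigns to the correct class. If a freshly initialized 10-class model spreads its probability roughly evenly, each class gets about 1/10. The loss is then about -ln(1/10) = ln(10), or roughly 2.3026, which is close to the first printed value of 2.296905. The sketch below recomputes this in plain Python. It assumes the default ``reduction="mean"`` of ``nn.CrossEntropyLoss``, which means log-softmax followed by negative log-likelihood, averaged over the batch. ``cross_entropy`` is a hypothetical helper written for illustration, not PyTorch's implementation.

.. code-block:: python

    import math


    def cross_entropy(logits_batch, targets):
        """Mean cross-entropy over a batch of raw logits (hypothetical helper).

        Illustrates the quantity nn.CrossEntropyLoss computes with its default
        reduction="mean": log-softmax followed by negative log-likelihood.
        """
        total = 0.0
        for logits, target in zip(logits_batch, targets):
            # Log-sum-exp, shifted by the max logit for numerical stability.
            m = max(logits)
            log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
            # -log_softmax(logits)[target] == logsumexp(logits) - logits[target]
            total += log_sum_exp - logits[target]
        return total / len(targets)


    # All-zero logits: each of the 10 classes gets probability 1/10.
    num_classes = 10
    batch_size = 64
    uniform_logits = [[0.0] * num_classes for _ in range(batch_size)]
    labels = [i % num_classes for i in range(batch_size)]

    print(f"loss for uniform predictions: {cross_entropy(uniform_logits, labels):.6f}")
    print(f"ln(10):                       {math.log(num_classes):.6f}")

With all-zero logits, every class is equally likely, so both lines print the same value of about 2.3026. As training pushes the correct class's logit above the others, the per-sample loss falls below ln(10). This is the downward trend visible in the epoch log above.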
--------------

Saving Models
-------------

A common way to save a model is to serialize the internal state dictionary (containing the model parameters).

.. code-block:: default

    torch.save(model.state_dict(), "model.pth")
    print("Saved PyTorch Model State to model.pth")

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Saved PyTorch Model State to model.pth

Loading Models
----------------------------

The process for loading a model includes re-creating the model structure and loading
the state dictionary into it.

.. code-block:: default

    model = NeuralNetwork()
    model.load_state_dict(torch.load("model.pth"))

This model can now be used to make predictions.

.. code-block:: default

    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]

    model.eval()
    x, y = test_data[0][0], test_data[0][1]
    with torch.no_grad():
        pred = model(x)
        predicted, actual = classes[pred[0].argmax(0)], classes[y]
        print(f'Predicted: "{predicted}", Actual: "{actual}"')

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Predicted: "Ankle boot", Actual: "Ankle boot"

Read more about `Saving & Loading your model `_.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 47.987 seconds)

.. _sphx_glr_download_beginner_basics_quickstart_tutorial.py:

.. only:: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example

        .. container:: sphx-glr-download

            :download:`Download Python source code: quickstart_tutorial.py `

        .. container:: sphx-glr-download

            :download:`Download Jupyter notebook: quickstart_tutorial.ipynb `

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery `_