.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_basics_buildmodel_tutorial.py:


Build the Neural Network
========================

Neural networks comprise layers/modules that perform operations on data.
The `torch.nn `_ namespace provides all the building blocks you need to
build your own neural network. Every module in PyTorch subclasses the `nn.Module `_.
A neural network is a module itself that consists of other modules (layers). This nested structure allows for
building and managing complex architectures easily.

In the following sections, we'll build a neural network to classify images in the FashionMNIST dataset.


.. code-block:: default


    import os
    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms


Get Device for Training
-----------------------
We want to be able to train our model on a hardware accelerator like the GPU,
if it is available. Let's check whether `torch.cuda `_
is available; otherwise we fall back to the CPU.


.. code-block:: default


    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Using cuda device


Define the Class
-------------------------
We define our neural network by subclassing ``nn.Module``, and
initialize the neural network layers in ``__init__``. Every ``nn.Module`` subclass implements
the operations on input data in the ``forward`` method.
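
The ``__init__``/``forward`` split described above can be seen in a deliberately small module before we build the full network. This is a minimal sketch: the class ``TinyNet`` and its layer sizes are hypothetical and not part of the FashionMNIST model. It shows that layers assigned as attributes in ``__init__`` are registered automatically, and that calling the module runs ``forward``.

.. code-block:: python

    import torch
    from torch import nn


    class TinyNet(nn.Module):
        """Hypothetical two-layer module, used only to illustrate the nn.Module contract."""

        def __init__(self):
            super().__init__()
            # Assigning modules as attributes registers them (and their parameters).
            self.hidden = nn.Linear(4, 8)
            self.out = nn.Linear(8, 2)

        def forward(self, x):
            # forward() defines the computation; it runs when the module is called.
            return self.out(torch.relu(self.hidden(x)))


    net = TinyNet()
    y = net(torch.rand(3, 4))  # call the module itself, not net.forward(...)
    print(y.shape)  # torch.Size([3, 2])
    print([name for name, _ in net.named_parameters()])
    # ['hidden.weight', 'hidden.bias', 'out.weight', 'out.bias']

Calling ``net(...)`` goes through ``nn.Module.__call__``, which runs any registered hooks around ``forward``; the full model below relies on the same mechanism.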

.. code-block:: default


    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits


We create an instance of ``NeuralNetwork``, move it to the ``device``, and print
its structure.


.. code-block:: default


    model = NeuralNetwork().to(device)
    print(model)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    )


To use the model, we pass it the input data. This executes the model's ``forward``,
along with some `background operations `_.
Do not call ``model.forward()`` directly!

Calling the model on the input returns a 2-dimensional tensor of shape ``(1, 10)``: dim=0 indexes the
inputs in the batch, and dim=1 holds the 10 raw predicted values, one per class.
We get the prediction probabilities by passing it through an instance of the ``nn.Softmax`` module.


.. code-block:: default


    X = torch.rand(1, 28, 28, device=device)
    logits = model(X)
    pred_probab = nn.Softmax(dim=1)(logits)
    y_pred = pred_probab.argmax(1)
    print(f"Predicted class: {y_pred}")


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Predicted class: tensor([1], device='cuda:0')


--------------


Model Layers
-------------------------

Let's break down the layers in the FashionMNIST model. To illustrate it, we
will take a sample minibatch of 3 images of size 28x28 and see what happens to it as
we pass it through the network.


.. code-block:: default


    input_image = torch.rand(3,28,28)
    print(input_image.size())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    torch.Size([3, 28, 28])


nn.Flatten
^^^^^^^^^^^^^^^^^^^^^^
We initialize the `nn.Flatten `_
layer to convert each 2D 28x28 image into a contiguous array of 784 pixel values
(the minibatch dimension at dim=0 is maintained).


.. code-block:: default


    flatten = nn.Flatten()
    flat_image = flatten(input_image)
    print(flat_image.size())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    torch.Size([3, 784])


nn.Linear
^^^^^^^^^^^^^^^^^^^^^^
The `linear layer `_
is a module that applies the linear transformation :math:`y = xW^T + b` to the input,
using its stored weights :math:`W` and biases :math:`b`.


.. code-block:: default


    layer1 = nn.Linear(in_features=28*28, out_features=20)
    hidden1 = layer1(flat_image)
    print(hidden1.size())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    torch.Size([3, 20])


nn.ReLU
^^^^^^^^^^^^^^^^^^^^^^
Non-linear activations are what create the complex mappings between the model's inputs and outputs.
They are applied after linear transformations to introduce *nonlinearity*, helping neural networks
learn a wide variety of phenomena.

In this model, we use `nn.ReLU `_ between our
linear layers, but there are other activations you can use to introduce non-linearity in your model.
ReLU replaces every negative value with zero (:math:`\mathrm{ReLU}(x) = \max(0, x)`), as the output below shows.


.. code-block:: default


    print(f"Before ReLU: {hidden1}\n\n")
    hidden1 = nn.ReLU()(hidden1)
    print(f"After ReLU: {hidden1}")


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Before ReLU: tensor([[-8.7687e-02,  5.9901e-01,  1.4558e-01, -2.5945e-01, -8.4226e-02,
             -2.9029e-02, -2.9880e-02,  3.0345e-01, -1.2606e-01,  1.1985e-01,
              1.4438e-01,  3.6986e-01, -4.2237e-01,  1.8449e-01,  2.6856e-01,
             -1.9637e-01,  4.1582e-01,  3.3586e-01,  3.5796e-01, -4.4780e-01],
            [ 1.0498e-01,  1.4883e-04,  2.8295e-01, -1.6375e-01,  8.0787e-02,
              2.2609e-02, -4.4070e-01,  3.2335e-01, -4.9612e-02, -1.2640e-01,
             -7.7641e-03,  1.5634e-01, -1.5698e-01,  4.1123e-02,  3.5126e-01,
              3.8443e-01,  5.5916e-01,  3.4274e-01,  4.1674e-01, -5.0670e-01],
            [ 1.8744e-01,  3.7170e-01,  1.2300e-01, -2.0135e-01, -1.9530e-01,
             -2.5365e-01, -1.3391e-01, -1.1461e-01,  3.9721e-01,  5.7025e-02,
             -6.0041e-03,  3.8003e-01, -8.3460e-02, -4.7511e-02,  7.6580e-02,
              2.0931e-01,  4.7226e-01,  2.2087e-01,  2.3597e-01, -6.7200e-01]],
           grad_fn=)


    After ReLU: tensor([[0.0000e+00, 5.9901e-01, 1.4558e-01, 0.0000e+00, 0.0000e+00,
             0.0000e+00, 0.0000e+00, 3.0345e-01, 0.0000e+00, 1.1985e-01,
             1.4438e-01, 3.6986e-01, 0.0000e+00, 1.8449e-01, 2.6856e-01,
             0.0000e+00, 4.1582e-01, 3.3586e-01, 3.5796e-01, 0.0000e+00],
            [1.0498e-01, 1.4883e-04, 2.8295e-01, 0.0000e+00, 8.0787e-02,
             2.2609e-02, 0.0000e+00, 3.2335e-01, 0.0000e+00, 0.0000e+00,
             0.0000e+00, 1.5634e-01, 0.0000e+00, 4.1123e-02, 3.5126e-01,
             3.8443e-01, 5.5916e-01, 3.4274e-01, 4.1674e-01, 0.0000e+00],
            [1.8744e-01, 3.7170e-01, 1.2300e-01, 0.0000e+00, 0.0000e+00,
             0.0000e+00, 0.0000e+00, 0.0000e+00, 3.9721e-01, 5.7025e-02,
             0.0000e+00, 3.8003e-01, 0.0000e+00, 0.0000e+00, 7.6580e-02,
             2.0931e-01, 4.7226e-01, 2.2087e-01, 2.3597e-01, 0.0000e+00]],
           grad_fn=)


nn.Sequential
^^^^^^^^^^^^^^^^^^^^^^
`nn.Sequential `_ is an ordered
container of modules. The data is passed through all the modules in the same order as defined. You can use
sequential containers to put together a quick network like ``seq_modules``.
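
To make the "same order as defined" behaviour concrete, the sketch below builds a container similar to ``seq_modules`` but with named stages, passed as an ``OrderedDict``, and checks that calling the container gives the same result as chaining the stages by hand. The stage names (``flatten``, ``hidden``, ``act``, ``out``) and the variable ``named_seq`` are illustrative assumptions, not part of the tutorial's model.

.. code-block:: python

    from collections import OrderedDict

    import torch
    from torch import nn

    torch.manual_seed(0)  # fixed seed so the random weights are reproducible

    # Hypothetical named stages; nn.Sequential also accepts an OrderedDict.
    named_seq = nn.Sequential(OrderedDict([
        ("flatten", nn.Flatten()),
        ("hidden", nn.Linear(28 * 28, 20)),
        ("act", nn.ReLU()),
        ("out", nn.Linear(20, 10)),
    ]))

    x = torch.rand(3, 28, 28)

    # Calling the container runs the stages in the order they were added...
    seq_result = named_seq(x)
    # ...which is the same as applying each stage by hand, in that order.
    manual_result = named_seq.out(named_seq.act(named_seq.hidden(named_seq.flatten(x))))

    print(torch.equal(seq_result, manual_result))  # True
    print(seq_result.shape)  # torch.Size([3, 10])
    print(named_seq[1])  # Linear(in_features=784, out_features=20, bias=True)

Naming the stages makes printed summaries and ``named_parameters()`` output easier to read (``hidden.weight`` rather than ``1.weight``), while positional indexing such as ``named_seq[1]`` still works.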

.. code-block:: default


    seq_modules = nn.Sequential(
        flatten,
        layer1,
        nn.ReLU(),
        nn.Linear(20, 10)
    )
    input_image = torch.rand(3,28,28)
    logits = seq_modules(input_image)


nn.Softmax
^^^^^^^^^^^^^^^^^^^^^^
The last linear layer of the neural network returns *logits*, raw values in :math:`[-\infty, \infty]`, which are passed to the
`nn.Softmax `_ module. Softmax computes
:math:`\mathrm{softmax}(x)_i = e^{x_i} / \sum_j e^{x_j}`, which scales the logits to values in :math:`[0, 1]`
representing the model's predicted probabilities for each class. The ``dim`` parameter indicates the dimension along
which the values must sum to 1.


.. code-block:: default


    softmax = nn.Softmax(dim=1)
    pred_probab = softmax(logits)


Model Parameters
-------------------------
Many layers inside a neural network are *parameterized*, i.e. have associated weights
and biases that are optimized during training. Subclassing ``nn.Module`` automatically
registers every submodule and parameter assigned as an attribute of your model object, and makes all parameters
accessible using your model's ``parameters()`` or ``named_parameters()`` methods.

In this example, we iterate over each parameter, and print its size and a preview of its values.


.. code-block:: default


    print(f"Model structure: {model}\n\n")

    for name, param in model.named_parameters():
        print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Model structure: NeuralNetwork(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear_relu_stack): Sequential(
        (0): Linear(in_features=784, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=10, bias=True)
      )
    )


    Layer: linear_relu_stack.0.weight | Size: torch.Size([512, 784]) | Values : tensor([[-0.0269, -0.0351,  0.0205,  ...,  0.0327,  0.0229,  0.0116],
            [ 0.0056, -0.0055,  0.0073,  ..., -0.0014,  0.0356, -0.0296]],
           device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.0.bias | Size: torch.Size([512]) | Values : tensor([-0.0301, -0.0045], device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.2.weight | Size: torch.Size([512, 512]) | Values : tensor([[ 0.0270,  0.0025,  0.0439,  ...,  0.0229, -0.0086,  0.0118],
            [ 0.0041, -0.0334,  0.0303,  ..., -0.0220, -0.0433,  0.0256]],
           device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.2.bias | Size: torch.Size([512]) | Values : tensor([ 0.0077, -0.0281], device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.4.weight | Size: torch.Size([10, 512]) | Values : tensor([[-0.0094,  0.0104, -0.0343,  ..., -0.0044, -0.0354, -0.0139],
            [ 0.0074,  0.0259, -0.0347,  ..., -0.0360,  0.0126, -0.0396]],
           device='cuda:0', grad_fn=)

    Layer: linear_relu_stack.4.bias | Size: torch.Size([10]) | Values : tensor([-0.0063,  0.0317], device='cuda:0', grad_fn=)


--------------


Further Reading
---------------
- `torch.nn API `_
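
As a quick check on the Model Parameters section, the sketch below totals the parameters of the ``NeuralNetwork`` defined earlier (redefined here so the snippet runs on its own, on the CPU). Each ``nn.Linear(in, out)`` layer holds an ``out`` × ``in`` weight matrix plus ``out`` biases, so the expected total is (784 × 512 + 512) + (512 × 512 + 512) + (512 × 10 + 10) = 401,920 + 262,656 + 5,130 = 669,706. The variable names ``total`` and ``trainable`` are illustrative.

.. code-block:: python

    import torch
    from torch import nn


    class NeuralNetwork(nn.Module):
        # Same architecture as the model defined earlier in this tutorial.
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28*28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )

        def forward(self, x):
            x = self.flatten(x)
            return self.linear_relu_stack(x)


    model = NeuralNetwork()

    # numel() is the number of elements in each parameter tensor.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(total, trainable)  # 669706 669706

    # Per-parameter breakdown, using the same names that named_parameters() printed above.
    for name, param in model.named_parameters():
        print(f"{name}: {param.numel()}")

The first layer alone accounts for 401,920 of the 669,706 parameters, because it connects all 784 input pixels to 512 hidden units.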