.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/introyt/modelsyt_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_introyt_modelsyt_tutorial.py:

`Introduction `_ ||
`Tensors `_ ||
`Autograd `_ ||
**Building Models** ||
`TensorBoard Support `_ ||
`Training Models `_ ||
`Model Understanding `_

Building Models with PyTorch
============================

Follow along with the video below or on `youtube `__.
``torch.nn.Module`` and ``torch.nn.Parameter``
----------------------------------------------

In this video, we’ll be discussing some of the tools PyTorch makes
available for building deep learning networks.

Except for ``Parameter``, the classes we discuss in this video are all
subclasses of ``torch.nn.Module``. This is the PyTorch base class meant
to encapsulate behaviors specific to PyTorch models and their
components.

One important behavior of ``torch.nn.Module`` is registering
parameters. If a particular ``Module`` subclass has learning weights,
these weights are expressed as instances of ``torch.nn.Parameter``. The
``Parameter`` class is a subclass of ``torch.Tensor``, with the special
behavior that, when it is assigned as an attribute of a ``Module``, it
is added to the list of that module’s parameters. These parameters may
be accessed through the ``parameters()`` method on the ``Module``
class.

As a simple example, here’s a model with two linear layers and an
activation function. We’ll create an instance of it and ask it to
report on its parameters:

.. GENERATED FROM PYTHON SOURCE LINES 45-82

.. code-block:: default


    import torch

    class TinyModel(torch.nn.Module):

        def __init__(self):
            super(TinyModel, self).__init__()

            self.linear1 = torch.nn.Linear(100, 200)
            self.activation = torch.nn.ReLU()
            self.linear2 = torch.nn.Linear(200, 10)
            self.softmax = torch.nn.Softmax()

        def forward(self, x):
            x = self.linear1(x)
            x = self.activation(x)
            x = self.linear2(x)
            x = self.softmax(x)
            return x

    tinymodel = TinyModel()

    print('The model:')
    print(tinymodel)

    print('\n\nJust one layer:')
    print(tinymodel.linear2)

    print('\n\nModel params:')
    for param in tinymodel.parameters():
        print(param)

    print('\n\nLayer params:')
    for param in tinymodel.linear2.parameters():
        print(param)


.. rst-class:: sphx-glr-script-out

 ..
code-block:: none The model: TinyModel( (linear1): Linear(in_features=100, out_features=200, bias=True) (activation): ReLU() (linear2): Linear(in_features=200, out_features=10, bias=True) (softmax): Softmax(dim=None) ) Just one layer: Linear(in_features=200, out_features=10, bias=True) Model params: Parameter containing: tensor([[ 0.0765, 0.0830, -0.0234, ..., -0.0337, -0.0355, -0.0968], [-0.0573, 0.0250, -0.0132, ..., -0.0060, 0.0240, 0.0280], [-0.0908, -0.0369, 0.0842, ..., -0.0078, -0.0333, -0.0324], ..., [-0.0273, -0.0162, -0.0878, ..., 0.0451, 0.0297, -0.0722], [ 0.0833, -0.0874, -0.0020, ..., -0.0215, 0.0356, 0.0405], [-0.0637, 0.0190, -0.0571, ..., -0.0874, 0.0176, 0.0712]], requires_grad=True) Parameter containing: tensor([ 0.0304, -0.0758, -0.0549, -0.0893, -0.0809, -0.0804, -0.0079, -0.0413, -0.0968, 0.0888, 0.0239, -0.0659, -0.0560, -0.0060, 0.0660, -0.0319, -0.0370, 0.0633, -0.0143, -0.0360, 0.0670, -0.0804, 0.0265, -0.0870, 0.0039, -0.0174, -0.0680, -0.0531, 0.0643, 0.0794, 0.0209, 0.0419, 0.0562, -0.0173, -0.0055, 0.0813, 0.0613, -0.0379, 0.0228, 0.0304, -0.0354, 0.0609, -0.0398, 0.0410, 0.0564, -0.0101, -0.0790, -0.0824, -0.0126, 0.0557, 0.0900, 0.0597, 0.0062, -0.0108, 0.0112, -0.0358, -0.0203, 0.0566, -0.0816, -0.0633, -0.0266, -0.0624, -0.0746, 0.0492, 0.0450, 0.0530, -0.0706, 0.0308, 0.0533, 0.0202, -0.0469, -0.0448, 0.0548, 0.0331, 0.0257, -0.0764, -0.0892, 0.0783, 0.0062, 0.0844, -0.0959, -0.0468, -0.0926, 0.0925, 0.0147, 0.0391, 0.0765, 0.0059, 0.0216, -0.0724, 0.0108, 0.0701, -0.0147, -0.0693, -0.0517, 0.0029, 0.0661, 0.0086, -0.0574, 0.0084, -0.0324, 0.0056, 0.0626, -0.0833, -0.0271, -0.0526, 0.0842, -0.0840, -0.0234, -0.0898, -0.0710, -0.0399, 0.0183, -0.0883, -0.0102, -0.0545, 0.0706, -0.0646, -0.0841, -0.0095, -0.0823, -0.0385, 0.0327, -0.0810, -0.0404, 0.0570, 0.0740, 0.0829, 0.0845, 0.0817, -0.0239, -0.0444, -0.0221, 0.0216, 0.0103, -0.0631, 0.0831, -0.0273, 0.0756, 0.0022, 0.0407, 0.0072, 0.0374, -0.0608, 0.0424, -0.0585, 0.0505, -0.0455, 0.0268, -0.0950, -0.0642, 0.0843, 0.0760, -0.0889, -0.0617, -0.0916, 0.0102, -0.0269, -0.0011, 0.0318, 0.0278, -0.0160, 0.0159, -0.0817, 0.0768, -0.0876, -0.0524, -0.0332, -0.0583, 0.0053, 0.0503, -0.0342, -0.0319, -0.0562, 0.0376, -0.0696, 0.0735, 0.0222, -0.0775, -0.0072, 0.0294, 0.0994, -0.0355, -0.0809, -0.0539, 0.0245, 0.0670, 0.0032, 0.0891, -0.0694, -0.0994, 0.0126, 0.0629, 0.0936, 0.0058, -0.0073, 0.0498, 0.0616, -0.0912, -0.0490], requires_grad=True) Parameter containing: tensor([[ 0.0504, -0.0203, -0.0573, ..., 0.0253, 0.0642, -0.0088], [-0.0078, -0.0608, -0.0626, ..., -0.0350, -0.0028, -0.0634], [-0.0317, -0.0202, -0.0593, ..., -0.0280, 0.0571, -0.0114], ..., [ 0.0582, -0.0471, -0.0236, ..., 0.0273, 0.0673, 0.0555], [ 0.0258, -0.0706, 0.0315, ..., -0.0663, -0.0133, 0.0078], [-0.0062, 0.0544, -0.0280, ..., -0.0303, -0.0326, -0.0462]], requires_grad=True) Parameter containing: tensor([ 0.0385, -0.0116, 0.0703, 0.0407, -0.0346, -0.0178, 0.0308, -0.0502, 0.0616, 0.0114], requires_grad=True) Layer params: Parameter containing: tensor([[ 0.0504, -0.0203, -0.0573, ..., 0.0253, 0.0642, -0.0088], [-0.0078, -0.0608, -0.0626, ..., -0.0350, -0.0028, -0.0634], [-0.0317, -0.0202, -0.0593, ..., -0.0280, 0.0571, -0.0114], ..., [ 0.0582, -0.0471, -0.0236, ..., 0.0273, 0.0673, 0.0555], [ 0.0258, -0.0706, 0.0315, ..., -0.0663, -0.0133, 0.0078], [-0.0062, 0.0544, -0.0280, ..., -0.0303, -0.0326, -0.0462]], requires_grad=True) Parameter containing: tensor([ 0.0385, -0.0116, 0.0703, 0.0407, -0.0346, -0.0178, 0.0308, -0.0502, 0.0616, 
0.0114], requires_grad=True)




.. GENERATED FROM PYTHON SOURCE LINES 83-101

This shows the fundamental structure of a PyTorch model: there is an
``__init__()`` method that defines the layers and other components of a
model, and a ``forward()`` method where the computation gets done. Note
that we can print the model, or any of its submodules, to learn about
its structure.

Common Layer Types
------------------

Linear Layers
~~~~~~~~~~~~~

The most basic type of neural network layer is a *linear* or *fully
connected* layer. This is a layer where every input influences every
output of the layer to a degree specified by the layer’s weights. If a
model has *m* inputs and *n* outputs, the weights will be an *n* x *m*
matrix (PyTorch stores a ``Linear`` layer’s weight with shape
``(out_features, in_features)``). For example:

.. GENERATED FROM PYTHON SOURCE LINES 101-116

.. code-block:: default


    lin = torch.nn.Linear(3, 2)
    x = torch.rand(1, 3)
    print('Input:')
    print(x)

    print('\n\nWeight and Bias parameters:')
    for param in lin.parameters():
        print(param)

    y = lin(x)
    print('\n\nOutput:')
    print(y)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Input:
    tensor([[0.8790, 0.9774, 0.2547]])


    Weight and Bias parameters:
    Parameter containing:
    tensor([[ 0.1656,  0.4969, -0.4972],
            [-0.2035, -0.2579, -0.3780]], requires_grad=True)
    Parameter containing:
    tensor([0.3768, 0.3781], requires_grad=True)


    Output:
    tensor([[ 0.8814, -0.1492]], grad_fn=)


.. GENERATED FROM PYTHON SOURCE LINES 117-146

If you do the matrix multiplication of ``x`` by the linear layer’s
weights, and add the biases, you’ll find that you get the output vector
``y``.
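If you’d like to verify this yourself, here is a small optional check
(a sketch, not part of the generated script above) that reproduces the
layer’s output by hand, using the ``lin``, ``x``, and ``y`` from the
cell above:

.. code-block:: python

    # Reproduce the linear layer's output manually. PyTorch stores the weight
    # with shape (out_features, in_features), so we multiply by its transpose
    # and then add the bias.
    manual_y = x @ lin.weight.T + lin.bias
    print(torch.allclose(manual_y, y))   # should print True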
One other important feature to note: when we looked at the weights of
our layer (whether through ``lin.parameters()`` as above, or directly
as ``lin.weight``), they reported themselves as ``Parameter`` objects
(a subclass of ``Tensor``), and let us know that they’re tracking
gradients with autograd. This is a default behavior for ``Parameter``
that differs from ``Tensor``.

Linear layers are used widely in deep learning models. One of the most
common places you’ll see them is in classifier models, which will
usually have one or more linear layers at the end, where the last layer
will have *n* outputs, where *n* is the number of classes the
classifier addresses.

Convolutional Layers
~~~~~~~~~~~~~~~~~~~~

*Convolutional* layers are built to handle data with a high degree of
spatial correlation. They are very commonly used in computer vision,
where they detect close groupings of features which they compose into
higher-level features. They pop up in other contexts too - for example,
in NLP applications, where a word’s immediate context (that is, the
other words nearby in the sequence) can affect the meaning of a
sentence.

We saw convolutional layers in action in LeNet5 in an earlier video:

.. GENERATED FROM PYTHON SOURCE LINES 146-182

.. code-block:: default


    import torch.nn.functional as F


    class LeNet(torch.nn.Module):

        def __init__(self):
            super(LeNet, self).__init__()
            # 1 input image channel (black & white), 6 output channels, 5x5 square convolution
            # kernel
            self.conv1 = torch.nn.Conv2d(1, 6, 5)
            self.conv2 = torch.nn.Conv2d(6, 16, 3)
            # an affine operation: y = Wx + b
            self.fc1 = torch.nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
            self.fc2 = torch.nn.Linear(120, 84)
            self.fc3 = torch.nn.Linear(84, 10)

        def forward(self, x):
            # Max pooling over a (2, 2) window
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            # If the size is a square, you can specify a single number
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = x.view(-1, self.num_flat_features(x))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

        def num_flat_features(self, x):
            size = x.size()[1:]  # all dimensions except the batch dimension
            num_features = 1
            for s in size:
                num_features *= s
            return num_features


.. GENERATED FROM PYTHON SOURCE LINES 183-248

Let’s break down what’s happening in the convolutional layers of this
model. Starting with ``conv1``:

-  LeNet5 is meant to take in a 1x32x32 black & white image. **The
   first argument to a convolutional layer’s constructor is the number
   of input channels.** Here, it is 1. If we were building this model
   to look at 3-color channels, it would be 3.
-  A convolutional layer is like a window that scans over the image,
   looking for a pattern it recognizes. These patterns are called
   *features,* and one of the parameters of a convolutional layer is
   the number of features we would like it to learn. **The second
   argument to the constructor is the number of output features.**
   Here, we’re asking our layer to learn 6 features.
-  Just above, I likened the convolutional layer to a window - but how
   big is the window? **The third argument is the window or kernel
   size.** Here, the “5” means we’ve chosen a 5x5 kernel. (If you want
   a kernel with height different from width, you can specify a tuple
   for this argument - e.g., ``(3, 5)`` to get a 3x5 convolution
   kernel.)

The output of a convolutional layer is an *activation map* - a spatial
representation of the presence of features in the input tensor.
``conv1`` will give us an output tensor of 6x28x28; 6 is the number of
features, and 28 is the height and width of our map. (The 28 comes from
the fact that when scanning a 5-pixel window over a 32-pixel row, there
are only 28 valid positions.)

We then pass the output of the convolution through a ReLU activation
function (more on activation functions later), then through a max
pooling layer. The max pooling layer takes features near each other in
the activation map and groups them together. It does this by reducing
the tensor, merging every 2x2 group of cells in the output into a
single cell, and assigning that cell the maximum value of the 4 cells
that went into it. This gives us a lower-resolution version of the
activation map, with dimensions 6x14x14.

Our next convolutional layer, ``conv2``, expects 6 input channels
(corresponding to the 6 features sought by the first layer), has 16
output channels, and a 3x3 kernel. It puts out a 16x12x12 activation
map, which is again reduced by a max pooling layer to 16x6x6. Prior to
passing this output to the linear layers, it is reshaped to a 16 \* 6
\* 6 = 576-element vector for consumption by the next layer.
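If you want to see these shapes for yourself, here is a short optional
check (a sketch, not part of the generated script) that runs a dummy
1x32x32 single-channel image through the ``LeNet`` class defined above
and prints the intermediate sizes:

.. code-block:: python

    # Trace the shapes described above, using the LeNet class defined earlier.
    net = LeNet()
    dummy_input = torch.rand(1, 1, 32, 32)   # (batch, channels, height, width)

    c1 = F.relu(net.conv1(dummy_input))
    print(c1.shape)                           # torch.Size([1, 6, 28, 28])
    p1 = F.max_pool2d(c1, 2)
    print(p1.shape)                           # torch.Size([1, 6, 14, 14])
    c2 = F.relu(net.conv2(p1))
    print(c2.shape)                           # torch.Size([1, 16, 12, 12])
    p2 = F.max_pool2d(c2, 2)
    print(p2.shape)                           # torch.Size([1, 16, 6, 6])

    print(net(dummy_input).shape)             # torch.Size([1, 10])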
There are convolutional layers for addressing 1D, 2D, and 3D tensors.
There are also many more optional arguments for a conv layer
constructor, including stride length (e.g., only scanning every second
or every third position in the input), padding (so you can scan out to
the edges of the input), and more. See the
`documentation `__ for more information.

Recurrent Layers
~~~~~~~~~~~~~~~~

*Recurrent neural networks* (or *RNNs*) are used for sequential data -
anything from time-series measurements from a scientific instrument to
natural language sentences to DNA nucleotides. An RNN handles this kind
of data by maintaining a *hidden state* that acts as a sort of memory
for what it has seen in the sequence so far.

The internal structure of an RNN layer - or its variants, the LSTM
(long short-term memory) and GRU (gated recurrent unit) - is moderately
complex and beyond the scope of this video, but we’ll show you what one
looks like in action with an LSTM-based part-of-speech tagger (a type
of classifier that tells you if a word is a noun, verb, etc.):

.. GENERATED FROM PYTHON SOURCE LINES 248-272

.. code-block:: default


    class LSTMTagger(torch.nn.Module):

        def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
            super(LSTMTagger, self).__init__()
            self.hidden_dim = hidden_dim

            self.word_embeddings = torch.nn.Embedding(vocab_size, embedding_dim)

            # The LSTM takes word embeddings as inputs, and outputs hidden states
            # with dimensionality hidden_dim.
            self.lstm = torch.nn.LSTM(embedding_dim, hidden_dim)

            # The linear layer that maps from hidden state space to tag space
            self.hidden2tag = torch.nn.Linear(hidden_dim, tagset_size)

        def forward(self, sentence):
            embeds = self.word_embeddings(sentence)
            lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
            tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
            tag_scores = F.log_softmax(tag_space, dim=1)
            return tag_scores


.. GENERATED FROM PYTHON SOURCE LINES 273-331

The constructor has four arguments:

-  ``vocab_size`` is the number of words in the input vocabulary. Each
   word is a one-hot vector (or unit vector) in a
   ``vocab_size``-dimensional space.
-  ``tagset_size`` is the number of tags in the output set.
-  ``embedding_dim`` is the size of the *embedding* space for the
   vocabulary. An embedding maps a vocabulary onto a low-dimensional
   space, where words with similar meanings are close together in the
   space.
-  ``hidden_dim`` is the size of the LSTM’s memory.

The input will be a sentence with the words represented as indices of
one-hot vectors. The embedding layer will then map these down to an
``embedding_dim``-dimensional space. The LSTM takes this sequence of
embeddings and iterates over it, yielding an output vector of length
``hidden_dim``. The final linear layer acts as a classifier; applying
``log_softmax()`` to the output of the final layer converts the output
into a normalized set of estimated probabilities that a given word maps
to a given tag.

If you’d like to see this network in action, check out the `Sequence
Models and LSTM Networks `__ tutorial on pytorch.org.

Transformers
~~~~~~~~~~~~

*Transformers* are multi-purpose networks that have taken over the
state of the art in NLP with models like BERT. A discussion of
transformer architecture is beyond the scope of this video, but PyTorch
has a ``Transformer`` class that allows you to define the overall
parameters of a transformer model - the number of attention heads, the
number of encoder & decoder layers, dropout and activation functions,
etc. (You can even build the BERT model from this single class, with
the right parameters!)
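As a quick illustration (a minimal sketch, not part of the generated
script; the sizes below are arbitrary and chosen small so it runs
quickly), you can instantiate a small ``Transformer`` and pass source
and target sequences through it:

.. code-block:: python

    # A minimal, arbitrary instantiation of torch.nn.Transformer.
    small_transformer = torch.nn.Transformer(
        d_model=64,            # size of each sequence element's feature vector
        nhead=4,               # number of attention heads
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=128,
        dropout=0.1,
    )

    src = torch.rand(10, 1, 64)   # (source sequence length, batch, d_model)
    tgt = torch.rand(20, 1, 64)   # (target sequence length, batch, d_model)
    out = small_transformer(src, tgt)
    print(out.shape)              # torch.Size([20, 1, 64])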
Beyond the full ``torch.nn.Transformer``, PyTorch also has classes to
encapsulate the individual components (``TransformerEncoder``,
``TransformerDecoder``) and subcomponents (``TransformerEncoderLayer``,
``TransformerDecoderLayer``). For details, check out the
`documentation `__ on transformer classes, and the relevant
`tutorial `__ on pytorch.org.

Other Layers and Functions
--------------------------

Data Manipulation Layers
~~~~~~~~~~~~~~~~~~~~~~~~

There are other layer types that perform important functions in models,
but don’t participate in the learning process themselves.

**Max pooling** (and its twin, min pooling) reduces a tensor by
combining cells, and assigning the maximum value of the input cells to
the output cell (we saw this above). For example:

.. GENERATED FROM PYTHON SOURCE LINES 331-339

.. code-block:: default


    my_tensor = torch.rand(1, 6, 6)
    print(my_tensor)

    maxpool_layer = torch.nn.MaxPool2d(3)
    print(maxpool_layer(my_tensor))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[[0.5036, 0.6285, 0.3460, 0.7817, 0.9876, 0.0074],
             [0.3969, 0.7950, 0.1449, 0.4110, 0.8216, 0.6235],
             [0.2347, 0.3741, 0.4997, 0.9737, 0.1741, 0.4616],
             [0.3962, 0.9970, 0.8778, 0.4292, 0.2772, 0.9926],
             [0.4406, 0.3624, 0.8960, 0.6484, 0.5544, 0.9501],
             [0.2489, 0.8971, 0.7499, 0.1803, 0.9571, 0.6733]]])
    tensor([[[0.7950, 0.9876],
             [0.9970, 0.9926]]])


.. GENERATED FROM PYTHON SOURCE LINES 340-349

If you look closely at the values above, you’ll see that each of the
values in the maxpooled output is the maximum value of each 3x3
quadrant of the 6x6 input.

**Normalization layers** re-center and normalize the output of one
layer before feeding it to another. Centering and scaling the
intermediate tensors has a number of beneficial effects, such as
letting you use higher learning rates without exploding/vanishing
gradients.

.. GENERATED FROM PYTHON SOURCE LINES 349-363

.. code-block:: default


    my_tensor = torch.rand(1, 4, 4) * 20 + 5
    print(my_tensor)

    print(my_tensor.mean())

    norm_layer = torch.nn.BatchNorm1d(4)
    normed_tensor = norm_layer(my_tensor)
    print(normed_tensor)

    print(normed_tensor.mean())


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[[ 7.7375, 23.5649,  6.8452, 16.3517],
             [19.5792, 20.3254,  6.1930, 23.7576],
             [23.7554, 20.8565, 18.4241,  8.5742],
             [22.5100, 15.6154, 13.5698, 11.8411]]])
    tensor(16.2188)
    tensor([[[-0.8614,  1.4543, -0.9919,  0.3990],
             [ 0.3160,  0.4274, -1.6834,  0.9400],
             [ 1.0256,  0.5176,  0.0914, -1.6346],
             [ 1.6352, -0.0663, -0.5711, -0.9978]]], grad_fn=)
    tensor(3.3528e-08, grad_fn=)


.. GENERATED FROM PYTHON SOURCE LINES 364-385

Running the cell above, we’ve added a large scaling factor and offset
to an input tensor; you should see the input tensor’s ``mean()``
somewhere in the neighborhood of 15. After running it through the
normalization layer, you can see that the values are smaller, and
grouped around zero - in fact, the mean should be very small (on the
order of 1e-8).

This is beneficial because many activation functions (discussed below)
have their strongest gradients near 0, but sometimes suffer from
vanishing or exploding gradients for inputs that drive them far away
from zero. Keeping the data centered around the area of steepest
gradient will tend to mean faster, better learning and higher feasible
learning rates.

**Dropout layers** are a tool for encouraging *sparse representations*
in your model - that is, pushing it to do inference with less data.

Dropout layers work by randomly setting parts of the input tensor to
zero *during training* - dropout layers are always turned off for
inference.
This forces the model to learn against this masked or reduced dataset.
For example:

.. GENERATED FROM PYTHON SOURCE LINES 385-393

.. code-block:: default


    my_tensor = torch.rand(1, 4, 4)

    dropout = torch.nn.Dropout(p=0.4)
    print(dropout(my_tensor))
    print(dropout(my_tensor))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[[0.8869, 0.6595, 0.2098, 0.0000],
             [0.5379, 0.0000, 0.0000, 0.0000],
             [0.1950, 0.2424, 1.3319, 0.5738],
             [0.5676, 0.8335, 0.0000, 0.2928]]])
    tensor([[[0.8869, 0.6595, 0.2098, 0.2878],
             [0.5379, 0.0000, 0.4029, 0.0000],
             [0.0000, 0.2424, 1.3319, 0.5738],
             [0.0000, 0.8335, 0.9647, 0.0000]]])


.. GENERATED FROM PYTHON SOURCE LINES 394-423

Above, you can see the effect of dropout on a sample tensor. You can
use the optional ``p`` argument to set the probability of an individual
element dropping out; if you don’t, it defaults to 0.5.

Activation Functions
~~~~~~~~~~~~~~~~~~~~

Activation functions make deep learning possible. A neural network is
really a program - with many parameters - that *simulates a
mathematical function*. If all we did was multiply tensors by layer
weights repeatedly, we could only simulate *linear functions*; further,
there would be no point to having many layers, as the whole network
could be reduced to a single matrix multiplication. Inserting
*non-linear* activation functions between layers is what allows a deep
learning model to simulate any function, rather than just linear ones.

The ``torch.nn`` module has objects encapsulating all of the major
activation functions, including ReLU and its many variants, Tanh,
Hardtanh, Sigmoid, and more. It also includes other functions, such as
Softmax, that are most useful at the output stage of a model.

Loss Functions
~~~~~~~~~~~~~~

Loss functions tell us how far a model’s prediction is from the correct
answer. PyTorch contains a variety of loss functions, including the
common MSE (mean squared error, an L2-type loss), Cross Entropy Loss
and Negative Log Likelihood Loss (both useful for classifiers), and
others.
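As a brief illustration (a sketch, not part of the generated script),
here is one of the common classification losses in use on some made-up
data:

.. code-block:: python

    # A small sketch on made-up data: CrossEntropyLoss expects raw, unnormalized
    # scores (logits) for each class and the index of the correct class.
    loss_fn = torch.nn.CrossEntropyLoss()

    logits = torch.randn(3, 10)        # a batch of 3 predictions over 10 classes
    targets = torch.tensor([1, 5, 9])  # the correct class index for each sample

    loss = loss_fn(logits, targets)
    print(loss)                        # a single scalar - lower means a better fit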