.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "recipes/recipes/what_is_state_dict.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_recipes_recipes_what_is_state_dict.py: What is a state_dict in PyTorch =============================== In PyTorch, the learnable parameters (i.e. weights and biases) of a ``torch.nn.Module`` model are contained in the model’s parameters (accessed with ``model.parameters()``). A ``state_dict`` is simply a Python dictionary object that maps each layer to its parameter tensor. Introduction ------------ A ``state_dict`` is an integral entity if you are interested in saving or loading models from PyTorch. Because ``state_dict`` objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s ``state_dict``. Optimizer objects (``torch.optim``) also have a ``state_dict``, which contains information about the optimizer’s state, as well as the hyperparameters used. In this recipe, we will see how ``state_dict`` is used with a simple model. Setup ----- Before we begin, we need to install ``torch`` if it isn’t already available. .. code-block:: sh pip install torch .. GENERATED FROM PYTHON SOURCE LINES 38-52 Steps ----- 1. Import all necessary libraries for loading our data 2. Define and initialize the neural network 3. Initialize the optimizer 4. Access the model and optimizer ``state_dict`` 1. Import necessary libraries for loading our data ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ For this recipe, we will use ``torch`` and its subsidiaries ``torch.nn`` and ``torch.optim``. .. GENERATED FROM PYTHON SOURCE LINES 52-59 .. code-block:: default import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim .. GENERATED FROM PYTHON SOURCE LINES 60-66 2. Define and initialize the neural network ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ For sake of example, we will create a neural network for training images. To learn more see the Defining a Neural Network recipe. .. GENERATED FROM PYTHON SOURCE LINES 66-90 .. code-block:: default class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net() print(net) .. GENERATED FROM PYTHON SOURCE LINES 91-96 3. Initialize the optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ We will use SGD with momentum. .. GENERATED FROM PYTHON SOURCE LINES 96-100 .. code-block:: default optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) .. GENERATED FROM PYTHON SOURCE LINES 101-107 4. Access the model and optimizer ``state_dict`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Now that we have constructed our model and optimizer, we can understand what is preserved in their respective ``state_dict`` properties. .. GENERATED FROM PYTHON SOURCE LINES 107-121 .. code-block:: default # Print model's state_dict print("Model's state_dict:") for param_tensor in net.state_dict(): print(param_tensor, "\t", net.state_dict()[param_tensor].size()) print() # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name]) .. GENERATED FROM PYTHON SOURCE LINES 122-134 This information is relevant for saving and loading the model and optimizers for future use. Congratulations! You have successfully used ``state_dict`` in PyTorch. Learn More ---------- Take a look at these other recipes to continue your learning: - `Saving and loading models for inference in PyTorch `__ - `Saving and loading a general checkpoint in PyTorch `__ .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 0.000 seconds) .. _sphx_glr_download_recipes_recipes_what_is_state_dict.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: what_is_state_dict.py ` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: what_is_state_dict.ipynb ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_