.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "recipes/recipes/saving_and_loading_a_general_checkpoint.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_recipes_recipes_saving_and_loading_a_general_checkpoint.py:

Saving and loading a general checkpoint in PyTorch
==================================================

Saving a general checkpoint lets you pick up where you left off, whether
you want to run inference or resume training. A general checkpoint must
contain more than just the model’s ``state_dict``. It is important to also
save the optimizer’s ``state_dict``, as this contains buffers and parameters
that are updated as the model trains. Other items that you may want to save
are the epoch you left off on, the latest recorded training loss, external
``torch.nn.Embedding`` layers, and more, based on your own algorithm.

Introduction
------------

To save multiple components, organize them in a dictionary and use
``torch.save()`` to serialize the dictionary. A common PyTorch convention
is to save these checkpoints using the ``.tar`` file extension. To load the
items, first initialize the model and optimizer, then load the dictionary
locally using ``torch.load()``. From there, you can access the saved items
by querying the dictionary with the same keys you used when saving.

In this recipe, we will explore how to save and load a general checkpoint
that bundles several components.

Setup
-----

Before we begin, we need to install ``torch`` if it isn’t already available.

::

   pip install torch

.. GENERATED FROM PYTHON SOURCE LINES 41-56

Steps
-----

1. Import all necessary libraries for loading our data
2. Define and initialize the neural network
3. Initialize the optimizer
4. Save the general checkpoint
5. Load the general checkpoint

1. Import necessary libraries for loading our data
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For this recipe, we will use ``torch`` and its submodules ``torch.nn``,
``torch.nn.functional``, and ``torch.optim``.

.. GENERATED FROM PYTHON SOURCE LINES 56-62

.. code-block:: default

    import torch
    import torch.nn as nn
    import torch.nn.functional as F  # provides F.relu, used in Net.forward
    import torch.optim as optim

.. GENERATED FROM PYTHON SOURCE LINES 63-69

2. Define and initialize the neural network
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For the sake of example, we will create a neural network for training on
images. To learn more, see the Defining a Neural Network recipe.

.. GENERATED FROM PYTHON SOURCE LINES 69-93

.. code-block:: default

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            # Flatten the feature maps before the fully connected layers
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    net = Net()
    print(net)

.. GENERATED FROM PYTHON SOURCE LINES 94-99

3. Initialize the optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We will use SGD with momentum.

.. GENERATED FROM PYTHON SOURCE LINES 99-103

.. code-block:: default

    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

.. GENERATED FROM PYTHON SOURCE LINES 104-109

4. Save the general checkpoint
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Collect all relevant information and build your dictionary.

.. GENERATED FROM PYTHON SOURCE LINES 109-123

.. code-block:: default

    # Additional information
    EPOCH = 5
    PATH = "model.pt"
    LOSS = 0.4

    torch.save({
                'epoch': EPOCH,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': LOSS,
                }, PATH)

.. GENERATED FROM PYTHON SOURCE LINES 124-130

5. Load the general checkpoint
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Remember to first initialize the model and optimizer, then load the
dictionary locally.

.. GENERATED FROM PYTHON SOURCE LINES 130-145

.. code-block:: default

    model = Net()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    model.eval()
    # - or -
    model.train()

.. GENERATED FROM PYTHON SOURCE LINES 146-156

You must call ``model.eval()`` to set dropout and batch normalization
layers to evaluation mode before running inference. Failing to do this
will yield inconsistent inference results.

If you wish to resume training, call ``model.train()`` to ensure these
layers are in training mode.

Congratulations! You have successfully saved and loaded a general
checkpoint for inference and/or resuming training in PyTorch.
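One way to check that a checkpoint really captures both the model and the
optimizer is a round trip: save, restore into freshly constructed objects,
and compare. The following is a minimal sketch under a few assumptions that
are not part of the recipe: a small ``nn.Linear`` stands in for ``Net``, the
file is written to a temporary directory, and the file name
``checkpoint.tar`` is only illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# A small stand-in model (an assumption; the recipe's Net behaves the same way).
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Take one optimizer step so the momentum buffers exist in the optimizer state.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Illustrative location; any writable path works.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.tar")
torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, path)

# Restore into freshly constructed objects, as in step 5.
restored_model = nn.Linear(4, 2)
restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(path)
restored_model.load_state_dict(checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Compare the model weights tensor by tensor.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(),
                    restored_model.state_dict().values())
)

# SGD with momentum keeps one 'momentum_buffer' per parameter.
saved_state = optimizer.state_dict()['state']
restored_state = restored_optimizer.state_dict()['state']
momentum_match = all(
    torch.equal(saved_state[k]['momentum_buffer'],
                restored_state[k]['momentum_buffer'])
    for k in saved_state
)

print(weights_match, momentum_match, checkpoint['epoch'])
```

If the optimizer's ``state_dict`` were left out of the checkpoint, the
weights would still match, but the restored optimizer would start with empty
momentum buffers, which changes the first updates after resuming.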
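The saved ``epoch`` value is what lets a later run continue counting from
the right place. The sketch below shows one possible way to wire that up. It
assumes the checkpoint stores the index of the last completed epoch, and it
introduces a hypothetical helper ``train_one_epoch``, random data, and a
small ``nn.Linear`` model purely for illustration. ``map_location="cpu"`` is
passed to ``torch.load()`` so that the file also loads on a machine without
the device it was saved from.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def train_one_epoch(model, optimizer, inputs, targets):
    """Hypothetical helper: run one full-batch step and return the loss."""
    model.train()
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
inputs, targets = torch.randn(16, 4), torch.randn(16, 2)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.tar")
TOTAL_EPOCHS = 6

# First run: train for 3 epochs, overwriting the checkpoint after each one.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(3):
    loss = train_one_epoch(model, optimizer, inputs, targets)
    torch.save({
        'epoch': epoch,  # index of the last completed epoch
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)

# Second run: rebuild the objects, restore their state, and continue.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1

epochs_run = []
for epoch in range(start_epoch, TOTAL_EPOCHS):
    loss = train_one_epoch(model, optimizer, inputs, targets)
    epochs_run.append(epoch)

print(start_epoch, epochs_run)
```

Whether ``epoch`` means "last completed" or "next to run" is a convention
you choose; what matters is that the saving and loading code agree on it.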