What is a state_dict in PyTorch¶
In PyTorch, the learnable parameters (i.e. weights and biases) of a
torch.nn.Module model are contained in the model’s parameters
state_dict is simply a
Python dictionary object that maps each layer to its parameter tensor.
state_dict is an integral entity if you are interested in saving
or loading models from PyTorch.
state_dict objects are Python dictionaries, they can be
easily saved, updated, altered, and restored, adding a great deal of
modularity to PyTorch models and optimizers.
Note that only layers with learnable parameters (convolutional layers,
linear layers, etc.) and registered buffers (batchnorm’s running_mean)
have entries in the model’s
state_dict. Optimizer objects
torch.optim) also have a
state_dict, which contains information
about the optimizer’s state, as well as the hyperparameters used.
In this recipe, we will see how
state_dict is used with a simple
Before we begin, we need to install
torch if it isn’t already
pip install torch
- Import all necessary libraries for loading our data
- Define and intialize the neural network
- Initialize the optimizer
- Access the model and optimizer
1. Import necessary libraries for loading our data¶
For this recipe, we will use
torch and its subsidiaries
import torch import torch.nn as nn import torch.optim as optim
2. Define and intialize the neural network¶
For sake of example, we will create a neural network for training images. To learn more see the Defining a Neural Network recipe.
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net() print(net)
3. Initialize the optimizer¶
We will use SGD with momentum.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4. Access the model and optimizer
Now that we have constructed our model and optimizer, we can understand
what is preserved in their respective
# Print model's state_dict print("Model's state_dict:") for param_tensor in net.state_dict(): print(param_tensor, "\t", net.state_dict()[param_tensor].size()) print() # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name])
This information is relevant for saving and loading the model and optimizers for future use.
Congratulations! You have successfully used
state_dict in PyTorch.
Take a look at these other recipes to continue your learning:
- Saving and loading models for inference in PyTorch
- Saving and loading a general checkpoint in PyTorch
Total running time of the script: ( 0 minutes 0.000 seconds)