PyTorch TensorBoard Support
Follow along with the accompanying video on YouTube.
Before You Start
To run this tutorial, you’ll need to install PyTorch, TorchVision, Matplotlib, and TensorBoard.
With conda:
conda install pytorch torchvision -c pytorch
conda install matplotlib tensorboard
With pip:
pip install torch torchvision matplotlib tensorboard
Once the dependencies are installed, restart this notebook in the Python environment where you installed them.
Introduction
In this notebook, we’ll be training a variant of LeNet-5 against the Fashion-MNIST dataset. Fashion-MNIST is a set of image tiles depicting various garments, with ten class labels indicating the type of garment depicted.
# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms
# Image display
import matplotlib.pyplot as plt
import numpy as np
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
# In case you are using an environment that has TensorFlow installed,
# such as Google Colab, uncomment the following code to avoid
# a bug with saving embeddings to your TensorBoard directory
# import tensorflow as tf
# import tensorboard as tb
# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
Showing Images in TensorBoard
Let’s start by adding sample images from our dataset to TensorBoard:
# Gather datasets and prepare them for consumption
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Store separate training and validation splits in ./data
training_set = torchvision.datasets.FashionMNIST('./data',
download=True,
train=True,
transform=transform)
validation_set = torchvision.datasets.FashionMNIST('./data',
download=True,
train=False,
transform=transform)
training_loader = torch.utils.data.DataLoader(training_set,
batch_size=4,
shuffle=True,
num_workers=2)
validation_loader = torch.utils.data.DataLoader(validation_set,
batch_size=4,
shuffle=False,
num_workers=2)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
# Extract a batch of 4 images
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
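The img / 2 + 0.5 line in matplotlib_imshow() undoes the Normalize((0.5,), (0.5,)) transform. ToTensor() scales each pixel into [0, 1], and Normalize then maps a value x to (x - 0.5) / 0.5, so pixels end up in [-1, 1]. A minimal pure-Python sketch of that round trip on a few scalar values (the helper names normalize and unnormalize are hypothetical, not part of TorchVision):

```python
# Hypothetical helpers mirroring transforms.Normalize((0.5,), (0.5,))
# and the "img / 2 + 0.5" unnormalize step in matplotlib_imshow().
MEAN, STD = 0.5, 0.5


def normalize(x):
    """Map a pixel value in [0, 1] to [-1, 1]: (x - mean) / std."""
    return (x - MEAN) / STD


def unnormalize(y):
    """Invert normalize(): y * std + mean, i.e. y / 2 + 0.5 here."""
    return y * STD + MEAN


pixels = [0.0, 0.25, 0.5, 1.0]
normed = [normalize(p) for p in pixels]
restored = [unnormalize(y) for y in normed]
print(normed)    # [-1.0, -0.5, 0.0, 1.0]
print(restored)  # [0.0, 0.25, 0.5, 1.0]
```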
Above, we used TorchVision and Matplotlib to create a visual grid of a minibatch of our input data. Below, we use the add_image() call on SummaryWriter to log the image for consumption by TensorBoard, and we also call flush() to make sure it's written to disk right away.
# Default log_dir argument is "runs" - but it's good to be specific
# torch.utils.tensorboard.SummaryWriter is imported above
writer = SummaryWriter('runs/fashion_mnist_experiment_1')
# Write image data to TensorBoard log dir
writer.add_image('Four Fashion-MNIST Images', img_grid)
writer.flush()
# To view, start TensorBoard on the command line with:
# tensorboard --logdir=runs
# ...and open a browser tab to http://localhost:6006/
If you start TensorBoard at the command line and open it in a new browser tab (usually at localhost:6006), you should see the image grid under the IMAGES tab.
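By default, add_image() expects a single image tensor in channel, height, width (CHW) order, and make_grid() returns exactly that: one tensor holding all the tiles. As a sketch of where the grid's size comes from, assuming make_grid()'s defaults of nrow=8 and padding=2 and its behavior of repeating a single-channel batch to three channels, the hypothetical helper below reproduces only the shape arithmetic:

```python
import math


def grid_shape(n, c, h, w, nrow=8, padding=2):
    """Hypothetical helper: CHW shape expected from make_grid() for a batch
    of n images of size c x h x w."""
    channels = 3 if c == 1 else c   # single-channel batches are repeated to 3 channels
    cols = min(nrow, n)             # images per row
    rows = math.ceil(n / cols)      # rows needed to hold all n images
    cell_h, cell_w = h + padding, w + padding
    # One extra padding strip closes off the bottom and right edges
    return (channels, rows * cell_h + padding, cols * cell_w + padding)


print(grid_shape(4, 1, 28, 28))   # (3, 32, 122)
print(grid_shape(16, 1, 28, 28))  # (3, 62, 242)
```

For the batch of four 1 x 28 x 28 tiles above, that gives a 3 x 32 x 122 grid. Because the three channels are identical copies, matplotlib_imshow() can average them with one_channel=True to recover the grayscale image.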
Graphing Scalars to Visualize Training
TensorBoard is useful for tracking the progress and efficacy of your training. Below, we’ll run a training loop, track some metrics, and save the data for TensorBoard’s consumption.
Let’s define a model to categorize our image tiles, and an optimizer and loss function for training:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
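The 16 * 4 * 4 input size of fc1 comes from tracking the spatial size of a 28 x 28 tile through the layers. With stride 1 and no padding, a convolution with kernel size k shrinks each side from n to n - k + 1, and MaxPool2d(2, 2) halves it, rounding down. A small sketch of that arithmetic (the helper names conv_out and pool_out are hypothetical):

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Side length after a Conv2d with dilation 1: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def pool_out(n, kernel=2, stride=2):
    """Side length after MaxPool2d(kernel, stride) with no padding."""
    return (n - kernel) // stride + 1


side = 28                            # Fashion-MNIST tiles are 28 x 28
side = pool_out(conv_out(side, 5))   # conv1 (5x5): 28 -> 24, pool: 24 -> 12
side = pool_out(conv_out(side, 5))   # conv2 (5x5): 12 -> 8,  pool: 8 -> 4
flat = 16 * side * side              # conv2 produces 16 channels
print(side, flat)  # 4 256
```

The same number appears in forward() as x.view(-1, 16 * 4 * 4); input tiles of a different size would require recomputing it.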
Now let’s train a single epoch, and evaluate the training vs. validation set losses every 1000 batches:
print(len(validation_loader))
for epoch in range(1):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(training_loader, 0):
        # basic training loop
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:    # Every 1000 mini-batches...
            print('Batch {}'.format(i + 1))
            # Check against the validation set
            running_vloss = 0.0
            # Evaluation mode disables training-only behavior such as dropout
            net.train(False)  # equivalent to net.eval()
            with torch.no_grad():  # validation needs no gradients
                for j, vdata in enumerate(validation_loader, 0):
                    vinputs, vlabels = vdata
                    voutputs = net(vinputs)
                    vloss = criterion(voutputs, vlabels)
                    running_vloss += vloss.item()
            net.train(True)  # Switch back to training mode
            avg_loss = running_loss / 1000
            avg_vloss = running_vloss / len(validation_loader)
            # Log the running loss averaged per batch
            writer.add_scalars('Training vs. Validation Loss',
                               {'Training': avg_loss, 'Validation': avg_vloss},
                               epoch * len(training_loader) + i)
            running_loss = 0.0
print('Finished Training')
writer.flush()
2500
Batch 1000
Batch 2000
Batch 3000
Batch 4000
Batch 5000
Batch 6000
Batch 7000
Batch 8000
Batch 9000
Batch 10000
Batch 11000
Batch 12000
Batch 13000
Batch 14000
Batch 15000
Finished Training
Switch to your open TensorBoard and have a look at the SCALARS tab.
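With 60,000 training images and batch_size=4, len(training_loader) is 15,000, so the i % 1000 == 999 check fires 15 times in the single epoch, matching the "Batch 1000" through "Batch 15000" lines printed above. Likewise, the 10,000 validation images give the 2500 validation batches printed at the start. Each point on the SCALARS chart is placed at the global step epoch * len(training_loader) + i. A short sketch of that schedule, assuming the same dataset sizes and batch size:

```python
import math

TRAIN_SIZE, VAL_SIZE, BATCH_SIZE = 60_000, 10_000, 4
LOG_EVERY = 1000

# DataLoader keeps a final partial batch by default, hence the ceiling
train_batches = math.ceil(TRAIN_SIZE / BATCH_SIZE)  # len(training_loader)
val_batches = math.ceil(VAL_SIZE / BATCH_SIZE)      # len(validation_loader)

epoch = 0
# Global steps at which the training loop calls writer.add_scalars()
logged_steps = [epoch * train_batches + i
                for i in range(train_batches)
                if i % LOG_EVERY == LOG_EVERY - 1]

print(train_batches, val_batches)           # 15000 2500
print(len(logged_steps), logged_steps[:3])  # 15 [999, 1999, 2999]
```

Because running_loss is reset after each log, each training point is the mean loss over the preceding 1,000 batches, while each validation point is the mean over the whole validation set.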
Visualizing Your Model
TensorBoard can also be used to examine the data flow within your model. To do this, call the add_graph() method with a model and sample input:
# Again, grab a single mini-batch of images
dataiter = iter(training_loader)
images, labels = next(dataiter)
# add_graph() will trace the sample input through your model,
# and render it as a graph.
writer.add_graph(net, images)
writer.flush()
When you switch over to TensorBoard, you should see a GRAPHS tab. Double-click the “NET” node to see the layers and data flow within your model.
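The graph shows how data flows between layers; it can also help to know how many learnable parameters each layer holds. For Conv2d(c_in, c_out, k) that is c_out * c_in * k * k weights plus c_out biases, and for Linear(n_in, n_out) it is n_in * n_out weights plus n_out biases; the pooling layer has none. A quick pure-Python tally for the Net class above, with the layer table written out by hand rather than read from the model:

```python
# (name, weight count, bias count) for each layer of Net, written out by hand
layers = [
    ("conv1", 6 * 1 * 5 * 5, 6),      # Conv2d(1, 6, 5)
    ("conv2", 16 * 6 * 5 * 5, 16),    # Conv2d(6, 16, 5)
    ("fc1", 16 * 4 * 4 * 120, 120),   # Linear(256, 120)
    ("fc2", 120 * 84, 84),            # Linear(120, 84)
    ("fc3", 84 * 10, 10),             # Linear(84, 10)
]

counts = {name: w + b for name, w, b in layers}
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name}: {n}")
print("total:", total)  # total: 44426
```

With the model in hand, sum(p.numel() for p in net.parameters()) should report the same total.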
Visualizing Your Dataset with Embeddings
The 28-by-28 image tiles we're using can be modeled as 784-dimensional vectors (28 * 28 = 784). It can be instructive to project these to a lower-dimensional representation. The add_embedding() method logs a set of data for TensorBoard, which projects it onto the three dimensions with highest variance and displays the result as an interactive 3D chart.
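Projecting onto the directions of highest variance is principal component analysis (PCA), one of the projection methods TensorBoard's projector offers, so you do not need to compute it yourself. The following is only a minimal, dependency-free sketch of the idea: center the data, form its covariance matrix, find the top eigenvectors by power iteration with deflation, and project onto them. The function pca_project and the tiny 3-D example data are hypothetical; power iteration assumes a clear gap between the largest eigenvalues, and 784-dimensional image data would call for a numerical library instead.

```python
import random


def _project(centered, components):
    # Coordinates of each centered row along each component (dot products)
    return [[sum(xj * cj for xj, cj in zip(x, c)) for c in components]
            for x in centered]


def pca_project(rows, k, iters=200, seed=0):
    """Project rows (equal-length lists of numbers) onto their top-k principal
    components, found by power iteration on the covariance matrix with deflation."""
    n, d = len(rows), len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(d)]
    centered = [[r[j] - means[j] for j in range(d)] for r in rows]
    # Sample covariance matrix (d x d)
    cov = [[sum(x[a] * x[b] for x in centered) / (n - 1) for b in range(d)]
           for a in range(d)]
    rng = random.Random(seed)
    components = []
    for _ in range(k):
        v = [rng.random() + 0.1 for _ in range(d)]  # strictly positive start vector
        for _ in range(iters):
            w = [sum(cov[a][b] * v[b] for b in range(d)) for a in range(d)]
            norm = sum(x * x for x in w) ** 0.5
            if norm < 1e-12:  # no variance left to explain
                return _project(centered, components)
            v = [x / norm for x in w]
        # Variance along v (Rayleigh quotient); deflate so the next pass
        # finds the next-largest direction
        lam = sum(v[a] * sum(cov[a][b] * v[b] for b in range(d)) for a in range(d))
        cov = [[cov[a][b] - lam * v[a] * v[b] for b in range(d)] for a in range(d)]
        components.append(v)
    return _project(centered, components)


# Hypothetical 3-D data lying on the line t * (1, 2, 2)
rows = [[t, 2 * t, 2 * t] for t in range(-3, 4)]
proj = pca_project(rows, k=1)
print([round(p[0], 6) for p in proj])  # [-9.0, -6.0, -3.0, 0.0, 3.0, 6.0, 9.0]
```

All of the variance lies along (1, 2, 2), whose length is 3, so the single principal coordinate of each point is 3t.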
Below, we’ll take a sample of our data, and generate such an embedding:
# Select a random subset of data and corresponding labels
def select_n_random(data, labels, n=100):
    assert len(data) == len(labels)
    # n random indices without replacement; avoids permuting the whole dataset
    perm = torch.randperm(len(data))[:n]
    return data[perm], labels[perm]
# Extract a random subset of data
images, labels = select_n_random(training_set.data, training_set.targets)
# get the class labels for each image
class_labels = [classes[label] for label in labels]
# log embeddings
features = images.view(-1, 28 * 28)
writer.add_embedding(features,
metadata=class_labels,
label_img=images.unsqueeze(1))
writer.flush()
writer.close()
Now if you switch to TensorBoard and select the PROJECTOR tab, you should see a 3D representation of the projection. You can rotate and zoom the chart. Examine it at large and small scales, and see whether you can spot patterns in the projected data and the clustering of labels.
For better visibility, it’s recommended to:
Select “label” from the “Color by” drop-down on the left.
Toggle the Night Mode icon along the top to place the light-colored images on a dark background.
Other Resources
For more information, have a look at:
PyTorch documentation on torch.utils.tensorboard.SummaryWriter
TensorBoard tutorial content in the PyTorch.org Tutorials
For more information about TensorBoard, see the TensorBoard documentation