Transfer Learning for Computer Vision Tutorial

Author: Sasank Chilamkurthy

In this tutorial, you will learn how to train a convolutional neural network for image classification using transfer learning. You can read more about transfer learning in the cs231n notes.

Quoting these notes,

In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.

These two major transfer learning scenarios look as follows:

  • Finetuning the convnet: Instead of random initialization, we initialize the network with a pretrained network, such as one trained on the ImageNet 1000-class dataset. The rest of the training proceeds as usual.
  • ConvNet as fixed feature extractor: Here, we freeze the weights of the entire network except for the final fully connected layer. This last fully connected layer is replaced with a new one with random weights, and only this layer is trained.
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

Load Data

We will use torchvision and torch.utils.data packages for loading the data.

The problem we’re going to solve today is to train a model to classify ants and bees. We have about 120 training images each for ants and bees. There are 75 validation images for each class. This would usually be too small a dataset to generalize from if we trained from scratch. Since we are using transfer learning, we should be able to generalize reasonably well.

This dataset is a very small subset of ImageNet.

Note

Download the data from here and extract it to the current directory.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
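
ImageFolder infers labels from the directory layout: each subdirectory of a split is treated as one class, and class indices follow the sorted subdirectory names. The sketch below mimics that rule with the standard library only. The subfolder names ants and bees and the temporary root directory are assumptions for illustration, not taken from the downloaded archive.

```python
import os
import tempfile

# Build an empty directory tree shaped like <root>/<split>/<class>/.
# The class folder names here are assumed for illustration.
root = tempfile.mkdtemp()
for split in ['train', 'val']:
    for cls in ['bees', 'ants']:
        os.makedirs(os.path.join(root, split, cls))

# Mimic the labeling rule: sorted subdirectory names become class indices.
train_dir = os.path.join(root, 'train')
classes = sorted(
    d for d in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, d))
)
class_to_idx = {name: idx for idx, name in enumerate(classes)}

print(classes)
print(class_to_idx)
```

Under this assumed layout, class_names would list ants before bees, so label 0 would mean ants and label 1 would mean bees.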

Visualize a few images

Let’s visualize a few training images so as to understand the data augmentations.

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
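
The imshow helper undoes the Normalize transform before plotting: Normalize maps each channel value v to (v - mean) / std, and imshow maps it back with std * v + mean, then clips to [0, 1]. A minimal plain-Python sketch of that round trip for a single RGB pixel follows; the function names and the sample pixel are hypothetical, and the real code operates on whole tensors and arrays.

```python
# Per-channel statistics used by the transforms in this tutorial.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def normalize(pixel):
    """Apply (value - mean) / std to each channel of one RGB pixel."""
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]


def denormalize(pixel):
    """Invert normalize() and clip each channel to the [0, 1] range."""
    return [min(max(s * v + m, 0.0), 1.0) for v, m, s in zip(pixel, mean, std)]


pixel = [0.2, 0.5, 0.9]            # hypothetical pixel, values in [0, 1]
normalized = normalize(pixel)      # what the model sees
restored = denormalize(normalized) # what imshow displays

print(normalized)
print(restored)
```

Without the inverse step, the normalized values would fall outside [0, 1] for many pixels and the plotted colors would look wrong.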

Training the model

Now, let’s write a general function to train a model. Here, we will illustrate:

  • Scheduling the learning rate
  • Saving the best model

In the following, the scheduler parameter is an LR scheduler object from torch.optim.lr_scheduler.

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
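
The statistics in train_model multiply each batch loss by the batch size because the loss is averaged within a batch, and the last batch can be smaller than the rest. The plain-Python sketch below reproduces that bookkeeping with made-up numbers; the losses, logits and labels are hypothetical, and argmax stands in for the index returned by torch.max(outputs, 1).

```python
def argmax(row):
    """Index of the largest value in a list of class scores."""
    return max(range(len(row)), key=row.__getitem__)


# Each entry: (mean loss over the batch, per-sample logits, true labels).
# All values are invented for illustration; the last batch is smaller.
batches = [
    (0.50, [[2.0, -1.0], [0.3, 0.9], [1.5, 0.2], [-0.4, 0.1]], [0, 1, 1, 1]),
    (0.70, [[0.1, 0.8], [1.2, -0.3], [0.0, 0.5], [0.9, 0.4]], [0, 0, 1, 1]),
    (0.20, [[-1.0, 2.0], [3.0, 0.5]], [1, 0]),
]

dataset_size = sum(len(labels) for _, _, labels in batches)
running_loss = 0.0
running_corrects = 0

for mean_loss, logits, labels in batches:
    preds = [argmax(row) for row in logits]
    running_loss += mean_loss * len(labels)  # undo the per-batch averaging
    running_corrects += sum(p == y for p, y in zip(preds, labels))

epoch_loss = running_loss / dataset_size
epoch_acc = running_corrects / dataset_size

# Averaging the batch means directly would over-weight the short last batch.
naive_loss = sum(mean_loss for mean_loss, _, _ in batches) / len(batches)

print(epoch_loss, epoch_acc, naive_loss)
```

Here the size-weighted loss and the naive mean of batch losses differ, which is why the function accumulates loss.item() * inputs.size(0) and divides by the dataset size.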

Visualizing the model predictions

A generic function to display predictions for a few images:

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)
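
visualize_model remembers whether the model was in training mode, switches to evaluation mode, and restores the original mode on both exit paths. One way to express the same save-and-restore pattern is a context manager, sketched below. The names eval_mode and StubModel are hypothetical: StubModel only imitates the training flag and the train()/eval() methods of nn.Module so the sketch runs without PyTorch.

```python
from contextlib import contextmanager


@contextmanager
def eval_mode(model):
    """Temporarily put a model in eval mode, then restore its prior mode."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        # Runs on normal exit, early return, or exception.
        model.train(mode=was_training)


class StubModel:
    """Minimal stand-in exposing the train/eval interface used above."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)


model = StubModel()
with eval_mode(model):
    inside = model.training   # mode while predictions would be made
after = model.training        # mode once the block has finished

print(inside, after)
```

The finally clause plays the role of the two model.train(mode=was_training) calls in visualize_model.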

Finetuning the convnet

Load a pretrained model and reset final fully connected layer.

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
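
With step_size=7 and gamma=0.1, step decay multiplies the learning rate by 0.1 after every 7 epochs. The plain-Python sketch below computes the resulting schedule in closed form, assuming the scheduler is stepped once per epoch as in train_model. The helper name step_lr is hypothetical and is not part of torch.optim.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate in effect during a 0-based epoch under step decay."""
    return base_lr * gamma ** (epoch // step_size)


# Schedule for the 25 epochs used in this tutorial, starting from lr=0.001.
schedule = [step_lr(0.001, epoch) for epoch in range(25)]

for epoch in (0, 6, 7, 13, 14, 20, 21, 24):
    print('epoch {:2d}: lr = {:.0e}'.format(epoch, schedule[epoch]))
```

Under these settings the rate stays at 0.001 for epochs 0-6, drops to 0.0001 for epochs 7-13, to 0.00001 for epochs 14-20, and to 0.000001 for the remaining epochs.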

Train and evaluate

It should take around 15-25 minutes on CPU. On a GPU, it takes roughly a minute.

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Out:

Epoch 0/24
----------
train Loss: 0.7280 Acc: 0.6107
val Loss: 0.2969 Acc: 0.8693

Epoch 1/24
----------
train Loss: 0.5610 Acc: 0.7623
val Loss: 0.3433 Acc: 0.8497

Epoch 2/24
----------
train Loss: 0.5469 Acc: 0.8156
val Loss: 0.2532 Acc: 0.9281

Epoch 3/24
----------
train Loss: 0.2919 Acc: 0.8934
val Loss: 0.2495 Acc: 0.9085

Epoch 4/24
----------
train Loss: 0.7041 Acc: 0.7336
val Loss: 0.3091 Acc: 0.8954

Epoch 5/24
----------
train Loss: 0.4392 Acc: 0.8361
val Loss: 0.4777 Acc: 0.8301

Epoch 6/24
----------
train Loss: 0.5393 Acc: 0.8156
val Loss: 0.4883 Acc: 0.8431

Epoch 7/24
----------
train Loss: 0.3601 Acc: 0.8689
val Loss: 0.3076 Acc: 0.8824

Epoch 8/24
----------
train Loss: 0.3768 Acc: 0.8525
val Loss: 0.3054 Acc: 0.8889

Epoch 9/24
----------
train Loss: 0.2447 Acc: 0.9057
val Loss: 0.3057 Acc: 0.8954

Epoch 10/24
----------
train Loss: 0.3431 Acc: 0.8689
val Loss: 0.2478 Acc: 0.9150

Epoch 11/24
----------
train Loss: 0.3315 Acc: 0.8811
val Loss: 0.2489 Acc: 0.9020

Epoch 12/24
----------
train Loss: 0.2112 Acc: 0.9098
val Loss: 0.2523 Acc: 0.9150

Epoch 13/24
----------
train Loss: 0.3064 Acc: 0.8770
val Loss: 0.2287 Acc: 0.9281

Epoch 14/24
----------
train Loss: 0.3628 Acc: 0.8320
val Loss: 0.2683 Acc: 0.8824

Epoch 15/24
----------
train Loss: 0.2031 Acc: 0.9262
val Loss: 0.2449 Acc: 0.9150

Epoch 16/24
----------
train Loss: 0.2018 Acc: 0.9221
val Loss: 0.2356 Acc: 0.9216

Epoch 17/24
----------
train Loss: 0.3246 Acc: 0.8402
val Loss: 0.2455 Acc: 0.9150

Epoch 18/24
----------
train Loss: 0.3181 Acc: 0.8525
val Loss: 0.2533 Acc: 0.9085

Epoch 19/24
----------
train Loss: 0.2917 Acc: 0.8811
val Loss: 0.2357 Acc: 0.9281

Epoch 20/24
----------
train Loss: 0.2952 Acc: 0.8566
val Loss: 0.2451 Acc: 0.8954

Epoch 21/24
----------
train Loss: 0.3048 Acc: 0.9098
val Loss: 0.2411 Acc: 0.9216

Epoch 22/24
----------
train Loss: 0.2478 Acc: 0.8770
val Loss: 0.2277 Acc: 0.9281

Epoch 23/24
----------
train Loss: 0.2556 Acc: 0.8934
val Loss: 0.2539 Acc: 0.9020

Epoch 24/24
----------
train Loss: 0.2914 Acc: 0.8811
val Loss: 0.2411 Acc: 0.9085

Training complete in 1m 9s
Best val Acc: 0.928105
visualize_model(model_ft)

ConvNet as fixed feature extractor

Here, we need to freeze all of the network except the final layer. We set requires_grad = False on the parameters to freeze them, so that their gradients are not computed in backward().

You can read more about this in the documentation here.

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
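
To check that freezing worked as intended, it can help to count which parameters still require gradients. The sketch below applies the same freeze-then-replace steps to a tiny stand-in network rather than ResNet-18, so it runs without downloading weights. The layer sizes and variable names are assumptions for illustration only.

```python
import torch
import torch.nn as nn

# Tiny stand-in for a pretrained backbone plus its original final layer.
# Layer sizes are arbitrary; the tutorial itself uses ResNet-18.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 4),  # plays the role of the original final layer
)

# Freeze every existing parameter.
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer; new modules have requires_grad=True by default.
model[2] = nn.Linear(16, 2)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)

# A backward pass populates gradients only for the new head.
loss = model(torch.randn(3, 8)).sum()
loss.backward()

print(trainable, num_trainable, num_frozen)
print(model[0].weight.grad is None, model[2].weight.grad is not None)
```

Only the replaced layer appears in the trainable list, which matches passing model_conv.fc.parameters() to the optimizer above.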

Train and evaluate

On CPU, this takes about half the time of the previous scenario. This is expected, as gradients don’t need to be computed for most of the network. However, the forward pass still needs to be computed.

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

Out:

Epoch 0/24
----------
train Loss: 0.6481 Acc: 0.6475
val Loss: 0.2219 Acc: 0.9216

Epoch 1/24
----------
train Loss: 0.5332 Acc: 0.7459
val Loss: 0.2385 Acc: 0.9085

Epoch 2/24
----------
train Loss: 0.5024 Acc: 0.8156
val Loss: 0.1614 Acc: 0.9412

Epoch 3/24
----------
train Loss: 0.6070 Acc: 0.7500
val Loss: 0.2013 Acc: 0.9216

Epoch 4/24
----------
train Loss: 0.6728 Acc: 0.7336
val Loss: 0.1996 Acc: 0.9412

Epoch 5/24
----------
train Loss: 0.5280 Acc: 0.7500
val Loss: 0.1955 Acc: 0.9412

Epoch 6/24
----------
train Loss: 0.4637 Acc: 0.7992
val Loss: 0.3117 Acc: 0.8889

Epoch 7/24
----------
train Loss: 0.4219 Acc: 0.8238
val Loss: 0.1844 Acc: 0.9477

Epoch 8/24
----------
train Loss: 0.3468 Acc: 0.8484
val Loss: 0.1843 Acc: 0.9542

Epoch 9/24
----------
train Loss: 0.3442 Acc: 0.8320
val Loss: 0.1760 Acc: 0.9542

Epoch 10/24
----------
train Loss: 0.3711 Acc: 0.8156
val Loss: 0.1899 Acc: 0.9477

Epoch 11/24
----------
train Loss: 0.4115 Acc: 0.8033
val Loss: 0.1958 Acc: 0.9477

Epoch 12/24
----------
train Loss: 0.3101 Acc: 0.8648
val Loss: 0.1897 Acc: 0.9542

Epoch 13/24
----------
train Loss: 0.3244 Acc: 0.8402
val Loss: 0.1964 Acc: 0.9542

Epoch 14/24
----------
train Loss: 0.3063 Acc: 0.8852
val Loss: 0.1783 Acc: 0.9477

Epoch 15/24
----------
train Loss: 0.3705 Acc: 0.8320
val Loss: 0.1923 Acc: 0.9412

Epoch 16/24
----------
train Loss: 0.2718 Acc: 0.8770
val Loss: 0.1996 Acc: 0.9281

Epoch 17/24
----------
train Loss: 0.3274 Acc: 0.8566
val Loss: 0.1960 Acc: 0.9412

Epoch 18/24
----------
train Loss: 0.3994 Acc: 0.8115
val Loss: 0.1948 Acc: 0.9542

Epoch 19/24
----------
train Loss: 0.3682 Acc: 0.8320
val Loss: 0.2164 Acc: 0.9281

Epoch 20/24
----------
train Loss: 0.3391 Acc: 0.8402
val Loss: 0.1836 Acc: 0.9542

Epoch 21/24
----------
train Loss: 0.3218 Acc: 0.8607
val Loss: 0.1906 Acc: 0.9477

Epoch 22/24
----------
train Loss: 0.3051 Acc: 0.8566
val Loss: 0.1893 Acc: 0.9542

Epoch 23/24
----------
train Loss: 0.3166 Acc: 0.8607
val Loss: 0.1779 Acc: 0.9542

Epoch 24/24
----------
train Loss: 0.3427 Acc: 0.8361
val Loss: 0.2201 Acc: 0.9216

Training complete in 0m 34s
Best val Acc: 0.954248
visualize_model(model_conv)

plt.ioff()
plt.show()

Further Learning

If you would like to learn more about the applications of transfer learning, check out our Quantized Transfer Learning for Computer Vision Tutorial.

Total running time of the script: (1 minute 54.229 seconds)
