.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/knowledge_distillation_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_beginner_knowledge_distillation_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_knowledge_distillation_tutorial.py:


Knowledge Distillation Tutorial
===============================
**Author**: Alexandros Chariton

.. GENERATED FROM PYTHON SOURCE LINES 9-33

Knowledge distillation is a technique that enables knowledge transfer from large,
computationally expensive models to smaller ones without losing validity. This allows for
deployment on less powerful hardware, making evaluation faster and more efficient.

In this tutorial, we will run a number of experiments focused on improving the accuracy of a
lightweight neural network, using a more powerful network as a teacher. The computational cost
and the speed of the lightweight network remain unaffected; our intervention targets only its
weights, not its forward pass. Applications of this technology can be found in devices such as
drones or mobile phones. We do not use any external packages, as everything we need is
available in ``torch`` and ``torchvision``.

In this tutorial, you will learn:

- How to modify model classes to extract hidden representations and use them for further
  calculations (sketched after the data-loading code below)
- How to modify regular train loops in PyTorch to include additional losses on top of, for
  example, cross-entropy for classification (also sketched below)
- How to improve the performance of lightweight models by using more complex models as teachers

Prerequisites
~~~~~~~~~~~~~

* 1 GPU, 4GB of memory
* PyTorch v2.0 or later
* CIFAR-10 dataset (downloaded by the script and saved in a directory called ``./data``)

.. GENERATED FROM PYTHON SOURCE LINES 33-43

.. code-block:: default


    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision.transforms as transforms
    import torchvision.datasets as datasets

    # Check if a GPU is available; if not, use the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

.. GENERATED FROM PYTHON SOURCE LINES 44-68

Loading CIFAR-10
----------------
CIFAR-10 is a popular image dataset with ten classes. Our objective is to predict one of the
following classes for each input image.

.. figure:: /../_static/img/cifar10.png
   :align: center

   Example of CIFAR-10 images

The input images are RGB, so they have 3 channels and are 32x32 pixels. Each image is thus
described by 3 x 32 x 32 = 3072 numbers ranging from 0 to 255. A common practice in neural
networks is to normalize the input, which is done for multiple reasons, including avoiding
saturation in commonly used activation functions and increasing numerical stability. Our
normalization process consists of subtracting the mean and dividing by the standard deviation
along each channel. The tensors ``mean=[0.485, 0.456, 0.406]`` and ``std=[0.229, 0.224, 0.225]``
were already computed, and they represent the mean and standard deviation of each channel in
the predefined subset of CIFAR-10 intended to be the training set. Notice how we use these
values for the test set as well, without recomputing the mean and standard deviation from
scratch. This is because the network was trained on features produced by subtracting and
dividing by the numbers above, and we want to maintain consistency. Furthermore, in real life,
we would not be able to compute the mean and standard deviation of the test set since, under
our assumptions, this data would not be accessible at that point.

As a closing point, we often refer to this held-out set as the validation set, and we use a
separate set, called the test set, after optimizing a model's performance on the validation
set. This is done to avoid selecting a model based on the greedy and biased optimization of a
single metric.

.. GENERATED FROM PYTHON SOURCE LINES 68-79

.. code-block:: default


    # Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
    transforms_cifar = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Loading the CIFAR-10 dataset:
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
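The normalization constants quoted above were computed in advance. As a quick sanity check,
here is a minimal sketch, not part of the original script, of how per-channel statistics can
be estimated from the training split; the printed values may not exactly match the constants
used above.

.. code-block:: default


    # Sketch: estimating per-channel normalization statistics from the
    # CIFAR-10 training split (illustrative; not part of the original script).
    # Relies on the torch/torchvision imports from the first code block.
    from torch.utils.data import DataLoader

    raw_train = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transforms.ToTensor())
    stats_loader = DataLoader(raw_train, batch_size=1024, shuffle=False)

    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    num_pixels = 0
    for images, _ in stats_loader:
        # images has shape (batch, 3, 32, 32); sum over batch and spatial dims
        channel_sum += images.sum(dim=[0, 2, 3])
        channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
        num_pixels += images.shape[0] * images.shape[2] * images.shape[3]

    mean = channel_sum / num_pixels
    std = (channel_sq_sum / num_pixels - mean ** 2).sqrt()
    print(mean, std)  # per-channel statistics on the [0, 1] pixel scale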
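The first learning objective listed earlier involves extracting hidden representations, which
usually comes down to returning more than one tensor from ``forward``. Below is a minimal
sketch; ``SmallNet`` is a hypothetical lightweight CNN used purely for illustration, not this
tutorial's actual student architecture.

.. code-block:: default


    # Hypothetical lightweight CNN (illustrative only). Its forward pass
    # returns both the logits and a flattened hidden representation, so that
    # additional losses can be computed on the hidden representation.
    class SmallNet(nn.Module):
        def __init__(self, num_classes=10):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),  # 32x32 -> 16x16
                nn.Conv2d(16, 16, kernel_size=3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),  # 16x16 -> 8x8
            )
            self.classifier = nn.Linear(16 * 8 * 8, num_classes)

        def forward(self, x):
            hidden = torch.flatten(self.features(x), 1)  # hidden representation
            logits = self.classifier(hidden)
            return logits, hidden

One way to use the extra output, in the spirit of FitNets-style hint training, is a cosine
similarity loss (for example ``nn.CosineEmbeddingLoss``) between student and teacher hidden
states, after projecting the two representations to a common dimensionality.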
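The second objective, adding a loss on top of cross-entropy, is the heart of knowledge
distillation. The sketch below follows the soft-target loss popularized by Hinton et al.: the
teacher's and student's logits are softened with a temperature ``T``, compared under a
KL-divergence term, and mixed with the usual cross-entropy on the labels. The temperature and
weighting values here are illustrative assumptions, not values prescribed by this tutorial.

.. code-block:: default


    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, labels,
                          T=2.0, soft_weight=0.25, ce_weight=0.75):
        # Temperature-softened distributions; a higher T spreads probability
        # mass over more classes, exposing the teacher's inter-class structure.
        soft_targets = F.softmax(teacher_logits / T, dim=-1)
        soft_log_probs = F.log_softmax(student_logits / T, dim=-1)
        # The T*T factor keeps gradient magnitudes comparable across temperatures.
        soft_loss = F.kl_div(soft_log_probs, soft_targets,
                             reduction='batchmean') * (T * T)
        hard_loss = F.cross_entropy(student_logits, labels)
        return soft_weight * soft_loss + ce_weight * hard_loss

    # Inside a regular train loop, with the teacher frozen:
    #     teacher.eval()
    #     with torch.no_grad():
    #         teacher_logits = teacher(inputs)
    #     loss = distillation_loss(student(inputs), teacher_logits, labels)
    #     loss.backward()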
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 7 minutes 53.174 seconds)


.. _sphx_glr_download_beginner_knowledge_distillation_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: knowledge_distillation_tutorial.py <knowledge_distillation_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: knowledge_distillation_tutorial.ipynb <knowledge_distillation_tutorial.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_