{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n", "# https://pytorch.org/tutorials/beginner/colab\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Writing Custom Datasets, DataLoaders and Transforms\n", "===================================================\n", "\n", "**Author**: [Sasank Chilamkurthy](https://chsasank.github.io)\n", "\n", "A lot of effort in solving any machine learning problem goes into\n", "preparing the data. PyTorch provides many tools to make data loading\n", "easy and, hopefully, to make your code more readable. In this tutorial,\n", "we will see how to load and preprocess/augment data from a non-trivial\n", "dataset.\n", "\n", "To run this tutorial, please make sure the following packages are\n", "installed:\n", "\n", "- `scikit-image`: For image I/O and transforms\n", "- `pandas`: For easier CSV parsing\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\n", "import torch\n", "import pandas as pd\n", "from skimage import io, transform\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision import transforms, utils\n", "\n", "# Ignore warnings\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "plt.ion() # interactive mode" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset we are going to deal with is that of facial pose. This means\n", "that a face is annotated like this:\n", "\n", "![](https://pytorch.org/tutorials/_static/img/landmarked_face2.png){width=\"400px\"}\n", "\n", "Overall, 68 different landmark points are annotated for each face.\n", "\n", "
**NOTE:** Download the dataset from here so that the images are in a directory named 'data/faces/'. This dataset was generated by applying dlib's excellent pose estimation to a few images from ImageNet tagged as 'face'.
\n", "\n", "The dataset comes with a `.csv` file with annotations which looks like this:\n", "\n", "``` {.sourceCode .sh}\n", "image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y\n", "0805personali01.jpg,27,83,27,98, ... ,84,134\n", "1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312\n", "```\n", "\n", "Let\'s take a single image name and its annotations from the CSV, in\n", "this case row index number 65 for person-7.jpg just as an example. Read\n", "it, store the image name in `img_name` and store its annotations in an\n", "(L, 2) array `landmarks` where L is the number of landmarks in that row.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')\n", "\n", "n = 65\n", "img_name = landmarks_frame.iloc[n, 0]\n", "landmarks = landmarks_frame.iloc[n, 1:]\n", "landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)\n", "\n", "print('Image name: {}'.format(img_name))\n", "print('Landmarks shape: {}'.format(landmarks.shape))\n", "print('First 4 Landmarks: {}'.format(landmarks[:4]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\'s write a simple helper function to show an image and its landmarks\n", "and use it to show a sample.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def show_landmarks(image, landmarks):\n", "    \"\"\"Show image with landmarks\"\"\"\n", "    plt.imshow(image)\n", "    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')\n", "    plt.pause(0.001)  # pause a bit so that plots are updated\n", "\n", "plt.figure()\n", "show_landmarks(io.imread(os.path.join('data/faces/', img_name)),\n", "               landmarks)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dataset class\n", "=============\n", "\n", "`torch.utils.data.Dataset` is an abstract class representing a 
dataset.\n", "Your custom dataset should inherit from `Dataset` and override the following\n", "methods:\n", "\n", "- `__len__` so that `len(dataset)` returns the size of the dataset.\n", "- `__getitem__` to support indexing such that `dataset[i]` can be\n", "    used to get the $i$-th sample.\n", "\n", "Let\'s create a dataset class for our face landmarks dataset. We will\n", "read the CSV in `__init__` but leave the reading of images to\n", "`__getitem__`. This is memory efficient because the images are not all\n", "stored in memory at once but are read as required.\n", "\n", "A sample of our dataset will be a dict\n", "`{'image': image, 'landmarks': landmarks}`. Our dataset will take an\n", "optional argument `transform` so that any required processing can be\n", "applied to the sample. We will see the usefulness of `transform` in the\n", "next section.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class FaceLandmarksDataset(Dataset):\n", "    \"\"\"Face Landmarks dataset.\"\"\"\n", "\n", "    def __init__(self, csv_file, root_dir, transform=None):\n", "        \"\"\"\n", "        Arguments:\n", "            csv_file (string): Path to the csv file with annotations.\n", "            root_dir (string): Directory with all the images.\n", "            transform (callable, optional): Optional transform to be applied\n", "                on a sample.\n", "        \"\"\"\n", "        self.landmarks_frame = pd.read_csv(csv_file)\n", "        self.root_dir = root_dir\n", "        self.transform = transform\n", "\n", "    def __len__(self):\n", "        return len(self.landmarks_frame)\n", "\n", "    def __getitem__(self, idx):\n", "        if torch.is_tensor(idx):\n", "            idx = idx.tolist()\n", "\n", "        img_name = os.path.join(self.root_dir,\n", "                                self.landmarks_frame.iloc[idx, 0])\n", "        image = io.imread(img_name)\n", "        landmarks = self.landmarks_frame.iloc[idx, 1:]\n", "        landmarks = np.array([landmarks], dtype=float).reshape(-1, 2)\n", "        sample = {'image': image, 'landmarks': landmarks}\n", "\n", "        if self.transform:\n", "            
sample = self.transform(sample)\n", "\n", "        return sample" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\'s instantiate this class and iterate through the data samples. We\n", "will print the sizes of the first 4 samples and show their landmarks.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',\n", "                                    root_dir='data/faces/')\n", "\n", "fig = plt.figure()\n", "\n", "for i, sample in enumerate(face_dataset):\n", "    print(i, sample['image'].shape, sample['landmarks'].shape)\n", "\n", "    ax = plt.subplot(1, 4, i + 1)\n", "    plt.tight_layout()\n", "    ax.set_title('Sample #{}'.format(i))\n", "    ax.axis('off')\n", "    show_landmarks(**sample)\n", "\n", "    if i == 3:\n", "        plt.show()\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transforms\n", "==========\n", "\n", "One issue we can see from the above is that the samples are not of the\n", "same size. Most neural networks expect images of a fixed size.\n", "Therefore, we will need to write some preprocessing code. Let\'s create\n", "three transforms:\n", "\n", "- `Rescale`: to scale the image\n", "- `RandomCrop`: to crop the image randomly. This is data\n", "    augmentation.\n", "- `ToTensor`: to convert the numpy images to torch images (we need to\n", "    swap axes).\n", "\n", "We will write them as callable classes instead of simple functions so\n", "that parameters of the transform need not be passed every time it\'s\n", "called. For this, we just need to implement the `__call__` method and,\n", "if required, the `__init__` method. 
We can then use a transform like this:\n", "\n", "``` {.sourceCode .python}\n", "tsfm = Transform(params)\n", "transformed_sample = tsfm(sample)\n", "```\n", "\n", "Observe below how these transforms have to be applied to both the image\n", "and the landmarks.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Rescale(object):\n", "    \"\"\"Rescale the image in a sample to a given size.\n", "\n", "    Args:\n", "        output_size (tuple or int): Desired output size. If tuple, output is\n", "            matched to output_size. If int, smaller of image edges is matched\n", "            to output_size keeping aspect ratio the same.\n", "    \"\"\"\n", "\n", "    def __init__(self, output_size):\n", "        assert isinstance(output_size, (int, tuple))\n", "        self.output_size = output_size\n", "\n", "    def __call__(self, sample):\n", "        image, landmarks = sample['image'], sample['landmarks']\n", "\n", "        h, w = image.shape[:2]\n", "        if isinstance(self.output_size, int):\n", "            if h > w:\n", "                new_h, new_w = self.output_size * h / w, self.output_size\n", "            else:\n", "                new_h, new_w = self.output_size, self.output_size * w / h\n", "        else:\n", "            new_h, new_w = self.output_size\n", "\n", "        new_h, new_w = int(new_h), int(new_w)\n", "\n", "        img = transform.resize(image, (new_h, new_w))\n", "\n", "        # h and w are swapped for landmarks because for images,\n", "        # x and y axes are axis 1 and 0 respectively\n", "        landmarks = landmarks * [new_w / w, new_h / h]\n", "\n", "        return {'image': img, 'landmarks': landmarks}\n", "\n", "\n", "class RandomCrop(object):\n", "    \"\"\"Crop randomly the image in a sample.\n", "\n", "    Args:\n", "        output_size (tuple or int): Desired output size. 
If int, square crop\n", "            is made.\n", "    \"\"\"\n", "\n", "    def __init__(self, output_size):\n", "        assert isinstance(output_size, (int, tuple))\n", "        if isinstance(output_size, int):\n", "            self.output_size = (output_size, output_size)\n", "        else:\n", "            assert len(output_size) == 2\n", "            self.output_size = output_size\n", "\n", "    def __call__(self, sample):\n", "        image, landmarks = sample['image'], sample['landmarks']\n", "\n", "        h, w = image.shape[:2]\n", "        new_h, new_w = self.output_size\n", "\n", "        top = np.random.randint(0, h - new_h + 1)\n", "        left = np.random.randint(0, w - new_w + 1)\n", "\n", "        image = image[top: top + new_h,\n", "                      left: left + new_w]\n", "\n", "        landmarks = landmarks - [left, top]\n", "\n", "        return {'image': image, 'landmarks': landmarks}\n", "\n", "\n", "class ToTensor(object):\n", "    \"\"\"Convert ndarrays in sample to Tensors.\"\"\"\n", "\n", "    def __call__(self, sample):\n", "        image, landmarks = sample['image'], sample['landmarks']\n", "\n", "        # swap color axis because\n", "        # numpy image: H x W x C\n", "        # torch image: C x H x W\n", "        image = image.transpose((2, 0, 1))\n", "        return {'image': torch.from_numpy(image),\n", "                'landmarks': torch.from_numpy(landmarks)}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
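The size and landmark arithmetic in `Rescale` can be checked by hand without loading any image. The following is a minimal sketch using only NumPy; the image size (200 x 400), the target size (100) and the two landmark points are hypothetical values chosen for illustration, not taken from the dataset.

```python
import numpy as np

# Hypothetical sizes chosen for illustration: a 200 x 400 (H x W) image
h, w = 200, 400
output_size = 100  # int: the shorter edge is matched to this value

# Same branch logic as in Rescale.__call__
if h > w:
    new_h, new_w = output_size * h / w, output_size
else:
    new_h, new_w = output_size, output_size * w / h
new_h, new_w = int(new_h), int(new_w)

# Landmarks are (x, y) pairs, so x scales with the width ratio
# and y scales with the height ratio
landmarks = np.array([[40.0, 20.0], [300.0, 150.0]])
scaled = landmarks * [new_w / w, new_h / h]

print(new_h, new_w)  # 100 200
print(scaled)
```

Because the aspect ratio is preserved here, both coordinates shrink by the same factor; with a tuple `output_size` the two factors can differ.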
**NOTE:** In the transforms above, `RandomCrop` uses an external library's random number generator (in this case, NumPy's `np.random.randint`). This can result in unexpected behavior with `DataLoader` (see here). In practice, it is safer to stick to PyTorch's random number generator, e.g. by using `torch.randint` instead.
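As a rough illustration of that advice, here is a minimal sketch of a crop transform that draws its offsets from PyTorch's random number generator. The class name `TorchRandomCrop` is hypothetical (it is not part of the tutorial or of torchvision), and the sample layout is assumed to be the same `{'image', 'landmarks'}` dict of NumPy arrays used above.

```python
import numpy as np
import torch


class TorchRandomCrop:
    """Randomly crop a sample, drawing offsets from PyTorch's RNG.

    Hypothetical variant of RandomCrop, written for illustration only.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # torch.randint samples integers from [low, high), so the upper
        # bound is exclusive, as with np.random.randint
        top = int(torch.randint(0, h - new_h + 1, (1,)).item())
        left = int(torch.randint(0, w - new_w + 1, (1,)).item())

        image = image[top: top + new_h, left: left + new_w]

        # Shift landmarks into the cropped coordinate frame (x, y order)
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


# Usage on a synthetic sample (assumed shapes, not real data)
sample = {'image': np.zeros((10, 12, 3)), 'landmarks': np.array([[5.0, 6.0]])}
cropped = TorchRandomCrop(4)(sample)
print(cropped['image'].shape, cropped['landmarks'].shape)
```

The only change from `RandomCrop` is the source of randomness; the cropping and landmark shift are unchanged.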
\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compose transforms\n", "==================\n", "\n", "Now, we apply the transforms to a sample.\n", "\n", "Let\'s say we want to rescale the shorter side of the image to 256 and\n", "then randomly crop a square of size 224 from it, i.e., we want to compose\n", "the `Rescale` and `RandomCrop` transforms. `torchvision.transforms.Compose`\n", "is a simple callable class which allows us to do this.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "scale = Rescale(256)\n", "crop = RandomCrop(128)\n", "composed = transforms.Compose([Rescale(256),\n", "                               RandomCrop(224)])\n", "\n", "# Apply each of the above transforms on sample.\n", "fig = plt.figure()\n", "sample = face_dataset[65]\n", "for i, tsfrm in enumerate([scale, crop, composed]):\n", "    transformed_sample = tsfrm(sample)\n", "\n", "    ax = plt.subplot(1, 3, i + 1)\n", "    plt.tight_layout()\n", "    ax.set_title(type(tsfrm).__name__)\n", "    show_landmarks(**transformed_sample)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Iterating through the dataset\n", "=============================\n", "\n", "Let\'s put this all together to create a dataset with composed\n", "transforms. 
To summarize, every time this dataset is sampled:\n", "\n", "- An image is read from the file on the fly\n", "- Transforms are applied to the read image\n", "- Since one of the transforms is random, data is augmented on sampling\n", "\n", "We can iterate over the created dataset with a `for` loop and\n", "`enumerate`, as before.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',\n", "                                           root_dir='data/faces/',\n", "                                           transform=transforms.Compose([\n", "                                               Rescale(256),\n", "                                               RandomCrop(224),\n", "                                               ToTensor()\n", "                                           ]))\n", "\n", "for i, sample in enumerate(transformed_dataset):\n", "    print(i, sample['image'].size(), sample['landmarks'].size())\n", "\n", "    if i == 3:\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, we are losing a lot of features by using a simple `for` loop to\n", "iterate over the data. In particular, we are missing out on:\n", "\n", "- Batching the data\n", "- Shuffling the data\n", "- Loading the data in parallel using `multiprocessing` workers.\n", "\n", "`torch.utils.data.DataLoader` is an iterator which provides all these\n", "features. Parameters used below should be clear. One parameter of\n", "interest is `collate_fn`. You can specify how exactly the samples need\n", "to be batched using `collate_fn`. 
However, the default collate function should work\n", "fine for most use cases.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dataloader = DataLoader(transformed_dataset, batch_size=4,\n", "                        shuffle=True, num_workers=0)\n", "\n", "\n", "# Helper function to show a batch\n", "def show_landmarks_batch(sample_batched):\n", "    \"\"\"Show image with landmarks for a batch of samples.\"\"\"\n", "    images_batch, landmarks_batch = \\\n", "            sample_batched['image'], sample_batched['landmarks']\n", "    batch_size = len(images_batch)\n", "    im_size = images_batch.size(2)\n", "    grid_border_size = 2\n", "\n", "    grid = utils.make_grid(images_batch)\n", "    plt.imshow(grid.numpy().transpose((1, 2, 0)))\n", "\n", "    for i in range(batch_size):\n", "        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,\n", "                    landmarks_batch[i, :, 1].numpy() + grid_border_size,\n", "                    s=10, marker='.', c='r')\n", "\n", "        plt.title('Batch from dataloader')\n", "\n", "# if you are using Windows, uncomment the next line and indent the for loop.\n", "# you might need to go back and change ``num_workers`` to 0.\n", "\n", "# if __name__ == '__main__':\n", "for i_batch, sample_batched in enumerate(dataloader):\n", "    print(i_batch, sample_batched['image'].size(),\n", "          sample_batched['landmarks'].size())\n", "\n", "    # observe 4th batch and stop.\n", "    if i_batch == 3:\n", "        plt.figure()\n", "        show_landmarks_batch(sample_batched)\n", "        plt.axis('off')\n", "        plt.ioff()\n", "        plt.show()\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Afterword: torchvision\n", "======================\n", "\n", "In this tutorial, we have seen how to write and use datasets, transforms\n", "and dataloaders. The `torchvision` package provides some common datasets and\n", "transforms. You might not even have to write custom classes. One of the\n", "more generic datasets available in torchvision is `ImageFolder`. 
It\n", "assumes that images are organized in the following way:\n", "\n", "``` {.sourceCode .sh}\n", "root/ants/xxx.png\n", "root/ants/xxy.jpeg\n", "root/ants/xxz.png\n", ".\n", ".\n", ".\n", "root/bees/123.jpg\n", "root/bees/nsdf3.png\n", "root/bees/asd932_.png\n", "```\n", "\n", "where \'ants\', \'bees\' etc. are class labels. Similarly, generic\n", "transforms which operate on `PIL.Image`, like `RandomHorizontalFlip` and\n", "`Resize`, are also available. You can use these to write a dataloader\n", "like this:\n", "\n", "``` {.sourceCode .python}\n", "import torch\n", "from torchvision import transforms, datasets\n", "\n", "data_transform = transforms.Compose([\n", "        transforms.RandomResizedCrop(224),\n", "        transforms.RandomHorizontalFlip(),\n", "        transforms.ToTensor(),\n", "        transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", "                             std=[0.229, 0.224, 0.225])\n", "    ])\n", "hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',\n", "                                           transform=data_transform)\n", "dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,\n", "                                             batch_size=4, shuffle=True,\n", "                                             num_workers=4)\n", "```\n", "\n", "For an example with training code, please see\n", "`transfer_learning_tutorial`{.interpreted-text role=\"doc\"}.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 0 }