torchvision.datasets

All datasets are subclasses of torch.utils.data.Dataset, i.e., they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers. For example:

import torch
import torchvision

imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4)
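Any object implementing __getitem__ and __len__ can be batched the same way. A minimal sketch with a toy Dataset subclass (the class name and data are illustrative, not part of torchvision):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class SquaresDataset(Dataset):
    """Toy dataset: sample i is the pair (i, i**2)."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return torch.tensor(index), torch.tensor(index ** 2)

# The DataLoader collates individual samples into batches of 4.
loader = DataLoader(SquaresDataset(8), batch_size=4, shuffle=False)
for inputs, targets in loader:
    print(inputs.tolist(), targets.tolist())
```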

The following datasets are available:

All of the datasets have a very similar API. They all share two common arguments: transform and target_transform, which transform the input and the target, respectively.
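The way these two arguments are applied is the same across datasets: __getitem__ loads a raw (input, target) pair, then passes each element through the corresponding callable if one was given. A stripped-down sketch of that pattern (the class and the string data are illustrative, not part of torchvision):

```python
class ToyDataset:
    """Mimics how torchvision datasets apply transform / target_transform."""
    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples          # list of (input, target) pairs
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img, target = self.samples[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

ds = ToyDataset([("img0", 0), ("img1", 1)],
                transform=str.upper,
                target_transform=lambda t: t + 10)
print(ds[1])  # ('IMG1', 11)
```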

MNIST

class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)[source]

MNIST Dataset.

Parameters:
  • root (string) – Root directory of dataset where processed/training.pt and processed/test.pt exist.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Fashion-MNIST

class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)[source]

Fashion-MNIST Dataset.

Parameters:
  • root (string) – Root directory of dataset where processed/training.pt and processed/test.pt exist.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

EMNIST

class torchvision.datasets.EMNIST(root, split, **kwargs)[source]

EMNIST Dataset.

Parameters:
  • root (string) – Root directory of dataset where processed/training.pt and processed/test.pt exist.
  • split (string) – The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. This argument specifies which one to use.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

COCO

Note

These require the COCO API (pycocotools) to be installed.

Captions

class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)[source]

MS Coco Captions Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • annFile (string) – Path to json annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Example

import torchvision.datasets as dset
import torchvision.transforms as transforms

cap = dset.CocoCaptions(root='dir where images are',
                        annFile='json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3]  # load 4th sample

print("Image Size: ", img.size())
print(target)

Output:

Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: Tuple (image, target). target is a list of captions for the image.
Return type: tuple

Detection

class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)[source]

MS Coco Detection Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • annFile (string) – Path to json annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: Tuple (image, target). target is the object returned by coco.loadAnns.
Return type: tuple

LSUN

class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)[source]

LSUN dataset.

Parameters:
  • root (string) – Root directory for the database files.
  • classes (string or list) – One of {'train', 'val', 'test'} or a list of categories to load, e.g. ['bedroom_train', 'church_train'].
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: Tuple (image, target) where target is the index of the target category.
Return type: tuple

ImageFolder

class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)[source]

A generic data loader where the images are arranged in this way:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Parameters:
  • root (string) – Root directory path.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • loader (callable, optional) – A function to load an image given its path.
__getitem__(index)
Parameters: index (int) – Index
Returns: (sample, target) where target is class_index of the target class.
Return type: tuple

DatasetFolder

class torchvision.datasets.DatasetFolder(root, loader, extensions, transform=None, target_transform=None)[source]

A generic data loader where the samples are arranged in this way:

root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext

root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Parameters:
  • root (string) – Root directory path.
  • loader (callable) – A function to load a sample given its path.
  • extensions (list[string]) – A list of allowed extensions.
  • transform (callable, optional) – A function/transform that takes in a sample and returns a transformed version. E.g., transforms.RandomCrop for images.
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: (sample, target) where target is class_index of the target class.
Return type: tuple

Imagenet-12

This can be implemented simply with an ImageFolder dataset, once the images have been preprocessed and arranged into one subdirectory per class as described above.

CIFAR

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)[source]

CIFAR10 Dataset.

Parameters:
  • root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
  • train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: (image, target) where target is index of the target class.
Return type: tuple
class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)[source]

CIFAR100 Dataset.

This is a subclass of the CIFAR10 Dataset.

STL10

class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)[source]

STL10 Dataset.

Parameters:
  • root (string) – Root directory of dataset where directory stl10_binary exists.
  • split (string) – One of {'train', 'test', 'unlabeled', 'train+unlabeled'}. Accordingly, the dataset is selected.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: (image, target) where target is index of the target class.
Return type: tuple

SVHN

class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)[source]

SVHN Dataset. Note: the SVHN dataset assigns the label 10 to the digit 0. However, in this Dataset, we assign the label 0 to the digit 0, to be compatible with PyTorch loss functions, which expect class labels to be in the range [0, C-1].

Parameters:
  • root (string) – Root directory of dataset where directory SVHN exists.
  • split (string) – One of {'train', 'test', 'extra'}. Accordingly, the dataset is selected. 'extra' is the extra training set.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: (image, target) where target is index of the target class.
Return type: tuple
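The relabeling described in the note above amounts to mapping the raw label 10 to 0 while leaving 1–9 unchanged, i.e. label % 10. A quick pure-Python sketch of that rule (the raw labels here are made up):

```python
# Raw SVHN labels are in 1..10, where 10 means the digit 0.
raw_labels = [10, 1, 9, 10, 5]

# torchvision stores them in 0..9 instead; label % 10 realizes the mapping.
labels = [label % 10 for label in raw_labels]
print(labels)  # [0, 1, 9, 0, 5]
```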

PhotoTour

class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)[source]

Learning Local Image Descriptors Dataset.

Parameters:
  • root (string) – Root directory where images are.
  • name (string) – Name of the dataset to load.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: (data1, data2, matches)
Return type: tuple

SBU

class torchvision.datasets.SBU(root, transform=None, target_transform=None, download=True)[source]

SBU Captioned Photo Dataset.

Parameters:
  • root (string) – Root directory of dataset where tarball SBUCaptionedPhotoDataset.tar.gz exists.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: (image, target) where target is a caption for the photo.
Return type: tuple

Flickr

class torchvision.datasets.Flickr8k(root, ann_file, transform=None, target_transform=None)[source]

Flickr8k Entities Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • ann_file (string) – Path to annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: Tuple (image, target). target is a list of captions for the image.
Return type: tuple
class torchvision.datasets.Flickr30k(root, ann_file, transform=None, target_transform=None)[source]

Flickr30k Entities Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • ann_file (string) – Path to annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: Tuple (image, target). target is a list of captions for the image.
Return type: tuple

VOC

class torchvision.datasets.VOCSegmentation(root, year='2012', image_set='train', download=False, transform=None, target_transform=None)[source]

Pascal VOC Segmentation Dataset.

Parameters:
  • root (string) – Root directory of the VOC Dataset.
  • year (string, optional) – The dataset year; supports years 2007 to 2012.
  • image_set (string, optional) – Select the image_set to use: 'train', 'trainval' or 'val'.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: (image, target) where target is the image segmentation.
Return type: tuple
class torchvision.datasets.VOCDetection(root, year='2012', image_set='train', download=False, transform=None, target_transform=None)[source]

Pascal VOC Detection Dataset.

Parameters:
  • root (string) – Root directory of the VOC Dataset.
  • year (string, optional) – The dataset year; supports years 2007 to 2012.
  • image_set (string, optional) – Select the image_set to use: 'train', 'trainval' or 'val'.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)[source]
Parameters: index (int) – Index
Returns: (image, target) where target is a dictionary of the XML tree.
Return type: tuple
