torchvision.datasets
All datasets are subclasses of torch.utils.data.Dataset, i.e., they have __getitem__ and __len__ methods implemented.
Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers.
For example:
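import torch
import torchvision
# Without a transform the samples are PIL images, which the default batching cannot collate.
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4)  # number of worker processes, e.g. args.nThreads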
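Any object that implements these two methods can be consumed the same way. The sketch below is a hypothetical, dependency-free dataset (`SquaresDataset`) plus a simplified batching loop (`iterate_batches`) that approximates what a DataLoader does with `shuffle=False` and no workers; unlike the real DataLoader, it returns plain lists instead of collated tensors.

```python
class SquaresDataset:
    """Hypothetical map-style dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; a loader uses it to decide which indices to request.
        return self.n

    def __getitem__(self, index):
        # Return one sample for an integer index in [0, len(self)).
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index


def iterate_batches(dataset, batch_size):
    """Yield lists of consecutive samples, keeping a final partial batch."""
    batch = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


for batch in iterate_batches(SquaresDataset(5), batch_size=2):
    print(batch)
```

A real dataset could subclass torch.utils.data.Dataset instead; the two methods are what matter to the loader.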
The following datasets are available:
Datasets
All the datasets have a similar API. They all have two common arguments, transform and target_transform, to transform the input and the target, respectively.
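In a typical dataset, `__getitem__` loads the raw input and target and then applies the two callables if they were given. A minimal sketch of that pattern, using a hypothetical in-memory `TransformedPairs` class and plain Python functions in place of torchvision transforms:

```python
class TransformedPairs:
    """Hypothetical dataset over in-memory (input, target) pairs."""

    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = list(samples)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img, target = self.samples[index]
        # transform acts on the input only...
        if self.transform is not None:
            img = self.transform(img)
        # ...and target_transform acts on the target only.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


pairs = [([1, 2, 3], 1), ([4, 5, 6], 2)]
ds = TransformedPairs(pairs,
                      transform=lambda x: [v * 10 for v in x],
                      target_transform=lambda t: t - 1)
print(ds[1])
```

Leaving either argument as None returns that part of the sample unchanged.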
MNIST
class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)
MNIST Dataset.
Parameters:
- root (string) – Root directory of dataset where MNIST/processed/training.pt and MNIST/processed/test.pt exist.
- train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
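The `train` flag and the download check reduce to a choice between the two processed files listed for `root`. A minimal sketch with hypothetical helper names (they are not torchvision functions); the `folder` argument reflects that Fashion-MNIST and KMNIST document the same layout under their own directory names:

```python
import os


def processed_files(root, folder="MNIST"):
    """Paths of the processed training and test files, per the layout above."""
    base = os.path.join(root, folder, "processed")
    return os.path.join(base, "training.pt"), os.path.join(base, "test.pt")


def data_file(root, train=True, folder="MNIST"):
    """train=True selects training.pt, train=False selects test.pt."""
    training_file, test_file = processed_files(root, folder)
    return training_file if train else test_file


def needs_download(root, folder="MNIST"):
    """With download=True, data is fetched only if a processed file is missing."""
    return not all(os.path.exists(path) for path in processed_files(root, folder))


print(data_file("data", train=False))
print(data_file("data", folder="Fashion-MNIST"))
```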
Fashion-MNIST
class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)
Fashion-MNIST Dataset.
Parameters:
- root (string) – Root directory of dataset where Fashion-MNIST/processed/training.pt and Fashion-MNIST/processed/test.pt exist.
- train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
KMNIST
class torchvision.datasets.KMNIST(root, train=True, transform=None, target_transform=None, download=False)
Kuzushiji-MNIST Dataset.
Parameters:
- root (string) – Root directory of dataset where KMNIST/processed/training.pt and KMNIST/processed/test.pt exist.
- train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
EMNIST
class torchvision.datasets.EMNIST(root, split, **kwargs)
EMNIST Dataset.
Parameters:
- root (string) – Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt exist.
- split (string) – The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. This argument specifies which one to use.
- train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
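The `(root, split, **kwargs)` signature means EMNIST consumes `split` itself and forwards the remaining keyword arguments (`train`, `download`, `transform`, `target_transform`) to its parent MNIST-style constructor. A minimal sketch of that pattern with hypothetical stand-in classes (`DigitsBase`, `SplitDigits`) that perform no file I/O:

```python
class DigitsBase:
    """Hypothetical stand-in for an MNIST-style base class (no file I/O)."""

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False):
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download


class SplitDigits(DigitsBase):
    """Hypothetical EMNIST-like subclass that validates `split` first."""

    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")

    def __init__(self, root, split, **kwargs):
        if split not in self.splits:
            raise ValueError("Split {!r} not found. Valid splits are: {}".format(
                split, ", ".join(self.splits)))
        self.split = split
        # Everything except `split` goes to the base class unchanged.
        super().__init__(root, **kwargs)


letters_test = SplitDigits("data", split="letters", train=False)
print(letters_test.split, letters_test.train)
```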
FakeData
class torchvision.datasets.FakeData(size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None, random_offset=0)
A fake dataset that returns randomly generated images as PIL images.
Parameters:
- size (int, optional) – Size of the dataset. Default: 1000 images
- image_size (tuple, optional) – Size of the returned images. Default: (3, 224, 224)
- num_classes (int, optional) – Number of classes in the dataset. Default: 10
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- random_offset (int, optional) – Offsets the index-based random seed used to generate each image. Default: 0
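One way to read `random_offset`: each sample's random generator is seeded with `index + random_offset`, so a given index always yields the same image and label, and changing the offset yields a different but equally reproducible dataset. The sketch below uses Python's `random.Random` and nested lists in place of torch tensors and PIL images, with a small default `image_size` to keep it light; `FakeSamples` is a hypothetical class, not torchvision's implementation.

```python
import random


class FakeSamples:
    """Hypothetical FakeData-like dataset producing reproducible random samples."""

    def __init__(self, size=1000, image_size=(3, 4, 4), num_classes=10, random_offset=0):
        self.size = size
        self.image_size = image_size
        self.num_classes = num_classes
        self.random_offset = random_offset

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        if not 0 <= index < self.size:
            raise IndexError(index)
        # Per-sample generator seeded by index + offset: no shared global state.
        rng = random.Random(index + self.random_offset)
        channels, height, width = self.image_size
        img = [[[rng.random() for _ in range(width)] for _ in range(height)]
               for _ in range(channels)]
        target = rng.randrange(self.num_classes)
        return img, target


ds = FakeSamples(size=5)
shifted = FakeSamples(size=5, random_offset=1)
print(ds[2] == ds[2], ds[1] == shifted[0])
```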
COCO
Note
These require the COCO API to be installed.
Captions
class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)
MS Coco Captions Dataset.
Parameters:
- root (string) – Root directory where images are downloaded to.
- annFile (string) – Path to json annotation file.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
Example
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
                        annFile = 'json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)
Output:
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
Detection
class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)
MS Coco Detection Dataset.
Parameters:
- root (string) – Root directory where images are downloaded to.
- annFile (string) – Path to json annotation file.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
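With CocoDetection the target is the list of annotation dictionaries for the image. Assuming the standard COCO annotation format, where each dictionary has a `bbox` of `[x, y, width, height]` and a `category_id`, a `target_transform` can convert it into corner-format boxes and labels. `coco_to_boxes` is a hypothetical helper, sketched here without torch:

```python
def coco_to_boxes(annotations):
    """Hypothetical target_transform: COCO [x, y, w, h] boxes -> [x1, y1, x2, y2]."""
    boxes, labels = [], []
    for ann in annotations:
        x, y, w, h = ann["bbox"]
        boxes.append([x, y, x + w, y + h])
        labels.append(ann["category_id"])
    return boxes, labels


# Usage (paths are placeholders):
# det = torchvision.datasets.CocoDetection(root='dir where images are',
#                                          annFile='json annotation file',
#                                          target_transform=coco_to_boxes)
print(coco_to_boxes([{"bbox": [10, 20, 30, 40], "category_id": 3}]))
```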
LSUN
class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)
LSUN dataset.
Parameters:
- root (string) – Root directory for the database files.
- classes (string or list) – One of {‘train’, ‘val’, ‘test’} or a list of categories to load, e.g. [‘bedroom_train’, ‘church_train’].
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
ImageFolder
class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Parameters:
- root (string) – Root directory path.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- loader (callable, optional) – A function to load an image given its path.
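Targets are class indices derived from the subdirectory names. A filesystem-free sketch of that mapping, mirroring ImageFolder's alphabetical class order with indices numbered from 0; `index_image_folder` is a hypothetical helper that takes paths relative to `root` instead of walking the directory tree:

```python
from pathlib import PurePosixPath


def index_image_folder(relative_paths):
    """Map 'class_name/file' paths to (path, class_index) samples."""
    # The first path component is the class; sorting makes indices deterministic.
    classes = sorted({PurePosixPath(p).parts[0] for p in relative_paths})
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = [(p, class_to_idx[PurePosixPath(p).parts[0]])
               for p in sorted(relative_paths)]
    return classes, class_to_idx, samples


layout = ["dog/xxx.png", "dog/xxy.png", "dog/xxz.png",
          "cat/123.png", "cat/nsdf3.png", "cat/asd932_.png"]
classes, class_to_idx, samples = index_image_folder(layout)
print(class_to_idx)
print(samples[0])
```

Because indices follow sorted names, adding or renaming a class directory can shift the indices of other classes.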
DatasetFolder
class torchvision.datasets.DatasetFolder(root, loader, extensions, transform=None, target_transform=None)
A generic data loader where the samples are arranged in this way:
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Parameters:
- root (string) – Root directory path.
- loader (callable) – A function to load a sample given its path.
- extensions (list[string]) – A list of allowed extensions.
- transform (callable, optional) – A function/transform that takes in a sample and returns a transformed version. E.g., transforms.RandomCrop for images.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
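`extensions` decides which files become samples, and `loader` turns a path into a sample only when that index is requested. A minimal sketch over a list of relative paths; `ListDatasetFolder` is hypothetical and filesystem-free, and a real loader would open the file, while the usage below just upper-cases the path:

```python
from pathlib import PurePosixPath


class ListDatasetFolder:
    """Hypothetical DatasetFolder-like class over 'class_name/file' paths."""

    def __init__(self, relative_paths, loader, extensions):
        exts = tuple(e.lower() for e in extensions)
        # Keep only files whose lower-cased name ends with an allowed extension.
        kept = sorted(p for p in relative_paths if p.lower().endswith(exts))
        classes = sorted({PurePosixPath(p).parts[0] for p in relative_paths})
        self.class_to_idx = {name: i for i, name in enumerate(classes)}
        self.samples = [(p, self.class_to_idx[PurePosixPath(p).parts[0]]) for p in kept]
        self.loader = loader

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        # Loading is deferred until the sample is actually requested.
        return self.loader(path), target


ds = ListDatasetFolder(["class_y/123.ext", "class_x/xxx.EXT", "class_x/notes.txt"],
                       loader=str.upper, extensions=[".ext"])
print(len(ds), ds[0])
```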
Imagenet-12
This should simply be implemented with an ImageFolder dataset.
The data is preprocessed as described here.
CIFAR
class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
CIFAR10 Dataset.
Parameters:
- root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
- train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
CIFAR100 Dataset.
This is a subclass of the CIFAR10 Dataset.
STL10
class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)
STL10 Dataset.
Parameters:
- root (string) – Root directory of dataset where directory stl10_binary exists.
- split (string) – One of {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’}. The dataset is selected accordingly.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
SVHN
class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)
SVHN Dataset. Note: The SVHN dataset assigns the label 10 to the digit 0. However, in this Dataset, we assign the label 0 to the digit 0 to be compatible with PyTorch loss functions, which expect the class labels to be in the range [0, C-1].
Parameters:
- root (string) – Root directory of dataset where directory SVHN exists.
- split (string) – One of {‘train’, ‘test’, ‘extra’}. The dataset is selected accordingly; ‘extra’ is the extra training set.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
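The label remapping described in the note amounts to replacing 10 with 0 and leaving 1 to 9 unchanged, so a 10-class loss sees targets in [0, 9]. The dataset applies this for you; the hypothetical helper below only illustrates the mapping:

```python
def svhn_to_zero_based(raw_labels, num_classes=10):
    """Hypothetical helper: SVHN raw labels (1-10, 10 = digit 0) -> 0-9."""
    labels = [0 if label == 10 else label for label in raw_labels]
    # Every label must now fit the [0, C-1] range expected by PyTorch losses.
    if any(not 0 <= label < num_classes for label in labels):
        raise ValueError("label outside [0, num_classes - 1]")
    return labels


print(svhn_to_zero_based([10, 1, 9, 10]))
```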
PhotoTour
class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)
Learning Local Image Descriptors Data Dataset.
Parameters:
- root (string) – Root directory where images are.
- name (string) – Name of the dataset to load.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
SBU
class torchvision.datasets.SBU(root, transform=None, target_transform=None, download=True)
SBU Captioned Photo Dataset.
Parameters:
- root (string) – Root directory of dataset where tarball SBUCaptionedPhotoDataset.tar.gz exists.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
Flickr
class torchvision.datasets.Flickr8k(root, ann_file, transform=None, target_transform=None)
Flickr8k Entities Dataset.
Parameters:
- root (string) – Root directory where images are downloaded to.
- ann_file (string) – Path to annotation file.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
class torchvision.datasets.Flickr30k(root, ann_file, transform=None, target_transform=None)
Flickr30k Entities Dataset.
Parameters:
- root (string) – Root directory where images are downloaded to.
- ann_file (string) – Path to annotation file.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
VOC
class torchvision.datasets.VOCSegmentation(root, year='2012', image_set='train', download=False, transform=None, target_transform=None)
Pascal VOC Segmentation Dataset.
Parameters:
- root (string) – Root directory of the VOC Dataset.
- year (string, optional) – The dataset year, supports years 2007 to 2012.
- image_set (string, optional) – Select the image_set to use, train, trainval or val.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
class torchvision.datasets.VOCDetection(root, year='2012', image_set='train', download=False, transform=None, target_transform=None)
Pascal VOC Detection Dataset.
Parameters:
- root (string) – Root directory of the VOC Dataset.
- year (string, optional) – The dataset year, supports years 2007 to 2012.
- image_set (string, optional) – Select the image_set to use, train, trainval or val.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
Cityscapes
Note
Requires the Cityscapes dataset to be downloaded.
class torchvision.datasets.Cityscapes(root, split='train', mode='fine', target_type='instance', transform=None, target_transform=None)
Cityscapes Dataset.
Parameters:
- root (string) – Root directory of dataset where directories leftImg8bit and gtFine or gtCoarse are located.
- split (string, optional) – The image split to use, train, test or val if mode="fine", otherwise train, train_extra or val.
- mode (string, optional) – The quality mode to use, fine or coarse.
- target_type (string or list, optional) – Type of target to use, instance, semantic, polygon or color. Can also be a list to output a tuple with all specified target types.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
Examples
Get semantic segmentation target
from torchvision.datasets import Cityscapes
dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                     target_type='semantic')
img, smnt = dataset[0]
Get multiple targets
dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                     target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0]
Validate on the “coarse” set
dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                     target_type='semantic')
img, smnt = dataset[0]
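The valid `split` values depend on `mode`, and `target_type` may be a single string or a list. A minimal validation sketch of those documented rules; `check_cityscapes_args` is a hypothetical helper, and the mapping of `mode` to the gtFine/gtCoarse directory names is an assumption based on the `root` description:

```python
SPLITS = {"fine": ("train", "test", "val"), "coarse": ("train", "train_extra", "val")}
TARGET_TYPES = ("instance", "semantic", "polygon", "color")


def check_cityscapes_args(split="train", mode="fine", target_type="instance"):
    """Hypothetical helper: validate arguments, return (annotation_dir, target_types)."""
    if mode not in SPLITS:
        raise ValueError("mode must be one of {}".format(sorted(SPLITS)))
    if split not in SPLITS[mode]:
        raise ValueError("split {!r} is not valid for mode {!r}".format(split, mode))
    # A single string means one target; a list means a tuple of targets per sample.
    types = [target_type] if isinstance(target_type, str) else list(target_type)
    for t in types:
        if t not in TARGET_TYPES:
            raise ValueError("unknown target_type {!r}".format(t))
    # Assumed directory naming: mode='fine' -> gtFine, mode='coarse' -> gtCoarse.
    annotation_dir = "gtFine" if mode == "fine" else "gtCoarse"
    return annotation_dir, types


print(check_cityscapes_args(split="val", mode="coarse", target_type="semantic"))
```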