Source code for torch.utils.data.dataset

import bisect


class Dataset(object):
    """Abstract base class for every dataset.

    Concrete datasets are expected to subclass this and override two methods:

    * ``__len__`` -- report the number of samples in the dataset.
    * ``__getitem__`` -- return the sample at an integer index in the range
      ``0 <= index < len(self)``.

    Two datasets can be combined with ``+``, which produces a
    ``ConcatDataset`` over both operands.
    """

    def __len__(self):
        """Return the dataset size; subclasses must override this."""
        raise NotImplementedError

    def __getitem__(self, index):
        """Return the sample at ``index``; subclasses must override this."""
        raise NotImplementedError

    def __add__(self, other):
        """Return a ``ConcatDataset`` holding ``self`` followed by ``other``."""
        parts = [self, other]
        return ConcatDataset(parts)
class TensorDataset(Dataset):
    """Dataset that pairs a data tensor with a target tensor.

    Sample ``i`` is the pair ``(data_tensor[i], target_tensor[i])``, i.e. both
    tensors are indexed along their first dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).
    """

    def __init__(self, data_tensor, target_tensor):
        n_samples = data_tensor.size(0)
        n_targets = target_tensor.size(0)
        # Both tensors must agree on the size of their first dimension.
        assert n_samples == n_targets
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __len__(self):
        """Return the size of the data tensor's first dimension."""
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair stored at ``index``."""
        sample = self.data_tensor[index]
        label = self.target_tensor[index]
        return sample, label
class ConcatDataset(Dataset):
    """Dataset to concatenate multiple datasets.

    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets, as the concatenation is done on the fly: samples
    are looked up in the underlying datasets on each access instead of being
    copied.

    Indexing follows the usual sequence convention: ``0 <= idx < len(self)``
    addresses samples front to back, and negative indices count from the end
    of the whole concatenation.

    Arguments:
        datasets (iterable): datasets to be concatenated. Any iterable is
            accepted (including generators); it is consumed once and stored
            as a list in ``self.datasets``.
    """

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` for each ``e`` in ``sequence``.

        Entry ``i`` of the result is the combined length of elements
        ``0..i``, e.g. lengths ``[3, 2]`` give ``[3, 5]``.
        """
        totals, running = [], 0
        for element in sequence:
            running += len(element)
            totals.append(running)
        return totals

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        # Materialise before checking emptiness: ``len()`` is not defined on
        # generators and other one-shot iterables.
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, \
            'datasets should not be an empty iterable'
        # Attribute name (including its spelling) is kept for compatibility
        # with existing readers of this attribute.
        self.cummulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return self.cummulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` of the concatenation.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        position = idx
        if position < 0:
            # Negative indices count from the end of the *whole*
            # concatenation, not from the end of the first dataset.
            position += total
        if position < 0 or position >= total:
            raise IndexError(
                'index {} is out of range for dataset of size {}'.format(
                    idx, total))
        # The first cumulative size strictly greater than ``position``
        # identifies the dataset that owns the sample.
        dataset_idx = bisect.bisect_right(self.cummulative_sizes, position)
        if dataset_idx == 0:
            sample_idx = position
        else:
            sample_idx = position - self.cummulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]