# mypy: allow-untyped-defs
import bisect
import itertools
import math
import warnings
from collections.abc import Sequence

# UP006 wants 'Iterable' to be imported from collections.abc but it needs to
# stay from typing for now due to BC concerns. In particular several internal
# targets fail to typecheck with:
# TypeError: Cannot create a consistent method resolution order (MRO) for
# bases Iterable, Generic
from typing import cast, Generic, Iterable, Optional, TypeVar, Union  # noqa: UP035

from typing_extensions import deprecated

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, Generator, randperm, Tensor


__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]


_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_dict = dict[str, _T_co]
_T_tuple = tuple[_T_co, ...]
_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict)


class Dataset(Generic[_T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, to speed up batched sample
    loading. This method accepts a list of sample indices for a batch and
    returns the list of corresponding samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> _T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[_T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
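

# A minimal map-style usage sketch for the abstract ``Dataset`` above; the
# ``SquaresDataset`` name and the sample values are illustrative only:
#
#     import torch
#     from torch.utils.data import Dataset
#
#     class SquaresDataset(Dataset[int]):
#         def __init__(self, n: int) -> None:
#             self.n = n
#
#         def __getitem__(self, index: int) -> int:
#             return index * index
#
#         def __len__(self) -> int:
#             return self.n
#
#     ds = SquaresDataset(5)
#     assert ds[3] == 9 and len(ds) == 5
#     combined = ds + SquaresDataset(2)  # ``__add__`` returns a ConcatDataset
#     assert len(combined) == 7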


class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __add__(self, other: Dataset[_T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class TensorDataset(Dataset[tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size in the first dimension.
    """

    tensors: tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(
            tensors[0].size(0) == tensor.size(0) for tensor in tensors
        ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
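

# A hedged usage sketch for ``TensorDataset``; the tensor shapes and names below
# are illustrative only:
#
#     import torch
#     from torch.utils.data import TensorDataset
#
#     features = torch.randn(8, 3)   # 8 samples, 3 features each
#     labels = torch.arange(8)       # 8 matching labels
#     ds = TensorDataset(features, labels)
#     x, y = ds[2]                   # (features[2], labels[2])
#     assert len(ds) == 8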


class StackDataset(Dataset[_T_stack]):
    r"""Dataset as a stacking of multiple datasets.

    This class is useful to assemble different parts of complex input data, given as datasets.

    Example:
        >>> # xdoctest: +SKIP
        >>> images = ImageDataset()
        >>> texts = TextDataset()
        >>> tuple_stack = StackDataset(images, texts)
        >>> tuple_stack[0] == (images[0], texts[0])
        >>> dict_stack = StackDataset(image=images, text=texts)
        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

    Args:
        *args (Dataset): Datasets for stacking returned as tuple.
        **kwargs (Dataset): Datasets for stacking returned as dict.
    """

    datasets: Union[tuple, dict]

    def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None:
        if args:
            if kwargs:
                raise ValueError(
                    "Supported either ``tuple``- (via ``args``) or "
                    "``dict``- (via ``kwargs``) like input/output, but both types are given."
                )
            self._length = len(args[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = args
        elif kwargs:
            tmp = list(kwargs.values())
            self._length = len(tmp[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = kwargs
        else:
            raise ValueError("At least one dataset should be passed")

    def __getitem__(self, index):
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    def __getitems__(self, indices: list):
        # add batched sampling support when the parent datasets support it.
        if isinstance(self.datasets, dict):
            dict_batch: list[_T_dict] = [{} for _ in indices]
            for k, dataset in self.datasets.items():
                if callable(getattr(dataset, "__getitems__", None)):
                    items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                    if len(items) != len(indices):
                        raise ValueError(
                            "Nested dataset's output size mismatch."
                            f" Expected {len(indices)}, got {len(items)}"
                        )
                    for data, d_sample in zip(items, dict_batch):
                        d_sample[k] = data
                else:
                    for idx, d_sample in zip(indices, dict_batch):
                        d_sample[k] = dataset[idx]
            return dict_batch

        # tuple data
        list_batch: list[list] = [[] for _ in indices]
        for dataset in self.datasets:
            if callable(getattr(dataset, "__getitems__", None)):
                items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                if len(items) != len(indices):
                    raise ValueError(
                        "Nested dataset's output size mismatch."
                        f" Expected {len(indices)}, got {len(items)}"
                    )
                for data, t_sample in zip(items, list_batch):
                    t_sample.append(data)
            else:
                for idx, t_sample in zip(indices, list_batch):
                    t_sample.append(dataset[idx])
        tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch]
        return tuple_batch

    def __len__(self):
        return self._length
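

# A hedged sketch of dict-style stacking fed to a ``DataLoader``; the dataset
# and field names below are illustrative only (note that each ``TensorDataset``
# sample is a 1-tuple):
#
#     import torch
#     from torch.utils.data import DataLoader, StackDataset, TensorDataset
#
#     images = TensorDataset(torch.randn(6, 3, 4, 4))
#     labels = TensorDataset(torch.arange(6))
#     stacked = StackDataset(image=images, label=labels)
#     sample = stacked[0]          # {"image": (tensor,), "label": (tensor,)}
#     loader = DataLoader(stacked, batch_size=2)
#     batch = next(iter(loader))   # dict with batched "image" and "label" fields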


class ConcatDataset(Dataset[_T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: list[Dataset[_T_co]]
    cumulative_sizes: list[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), (
                "ConcatDataset does not support IterableDataset"
            )
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    @deprecated(
        "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
        category=FutureWarning,
    )
    def cummulative_sizes(self):
        return self.cumulative_sizes
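

# A hedged usage sketch for ``ConcatDataset``; the tensor contents below are
# illustrative only:
#
#     import torch
#     from torch.utils.data import ConcatDataset, TensorDataset
#
#     ds_a = TensorDataset(torch.zeros(3, 2))
#     ds_b = TensorDataset(torch.ones(5, 2))
#     full = ConcatDataset([ds_a, ds_b])
#     assert len(full) == 8
#     (row,) = full[4]             # global index 4 maps to ds_b[1]
#     assert row.eq(1).all()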


class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), (
                "ChainDataset only supports IterableDataset"
            )
            yield from d

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), (
                "ChainDataset only supports IterableDataset"
            )
            total += len(d)  # type: ignore[arg-type]
        return total
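

# A hedged usage sketch for ``ChainDataset``; ``RangeStream`` is a hypothetical
# helper defined here only for illustration:
#
#     from torch.utils.data import ChainDataset, IterableDataset
#
#     class RangeStream(IterableDataset):
#         def __init__(self, start, end):
#             self.start, self.end = start, end
#
#         def __iter__(self):
#             return iter(range(self.start, self.end))
#
#     chained = ChainDataset([RangeStream(0, 3), RangeStream(10, 12)])
#     assert list(chained) == [0, 1, 2, 10, 11]
#     # Equivalent composition via ``IterableDataset.__add__``:
#     assert list(RangeStream(0, 3) + RangeStream(10, 12)) == [0, 1, 2, 10, 11]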


class Subset(Dataset[_T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset[_T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __getitems__(self, indices: list[int]) -> list[_T_co]:
        # add batched sampling support when parent dataset supports it.
        # see torch.utils.data._utils.fetch._MapDatasetFetcher
        if callable(getattr(self.dataset, "__getitems__", None)):
            return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
        else:
            return [self.dataset[self.indices[idx]] for idx in indices]

    def __len__(self):
        return len(self.indices)
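

# A hedged usage sketch for ``Subset``; the indices below are illustrative only:
#
#     import torch
#     from torch.utils.data import Subset, TensorDataset
#
#     ds = TensorDataset(torch.arange(10))
#     evens = Subset(ds, indices=[0, 2, 4, 6, 8])
#     assert len(evens) == 5
#     (value,) = evens[1]          # position 1 in the subset maps to ds[2]
#     assert value.item() == 2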


def random_split(
    dataset: Dataset[_T],
    lengths: Sequence[Union[int, float]],
    generator: Optional[Generator] = default_generator,
) -> list[Subset[_T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given, the lengths will be computed
    automatically as floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths until there are no
    remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: list[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset."
                )

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
    lengths = cast(Sequence[int], lengths)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]
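

# A hedged usage sketch for ``random_split`` with fractional lengths; the sizes
# and seed below are illustrative only:
#
#     import torch
#     from torch.utils.data import TensorDataset, random_split
#
#     ds = TensorDataset(torch.arange(10))
#     g = torch.Generator().manual_seed(0)
#     train, val = random_split(ds, [0.8, 0.2], generator=g)
#     # floor(10 * 0.8) = 8 and floor(10 * 0.2) = 2, with no remainder to
#     # redistribute, so the splits have 8 and 2 samples respectively.
#     assert len(train) == 8 and len(val) == 2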