torch.utils.data

class torch.utils.data.Dataset

An abstract class representing a Dataset.

All other datasets should subclass it. All subclasses should override __len__, which provides the size of the dataset, and __getitem__, which supports integer indexing in the range from 0 to len(self) exclusive.
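A minimal sketch of a subclass that satisfies this contract. The class name SquaresDataset and its (i, i * i) samples are hypothetical, chosen only to make the indexing easy to check. The sketch assumes nothing beyond overriding __len__ and __getitem__ as described above.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical dataset whose i-th sample is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Size of the dataset; valid indices are 0 .. len(self) - 1.
        return self.size

    def __getitem__(self, index):
        # Reject indices outside the documented range 0 .. len(self) - 1.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (3, 9)
```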

class torch.utils.data.TensorDataset(data_tensor, target_tensor)

Dataset wrapping data and target tensors.

Each sample will be retrieved by indexing both tensors along the first dimension.

Parameters:
  • data_tensor (Tensor) – contains sample data.
  • target_tensor (Tensor) – contains sample targets (labels).
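The indexing rule can be illustrated with a small sketch. The toy tensors below are made up for illustration. The two tensors are passed positionally in the order (data_tensor, target_tensor), and both are assumed to have the same size along the first dimension.

```python
import torch
from torch.utils.data import TensorDataset

# Toy data: 4 samples with 3 features each, and one integer label per sample.
data = torch.arange(12).float().view(4, 3)
target = torch.tensor([0, 1, 0, 1])

# Data first, then targets; both tensors have size 4 along dim 0.
ds = TensorDataset(data, target)

x, y = ds[2]      # indexes both tensors at position 2 along the first dimension
print(x)          # tensor([6., 7., 8.])
print(y.item())   # 0
print(len(ds))    # 4
```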

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False)

Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.
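Below is a minimal sketch of the single-process case (num_workers=0) over a toy TensorDataset. The data is invented for illustration. With 10 samples and 4 samples per batch, the last batch is expected to be smaller because no samples are dropped.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy dataset: 10 samples, each a 2-element feature vector with an integer label.
features = torch.arange(20).float().view(10, 2)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# Load in the main process, 4 samples per batch, in dataset order.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batch_sizes = []
for x, y in loader:
    # The default collate function stacks per-sample tensors along a new first dimension.
    batch_sizes.append(x.size(0))
print(batch_sizes)  # [4, 4, 2]
```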

Parameters:
  • dataset (Dataset) – dataset from which to load the data.
  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, the shuffle argument is ignored.
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process (default: 0).
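  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch (default: default_collate).
  • pin_memory (bool, optional) – if True, the data loader will copy tensors into CUDA pinned memory before returning them (default: False).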
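The sampler and collate_fn parameters can be combined, as in the sketch below. ReverseSampler and collate_labels_only are hypothetical names used only for illustration. The sketch assumes a sampler is a subclass of Sampler from torch.utils.data.sampler that yields dataset indices from __iter__ and reports their count from __len__. It also assumes collate_fn receives the list of samples for one batch. shuffle is left at its default because a sampler is given. pin_memory is left at False because it only matters when copying batches to a CUDA device.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import Sampler


class ReverseSampler(Sampler):
    """Hypothetical sampler that yields dataset indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)


def collate_labels_only(batch):
    # Hypothetical collate_fn: batch is a list of (features, label) samples;
    # keep only the labels, as plain Python ints.
    return [int(label) for _, label in batch]


dataset = TensorDataset(torch.zeros(5, 2), torch.arange(5))
loader = DataLoader(dataset, batch_size=3,
                    sampler=ReverseSampler(dataset),
                    collate_fn=collate_labels_only)

batches = list(loader)
print(batches)  # [[4, 3, 2], [1, 0]]
```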