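At the heart of the PyTorch data loading utility is the torch.utils.data.DataLoader
class. It represents a Python iterable over a dataset, with support for
map-style and iterable-style datasets, customizing data loading order,
automatic batching, single- and multi-process data loading, and automatic
memory pinning.
These options are configured by the constructor arguments of a
DataLoader, which has the signature: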
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
The sections below describe in detail the effects and usage of these options.
Dataset Types¶
The most important argument of the DataLoader constructor is
dataset, which indicates a dataset object to load data
from. PyTorch supports two different types of datasets: map-style datasets
and iterable-style datasets.
A map-style dataset is one that implements the __getitem__() and
__len__() protocols, and represents a map from (possibly non-integral)
indices/keys to data samples.
For example, such a dataset, when accessed with
dataset[idx], could read the
idx-th image and its corresponding label from a folder on the disk.
See Dataset for more details.
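As a minimal sketch of the map-style protocol described above, the hypothetical SquaresDataset below (a name invented for this illustration) implements __getitem__() and __len__() and is then consumed by a DataLoader:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: index i maps to the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if not 0 <= idx < self.size:
            raise IndexError(idx)
        return torch.tensor(idx), torch.tensor(idx * idx)


dataset = SquaresDataset(5)
index, square = dataset[3]  # random access by key

loader = DataLoader(dataset, batch_size=2)
batches = [(x.tolist(), y.tolist()) for x, y in loader]
```

Because the dataset reports its length, the loader can build a default sequential sampler over the indices 0 to 4 on its own.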
An iterable-style dataset is an instance of a subclass of IterableDataset
that implements the
__iter__() protocol, and represents an iterable over
data samples. This type of dataset is particularly suitable for cases where
random reads are expensive or even impractical, and where the batch size depends
on the fetched data.
For example, such a dataset, when called with
iter(dataset), could return a
stream of data read from a database, a remote server, or even logs generated
in real time.
See IterableDataset for more details.
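A small sketch of the iterable-style protocol follows. The LineStream class and its in-memory list of lines are stand-ins invented for illustration; a real stream would more likely read from a file, socket, or database cursor.

```python
from torch.utils.data import DataLoader, IterableDataset


class LineStream(IterableDataset):
    """Hypothetical iterable-style dataset over a stream of text lines."""

    def __init__(self, lines):
        super().__init__()
        self.lines = lines

    def __iter__(self):
        for line in self.lines:
            # Yield one sample per line: here, its number of tokens.
            yield len(line.split())


stream = LineStream(["a b c", "d e", "f"])
loader = DataLoader(stream, batch_size=2)
batches = [batch.tolist() for batch in loader]
```

No index is ever requested from the dataset; samples arrive in whatever order __iter__() produces them.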
Data Loading Order and Sampler¶
For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample each time).
The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler
classes are used to specify the sequence of indices/keys used in data loading.
They represent iterable objects over the indices to datasets. E.g., in the
common case with stochastic gradient descent (SGD), a
Sampler could randomly permute a list of indices
and yield each one at a time, or yield a small number of them for mini-batch
SGD.
A sequential or shuffled sampler will be automatically constructed based on the
shuffle argument to a DataLoader.
Alternatively, users may use the
sampler argument to specify a custom
Sampler object that at each time yields
the next index/key to fetch.
A custom Sampler that yields a list of batch
indices at a time can be passed as the batch_sampler argument.
Automatic batching can also be enabled via the batch_size and
drop_last arguments. See
the next section for more details on this.
Neither sampler nor batch_sampler is compatible with
iterable-style datasets, since such datasets have no notion of a key or an
index.
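The three ways of controlling loading order for a map-style dataset can be sketched side by side. The toy TensorDataset and the batch size of 4 are arbitrary choices for illustration:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

torch.manual_seed(0)
dataset = TensorDataset(torch.arange(10))

# Option 1: let the loader build a shuffled sampler from shuffle=True.
loader_a = DataLoader(dataset, batch_size=4, shuffle=True)

# Option 2: pass an explicit sampler; batching still comes from batch_size.
loader_b = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))

# Option 3: pass a batch sampler that yields whole lists of indices.
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)
loader_c = DataLoader(dataset, batch_sampler=batch_sampler)

batch_sizes = [
    [len(batch[0]) for batch in loader]
    for loader in (loader_a, loader_b, loader_c)
]
```

All three loaders produce two full batches and one batch of two samples; only the way the index order is specified differs.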
Loading Batched and Non-Batched Data¶
DataLoader supports automatically collating
individual fetched data samples into batches via the arguments
batch_size, drop_last, batch_sampler, and collate_fn (which has a default function).
Automatic batching (default)¶
This is the most common case, and corresponds to fetching a minibatch of data and collating them into batched samples, i.e., containing Tensors with one dimension being the batch dimension (usually the first).
When batch_size (default 1) is not
None, the data loader yields
batched samples instead of individual samples. The batch_size and
drop_last arguments are used to specify how the data loader obtains
batches of dataset keys. For map-style datasets, users can alternatively
specify batch_sampler, which yields a list of keys at a time.
The batch_size and
drop_last arguments essentially are used
to construct a batch_sampler from
sampler. For map-style datasets, the
sampler is either provided by the user or constructed
based on the
shuffle argument. For iterable-style datasets, the
sampler is a dummy infinite one. See
the section on samplers above for more details.
After fetching a list of samples using the indices from the sampler, the function
passed as the
collate_fn argument is used to collate lists of samples
into batches.
In this case, loading from a map-style dataset is roughly equivalent to:
    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])
and loading from an iterable-style dataset is roughly equivalent to:
    dataset_iter = iter(dataset)
    for indices in batch_sampler:
        yield collate_fn([next(dataset_iter) for _ in indices])
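The pseudocode above can be turned into a runnable pure-Python sketch. The helper names batch_indices and load_batches are invented for this illustration and are not part of PyTorch; they only mimic how batch_size and drop_last group sampler indices before collation.

```python
def batch_indices(sampler, batch_size, drop_last):
    """Group indices from `sampler` into lists of at most `batch_size`."""
    batch = []
    for index in sampler:
        batch.append(index)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Emit the trailing short batch unless the caller asked to drop it.
    if batch and not drop_last:
        yield batch


def load_batches(dataset, sampler, batch_size, drop_last, collate_fn):
    """Mimic automatic batching for a map-style dataset."""
    for indices in batch_indices(sampler, batch_size, drop_last):
        yield collate_fn([dataset[i] for i in indices])


squares = [i * i for i in range(7)]  # a plain list already acts as a map-style dataset
sampler = range(len(squares))        # sequential order

kept = list(load_batches(squares, sampler, 3, drop_last=False, collate_fn=list))
dropped = list(load_batches(squares, sampler, 3, drop_last=True, collate_fn=list))
```

With seven samples and a batch size of three, keeping the last batch yields three batches and dropping it yields two.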
A custom collate_fn can be used to customize collation, e.g., padding
sequential data to the max length of a batch. See
the section on collate_fn below for more.
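One possible padding collate_fn is sketched below. The toy samples and the name pad_collate are assumptions made for illustration; torch.nn.utils.rnn.pad_sequence does the padding to the longest sequence in each batch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical toy data: (variable-length sequence, label) pairs.
samples = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6]), 0),
    (torch.tensor([7, 8, 9, 10]), 1),
]


def pad_collate(batch):
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in sequences])
    # Pad every sequence with zeros up to the longest one in this batch.
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    return padded, lengths, torch.tensor(labels)


loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)
shapes = [tuple(padded.shape) for padded, _, _ in loader]
```

Returning the original lengths alongside the padded tensor lets downstream code mask out the padding.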
Disable automatic batching¶
In certain cases, users may want to handle batching manually in dataset code,
or simply load individual samples. For example, it could be cheaper to directly
load batched data (e.g., bulk reads from a database or reading continuous
chunks of memory), or the batch size is data dependent, or the program is
designed to work on individual samples. Under these scenarios, it’s likely
better to not use automatic batching (where
collate_fn is used to
collate the samples), but let the data loader directly return each member of
the dataset object.
When both batch_size and batch_sampler are None (the default value for
batch_sampler is already
None), automatic batching is
disabled. Each sample obtained from the
dataset is processed with the
function passed as the collate_fn argument.
When automatic batching is disabled, the default collate_fn simply
converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched.
In this case, loading from a map-style dataset is roughly equivalent to:
    for index in sampler:
        yield collate_fn(dataset[index])
and loading from an iterable-style dataset is roughly equivalent to:
    for data in iter(dataset):
        yield collate_fn(data)
See the section on collate_fn below for more.
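As an illustration of disabling automatic batching, the sketch below uses a hypothetical ChunkedRange dataset that already yields batched chunks, and passes batch_size=None so the loader hands each chunk through unchanged:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ChunkedRange(IterableDataset):
    """Hypothetical dataset that yields pre-batched chunks of a range."""

    def __init__(self, end, chunk_size):
        super().__init__()
        self.end = end
        self.chunk_size = chunk_size

    def __iter__(self):
        for start in range(0, self.end, self.chunk_size):
            stop = min(start + self.chunk_size, self.end)
            yield torch.arange(start, stop)


# batch_size=None disables automatic batching; each chunk is yielded as-is.
loader = DataLoader(ChunkedRange(end=10, chunk_size=4), batch_size=None)
chunks = [chunk.tolist() for chunk in loader]
```

The chunk sizes here are decided by the dataset, which is the data-dependent batching case mentioned above.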
Working with collate_fn¶
The use of
collate_fn is slightly different when automatic batching is
enabled or disabled.
When automatic batching is disabled,
collate_fn is called with
each individual data sample, and the output is yielded from the data loader
iterator. In this case, the default
collate_fn simply converts NumPy
arrays into PyTorch Tensors.
When automatic batching is enabled,
collate_fn is called with a list
of data samples at each time. It is expected to collate the input samples into
a batch for yielding from the data loader iterator. The rest of this section
describes behavior of the default
collate_fn in this case.
For instance, if each data sample consists of a 3-channel image and an integral
class label, i.e., each element of the dataset returns a tuple
(image, class_index), the default
collate_fn collates a list of
such tuples into a single tuple of a batched image tensor and a batched class
label Tensor. In particular, the default
collate_fn has the following properties:
It always prepends a new dimension as the batch dimension.
It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values cannot be converted into Tensors). The same holds for lists, tuples, namedtuples, etc.
Users may use a custom
collate_fn to achieve custom batching, e.g.,
collating along a dimension other than the first, padding sequences of
various lengths, or adding support for custom data types.
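The properties of the default collate_fn listed above can be observed with a few dictionary samples. The sample layout (a "features" tensor and an integer "label") is an assumption made only for this sketch:

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical samples: each one is a dict holding a tensor and a Python int.
samples = [{"features": torch.full((3,), float(i)), "label": i} for i in range(4)]

loader = DataLoader(samples, batch_size=2)  # uses the default collate_fn
first = next(iter(loader))

feature_shape = tuple(first["features"].shape)
label_values = first["label"].tolist()
```

The batch keeps the dictionary structure, the tensors gain a leading batch dimension, and the Python integers become a single tensor.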
Single- and Multi-process Data Loading¶
A DataLoader uses single-process data loading by
default.
Within a Python process, the
Global Interpreter Lock (GIL)
prevents truly parallelizing Python code across threads. To avoid blocking
computation code with data loading, PyTorch provides an easy switch to perform
multi-process data loading by simply setting the argument num_workers
to a positive integer.
Single-process data loading (default)¶
In this mode, data fetching is done in the same process in which a
DataLoader is initialized. Therefore, data loading
may block computing. However, this mode may be preferred when the resources used
for sharing data among processes (e.g., shared memory, file descriptors) are
limited, or when the entire dataset is small and can be loaded entirely in
memory. Additionally, single-process loading often shows more readable error
traces and thus is useful for debugging.
Multi-process data loading¶
Setting the argument
num_workers to a positive integer will
turn on multi-process data loading with the specified number of loader worker
processes.
In this mode, each time an iterator of a DataLoader
is created (e.g., when you call enumerate(dataloader)), num_workers
worker processes are created. At this point, the dataset, collate_fn, and
worker_init_fn are passed to each
worker, where they are used to initialize and fetch data. This means that
dataset access, together with its internal IO and transforms
(including collate_fn), runs in the worker process.
torch.utils.data.get_worker_info() returns various useful information
in a worker process (including the worker id, dataset replica, initial seed,
etc.), and returns
None in the main process. Users may use this function in
dataset code and/or
worker_init_fn to individually configure each
dataset replica, and to determine whether the code is running in a worker
process. For example, this can be particularly helpful in sharding the dataset.
For map-style datasets, the main process generates the indices using
sampler and sends them to the workers. So any shuffle randomization is
done in the main process which guides loading by assigning indices to load.
For iterable-style datasets, since each worker process gets a replica of the
dataset object, naive multi-process loading will often result in
duplicated data. Using torch.utils.data.get_worker_info() and/or
worker_init_fn, users may configure each replica independently. (See the
IterableDataset documentation for how to achieve
this.) For similar reasons, in multi-process loading, the drop_last
argument drops the last non-full batch of each worker’s iterable-style dataset
replica.
Workers are shut down once the end of the iteration is reached, or when the iterator is garbage collected.
It is generally not recommended to return CUDA tensors in multi-process
loading because of many subtleties in using CUDA and sharing CUDA tensors in
multiprocessing (see CUDA in multiprocessing). Instead, we recommend
using automatic memory pinning (i.e., setting
pin_memory=True), which enables fast data transfer to CUDA-enabled
GPUs.
Platform-specific behaviors¶
Since workers rely on Python
multiprocessing, worker launch behavior is
different on Windows compared to Unix.
On Unix, fork() is the default multiprocessing start method. Using fork(), child workers typically can access the dataset and Python argument functions directly through the cloned address space.
On Windows, spawn() is the default multiprocessing start method. Using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.
This separate serialization means that you should take two steps to ensure you are compatible with Windows while using multi-process data loading:
Wrap most of your main script’s code within an if __name__ == '__main__': block, to make sure it doesn’t run again (most likely generating an error) when each worker process is launched. You can place your dataset and DataLoader instance creation logic here, as it doesn’t need to be re-executed in workers.
Make sure that any custom collate_fn, worker_init_fn or dataset code is declared as top-level definitions, outside of the __main__ check. This ensures that they are available in worker processes. (This is needed since functions are pickled as references only, not bytecode.)
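The two steps above could be laid out in a script roughly as follows. SquaresDataset, collate_as_list and main are hypothetical names chosen for this sketch; the point is only where each piece sits relative to the __main__ check.

```python
import torch
from torch.utils.data import DataLoader, Dataset


# Top-level definitions: importable, and therefore picklable by reference,
# in worker processes started with spawn().
class SquaresDataset(Dataset):
    """Hypothetical map-style dataset used only for illustration."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx * idx)


def collate_as_list(batch):
    return [int(x) for x in batch]


def main(num_workers=2):
    # Dataset and DataLoader creation live here, so a worker that re-imports
    # this script does not repeat them.
    loader = DataLoader(SquaresDataset(6), batch_size=3,
                        num_workers=num_workers, collate_fn=collate_as_list)
    return list(loader)


if __name__ == '__main__':
    print(main())
```

Passing num_workers=0 to main() runs the same code in a single process, which can be handy when debugging.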
Randomness in multi-process data loading¶
By default, each worker will have its PyTorch seed set to
base_seed + worker_id, where
base_seed is a long generated by the main process using its RNG (thereby
consuming an RNG state mandatorily). However, seeds for other libraries may be
duplicated upon initializing workers (e.g., NumPy), causing each worker to return
identical random numbers. (See this section in FAQ.)
In worker_init_fn, you may access the PyTorch seed set for each worker
with either torch.utils.data.get_worker_info().seed or
torch.initial_seed(), and use it to seed other libraries before data
loading.
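A common way to apply this is a worker_init_fn along the following lines. The function name seed_worker and the reduction modulo 2**32 (NumPy's legacy seeding accepts only 32-bit values) are choices made for this sketch rather than anything prescribed above.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Inside a worker, torch.initial_seed() returns base_seed + worker_id.
    # Reduce it to 32 bits so that np.random.seed accepts it.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


dataset = TensorDataset(torch.arange(8))
# Workers are only started once the loader is iterated.
loader = DataLoader(dataset, batch_size=2, num_workers=2,
                    worker_init_fn=seed_worker)
```

Each worker then draws different NumPy and random values from the others, while runs that start from the same main-process seed stay reproducible.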
Memory Pinning¶
Host to GPU copies are much faster when they originate from pinned (page-locked) memory. See Use pinned memory buffers for more details on when and how to use pinned memory generally.
For data loading, passing
pin_memory=True to a
DataLoader will automatically put the fetched data
Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled
GPUs.
The default memory pinning logic only recognizes Tensors and maps and iterables
containing Tensors. By default, if the pinning logic sees a batch that is a
custom type (which will occur if you have a
collate_fn that returns a
custom batch type), or if each element of your batch is a custom type, the
pinning logic will not recognize them, and it will return that batch (or those
elements) without pinning the memory. To enable memory pinning for custom
batch or data type(s), define a
pin_memory() method on your custom
type(s).
See the example below.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class SimpleCustomBatch:
        def __init__(self, data):
            # data is a list of (input, target) pairs; regroup by field
            transposed_data = list(zip(*data))
            self.inp = torch.stack(transposed_data[0], 0)
            self.tgt = torch.stack(transposed_data[1], 0)

        # custom memory pinning method on custom type
        def pin_memory(self):
            self.inp = self.inp.pin_memory()
            self.tgt = self.tgt.pin_memory()
            return self

    def collate_wrapper(batch):
        return SimpleCustomBatch(batch)

    inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
    tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
    dataset = TensorDataset(inps, tgts)

    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                        pin_memory=True)

    for batch_ndx, sample in enumerate(loader):
        print(sample.inp.is_pinned())
        print(sample.tgt.is_pinned())
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)¶
Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
See the torch.utils.data documentation page for more details.
dataset (Dataset) – dataset from which to load the data.
batch_size (int, optional) – how many samples per batch to load (default: 1).
shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the memory pinning example above.
drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
If the spawn start method is used, worker_init_fn cannot be an unpicklable object, e.g., a lambda function. See Multiprocessing best practices for more details related to multiprocessing in PyTorch.
The len(dataloader) heuristic is based on the length of the sampler used. When dataset is an IterableDataset, len(dataset) (if implemented) is returned instead, regardless of multi-process loading configurations, because PyTorch trusts user dataset code to correctly handle multi-process loading and avoid duplicate data. See Dataset Types for more details on these two types of datasets and how IterableDataset interacts with Multi-process data loading.
Dataset¶
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally override __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.
DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
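To illustrate the point about non-integral keys, the sketch below pairs a string-keyed dataset with a custom sampler. KeyedDataset and KeySampler are hypothetical classes written for this example.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class KeyedDataset(Dataset):
    """Hypothetical map-style dataset addressed by string keys."""

    def __init__(self, table):
        self.table = table

    def __getitem__(self, key):
        return torch.tensor(self.table[key])

    def __len__(self):
        return len(self.table)


class KeySampler(Sampler):
    """Yields the dataset keys in sorted order."""

    def __init__(self, keys):
        self.keys = sorted(keys)

    def __iter__(self):
        return iter(self.keys)

    def __len__(self):
        return len(self.keys)


table = {"b": 2.0, "a": 1.0, "c": 3.0}
loader = DataLoader(KeyedDataset(table), batch_size=2, sampler=KeySampler(table))
batches = [batch.tolist() for batch in loader]
```

The default index sampler would have requested the keys 0, 1 and 2, which this dataset does not have; the custom sampler supplies the string keys instead.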
IterableDataset¶
An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream.
All subclasses should override
__iter__(), which would return an iterator of samples in this dataset.
When a subclass is used with
DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When
num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers.
get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset’s
__iter__() method or the DataLoader’s
worker_init_fn option to modify each copy’s behavior.
Example 1: splitting workload across all workers in __iter__():
>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset, self).__init__()
...         assert end > start, "this example code only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> # (each item is wrapped in a 1-element tensor because batch_size defaults to 1)
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
Example 2: splitting workload across all workers using worker_init_fn:
>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset, self).__init__()
...         assert end > start, "this example code only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> # (each item is wrapped in a 1-element tensor because batch_size defaults to 1)
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]
>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...
>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
TensorDataset(*tensors)¶
Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
*tensors (Tensor) – tensors that have the same size in the first dimension.
ConcatDataset(datasets)¶
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
datasets (sequence) – List of datasets to be concatenated
ChainDataset(datasets)¶
Dataset for chaining multiple IterableDatasets.
This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient.
datasets (iterable of IterableDataset) – datasets to be chained together
Subset(dataset, indices)¶
Subset of a dataset at specified indices.
dataset (Dataset) – The whole Dataset
indices (sequence) – Indices in the whole set selected for subset
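The wrappers above compose naturally. A short sketch with toy tensors (the values are arbitrary) shows ConcatDataset joining two TensorDatasets and Subset selecting a few positions from the result:

```python
import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset

first = TensorDataset(torch.arange(0, 3))     # samples 0, 1, 2
second = TensorDataset(torch.arange(10, 12))  # samples 10, 11

combined = ConcatDataset([first, second])     # positions 0..4
picked = Subset(combined, [0, 3, 4])

# Each TensorDataset sample is a 1-tuple, hence the [0].
values = [int(picked[i][0]) for i in range(len(picked))]
```

Positions 3 and 4 of the concatenation fall into the second dataset, so they map to its first and second samples.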
get_worker_info()¶
Returns the information about the current
DataLoader iterator worker process.
When called in a worker, this returns an object guaranteed to have the following attributes:
id: the current worker id.
num_workers: the total number of workers.
seed: the random seed set for the current worker. This value is determined by the main process RNG and the worker id. See
DataLoader’s documentation for more details.
dataset: the copy of the dataset object in this process. Note that this will be a different object in a different process than the one in the main process.
When called in the main process, this returns None.
When used in a
worker_init_fn passed over to
DataLoader, this method can be useful to set up each worker process differently, for instance, using
worker_id to configure the
dataset object to only read a specific fraction of a sharded dataset, or using
seed to seed other libraries used in dataset code (e.g., NumPy).
random_split(dataset, lengths)¶
Randomly split a dataset into non-overlapping new datasets of given lengths.
dataset (Dataset) – Dataset to be split
lengths (sequence) – lengths of splits to be produced
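A typical use is carving a validation set out of a training set. In this sketch the 8/2 split and the toy dataset are arbitrary, and torch.manual_seed is called only to make the split repeatable.

```python
import torch
from torch.utils.data import TensorDataset, random_split

torch.manual_seed(0)  # fixes the permutation drawn by random_split
dataset = TensorDataset(torch.arange(10))

train_set, val_set = random_split(dataset, [8, 2])

train_ids = {int(train_set[i][0]) for i in range(len(train_set))}
val_ids = {int(val_set[i][0]) for i in range(len(val_set))}
```

The two resulting datasets never share a sample, and together they cover the original dataset.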
Sampler¶
Base class for all Samplers.
Every Sampler subclass has to provide an
__iter__() method, providing a way to iterate over indices of dataset elements, and a
__len__() method that returns the length of the returned iterators.
SequentialSampler(data_source)¶
Samples elements sequentially, always in the same order.
data_source (Dataset) – dataset to sample from
RandomSampler(data_source, replacement=False, num_samples=None)¶
Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then the user can specify num_samples to draw.
SubsetRandomSampler(indices)¶
Samples elements randomly from a given list of indices, without replacement.
indices (sequence) – a sequence of indices
WeightedRandomSampler(weights, num_samples, replacement=True)¶
Samples elements from
[0,..,len(weights)-1] with given probabilities (weights).
weights (sequence) – a sequence of weights, not necessarily summing up to one
num_samples (int) – number of samples to draw
replacement (bool) – if
True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
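One frequent application, sketched here under the assumption of a two-class dataset with a 90/10 imbalance, is to weight each sample by the inverse frequency of its class so that both classes are drawn about equally often:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)
labels = torch.tensor([0] * 90 + [1] * 10)  # hypothetical imbalanced labels
dataset = TensorDataset(torch.arange(100), labels)

# Weight each sample by the inverse frequency of its class.
class_counts = torch.bincount(labels)       # tensor([90, 10])
sample_weights = 1.0 / class_counts[labels].float()

sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels),
                                replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

drawn = torch.cat([batch_labels for _, batch_labels in loader])
minority_share = (drawn == 1).float().mean().item()
```

Each class carries the same total weight, so the minority class should make up roughly half of the drawn samples instead of a tenth.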
BatchSampler(sampler, batch_size, drop_last)¶
Wraps another sampler to yield a mini-batch of indices.
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)¶
Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
torch.nn.parallel.DistributedDataParallel. In such case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.
Dataset is assumed to be of constant size.
dataset – Dataset used for sampling.
num_replicas (optional) – Number of processes participating in distributed training.
rank (optional) – Rank of the current process within num_replicas.
shuffle (optional) – If true (default), sampler will shuffle the indices
In distributed mode, calling the
set_epoch method is needed to make shuffling work; each process will use the same random seed otherwise.
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed: