import torch
import torch.multiprocessing as multiprocessing
from .sampler import SequentialSampler, RandomSampler, BatchSampler
import collections
import re
import sys
import traceback
import threading
from torch._six import string_classes
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
# Module-level switch read by default_collate. It starts out False and is set
# to True at the top of _worker_loop, so tensor batches are only collated
# straight into shared-memory storage after a worker loop has started in the
# current process.
_use_shared_memory = False
"""Whether to use shared memory in default_collate"""
class ExceptionWrapper(object):
    """Carrier for an exception's type and its fully formatted traceback.

    Arguments:
        exc_info (tuple): a ``(type, value, traceback)`` triple in the shape
            returned by ``sys.exc_info()``.

    Attributes:
        exc_type: the exception class (first element of ``exc_info``).
        exc_msg (str): the traceback rendered to a single string by
            ``traceback.format_exception``.
    """

    def __init__(self, exc_info):
        formatted_lines = traceback.format_exception(*exc_info)
        self.exc_type = exc_info[0]
        # Keep only text, not the live traceback object, so the wrapper can
        # be handed across queues.
        self.exc_msg = "".join(formatted_lines)
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    """Serve batch requests from ``index_queue`` until a ``None`` sentinel.

    Arguments:
        dataset: object indexed with ``dataset[i]`` to fetch one sample.
        index_queue: source of ``(batch_id, sample_indices)`` requests; a
            ``None`` entry asks the loop to stop.
        data_queue: sink for ``(batch_id, result)`` replies. ``result`` is the
            collated batch, or an ``ExceptionWrapper`` if loading or collating
            raised. A final ``None`` is put when the loop stops.
        collate_fn (callable): merges the list of fetched samples.
    """
    # Flip the module flag so default_collate builds tensors in shared memory.
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)

    request = index_queue.get()
    while request is not None:
        batch_id, sample_indices = request
        try:
            result = collate_fn([dataset[i] for i in sample_indices])
        except Exception:
            # Ship the failure to the consumer instead of killing the loop.
            result = ExceptionWrapper(sys.exc_info())
        data_queue.put((batch_id, result))
        request = index_queue.get()

    # Echo the sentinel downstream so the reader of data_queue sees the stop.
    data_queue.put(None)
def _pin_memory_loop(in_queue, out_queue, done_event):
    """Move batches from ``in_queue`` to ``out_queue``, pinning each one.

    Arguments:
        in_queue: source of ``(batch_id, batch)`` items; ``None`` ends the loop.
        out_queue: sink for ``(batch_id, pinned_batch)`` items. Items whose
            second element is an ``ExceptionWrapper`` are forwarded unchanged,
            and a failure while pinning is forwarded as a new
            ``ExceptionWrapper`` under the same ``batch_id``.
        done_event: event checked when ``in_queue.get()`` raises; if it is set
            the loop returns quietly, otherwise the error is re-raised.
    """
    while True:
        try:
            item = in_queue.get()
        except Exception:
            if not done_event.is_set():
                raise
            return

        if item is None:
            return

        if isinstance(item[1], ExceptionWrapper):
            # Upstream already failed for this batch; pass the error along.
            out_queue.put(item)
            continue

        batch_id, payload = item
        try:
            pinned = pin_memory_batch(payload)
        except Exception:
            out_queue.put((batch_id, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((batch_id, pinned))
# Lookup table from a numpy dtype name (``dtype.name``) to the torch tensor
# constructor used by default_collate when it collates numpy scalars.
numpy_type_map = dict(
    float64=torch.DoubleTensor,
    float32=torch.FloatTensor,
    float16=torch.HalfTensor,
    int64=torch.LongTensor,
    int32=torch.IntTensor,
    int16=torch.ShortTensor,
    int8=torch.CharTensor,
    uint8=torch.ByteTensor,
)
def default_collate(batch):
    """Puts each data field into a tensor with outer dimension batch size.

    The handling is chosen from the type of the first element, ``batch[0]``:

    * tensors are stacked along a new dimension 0 (directly into a
      shared-memory tensor when ``_use_shared_memory`` is set);
    * numpy ndarrays are converted with ``torch.from_numpy`` and stacked;
    * numpy scalars become a tensor of the type given by ``numpy_type_map``;
    * Python ints become a ``LongTensor``, floats a ``DoubleTensor``;
    * strings are returned unchanged, as the input list;
    * mappings are collated key by key into a dict;
    * other sequences are transposed and collated position by position into
      a list.

    Arguments:
        batch (list): non-empty list of samples to merge.

    Returns:
        The collated batch, shaped as described above.

    Raises:
        TypeError: if the first element is a numpy array of string/object
            dtype, or is of a type not covered above.
    """
    # The ``collections.Mapping`` / ``collections.Sequence`` aliases no longer
    # exist on Python 3.10+; the ABCs live in ``collections.abc``. Fall back
    # to the old location on Python 2, which has no ``collections.abc``.
    try:
        from collections.abc import Mapping, Sequence
    except ImportError:
        from collections import Mapping, Sequence

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        # Must be tested before Sequence: strings are sequences too.
        return batch
    elif isinstance(batch[0], Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], Sequence):
        # Regroup so that the i-th field of every sample is collated together.
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))
def pin_memory_batch(batch):
    """Recursively call ``pin_memory()`` on every tensor found in ``batch``.

    Arguments:
        batch: a tensor, a string, a mapping, a sequence, or any other object.

    Returns:
        For a tensor, the result of ``batch.pin_memory()``. Strings are
        returned unchanged. A mapping yields a new dict with each value
        processed recursively; any other sequence yields a new list with each
        element processed recursively. Anything else is returned as is.
    """
    # The ``collections.Mapping`` / ``collections.Sequence`` aliases no longer
    # exist on Python 3.10+; the ABCs live in ``collections.abc``. Fall back
    # to the old location on Python 2, which has no ``collections.abc``.
    try:
        from collections.abc import Mapping, Sequence
    except ImportError:
        from collections import Mapping, Sequence

    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        # Must be tested before Sequence: strings are sequences too.
        return batch
    elif isinstance(batch, Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
class DataLoaderIter(object):
    """Single-pass iterator over a loader's dataset, driven by its batch sampler.

    With ``num_workers == 0`` every batch is loaded and collated in the
    calling process. Otherwise index lists are handed to worker processes
    through ``index_queue`` and results are read back from ``data_queue``,
    being re-ordered so that batches are returned in the order they were
    requested.
    """

    def __init__(self, loader):
        # Snapshot the loader's configuration.
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if not self.num_workers > 0:
            # Same-process loading needs none of the machinery below.
            return

        self.index_queue = multiprocessing.SimpleQueue()
        self.data_queue = multiprocessing.SimpleQueue()
        self.batches_outstanding = 0  # requests sent but not yet read back
        self.shutdown = False
        self.send_idx = 0  # id given to the next request
        self.rcvd_idx = 0  # id of the next batch to hand to the caller
        self.reorder_dict = {}  # batches that arrived ahead of their turn

        worker_args = (self.dataset, self.index_queue, self.data_queue, self.collate_fn)
        self.workers = [
            multiprocessing.Process(target=_worker_loop, args=worker_args)
            for _ in range(self.num_workers)
        ]
        for worker in self.workers:
            # ensure that the worker exits on process exit
            worker.daemon = True
            worker.start()

        if self.pin_memory:
            # Insert a pinning thread between the workers and the consumer:
            # it reads the workers' queue and feeds a thread-local queue.
            worker_output = self.data_queue
            self.data_queue = queue.Queue()
            self.pin_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(worker_output, self.data_queue, self.done_event))
            self.pin_thread.daemon = True
            self.pin_thread.start()

        # prime the prefetch loop
        for _ in range(2 * self.num_workers):
            self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        if self.num_workers == 0:
            # same-process loading; next() may raise StopIteration
            sample_indices = next(self.sample_iter)
            samples = [self.dataset[i] for i in sample_indices]
            collated = self.collate_fn(samples)
            if self.pin_memory:
                collated = pin_memory_batch(collated)
            return collated

        # The wanted batch may already have arrived out of order.
        if self.rcvd_idx in self.reorder_dict:
            ready = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(ready)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            batch_id, payload = self.data_queue.get()
            self.batches_outstanding -= 1
            if batch_id != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[batch_id] = payload
            else:
                return self._process_next_batch(payload)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        """Send the sampler's next index list to the workers, if any is left."""
        assert self.batches_outstanding < 2 * self.num_workers
        next_indices = next(self.sample_iter, None)
        if next_indices is not None:
            self.index_queue.put((self.send_idx, next_indices))
            self.batches_outstanding += 1
            self.send_idx += 1

    def _process_next_batch(self, batch):
        """Advance the cursor, refill the request queue, then return or raise."""
        self.rcvd_idx += 1
        self._put_indices()
        if not isinstance(batch, ExceptionWrapper):
            return batch
        # Re-raise the wrapped failure as its original exception type.
        raise batch.exc_type(batch.exc_msg)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        """Signal shutdown once: set ``done_event`` and queue one ``None`` per worker."""
        if self.shutdown:
            return
        self.shutdown = True
        self.done_event.set()
        for _ in self.workers:
            self.index_queue.put(None)

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()
class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process
            (default: 0). Must not be negative.
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)

    Raises:
        ValueError: if ``num_workers`` is negative, if ``batch_sampler`` is
            combined with ``batch_size``, ``shuffle``, ``sampler`` or
            ``drop_last``, or if ``sampler`` is combined with ``shuffle``.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        # A negative worker count would otherwise be accepted here and only
        # fail later, on the first iteration, with an unrelated AttributeError.
        if num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            # Build the default pipeline: pick a sampler if none was given,
            # then group its indices into batches.
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        # Each call starts a fresh, independent pass over the dataset.
        return DataLoaderIter(self)

    def __len__(self):
        # Number of batches, as reported by the batch sampler.
        return len(self.batch_sampler)