Source code for torch.utils.data.sampler
# mypy: allow-untyped-defs
from typing import (
Generic,
Iterable,
Iterator,
List,
Optional,
Sequence,
Sized,
TypeVar,
Union,
)
import torch
# Public names exported by a star-import of this module.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]

# Covariant type variable for the element type a ``Sampler`` yields; it
# parametrizes ``Sampler`` (``Generic[_T_co]``) and its ``__iter__`` return type.
_T_co = TypeVar("_T_co", covariant=True)
class Sampler(Generic[_T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements,
    and may provide a :meth:`__len__` method that returns the length of the returned iterators.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have custom implementation that utilizes it.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AscendingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        """Emit a deprecation-style warning when ``data_source`` is supplied.

        Args:
            data_source: Ignored by this base class; passing anything other
                than ``None`` only triggers a warning.
        """
        if data_source is not None:
            import warnings

            # The two literals are concatenated, so the first must end with a
            # space to keep the sentences apart in the rendered message.
            # ``stacklevel=2`` attributes the warning to the caller (typically
            # a subclass ``__init__``) instead of this line.
            warnings.warn(
                "`data_source` argument is not used and will be removed in 2.2.0. "
                "You may still have custom implementation that utilizes it.",
                stacklevel=2,
            )

    def __iter__(self) -> Iterator[_T_co]:
        """Iterate over sampled elements; subclasses must override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call fail
    #     where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)
class SequentialSampler(Sampler[int]):
    r"""Yields the indices ``0 .. len(data_source) - 1`` in increasing order.

    The order is the same on every iteration.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        # Only the length of ``data_source`` is ever consulted by this class.
        self.data_source = data_source

    def __len__(self) -> int:
        """Return the number of indices one pass produces."""
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        """Return a fresh iterator over ``range(len(data_source))``."""
        size = len(self.data_source)
        return iter(range(size))
class RandomSampler(Sampler[int]):
    r"""Draws dataset indices in random order.

    Without replacement, indices come from successive random permutations of
    the dataset. With replacement, indices are drawn uniformly at random and
    the user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """

    data_source: Sized
    replacement: bool

    def __init__(
        self,
        data_source: Sized,
        replacement: bool = False,
        num_samples: Optional[int] = None,
        generator=None,
    ) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError(
                f"replacement should be a boolean value, but got replacement={self.replacement}"
            )

        # ``self.num_samples`` falls back to ``len(data_source)``, so an empty
        # dataset without an explicit ``num_samples`` is rejected here as well.
        if not (isinstance(self.num_samples, int) and self.num_samples > 0):
            raise ValueError(
                f"num_samples should be a positive integer value, but got num_samples={self.num_samples}"
            )

    @property
    def num_samples(self) -> int:
        """Number of indices per pass; tracks ``len(data_source)`` when unset."""
        # dataset size might change at runtime
        explicit = self._num_samples
        return len(self.data_source) if explicit is None else explicit

    def __len__(self) -> int:
        return self.num_samples

    def __iter__(self) -> Iterator[int]:
        """Yield ``num_samples`` random indices into ``data_source``."""
        population = len(self.data_source)

        rng = self.generator
        if rng is None:
            # No generator supplied: build a private one, seeded from torch's
            # global RNG.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            rng = torch.Generator()
            rng.manual_seed(seed)

        if not self.replacement:
            # Whole permutations first, then a prefix of one extra permutation.
            for _ in range(self.num_samples // population):
                yield from torch.randperm(population, generator=rng).tolist()
            last_perm = torch.randperm(population, generator=rng).tolist()
            yield from last_perm[: self.num_samples % population]
            return

        # With replacement: draw in chunks of 32, then the remainder.
        for _ in range(self.num_samples // 32):
            chunk = torch.randint(
                high=population, size=(32,), dtype=torch.int64, generator=rng
            )
            yield from chunk.tolist()
        remainder = torch.randint(
            high=population,
            size=(self.num_samples % 32,),
            dtype=torch.int64,
            generator=rng,
        )
        yield from remainder.tolist()
class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Each pass yields every entry of ``indices`` exactly once, in an order
    given by a fresh random permutation.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """

    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        """Yield the stored indices in a newly drawn random order."""
        # Convert the permutation to plain Python ints once. Iterating the
        # tensor directly would create a 0-d tensor for every position.
        order = torch.randperm(len(self.indices), generator=self.generator).tolist()
        for position in order:
            yield self.indices[position]

    def __len__(self) -> int:
        """Return the number of stored indices."""
        return len(self.indices)
class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: torch.Tensor
    num_samples: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        replacement: bool = True,
        generator=None,
    ) -> None:
        # ``bool`` is a subclass of ``int``, so it is excluded explicitly.
        count_ok = (
            isinstance(num_samples, int)
            and not isinstance(num_samples, bool)
            and num_samples > 0
        )
        if not count_ok:
            raise ValueError(
                f"num_samples should be a positive integer value, but got num_samples={num_samples}"
            )

        if not isinstance(replacement, bool):
            raise ValueError(
                f"replacement should be a boolean value, but got replacement={replacement}"
            )

        # Weights are stored as a double-precision tensor and must be 1-D.
        as_double = torch.as_tensor(weights, dtype=torch.double)
        if len(as_double.shape) != 1:
            raise ValueError(
                "weights should be a 1d sequence but given "
                f"weights have shape {tuple(as_double.shape)}"
            )

        self.weights = as_double
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __len__(self) -> int:
        return self.num_samples

    def __iter__(self) -> Iterator[int]:
        """Draw all ``num_samples`` indices in one call and yield them."""
        drawn = torch.multinomial(
            self.weights, self.num_samples, self.replacement, generator=self.generator
        )
        yield from drawn.tolist()
class BatchSampler(Sampler[List[int]]):
    r"""Groups the indices produced by another sampler into mini-batches.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,
        drop_last: bool,
    ) -> None:
        # ``sampler`` is deliberately not type-checked: collections.abc.Iterable
        # does not recognise objects that are iterable only via `__getitem__`.
        size_ok = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not size_ok:
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )

        if not isinstance(drop_last, bool):
            raise ValueError(
                f"drop_last should be a boolean value, but got drop_last={drop_last}"
            )

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        """Yield lists of up to ``batch_size`` indices taken from ``sampler``."""
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            source = iter(self.sampler)
            while True:
                try:
                    # ``next`` raising StopIteration mid-batch abandons the
                    # partially filled batch.
                    full_batch = [next(source) for _ in range(self.batch_size)]
                    yield full_batch
                except StopIteration:
                    return

        # Keep the tail: collect indices and flush whenever a batch fills up.
        pending: List[int] = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        if pending:
            yield pending

    def __len__(self) -> int:
        """Return the number of batches one pass yields."""
        # Can only be called if self.sampler has __len__ implemented.
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        total = len(self.sampler)  # type: ignore[arg-type]
        if self.drop_last:
            return total // self.batch_size
        return (total + self.batch_size - 1) // self.batch_size