
torchtnt.utils.data.RandomizedBatchSamplerIterator

class torchtnt.utils.data.RandomizedBatchSamplerIterator(individual_dataloaders: Mapping[str, Union[DataLoader, Iterable]], iteration_strategy: RandomizedBatchSampler)

RandomizedBatchSamplerIterator randomly samples from each dataset using the provided weights.

By default, the iterator stops after all datasets are exhausted. This can be changed by configuring a different stopping mechanism on the RandomizedBatchSampler.

Returns batches of the format: {dataloader_name: batch_from_dataloader}

Parameters:
  • individual_dataloaders (Mapping[str, Union[DataLoader, Iterable]]) – A mapping of DataLoaders or Iterables, with the dataloader name as key and the dataloader/iterable object as value.
  • iteration_strategy (RandomizedBatchSampler) – A RandomizedBatchSampler dataclass indicating how the dataloaders are iterated over.

Examples

>>> import torch
>>> from torchtnt.utils.data import (
...     RandomizedBatchSampler,
...     RandomizedBatchSamplerIterator,
...     StoppingMechanism,
... )
>>> loaders = {'a': torch.utils.data.DataLoader(range(4), batch_size=4),
...            'b': torch.utils.data.DataLoader(range(15), batch_size=5)}
>>> randomized_batch_sampler = RandomizedBatchSampler(
...     stopping_mechanism=StoppingMechanism.ALL_DATASETS_EXHAUSTED
... )
>>> combined_iterator = RandomizedBatchSamplerIterator(loaders, randomized_batch_sampler)
>>> for item in combined_iterator:
...     print(item)
{'b': tensor([0, 1, 2, 3, 4])}
{'b': tensor([5, 6, 7, 8, 9])}
{'a': tensor([0, 1, 2, 3])}
{'b': tensor([10, 11, 12, 13, 14])}
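
To bias sampling toward one dataloader, or to stop earlier than full exhaustion, the RandomizedBatchSampler can be configured accordingly. The following is a minimal sketch only: it assumes the dataclass accepts a weights mapping and that StoppingMechanism provides a SMALLEST_DATASET_EXHAUSTED member; check the RandomizedBatchSampler documentation for the exact fields available in your version.

>>> # Assumed fields: 'weights' and StoppingMechanism.SMALLEST_DATASET_EXHAUSTED;
>>> # verify against the RandomizedBatchSampler dataclass before relying on them.
>>> weighted_sampler = RandomizedBatchSampler(
...     weights={'a': 0.2, 'b': 0.8},
...     stopping_mechanism=StoppingMechanism.SMALLEST_DATASET_EXHAUSTED,
... )
>>> weighted_iterator = RandomizedBatchSamplerIterator(loaders, weighted_sampler)
>>> for item in weighted_iterator:
...     print(item)  # batches from 'b' are drawn ~4x as often; iteration ends once 'a' is exhausted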
__init__(individual_dataloaders: Mapping[str, Union[DataLoader, Iterable]], iteration_strategy: RandomizedBatchSampler) → None

Methods

__init__(individual_dataloaders, ...)
