torchtnt.utils.data.RandomizedBatchSamplerIterator
class torchtnt.utils.data.RandomizedBatchSamplerIterator(individual_dataloaders: Mapping[str, Union[DataLoader, Iterable]], iteration_strategy: RandomizedBatchSampler)
RandomizedBatchSamplerIterator randomly samples from each dataset using the provided weights.
By default, the iterator stops after all datasets are exhausted. This can be changed by setting a different stopping_mechanism on the RandomizedBatchSampler passed as iteration_strategy.
Returns batches of the format: {dataloader_name: batch_from_dataloader}
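The behavior described above can be illustrated with a minimal pure-Python sketch. It is not torchtnt's implementation: the helper name weighted_random_batches and its weights and seed parameters are hypothetical, and plain lists stand in for DataLoaders. It picks a source at random in proportion to its weight, yields one {name: batch} dict per step, and stops once every source is exhausted.

```python
import random
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional


def weighted_random_batches(
    loaders: Mapping[str, Iterable[Any]],
    weights: Optional[Mapping[str, float]] = None,
    seed: Optional[int] = None,
) -> Iterator[Dict[str, Any]]:
    """Yield {name: batch} dicts, choosing the source at random by weight.

    Hypothetical helper for illustration only; weights must be positive.
    """
    rng = random.Random(seed)
    iterators = {name: iter(loader) for name, loader in loaders.items()}
    # Fall back to equal weights when none are supplied.
    active = {name: (weights[name] if weights else 1.0) for name in iterators}
    while active:
        names = list(active)
        name = rng.choices(names, weights=[active[n] for n in names], k=1)[0]
        try:
            batch = next(iterators[name])
        except StopIteration:
            # This source is exhausted, so stop sampling from it.
            del active[name]
            continue
        yield {name: batch}


# Plain lists of "batches" stand in for DataLoaders here.
loaders = {
    "a": [[0, 1, 2, 3]],
    "b": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]],
}
batches = list(weighted_random_batches(loaders, weights={"a": 1.0, "b": 3.0}, seed=0))
```

Each yielded item holds exactly one key, and batches from the same source keep their original order. Only the interleaving between sources is random.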
Parameters:
- individual_dataloaders (Mapping[str, Union[DataLoader, Iterable]]) – A mapping of DataLoaders or Iterables, with the dataloader name as key and the dataloader/iterable object as value.
- iteration_strategy (RandomizedBatchSampler) – A RandomizedBatchSampler dataclass indicating how the dataloaders are iterated over.
Examples
>>> loaders = {'a': torch.utils.data.DataLoader(range(4), batch_size=4),
...            'b': torch.utils.data.DataLoader(range(15), batch_size=5)}
>>> randomized_batch_sampler = RandomizedBatchSampler(
...     stopping_mechanism=StoppingMechanism.ALL_DATASETS_EXHAUSTED
... )
>>> combined_iterator = RandomizedBatchSamplerIterator(loaders, randomized_batch_sampler)
>>> for item in combined_iterator:
...     print(item)
{'b': tensor([0, 1, 2, 3, 4])}
{'b': tensor([5, 6, 7, 8, 9])}
{'a': tensor([0, 1, 2, 3])}
{'b': tensor([10, 11, 12, 13, 14])}
__init__(individual_dataloaders: Mapping[str, Union[DataLoader, Iterable]], iteration_strategy: RandomizedBatchSampler) → None
Methods
__init__(individual_dataloaders, ...)