torchtnt.utils.data.RoundRobinIterator
class torchtnt.utils.data.RoundRobinIterator(individual_dataloaders: Mapping[str, Union[DataLoader, Iterable]], iteration_strategy: RoundRobin)
RoundRobinIterator cycles over the dataloaders one at a time. The iteration order can be defined via the RoundRobin strategy.
This supports two stopping mechanisms:
1. ALL_DATASETS_EXHAUSTED: iterates until the largest dataset is exhausted, skipping datasets that are already done.
2. SMALLEST_DATASET_EXHAUSTED: stops iteration once the smallest dataset has been exhausted.
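The difference between the two stopping mechanisms can be illustrated with a minimal pure-Python sketch. This is not torchtnt's implementation: the names `round_robin` and `Stop` are hypothetical, plain lists of batches stand in for DataLoaders, and the iteration order is assumed to be the mapping's key order.

```python
from enum import Enum, auto
from typing import Any, Dict, Iterable, Iterator, Mapping


class Stop(Enum):
    # Hypothetical stand-in for torchtnt's StoppingMechanism values.
    ALL_DATASETS_EXHAUSTED = auto()
    SMALLEST_DATASET_EXHAUSTED = auto()


def round_robin(
    loaders: Mapping[str, Iterable[Any]], stop: Stop
) -> Iterator[Dict[str, Any]]:
    """Yield {name: batch} items, taking one batch from each loader in turn (sketch)."""
    iterators = {name: iter(loader) for name, loader in loaders.items()}
    order = list(iterators)  # assumed order: mapping key order
    idx = 0
    while order:
        name = order[idx]
        try:
            batch = next(iterators[name])
        except StopIteration:
            if stop is Stop.SMALLEST_DATASET_EXHAUSTED:
                return  # the first exhausted loader ends the whole iteration
            # ALL_DATASETS_EXHAUSTED: drop the finished loader, keep cycling the rest.
            order.pop(idx)
            if order:
                idx %= len(order)
            continue
        yield {name: batch}
        idx = (idx + 1) % len(order)


# Usage: mirrors the DataLoader example below, with lists as pre-made batches.
loaders = {
    "a": [[0, 1, 2, 3]],
    "b": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]],
}
for item in round_robin(loaders, Stop.ALL_DATASETS_EXHAUSTED):
    print(item)
```

With SMALLEST_DATASET_EXHAUSTED, the same loaders would yield only the first two items, because "a" runs out on the third step.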
Returns batches in the format: {dataloader_name: batch_from_dataloader}
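Because each returned item is a one-entry dict, a consumer can unpack it and route the batch by dataloader name. A minimal sketch, assuming hypothetical per-dataset handler functions; `dispatch` and `handlers` are illustrative names, not part of torchtnt.

```python
from typing import Any, Callable, Dict, Iterable, List


def dispatch(
    items: Iterable[Dict[str, Any]],
    handlers: Dict[str, Callable[[Any], Any]],
) -> List[Any]:
    """Route each {dataloader_name: batch} item to the handler for that name (hypothetical helper)."""
    results = []
    for item in items:
        # Each item holds exactly one entry; this unpacking raises ValueError otherwise.
        ((name, batch),) = item.items()
        results.append(handlers[name](batch))
    return results


# Usage: the per-dataset "steps" here are just sum and len, for illustration.
items = [{"a": [0, 1, 2, 3]}, {"b": [0, 1, 2, 3, 4]}]
print(dispatch(items, {"a": sum, "b": len}))  # prints [6, 5]
```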
Parameters:
- individual_dataloaders (Mapping[str, Union[DataLoader, Iterable]]) – A mapping of DataLoaders or Iterables, with the dataloader name as key and the dataloader/iterable object as value.
- iteration_strategy (RoundRobin) – A RoundRobin dataclass indicating how the dataloaders are iterated over.
Examples
>>> import torch
>>> from torchtnt.utils.data.iterators import RoundRobin, RoundRobinIterator, StoppingMechanism
>>> loaders = {
...     'a': torch.utils.data.DataLoader(range(4), batch_size=4),
...     'b': torch.utils.data.DataLoader(range(15), batch_size=5),
... }
>>> round_robin_strategy = RoundRobin(
...     stopping_mechanism=StoppingMechanism.ALL_DATASETS_EXHAUSTED
... )
>>> combined_iterator = RoundRobinIterator(loaders, round_robin_strategy)
>>> for item in combined_iterator:
...     print(item)
{'a': tensor([0, 1, 2, 3])}
{'b': tensor([0, 1, 2, 3, 4])}
{'b': tensor([5, 6, 7, 8, 9])}
{'b': tensor([10, 11, 12, 13, 14])}
__init__(individual_dataloaders: Mapping[str, Union[DataLoader, Iterable]], iteration_strategy: RoundRobin) -> None
Methods
__init__(individual_dataloaders, ...)