torchtnt.utils.data.AllDatasetBatchesIterator
class torchtnt.utils.data.AllDatasetBatchesIterator(individual_dataloaders: Mapping[str, Union[DataLoader, Iterable]], iteration_strategy: AllDatasetBatches)

AllDatasetBatchesIterator returns a dict containing batches from all dataloaders. When the stopping mechanism is set to ALL_DATASETS_EXHAUSTED, it skips over the finished datasets.

This supports three stopping mechanisms:

1. ALL_DATASETS_EXHAUSTED: iterates until the largest dataset is exhausted, skipping those that are done
2. SMALLEST_DATASET_EXHAUSTED: stops iteration once the smallest dataset has been exhausted
3. RESTART_UNTIL_ALL_DATASETS_EXHAUSTED: iterates until the largest dataset is exhausted, restarting those that are done
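For example, under SMALLEST_DATASET_EXHAUSTED, iteration ends as soon as any dataloader runs out of batches. A minimal sketch, assuming AllDatasetBatches, AllDatasetBatchesIterator, and StoppingMechanism are importable from torchtnt.utils.data as the qualified name above suggests:

>>> import torch
>>> from torchtnt.utils.data import (
...     AllDatasetBatches,
...     AllDatasetBatchesIterator,
...     StoppingMechanism,
... )
>>> loaders = {'a': torch.utils.data.DataLoader(range(4), batch_size=4),
...            'b': torch.utils.data.DataLoader(range(15), batch_size=5)}
>>> strategy = AllDatasetBatches(
...     stopping_mechanism=StoppingMechanism.SMALLEST_DATASET_EXHAUSTED
... )
>>> for item in AllDatasetBatchesIterator(loaders, strategy):
...     print(item)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}

Here 'a' yields only one batch, so iteration stops after the first combined batch.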
Returns batches of the format:

{
    dataloader_1_name: batch_obtained_from_dataloader_1,
    dataloader_2_name: batch_obtained_from_dataloader_2,
}
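Because exhausted dataloaders are skipped under ALL_DATASETS_EXHAUSTED, later batches may not contain every key, so consumers should iterate over the keys actually present. A minimal sketch, reusing the loaders and strategy from the Examples section below; per_dataloader_step is a hypothetical handler:

>>> for batch in AllDatasetBatchesIterator(loaders, all_dataset_batch_strategy):
...     for name, data in batch.items():  # only non-exhausted dataloaders appear
...         per_dataloader_step(name, data)  # hypothetical per-dataloader handler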
Parameters:

- individual_dataloaders (Mapping[str, Union[DataLoader, Iterable]]) – A mapping of DataLoaders or Iterables, with the dataloader name as key and the dataloader/iterable object as value.
- iteration_strategy (AllDatasetBatches) – An AllDatasetBatches dataclass indicating how the dataloaders are iterated over.
Examples

>>> loaders = {'a': torch.utils.data.DataLoader(range(4), batch_size=4),
...            'b': torch.utils.data.DataLoader(range(15), batch_size=5)}
>>> all_dataset_batch_strategy = AllDatasetBatches(
...     stopping_mechanism=StoppingMechanism.ALL_DATASETS_EXHAUSTED
... )
>>> combined_iterator = AllDatasetBatchesIterator(loaders, all_dataset_batch_strategy)
>>> for item in combined_iterator:
...     print(item)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'b': tensor([5, 6, 7, 8, 9])}
{'b': tensor([10, 11, 12, 13, 14])}
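For comparison, a sketch of the same loaders under RESTART_UNTIL_ALL_DATASETS_EXHAUSTED. Since finished dataloaders are restarted until the largest one is exhausted, every combined batch contains all keys; the output below assumes the restarted 'a' replays its single batch on each pass:

>>> restart_strategy = AllDatasetBatches(
...     stopping_mechanism=StoppingMechanism.RESTART_UNTIL_ALL_DATASETS_EXHAUSTED
... )
>>> for item in AllDatasetBatchesIterator(loaders, restart_strategy):
...     print(item)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([0, 1, 2, 3]), 'b': tensor([5, 6, 7, 8, 9])}
{'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}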
__init__(individual_dataloaders: Mapping[str, Union[DataLoader, Iterable]], iteration_strategy: AllDatasetBatches) → None
Methods

__init__(individual_dataloaders, iteration_strategy)