torchtnt.utils.data.CudaDataPrefetcher
class torchtnt.utils.data.CudaDataPrefetcher(data_iterable: Iterable[Batch], device: device, num_prefetch_batches: int = 1)
CudaDataPrefetcher prefetches batches and moves them to the device.
This class can be used to interleave data loading, host-to-device copies, and computation more effectively.
Parameters:
- data_iterable – an iterable of batches (for example, a dataloader) from which CudaDataPrefetcher draws its data
- device – the device to which each batch should be moved
- num_prefetch_batches – number of batches to prefetch ahead of the one being consumed (default: 1)
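To illustrate what num_prefetch_batches controls, here is a minimal pure-Python sketch of a look-ahead buffer. It is not the torchtnt implementation: the function name lookahead is hypothetical, and the sketch only shows the buffering idea, with no device copies or CUDA streams involved.

```python
from collections import deque
from typing import Deque, Iterable, Iterator, TypeVar

T = TypeVar("T")


def lookahead(data_iterable: Iterable[T], num_prefetch_batches: int = 1) -> Iterator[T]:
    """Yield items in order while keeping up to num_prefetch_batches fetched ahead.

    Hypothetical helper for illustration only.
    """
    if num_prefetch_batches < 1:
        raise ValueError("num_prefetch_batches must be at least 1")

    it = iter(data_iterable)
    buffer: Deque[T] = deque()

    # Fill the buffer before the consumer receives its first item.
    for item in it:
        buffer.append(item)
        if len(buffer) == num_prefetch_batches:
            break

    done = object()  # sentinel marking an exhausted source
    while buffer:
        current = buffer.popleft()
        # Fetch one replacement item before handing out the current one,
        # so the buffer stays full while the consumer works on `current`.
        upcoming = next(it, done)
        if upcoming is not done:
            buffer.append(upcoming)
        yield current


batches = list(lookahead(range(5), num_prefetch_batches=2))
```

In a prefetcher that targets a GPU, the "fetch" step would additionally start the host-to-device copy, so that the copy can overlap with computation on the batch currently being consumed.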
Note
We recommend users leverage memory pinning when constructing their dataloader: https://pytorch.org/docs/stable/data.html#memory-pinning.
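As a hedged sketch of that recommendation, the snippet below builds a dataloader with pin_memory enabled whenever CUDA is available. The toy dataset, its shapes, and the batch size are assumptions chosen purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 8 samples with 3 features each (illustrative values only).
dataset = TensorDataset(torch.randn(8, 3), torch.arange(8))

# Request pinned (page-locked) host memory only when CUDA is present.
# Pinned memory is what allows host-to-device copies to run asynchronously.
use_cuda = torch.cuda.is_available()
dataloader = DataLoader(dataset, batch_size=4, pin_memory=use_cuda)

# Collect the shape of each feature batch to show the loader works as usual.
batch_shapes = [tuple(features.shape) for features, labels in dataloader]
```

A dataloader built this way can then be passed as data_iterable to CudaDataPrefetcher.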
Example:
import torch
from torchtnt.utils.data import CudaDataPrefetcher

dataloader = ...
device = torch.device("cuda")
num_prefetch_batches = 2
data_prefetcher = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)
for batch in data_prefetcher:
    # batch is already on device
    # operate on batch
    ...
Methods
__init__(data_iterable, device[, ...])