FullSync
- class torchdata.datapipes.iter.FullSync(datapipe: IterDataPipe, timeout=1800)
Synchronizes data across distributed processes to prevent hanging during training, which is caused by unevenly sharded data (functional name: fullsync). It stops when the shortest distributed shard is exhausted. It is appended automatically to the end of the graph of DataPipe by DistributedReadingService.
- Parameters:
datapipe – IterDataPipe that needs to be synchronized
timeout – Timeout for prefetching data in seconds. Default value is 1800 (30 minutes)
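When fullsync is applied manually, the timeout can be adjusted through the constructor or the functional form; a minimal sketch, where the 600-second value is chosen purely for illustration:

>>> from torchdata.datapipes.iter import FullSync
>>> dp = FullSync(dp, timeout=600)  # or equivalently: dp.fullsync(timeout=600)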
Example
>>> import torch
>>> from torchdata.datapipes.iter import IterableWrapper
>>> # Distributed training with world size 2; `rank` is the current process rank
>>> world_size = 2
>>> dp = IterableWrapper(list(range(23))).sharding_filter()
>>> torch.utils.data.graph_settings.apply_sharding(dp, world_size, rank)
>>> # Rank 0 has 12 elements; Rank 1 has 11 elements
>>> for d in dp:
...     model(d)  # Hanging at the end of epoch due to uneven sharding
>>> dp = dp.fullsync()
>>> # Both ranks have 11 elements
>>> for d in dp:
...     model(d)  # Not hanging anymore
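Since DistributedReadingService appends fullsync automatically, the explicit call above is usually unnecessary when iterating through DataLoader2. A minimal sketch of that setup, assuming the processes are launched with torchrun (or the process group is otherwise initialized):

>>> import torch.distributed as dist
>>> from torchdata.datapipes.iter import IterableWrapper
>>> from torchdata.dataloader2 import DataLoader2, DistributedReadingService
>>> dist.init_process_group("gloo")  # assumes env:// variables set by torchrun
>>> dp = IterableWrapper(list(range(23))).sharding_filter()
>>> rs = DistributedReadingService()
>>> dl = DataLoader2(dp, reading_service=rs)  # fullsync is appended to the graph automatically
>>> for d in dl:
...     model(d)  # each rank sees the same number of elements per epoch
>>> dl.shutdown()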