Data¶
This module provides helper classes for implementing fault-tolerant data loaders.
We recommend using torchdata’s StatefulDataLoader and checkpointing each replica’s dataloader state frequently so that a restarted replica does not replay duplicate batches.
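A minimal sketch of that pattern, assuming torchdata is installed; the dataset, save interval, and checkpoint path are illustrative, not part of torchft:

```python
# Sketch: periodically checkpointing a StatefulDataLoader so a restarted
# replica resumes where it left off instead of replaying batches.
# The save interval and checkpoint path are illustrative assumptions.
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(1024))  # any map-style dataset
loader = StatefulDataLoader(dataset, batch_size=8, num_workers=2)

CKPT_PATH = "dataloader_state.pt"  # hypothetical path

for step, batch in enumerate(loader):
    # ... forward/backward/optimizer step ...
    if step % 100 == 0:
        # Persist the dataloader state alongside the model checkpoint.
        torch.save(loader.state_dict(), CKPT_PATH)

# On recovery, restore the state before iterating again.
new_loader = StatefulDataLoader(dataset, batch_size=8, num_workers=2)
new_loader.load_state_dict(torch.load(CKPT_PATH, weights_only=False))
```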
- class torchft.data.DistributedSampler(dataset: Dataset, replica_group: int, num_replica_groups: int, rank: Optional[int] = None, num_replicas: Optional[int] = None, **kwargs: object)[source]¶
Bases: DistributedSampler
DistributedSampler extends the standard PyTorch DistributedSampler with a num_replica_groups argument that shards the data across the fault-tolerant replica groups.
torchft doesn’t know the number of replica groups ahead of time, so this should be set to the maximum number of replica groups.
This sampler is inherently lossy when used with torchft: torchft occasionally drops batches when a replica group rejoins, and if a replica group is down, that group’s examples will never be used. This can lead to imbalances when using a small dataset.
This will shard the input dataset into num_replicas * num_replica_groups shards. Each shard’s rank is calculated via: rank + num_replicas * replica_group.
num_replicas and replica_group must be the same on all workers.
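A minimal usage sketch based on the signature above; the worker rank, group size, replica group index, and dataset are illustrative values, not torchft defaults:

```python
# Sketch: constructing the fault-tolerant DistributedSampler for one worker.
# All numeric values below are illustrative assumptions.
import torch
from torch.utils.data import DataLoader, TensorDataset

from torchft.data import DistributedSampler

dataset = TensorDataset(torch.arange(1024).float())

num_replicas = 2        # workers within this replica group
rank = 1                # this worker's rank within the replica group
replica_group = 3       # index of this worker's replica group
num_replica_groups = 4  # maximum number of replica groups torchft may run

sampler = DistributedSampler(
    dataset,
    replica_group=replica_group,
    num_replica_groups=num_replica_groups,
    rank=rank,
    num_replicas=num_replicas,
)
# The dataset is split into num_replicas * num_replica_groups = 8 shards;
# this worker reads shard rank + num_replicas * replica_group = 1 + 2*3 = 7.

loader = DataLoader(dataset, batch_size=8, sampler=sampler)
for batch in loader:
    pass  # training step
```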