
Data

This module provides helper classes for implementing fault-tolerant data loaders.

We recommend using torchdata’s StatefulDataLoader to checkpoint each replica’s dataloader frequently to avoid duplicate batches.
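A minimal sketch of that pattern, assuming torchdata is installed; my_dataset, train_step, the checkpoint interval, and the checkpoint path are placeholders for illustration, not a prescribed API:

    import torch
    from torchdata.stateful_dataloader import StatefulDataLoader

    # Placeholder dataset and training step; substitute your own.
    dataloader = StatefulDataLoader(my_dataset, batch_size=32, num_workers=2)

    for step, batch in enumerate(dataloader):
        train_step(batch)

        # Checkpoint the dataloader position periodically so a restarted
        # replica can resume without replaying batches it already consumed.
        if step % 100 == 0:
            torch.save(dataloader.state_dict(), "dataloader_ckpt.pt")

    # On restart, restore the position before iterating again:
    # dataloader.load_state_dict(torch.load("dataloader_ckpt.pt"))

In practice the dataloader state is typically saved alongside the model and optimizer state rather than to a separate file.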

class torchft.data.DistributedSampler(dataset: Dataset, replica_group: int, num_replica_groups: int, rank: Optional[int] = None, num_replicas: Optional[int] = None, **kwargs: object)[source]

Bases: DistributedSampler

DistributedSampler extends the standard PyTorch DistributedSampler with a num_replica_groups argument that is used to shard the data across the fault-tolerant replica groups.

torchft doesn’t know the number of replica groups ahead of time, so num_replica_groups must be set to the maximum number of replica groups.

This sampler is inherently lossy when used with torchft: torchft occasionally drops batches when a replica group rejoins, and if a replica group is down, that group’s examples will never be used. This can lead to imbalances when using a small dataset.

This shards the input dataset into num_replicas * num_replica_groups shards.

Each shard’s rank is calculated as: rank + num_replicas * replica_group

num_replicas and replica_group must be the same for all workers within a replica group.
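A usage sketch under the signature above, with illustrative values for the group sizes and ranks; my_dataset is a placeholder Dataset:

    from torch.utils.data import DataLoader
    from torchft.data import DistributedSampler

    # Worker 0 in replica group 1, with 2 workers per group and at most
    # 4 replica groups. All values are illustrative.
    sampler = DistributedSampler(
        my_dataset,
        replica_group=1,
        num_replica_groups=4,
        rank=0,
        num_replicas=2,
    )
    # Effective shard rank: rank + num_replicas * replica_group = 0 + 2 * 1 = 2,
    # out of num_replicas * num_replica_groups = 8 total shards.

    dataloader = DataLoader(my_dataset, batch_size=32, sampler=sampler)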
