
Attention

June 2024 Status Update: Removing DataPipes and DataLoader V2

We are re-focusing the torchdata repo to be an iterative enhancement of torch.utils.data.DataLoader. We do not plan on continuing development or maintaining the [DataPipes] and [DataLoaderV2] solutions, and they will be removed from the torchdata repo. We'll also be revisiting the DataPipes references in pytorch/pytorch. In release torchdata==0.8.0 (July 2024) they will be marked as deprecated, and in 0.10.0 (late 2024) they will be deleted. Existing users are advised to pin to torchdata<=0.9.0 or an older version until they are able to migrate away; subsequent releases will not include DataPipes or DataLoaderV2. Please reach out if you have suggestions or comments (please use this issue for feedback).

Stateful DataLoader Tutorial

Saving and loading state

StatefulDataLoader adds state_dict and load_state_dict methods to torch.utils.data.DataLoader. State can be fetched and restored as follows:

from torchdata.stateful_dataloader import StatefulDataLoader

dataloader = StatefulDataLoader(dataset, num_workers=2)
for i, batch in enumerate(dataloader):
    ...
    if i == 10:
        state_dict = dataloader.state_dict()
        break

# Training run resumes with the previous checkpoint
dataloader = StatefulDataLoader(dataset, num_workers=2)
# Resume state with DataLoader
dataloader.load_state_dict(state_dict)
for i, batch in enumerate(dataloader):
    ...
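
In a real training loop, the dataloader state is typically persisted to disk together with the model and optimizer state. A minimal sketch, assuming an existing model and optimizer and a hypothetical checkpoint.pt path:

import torch

# Save the dataloader state alongside the rest of the training checkpoint.
checkpoint = {
    "model": model.state_dict(),          # assumes an existing model
    "optimizer": optimizer.state_dict(),  # assumes an existing optimizer
    "dataloader": dataloader.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# Later, restore the dataloader state before resuming iteration.
checkpoint = torch.load("checkpoint.pt")
dataloader = StatefulDataLoader(dataset, num_workers=2)
dataloader.load_state_dict(checkpoint["dataloader"])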

Saving Custom State with Map-Style Datasets

For efficient resuming of Map-style datasets, you can resume iteration by defining state_dict / load_state_dict methods on your sampler. If your dataset has worker-specific state (e.g. RNG state used in transforms), you can also add state_dict / load_state_dict methods to your dataset.

from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader

# If you are using the default RandomSampler and BatchSampler in torch.utils.data,
# they are patched when you import torchdata.stateful_dataloader, so defining
# a custom sampler here is unnecessary.
class MySampler(torch.utils.data.Sampler[int]):
    def __init__(self, high: int, seed: int, limit: int):
        self.seed, self.high, self.limit = seed, high, limit
        self.g = torch.Generator()
        self.g.manual_seed(self.seed)
        self.i = 0

    def __iter__(self):
        while self.i < self.limit:
            val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
            self.i += 1
            yield val

    def load_state_dict(self, state_dict: Dict[str, Any]):
        self.i = state_dict["i"]
        self.g.set_state(state_dict["rng"])

    def state_dict(self) -> Dict[str, Any]:
        return {"i": self.i, "rng": self.g.get_state()}

# Optional: save dataset random transform state
class NoisyRange(torch.utils.data.Dataset):
    def __init__(self, high: int, mean: float, std: float):
        self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)

    def __len__(self):
        return self.high

    def __getitem__(self, idx: int) -> float:
        if not (0 <= idx < self.high):
            raise IndexError()
        x = torch.normal(self.mean, self.std)
        noise = x.item()
        return idx + noise

    def load_state_dict(self, state_dict):
        torch.set_rng_state(state_dict["rng"])

    def state_dict(self):
        return {"rng": torch.get_rng_state()}

# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
    print(f"{num_workers=}")
    dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
        batch_size=2, drop_last=False, num_workers=num_workers)

    batches = []
    for i, batch in enumerate(dl):
        batches.append(batch)
        if i == 2:
            sd = dl.state_dict()

    dl.load_state_dict(sd)
    batches2 = list(dl)

    print(batches[3:])
    print(batches2)

"""
Output:
num_workers=0
[tensor([-0.4526,  3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
[tensor([-0.4526,  3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
num_workers=2
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
"""

Saving Custom State with Iterable-Style Datasets

Tracking iteration order with Iterable-style datasets requires capturing state from each worker-level instance of the dataset. You can define state_dict / load_state_dict methods on your dataset which capture worker-level state; StatefulDataLoader handles aggregation across workers and distribution back to them. Calling load_state_dict requires the StatefulDataLoader to be constructed with the same num_workers as the loader that produced the state_dict.

from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, high: int, seed: int):
        self.high, self.seed = high, seed
        self.g = torch.Generator()
        self.i = 0

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            worker_id = 0
            num_workers = 1
        self.g.manual_seed(self.seed)
        arr = torch.randperm(self.high, generator=self.g)
        arr = arr[worker_id:self.high:num_workers]
        for idx in range(self.i, len(arr)):
            self.i += 1
            yield arr[idx]
        self.i = 0

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]

# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
    print(f"{num_workers=}")
    dl = StatefulDataLoader(
        MyIterableDataset(12, 0), batch_size=2, drop_last=False,
        num_workers=num_workers)

    batches = []
    for i, batch in enumerate(dl):
        batches.append(batch)
        if i == 2:
            sd = dl.state_dict()

    dl.load_state_dict(sd)
    batches2 = list(dl)

    print(batches[3:])
    print(batches2)

"""
Output:
num_workers=0
[tensor([ 2, 10]), tensor([3, 1]), tensor([11,  6])]
[tensor([ 2, 10]), tensor([3, 1]), tensor([11,  6])]
num_workers=2
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
"""
