Stateful DataLoader Tutorial
============================

Saving and loading state
------------------------

``StatefulDataLoader`` adds ``state_dict`` and ``load_state_dict`` methods to ``torch.utils.data.DataLoader``. State can be saved and restored as follows:

.. code:: python

    from torchdata.stateful_dataloader import StatefulDataLoader

    # `dataset` is a placeholder for your own dataset; it is not defined in this snippet.
    dataloader = StatefulDataLoader(dataset, num_workers=2)
    for i, batch in enumerate(dataloader):
        ...
        if i == 10:
            # Capture the loader state after the 11th batch (i is zero-based), then stop.
            state_dict = dataloader.state_dict()
            break

    # The training run resumes from the previous checkpoint with a freshly built loader
    dataloader = StatefulDataLoader(dataset, num_workers=2)
    # Restore the saved state into the new loader before iterating it
    dataloader.load_state_dict(state_dict)
    for i, batch in enumerate(dataloader):
        ...

Saving Custom State with Map-Style Datasets
-------------------------------------------

To resume `Map-style datasets <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_ efficiently, define ``state_dict`` / ``load_state_dict`` methods on your sampler. If your dataset has worker-specific state (e.g., the RNG state of a random transform), you can also add ``state_dict`` / ``load_state_dict`` methods to your dataset.

.. code:: python

    from typing import *
    import torch
    import torch.utils.data
    from torchdata.stateful_dataloader import StatefulDataLoader

    # If you are using the default RandomSampler and BatchSampler in torch.utils.data, they are patched when you import torchdata.stateful_dataloader, so defining a custom sampler here is unnecessary
    class MySampler(torch.utils.data.Sampler[int]):
        """Resumable sampler yielding ``limit`` random indices drawn from ``[0, high)``.

        The sampler owns a private ``torch.Generator`` seeded with ``seed`` and a
        counter ``i`` of indices produced so far. Both are exposed through
        ``state_dict`` / ``load_state_dict`` so iteration can continue from the
        exact point at which the state was captured.

        Note that ``i`` is not reset when iteration finishes: once ``limit``
        indices have been produced, iterating again yields nothing until the
        state is restored with ``load_state_dict``.
        """

        def __init__(self, high: int, seed: int, limit: int):
            """Store the configuration and seed the private generator.

            Args:
                high: Exclusive upper bound of the sampled indices.
                seed: Seed for the sampler's own ``torch.Generator``.
                limit: Total number of indices to yield.
            """
            self.seed, self.high, self.limit = seed, high, limit
            self.g = torch.Generator()
            self.g.manual_seed(self.seed)
            self.i = 0

        def __iter__(self):
            """Yield random indices until ``limit`` indices have been produced."""
            while self.i < self.limit:
                val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
                # Advance the counter before yielding so that a state_dict taken
                # while the consumer holds `val` already accounts for it.
                self.i += 1
                yield val

        def load_state_dict(self, state_dict: Dict[str, Any]):
            """Restore the counter and generator state saved by ``state_dict``."""
            self.i = state_dict["i"]
            self.g.set_state(state_dict["rng"])

        def state_dict(self) -> Dict[str, Any]:
            """Return the counter and generator state needed to resume sampling."""
            return {"i": self.i, "rng": self.g.get_state()}

    # Optional: save dataset random transform state
    class NoisyRange(torch.utils.data.Dataset):
        """Map-style dataset returning ``idx`` plus Gaussian noise for ``idx`` in ``[0, high)``.

        The noise is drawn with ``torch.normal`` without an explicit generator,
        so it consumes the default torch RNG. ``state_dict`` / ``load_state_dict``
        save and restore that RNG state via ``torch.get_rng_state`` /
        ``torch.set_rng_state`` so the noise sequence can be reproduced on resume.
        """

        def __init__(self, high: int, mean: float, std: float):
            """Store the dataset length and the noise parameters.

            Args:
                high: Number of items; valid indices are ``0 .. high - 1``.
                mean: Mean of the noise (kept as a one-element tensor).
                std: Standard deviation of the noise.
            """
            self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)

        def __len__(self):
            """Return the number of items in the dataset."""
            return self.high

        def __getitem__(self, idx: int) -> float:
            """Return ``idx`` plus one noise sample.

            Raises:
                IndexError: If ``idx`` is outside ``[0, high)``.
            """
            if not (0 <= idx < self.high):
                raise IndexError()
            x = torch.normal(self.mean, self.std)
            noise = x.item()
            return idx + noise

        def load_state_dict(self, state_dict):
            """Restore the torch RNG state saved by ``state_dict``."""
            torch.set_rng_state(state_dict["rng"])

        def state_dict(self):
            """Return the current torch RNG state under the key ``"rng"``."""
            return {"rng": torch.get_rng_state()}

    # Test both single/multiprocess dataloading
    for num_workers in [0, 2]:
        print(f"{num_workers=}")
        dl = StatefulDataLoader(
            NoisyRange(5, 1, 1),
            sampler=MySampler(5, 1, 10),
            batch_size=2,
            drop_last=False,
            num_workers=num_workers,
        )

        batches = []
        for i, batch in enumerate(dl):
            batches.append(batch)
            if i == 2:
                # Checkpoint after the third batch; iteration continues so the
                # remaining batches can be compared with the resumed run below.
                sd = dl.state_dict()

        # Rewind the loader to the checkpoint and collect what it yields from there.
        dl.load_state_dict(sd)
        batches2 = list(dl)

        # The two printed lists should match: batches after the checkpoint
        # versus batches produced after restoring the checkpoint.
        print(batches[3:])
        print(batches2)

    """
    Output:
    num_workers=0
    [tensor([-0.4526,  3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
    [tensor([-0.4526,  3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
    num_workers=2
    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
    """

Saving Custom State with Iterable-Style Datasets
------------------------------------------------

Tracking iteration order with `Iterable-style datasets <https://pytorch.org/docs/stable/data.html#iterable-style-datasets>`_ requires state from each worker-level instance of the dataset to be captured. You can define ``state_dict`` / ``load_state_dict`` methods on your dataset which capture worker-level state. :class:`StatefulDataLoader` will handle aggregation across workers and distribution back to the workers. Calling ``load_state_dict`` requires the :class:`StatefulDataLoader` to have the same ``num_workers`` as was used to produce the provided ``state_dict``.

.. code:: python

    from typing import *
    import torch
    import torch.utils.data
    from torchdata.stateful_dataloader import StatefulDataLoader


    class MyIterableDataset(torch.utils.data.IterableDataset):
        """Resumable iterable dataset over a seeded permutation of ``range(high)``.

        Each instance (one per worker when loaded with multiple workers) iterates
        over its own strided shard of the permutation and tracks its position in
        ``i``, which ``state_dict`` / ``load_state_dict`` save and restore.
        """

        def __init__(self, high: int, seed: int):
            """Store the configuration and create the generator.

            Args:
                high: Size of the permutation; values are ``0 .. high - 1``.
                seed: Seed applied to the generator at the start of every pass.
            """
            self.high, self.seed = high, seed
            self.g = torch.Generator()
            self.i = 0

        def __iter__(self):
            """Yield this worker's shard of the permutation, starting at position ``i``."""
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                worker_id = worker_info.id
                num_workers = worker_info.num_workers
            else:
                # No worker info available: treat this instance as the only shard.
                worker_id = 0
                num_workers = 1
            # Re-seed on every pass so the permutation is identical across
            # workers and across resumes; only the offset `i` varies.
            self.g.manual_seed(self.seed)
            arr = torch.randperm(self.high, generator=self.g)
            # Strided shard: this worker takes every `num_workers`-th element.
            arr = arr[worker_id:self.high:num_workers]
            for idx in range(self.i, len(arr)):
                # Advance before yielding so that a state_dict taken while the
                # consumer holds this item already accounts for it.
                self.i += 1
                yield arr[idx]
            # A completed pass resets the position so the next pass starts over.
            self.i = 0

        def state_dict(self):
            """Return the current position within this instance's shard."""
            return {"i": self.i}

        def load_state_dict(self, state_dict):
            """Restore the position saved by ``state_dict``."""
            self.i = state_dict["i"]

    # Test both single/multiprocess dataloading
    for num_workers in [0, 2]:
        print(f"{num_workers=}")
        dl = StatefulDataLoader(
            MyIterableDataset(12, 0),
            batch_size=2,
            drop_last=False,
            num_workers=num_workers,
        )

        batches = []
        for i, batch in enumerate(dl):
            batches.append(batch)
            if i == 2:
                # Checkpoint after the third batch; iteration continues so the
                # remaining batches can be compared with the resumed run below.
                sd = dl.state_dict()

        # Rewind the loader to the checkpoint and collect what it yields from there.
        dl.load_state_dict(sd)
        batches2 = list(dl)

        # The two printed lists should match: batches after the checkpoint
        # versus batches produced after restoring the checkpoint.
        print(batches[3:])
        print(batches2)

    """
    Output:
    num_workers=0
    [tensor([ 2, 10]), tensor([3, 1]), tensor([11,  6])]
    [tensor([ 2, 10]), tensor([3, 1]), tensor([11,  6])]
    num_workers=2
    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
    """