Stateful DataLoader Tutorial¶
Saving and loading state¶
Stateful DataLoader adds the load_state_dict
, state_dict
methods to the torch.utils.data.DataLoader
. State fetch and set can be done as follows:
from torchdata.stateful_dataloader import StatefulDataLoader
dataloader = StatefulDataLoader(dataset, num_workers=2)
for i, batch in enumerate(dataloader):
...
if i == 10:
state_dict = dataloader.state_dict()
break
# Training run resumes with the previous checkpoint
dataloader = StatefulDataLoader(dataset, num_workers=2)
# Resume state with DataLoader
dataloader.load_state_dict(state_dict)
for i, batch in enumerate(dataloader):
...
Saving Custom State with Map-Style Datasets¶
For efficient resuming of Map-style datasets, you can resume iteration by defining state_dict
/ load_state_dict
methods in your sampler. If your dataset has worker-specific state (eg RNG transform state) you can add state_dict
/ load_state_dict
methods to your dataset.
from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader
# If you are using the default RandomSampler and BatchSampler in torch.utils.data, they are patched when you import torchdata.stateful_dataloader so that defining, a custom sampler here is unnecessary
class MySampler(torch.utils.data.Sampler[int]):
def __init__(self, high: int, seed: int, limit: int):
self.seed, self.high, self.limit = seed, high, limit
self.g = torch.Generator()
self.g.manual_seed(self.seed)
self.i = 0
def __iter__(self):
while self.i < self.limit:
val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
self.i += 1
yield val
def load_state_dict(self, state_dict: Dict[str, Any]):
self.i = state_dict["i"]
self.g.set_state(state_dict["rng"])
def state_dict(self) -> Dict[str, Any]:
return {"i": self.i, "rng": self.g.get_state()}
# Optional: save dataset random transform state
class NoisyRange(torch.utils.data.Dataset):
def __init__(self, high: int, mean: float, std: float):
self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)
def __len__(self):
return self.high
def __getitem__(self, idx: int) -> float:
if not (0 <= idx < self.high):
raise IndexError()
x = torch.normal(self.mean, self.std)
noise = x.item()
return idx + noise
def load_state_dict(self, state_dict):
torch.set_rng_state(state_dict["rng"])
def state_dict(self):
return {"rng": torch.get_rng_state()}
# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
print(f"{num_workers=}")
dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
batch_size=2, drop_last=False, num_workers=num_workers)
batches = []
for i, batch in enumerate(dl):
batches.append(batch)
if i == 2:
sd = dl.state_dict()
dl.load_state_dict(sd)
batches2 = list(dl)
print(batches[3:])
print(batches2)
"""
Output:
num_workers=0
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
num_workers=2
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
"""
Saving Custom State with Iterable-Style Datasets¶
Tracking iteration order with Iterable-style datasets requires state from each worker-level instance of the dataset to be captured. You can define state_dict
/ load_state_dict
methods on your dataset which capture worker-level state. StatefulDataLoader
will handle aggregation across workers and distribution back to the workers. Calling load_state_dict
requires StatefulDataLoader`
to have same num_workers
as those of the provided state_dict
.
from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, high: int, seed: int):
self.high, self.seed = high, seed
self.g = torch.Generator()
self.i = 0
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker_id = worker_info.id
num_workers = worker_info.num_workers
else:
worker_id = 0
num_workers = 1
self.g.manual_seed(self.seed)
arr = torch.randperm(self.high, generator=self.g)
arr = arr[worker_id:self.high:num_workers]
for idx in range(self.i, len(arr)):
self.i += 1
yield arr[idx]
self.i = 0
def state_dict(self):
return {"i": self.i}
def load_state_dict(self, state_dict):
self.i = state_dict["i"]
# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
print(f"{num_workers=}")
dl = StatefulDataLoader(
MyIterableDataset(12, 0), batch_size=2, drop_last=False,
num_workers=num_workers)
batches = []
for i, batch in enumerate(dl):
batches.append(batch)
if i == 2:
sd = dl.state_dict()
dl.load_state_dict(sd)
batches2 = list(dl)
print(batches[3:])
print(batches2)
"""
Output:
num_workers=0
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
num_workers=2
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
"""