Attention
June 2024 Status Update: Removing DataPipes and DataLoader V2
We are re-focusing the torchdata repo to be an iterative enhancement of torch.utils.data.DataLoader. We do not plan on continuing development or maintaining the [DataPipes] and [DataLoaderV2] solutions, and they will be removed from the torchdata repo. We’ll also be revisiting the DataPipes references in pytorch/pytorch. In release torchdata==0.8.0 (July 2024) they will be marked as deprecated, and in 0.10.0 (Late 2024) they will be deleted. Existing users are advised to pin to torchdata<=0.9.0 or an older version until they are able to migrate away. Subsequent releases will not include DataPipes or DataLoaderV2. Please reach out if you have suggestions or comments (please use this issue for feedback).
Stateful DataLoader Tutorial¶
Saving and loading state¶
Stateful DataLoader adds the load_state_dict and state_dict methods to torch.utils.data.DataLoader. State fetch and set can be done as follows:
from torchdata.stateful_dataloader import StatefulDataLoader
dataloader = StatefulDataLoader(dataset, num_workers=2)
for i, batch in enumerate(dataloader):
    ...
    if i == 10:
        # Capture the loader's progress mid-epoch so it can be restored later.
        state_dict = dataloader.state_dict()
        break

# Training run resumes with the previous checkpoint
dataloader = StatefulDataLoader(dataset, num_workers=2)
# Resume state with DataLoader: load the saved state before iterating so the
# next loop continues from where the checkpoint was taken.
dataloader.load_state_dict(state_dict)
for i, batch in enumerate(dataloader):
    ...
Saving Custom State with Map-Style Datasets¶
For efficient resuming of Map-style datasets, you can resume iteration by defining state_dict / load_state_dict methods in your sampler. If your dataset has worker-specific state (e.g. RNG transform state), you can add state_dict / load_state_dict methods to your dataset.
from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader
# If you are using the default RandomSampler and BatchSampler in torch.utils.data, they are patched when you import torchdata.stateful_dataloader, so defining a custom sampler here is unnecessary
class MySampler(torch.utils.data.Sampler[int]):
    """Seeded random-index sampler whose progress can be checkpointed.

    Each pass yields ``limit`` integers drawn from ``[0, high)`` with a private
    ``torch.Generator`` seeded with ``seed``. The position within the pass and
    the generator state are exposed through ``state_dict`` /
    ``load_state_dict`` so iteration can resume mid-pass.
    """

    def __init__(self, high: int, seed: int, limit: int):
        self.seed, self.high, self.limit = seed, high, limit
        # Private generator: sampling is independent of the global RNG.
        self.g = torch.Generator()
        self.g.manual_seed(self.seed)
        # Number of indices already yielded in the current pass.
        self.i = 0

    def __iter__(self):
        """Yield indices until ``limit`` have been produced in this pass."""
        while self.i < self.limit:
            val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
            # Advance before yielding so a state_dict() taken while the
            # consumer holds ``val`` already counts it as produced.
            self.i += 1
            yield val
        # Start the next pass from zero; without this the sampler would be
        # permanently exhausted after the first full pass. The generator is
        # left untouched, so a later pass draws new values.
        self.i = 0

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the position and generator state saved by ``state_dict``."""
        self.i = state_dict["i"]
        self.g.set_state(state_dict["rng"])

    def state_dict(self) -> Dict[str, Any]:
        """Return the current position (``"i"``) and generator state (``"rng"``)."""
        return {"i": self.i, "rng": self.g.get_state()}
# Optional: save dataset random transform state
class NoisyRange(torch.utils.data.Dataset):
    """Map-style dataset returning ``idx`` plus Gaussian noise.

    Item ``idx`` (for ``0 <= idx < high``) is ``idx`` plus one sample from a
    normal distribution with the given ``mean`` and ``std``. The noise is
    drawn without an explicit generator, so the state saved and restored here
    is PyTorch's global RNG state.
    """

    def __init__(self, high: int, mean: float, std: float):
        # ``mean`` is kept as a 1-element tensor so torch.normal returns a
        # 1-element tensor; ``std`` stays a plain float.
        self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)

    def __len__(self):
        return self.high

    def __getitem__(self, idx: int) -> float:
        """Return ``idx`` plus a freshly drawn noise value.

        Raises:
            IndexError: if ``idx`` is outside ``[0, high)``.
        """
        if not (0 <= idx < self.high):
            raise IndexError(
                f"index {idx} is out of range for NoisyRange of size {self.high}"
            )
        x = torch.normal(self.mean, self.std)
        noise = x.item()
        return idx + noise

    def load_state_dict(self, state_dict):
        """Restore the global RNG state saved by ``state_dict``."""
        torch.set_rng_state(state_dict["rng"])

    def state_dict(self):
        """Return the current global RNG state under the ``"rng"`` key."""
        return {"rng": torch.get_rng_state()}
# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
    print(f"{num_workers=}")
    dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
                            batch_size=2, drop_last=False, num_workers=num_workers)

    batches = []
    for i, batch in enumerate(dl):
        batches.append(batch)
        if i == 2:
            # Snapshot after the third batch; the pass then runs to the end.
            sd = dl.state_dict()

    # Restore the snapshot and iterate again: the batches produced now should
    # match the ones that followed the snapshot in the first pass.
    dl.load_state_dict(sd)
    batches2 = list(dl)

    print(batches[3:])
    print(batches2)

"""
Output:
num_workers=0
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
num_workers=2
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
"""
Saving Custom State with Iterable-Style Datasets¶
Tracking iteration order with Iterable-style datasets requires state from each worker-level instance of the dataset to be captured. You can define state_dict / load_state_dict methods on your dataset which capture worker-level state. StatefulDataLoader will handle aggregation across workers and distribution back to the workers. Calling load_state_dict requires the StatefulDataLoader to have the same num_workers as the provided state_dict.
from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset over a seeded permutation of ``range(high)``.

    Every worker builds the same permutation (the generator is re-seeded with
    ``seed`` at the start of each pass) and keeps the elements at positions
    ``worker_id, worker_id + num_workers, ...``. The per-worker position in
    that shard is the only state saved by ``state_dict``.
    """

    def __init__(self, high: int, seed: int):
        self.high, self.seed = high, seed
        self.g = torch.Generator()
        # Number of elements of this worker's shard already yielded.
        self.i = 0

    def __iter__(self):
        """Yield this worker's shard of the permutation, starting at ``self.i``."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            # No worker info: behave as a single worker that owns every element.
            worker_id = 0
            num_workers = 1
        # Re-seeding here makes the permutation identical on every pass, so a
        # restored position refers to the same element order.
        self.g.manual_seed(self.seed)
        arr = torch.randperm(self.high, generator=self.g)
        arr = arr[worker_id:self.high:num_workers]
        for idx in range(self.i, len(arr)):
            # Advance before yielding so a state_dict() taken while the
            # consumer holds this element already counts it as produced.
            self.i += 1
            yield arr[idx]
        # Pass finished: the next one starts from the beginning.
        self.i = 0

    def state_dict(self):
        """Return the position within this worker's shard under the ``"i"`` key."""
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        """Restore the position saved by ``state_dict``."""
        self.i = state_dict["i"]
# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
    print(f"{num_workers=}")
    dl = StatefulDataLoader(
        MyIterableDataset(12, 0), batch_size=2, drop_last=False,
        num_workers=num_workers)

    batches = []
    for i, batch in enumerate(dl):
        batches.append(batch)
        if i == 2:
            # Snapshot after the third batch; the pass then runs to the end.
            sd = dl.state_dict()

    # Restore the snapshot and iterate again: the batches produced now should
    # match the ones that followed the snapshot in the first pass.
    dl.load_state_dict(sd)
    batches2 = list(dl)

    print(batches[3:])
    print(batches2)

"""
Output:
num_workers=0
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
num_workers=2
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
"""