Attention
June 2024 Status Update: Removing DataPipes and DataLoader V2
We are re-focusing the torchdata repo to be an iterative enhancement of torch.utils.data.DataLoader. We do not plan on continuing development or maintaining the DataPipes and DataLoaderV2 solutions, and they will be removed from the torchdata repo. We’ll also be revisiting the DataPipes references in pytorch/pytorch. In release torchdata==0.8.0 (July 2024) they will be marked as deprecated, and in 0.10.0 (late 2024) they will be deleted. Existing users are advised to pin to torchdata<=0.9.0 or an older version until they are able to migrate away. Subsequent releases will not include DataPipes or DataLoaderV2. Please reach out if you have suggestions or comments (please use this issue for feedback).
Getting Started With torchdata.nodes (beta)
Install torchdata with pip.
pip install "torchdata>=0.10.0"
Generator Example
Wrap a generator (or any iterable) to convert it to a BaseNode and get started:
from torchdata.nodes import IterableWrapper, ParallelMapper, Loader
node = IterableWrapper(range(10))
node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
loader = Loader(node)
result = list(loader)
print(result)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
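The result above comes back in input order even though three threads share the work. As a rough analogy only, and not a description of how torchdata implements ParallelMapper, the same order-preserving behavior can be sketched with the standard library's ThreadPoolExecutor. The helper name `threaded_map` is hypothetical and exists only for this illustration.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def threaded_map(fn: Callable[[T], R], items: Iterable[T], num_workers: int = 3) -> List[R]:
    """Apply fn to each item on a thread pool and return results in input order.

    Hypothetical helper for illustration; it is not part of torchdata.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # Executor.map yields results in the order of the inputs,
        # regardless of which worker thread finishes first.
        return list(pool.map(fn, items))


squares = threaded_map(lambda x: x**2, range(10))
print(squares)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Unlike this sketch, which collects everything into a list, the Loader in the example above is iterated lazily.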
Sampler Example
Samplers are still supported, and you can use your existing torch.utils.data.Dataset classes. See Migrating to torchdata.nodes from torch.utils.data for an in-depth example.
import torch.utils.data
from torch.utils.data import RandomSampler
from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader

class SquaredDataset(torch.utils.data.Dataset):
    def __getitem__(self, i: int) -> int:
        return i**2

    def __len__(self):
        return 10

dataset = SquaredDataset()
sampler = RandomSampler(dataset)

# For fine-grained control of iteration order, define your own sampler
node = SamplerWrapper(sampler)
# Simply apply dataset's __getitem__ as a map function to the indices generated from sampler
node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
# Loader is used to convert a node (iterator) into an Iterable that may be reused for multiple epochs
loader = Loader(node)
print(list(loader))
# e.g. [25, 36, 9, 49, 0, 81, 4, 16, 64, 1] (order is random)
print(list(loader))
# e.g. [0, 4, 1, 64, 49, 25, 9, 16, 81, 36] (a new order each epoch)
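The comment about defining your own sampler can be made concrete with a small, dependency-free sketch. It assumes only that a sampler is an iterable of indices that produces a fresh order each time iteration starts. The class name `SeededShuffleSampler` and its seed-plus-epoch scheme are hypothetical, and the dataset here is a plain Python stand-in. Whether SamplerWrapper accepts such a plain iterable unchanged is not confirmed here; for real use, subclassing torch.utils.data.Sampler is the safer route.

```python
import random
from typing import Iterator


class SeededShuffleSampler:
    """Hypothetical sampler-like iterable over the indices 0..n-1.

    Each call to __iter__ starts a new epoch whose permutation is
    derived from (seed + epoch), so runs are reproducible.
    """

    def __init__(self, n: int, seed: int = 0) -> None:
        self.n = n
        self.seed = seed
        self.epoch = 0

    def __len__(self) -> int:
        return self.n

    def __iter__(self) -> Iterator[int]:
        rng = random.Random(self.seed + self.epoch)
        self.epoch += 1  # the next iteration uses a different permutation seed
        indices = list(range(self.n))
        rng.shuffle(indices)
        return iter(indices)


class SquaredDataset:
    """Plain-Python stand-in for the map-style dataset above."""

    def __getitem__(self, i: int) -> int:
        return i**2

    def __len__(self) -> int:
        return 10


dataset = SquaredDataset()
sampler = SeededShuffleSampler(len(dataset), seed=0)

# Map each sampled index through the dataset, once per epoch.
epoch_1 = [dataset[i] for i in sampler]
epoch_2 = [dataset[i] for i in sampler]
print(epoch_1)
print(epoch_2)
```

Each epoch visits every index exactly once, and a fresh sampler built with the same seed replays the first epoch's order.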