Attention
June 2024 Status Update: Removing DataPipes and DataLoaderV2
We are re-focusing the torchdata repo to be an iterative enhancement of torch.utils.data.DataLoader. We do not plan on continuing development or maintaining the [DataPipes] and [DataLoaderV2] solutions, and they will be removed from the torchdata repo. We’ll also be revisiting the DataPipes references in pytorch/pytorch. In release torchdata==0.8.0 (July 2024) they will be marked as deprecated, and in 0.10.0 (late 2024) they will be deleted. Existing users are advised to pin to torchdata<=0.9.0 or an older version until they are able to migrate away. Subsequent releases will not include DataPipes or DataLoaderV2. Please reach out if you have suggestions or comments (please use this issue for feedback).
Migrating to torchdata.nodes from torch.utils.data

This guide is intended to help people familiar with torch.utils.data, or StatefulDataLoader, get started with torchdata.nodes, and to provide a starting point for defining your own dataloading pipelines.

We’ll demonstrate how to achieve the most common DataLoader features, re-use existing samplers and datasets, and load/save dataloader state. torchdata.nodes performs at least as well as DataLoader and StatefulDataLoader; see How does torchdata.nodes perform?.
Map-Style Datasets
Let’s look at the DataLoader constructor args and go from there:
class DataLoader:
    def __init__(
        self,
        dataset: Dataset[_T_co],
        batch_size: Optional[int] = 1,
        shuffle: Optional[bool] = None,
        sampler: Union[Sampler, Iterable, None] = None,
        batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
        num_workers: int = 0,
        collate_fn: Optional[_collate_fn_t] = None,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0,
        worker_init_fn: Optional[_worker_init_fn_t] = None,
        multiprocessing_context=None,
        generator=None,
        *,
        prefetch_factor: Optional[int] = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
        in_order: bool = True,
    ):
        ...
As a refresher, here is roughly how dataloading works in torch.utils.data.DataLoader:

DataLoader begins by generating indices from a sampler and creates batches of batch_size indices. If no sampler is provided, a RandomSampler or SequentialSampler is created by default. The indices are passed to Dataset.__getitem__(), and then a collate_fn is applied to the batch of samples. If num_workers > 0, it will use multi-processing to create subprocesses, and pass the batches of indices to the worker processes, which will then call Dataset.__getitem__() and apply collate_fn before returning the batches to the main process. At that point, pin_memory may be applied to the tensors in the batch.
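Ignoring multi-processing and pin_memory, that flow boils down to a few lines of plain Python. The sketch below is only illustrative (the helper name naive_dataloader is made up for this example, and this is not DataLoader’s actual implementation):

from torch.utils.data import RandomSampler, SequentialSampler, default_collate

def naive_dataloader(dataset, batch_size, shuffle=False, drop_last=False):
    # Generate indices from a sampler (RandomSampler or SequentialSampler by default)
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    batch_indices = []
    for idx in sampler:
        batch_indices.append(idx)
        if len(batch_indices) == batch_size:
            # Map each index through Dataset.__getitem__, then apply collate_fn
            yield default_collate([dataset[i] for i in batch_indices])
            batch_indices = []
    # Emit the final, partial batch unless drop_last is set
    if batch_indices and not drop_last:
        yield default_collate([dataset[i] for i in batch_indices])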
Now let’s look at what an equivalent implementation for DataLoader might look like, built with torchdata.nodes.
from typing import List, Callable

import torchdata.nodes as tn
from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset


class MapAndCollate:
    """A simple transform that takes a batch of indices, maps with dataset, and then applies
    collate.
    TODO: make this a standard utility in torchdata.nodes
    """
    def __init__(self, dataset, collate_fn):
        self.dataset = dataset
        self.collate_fn = collate_fn

    def __call__(self, batch_of_indices: List[int]):
        batch = [self.dataset[i] for i in batch_of_indices]
        return self.collate_fn(batch)


# To keep things simple, let's assume that the following args are provided by the caller
def NodesDataLoader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    collate_fn: Callable | None,
    pin_memory: bool,
    drop_last: bool,
):
    # Assume we're working with a map-style dataset
    assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")

    # Start with a sampler, since caller did not provide one
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    # SamplerWrapper converts a Sampler to a BaseNode
    node = tn.SamplerWrapper(sampler)

    # Now let's batch sampler indices together
    node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

    # Create a map function that accepts a list of indices, applies getitem to it, and
    # then collates them
    map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

    # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
    # choose process or thread workers. Note that if you're not using Free-Threaded
    # Python (e.g. 3.13t) with -Xgil=0, then multi-threading might result in GIL contention,
    # and slow down training.
    node = tn.ParallelMapper(
        node,
        map_fn=map_and_collate,
        num_workers=num_workers,
        method="process",  # Set this to "thread" for multi-threading
        in_order=True,
    )

    # Optionally apply pin-memory, and we usually do some pre-fetching
    if pin_memory:
        node = tn.PinMemory(node)
    node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

    # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
    # on it to start a new epoch.
    # Instead, we wrap the node in a Loader, which is an iterable and handles reset. It
    # also provides state_dict and load_state_dict methods.
    return tn.Loader(node)
Now let’s test this out with a trivial dataset, and demonstrate how state management works.
class SquaredDataset(Dataset):
    def __init__(self, len: int):
        self.len = len

    def __len__(self):
        return self.len

    def __getitem__(self, i: int) -> int:
        return i**2
loader = NodesDataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
)

batches = []
for idx, batch in enumerate(loader):
    if idx == 2:
        state_dict = loader.state_dict()
        # Saves the state_dict after batch 2 has been returned
    batches.append(batch)

loader.load_state_dict(state_dict)
batches_after_loading = list(loader)
print(batches[3:])
# [tensor([ 81, 100, 121]), tensor([144, 169])]
print(batches_after_loading)
# [tensor([ 81, 100, 121]), tensor([144, 169])]
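In a real training loop, you would typically persist this state to disk along with the rest of your checkpoint. Here is a minimal sketch using torch.save/torch.load (the file name checkpoint.pt is just a placeholder):

import torch

# Save the loader state alongside any other training state you checkpoint
torch.save({"loader": loader.state_dict()}, "checkpoint.pt")

# ...later, restore it (weights_only=False because the state dict contains
# arbitrary Python objects, not just tensors)
checkpoint = torch.load("checkpoint.pt", weights_only=False)
loader.load_state_dict(checkpoint["loader"])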
Let’s also compare this to torch.utils.data.DataLoader, as a sanity check.
import torch

loaderv1 = torch.utils.data.DataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    persistent_workers=False,  # Coming soon to torchdata.nodes!
)
print(list(loaderv1))
# [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
print(batches)
# [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
IterableDatasets

Coming soon! While you can already plug your IterableDataset into a tn.IterableWrapper, some functions like get_worker_info are not yet supported. However, we believe that sharding work between multi-process workers is often not actually necessary: you can keep some sort of indexing in the main process while only parallelizing some of the heavier transforms, similar to how the Map-Style Datasets example above works.
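For example, here is a minimal sketch of that pattern (records and heavy_transform below are placeholders, not part of the guide): keep a lightweight iterable in the main process and parallelize only the expensive per-sample work.

import torchdata.nodes as tn

records = range(16)  # placeholder: could be file paths, rows, or an IterableDataset

def heavy_transform(record):
    # Placeholder for expensive per-sample work such as decoding or augmentation
    return record * record

node = tn.IterableWrapper(records)  # wraps any iterable as a BaseNode
node = tn.ParallelMapper(node, map_fn=heavy_transform, num_workers=4, method="thread")
node = tn.Prefetcher(node, prefetch_factor=8)
loader = tn.Loader(node)
print(list(loader))  # [0, 1, 4, 9, ..., 225]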