.. _migrate-to-nodes-from-utils:

Migrating to ``torchdata.nodes`` from ``torch.utils.data``
==========================================================

This guide is intended to help people familiar with ``torch.utils.data``, or
:class:`~torchdata.stateful_dataloader.StatefulDataLoader`, get started with ``torchdata.nodes``,
and to provide a starting ground for defining your own dataloading pipelines.

We'll demonstrate how to achieve the most common DataLoader features, re-use existing samplers and datasets,
and load/save dataloader state. ``torchdata.nodes`` performs at least as well as ``DataLoader`` and
``StatefulDataLoader``, see :ref:`how-does-nodes-perform`.

Map-Style Datasets
~~~~~~~~~~~~~~~~~~

Let's look at the ``DataLoader`` constructor args and go from there.

.. code:: python

    class DataLoader:
        def __init__(
            self,
            dataset: Dataset[_T_co],
            batch_size: Optional[int] = 1,
            shuffle: Optional[bool] = None,
            sampler: Union[Sampler, Iterable, None] = None,
            batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
            num_workers: int = 0,
            collate_fn: Optional[_collate_fn_t] = None,
            pin_memory: bool = False,
            drop_last: bool = False,
            timeout: float = 0,
            worker_init_fn: Optional[_worker_init_fn_t] = None,
            multiprocessing_context=None,
            generator=None,
            *,
            prefetch_factor: Optional[int] = None,
            persistent_workers: bool = False,
            pin_memory_device: str = "",
            in_order: bool = True,
        ):
            ...

As a refresher, here is roughly how dataloading works in ``torch.utils.data.DataLoader``:
``DataLoader`` begins by generating indices from a ``sampler`` and creating batches of ``batch_size`` indices.
If no sampler is provided, then a ``RandomSampler`` or ``SequentialSampler`` is created by default.
The indices are passed to ``Dataset.__getitem__()``, and then a ``collate_fn`` is applied to the batch of samples.
If ``num_workers > 0``, ``DataLoader`` uses multi-processing to create subprocesses, and passes the batches of
indices to the worker processes, which then call ``Dataset.__getitem__()`` and apply ``collate_fn`` before
returning the batches to the main process. At that point, ``pin_memory`` may be applied to the tensors in the batch.

Now let's look at what an equivalent implementation for ``DataLoader`` might look like, built with ``torchdata.nodes``.

.. code:: python

    from typing import List, Callable

    import torchdata.nodes as tn
    from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset


    class MapAndCollate:
        """A simple transform that takes a batch of indices, maps with dataset, and then applies collate.
        TODO: make this a standard utility in torchdata.nodes
        """
        def __init__(self, dataset, collate_fn):
            self.dataset = dataset
            self.collate_fn = collate_fn

        def __call__(self, batch_of_indices: List[int]):
            batch = [self.dataset[i] for i in batch_of_indices]
            return self.collate_fn(batch)


    # To keep things simple, let's assume that the following args are provided by the caller
    def NodesDataLoader(
        dataset: Dataset,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
        collate_fn: Callable | None,
        pin_memory: bool,
        drop_last: bool,
    ):
        # Assume we're working with a map-style dataset
        assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")

        # Start with a sampler, since caller did not provide one
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        # Sampler wrapper converts a Sampler to a BaseNode
        node = tn.SamplerWrapper(sampler)

        # Now let's batch sampler indices together
        node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

        # Create a Map Function that accepts a list of indices, applies getitem to it, and
        # then collates them
        map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

        # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
        # choose process or thread workers. Note that if you're not using Free-Threaded
        # Python (e.g. 3.13t) with -Xgil=0, then multi-threading might result in GIL contention,
        # and slow down training.
        node = tn.ParallelMapper(
            node,
            map_fn=map_and_collate,
            num_workers=num_workers,
            method="process",  # Set this to "thread" for multi-threading
            in_order=True,
        )

        # Optionally apply pin-memory, and we usually do some pre-fetching
        if pin_memory:
            node = tn.PinMemory(node)
        node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

        # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
        # on it to start a new epoch.
        # Instead, we wrap the node in a Loader, which is an iterable and handles reset. It
        # also provides state_dict and load_state_dict methods.
        return tn.Loader(node)

Now let's test this out with a trivial dataset, and demonstrate how state management works.

.. code:: python

    class SquaredDataset(Dataset):
        def __init__(self, len: int):
            self.len = len

        def __len__(self):
            return self.len

        def __getitem__(self, i: int) -> int:
            return i**2


    loader = NodesDataLoader(
        dataset=SquaredDataset(14),
        batch_size=3,
        shuffle=False,
        num_workers=2,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
    )

    batches = []
    for idx, batch in enumerate(loader):
        if idx == 2:
            state_dict = loader.state_dict()  # Saves the state_dict after batch 2 has been returned
        batches.append(batch)

    loader.load_state_dict(state_dict)
    batches_after_loading = list(loader)
    print(batches[3:])
    # [tensor([ 81, 100, 121]), tensor([144, 169])]
    print(batches_after_loading)
    # [tensor([ 81, 100, 121]), tensor([144, 169])]

Let's also compare this to ``torch.utils.data.DataLoader``, as a sanity check.

.. code:: python

    import torch

    loaderv1 = torch.utils.data.DataLoader(
        dataset=SquaredDataset(14),
        batch_size=3,
        shuffle=False,
        num_workers=2,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        persistent_workers=False,  # Coming soon to torchdata.nodes!
    )
    print(list(loaderv1))
    # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
    print(batches)
    # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
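Since ``tn.SamplerWrapper`` accepts any ``torch.utils.data.Sampler``, existing samplers can be re-used directly.
Below is a minimal sketch (not part of the library) that plugs a ``DistributedSampler`` into the same kind of pipeline,
re-using the ``SquaredDataset`` and ``MapAndCollate`` helpers from above; the ``num_replicas=2, rank=0`` values are
placeholders so the snippet runs without initializing a process group. If your sampler needs ``set_epoch`` called
between epochs (as ``DistributedSampler`` does), check ``SamplerWrapper``'s epoch-handling options or call it yourself
before starting a new epoch.

.. code:: python

    import torchdata.nodes as tn
    from torch.utils.data import default_collate
    from torch.utils.data.distributed import DistributedSampler

    dataset = SquaredDataset(14)
    # Placeholder values: in real distributed training these come from the process group
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

    node = tn.SamplerWrapper(sampler)  # re-use the existing sampler as-is
    node = tn.Batcher(node, batch_size=3, drop_last=False)
    node = tn.ParallelMapper(
        node,
        map_fn=MapAndCollate(dataset, default_collate),
        num_workers=2,
        method="process",
    )
    loader = tn.Loader(node)

    print(len(list(loader)))  # 3 batches covering this rank's 7 indices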
IterableDatasets
~~~~~~~~~~~~~~~~

Coming soon! While you can already plug your IterableDataset into a ``tn.IterableWrapper``, some functions like
``get_worker_info`` are not yet supported. However, we believe that sharding work between multi-process workers is
often not actually necessary: you can keep some sort of indexing in the main process, and parallelize only the
heavier transforms, similar to how Map-Style Datasets work above.
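In the meantime, here is a rough sketch of that pattern using nodes that exist today: a plain iterable (or an
``IterableDataset`` iterated in the main process) is wrapped in ``tn.IterableWrapper``, and only the heavy per-sample
work is parallelized with ``tn.ParallelMapper``. The ``decode_sample`` function and the file names below are
hypothetical stand-ins for whatever expensive transform and data source you have; treat this as an illustration,
not the eventual supported IterableDataset integration.

.. code:: python

    import torchdata.nodes as tn

    # Hypothetical heavy transform; stands in for decoding, tokenizing, augmenting, etc.
    def decode_sample(path: str) -> dict:
        return {"path": path, "data": path.encode()}

    paths = [f"shard_{i}.bin" for i in range(100)]  # any plain iterable works here

    node = tn.IterableWrapper(paths)  # convert the iterable into a BaseNode
    node = tn.ParallelMapper(         # parallelize only the expensive transform
        node,
        map_fn=decode_sample,
        num_workers=4,
        method="thread",              # or "process"; see the GIL note above
    )
    node = tn.Prefetcher(node, prefetch_factor=8)
    loader = tn.Loader(node)          # handles reset between epochs, state_dict/load_state_dict

    for sample in loader:
        pass  # training step goes here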