
Attention

June 2024 Status Update: Removing DataPipes and DataLoader V2

We are re-focusing the torchdata repo to be an iterative enhancement of torch.utils.data.DataLoader. We do not plan on continuing development or maintaining the DataPipes and DataLoaderV2 solutions, and they will be removed from the torchdata repo. We’ll also be revisiting the DataPipes references in pytorch/pytorch. In release torchdata==0.8.0 (July 2024) they will be marked as deprecated, and in 0.10.0 (Late 2024) they will be deleted. Existing users are advised to pin to torchdata<=0.9.0 or an older version until they are able to migrate away. Subsequent releases will not include DataPipes or DataLoaderV2. Please reach out if you have suggestions or comments (please use this issue for feedback).

Migrating to torchdata.nodes from torch.utils.data

This guide is intended to help people familiar with torch.utils.data, or StatefulDataLoader, get started with torchdata.nodes, and to provide a starting point for defining your own dataloading pipelines.

We’ll demonstrate how to achieve the most common DataLoader features, re-use existing samplers and datasets, and load/save dataloader state. torchdata.nodes performs at least as well as DataLoader and StatefulDataLoader; see How does torchdata.nodes perform?

Map-Style Datasets

Let’s look at the DataLoader constructor arguments and go from there:

class DataLoader:
    def __init__(
        self,
        dataset: Dataset[_T_co],
        batch_size: Optional[int] = 1,
        shuffle: Optional[bool] = None,
        sampler: Union[Sampler, Iterable, None] = None,
        batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
        num_workers: int = 0,
        collate_fn: Optional[_collate_fn_t] = None,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0,
        worker_init_fn: Optional[_worker_init_fn_t] = None,
        multiprocessing_context=None,
        generator=None,
        *,
        prefetch_factor: Optional[int] = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
        in_order: bool = True,
    ):
        ...

As a refresher, here is roughly how dataloading works in torch.utils.data.DataLoader. DataLoader begins by generating indices from a sampler and grouping them into batches of batch_size indices. If no sampler is provided, a RandomSampler (when shuffle=True) or a SequentialSampler is created by default. The indices are passed to Dataset.__getitem__(), and then collate_fn is applied to the batch of samples. If num_workers > 0, DataLoader uses multiprocessing to create worker subprocesses and passes the batches of indices to them; the workers call Dataset.__getitem__() and apply collate_fn before returning the batches to the main process. At that point, pin_memory may be applied to the tensors in the batch.
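For the single-process case (num_workers=0), the flow above can be sketched in plain Python. This is a simplified illustration, not DataLoader's actual implementation: `simple_loader` and `list_collate` are hypothetical names, shuffling uses Python's `random` module in place of a RandomSampler, and collation returns a plain list rather than stacking tensors the way default_collate does.

```python
import random
from typing import Callable, Iterator, List, Sequence


def list_collate(samples: List[int]) -> List[int]:
    # Hypothetical stand-in for default_collate: keep the samples as a plain list
    return list(samples)


def simple_loader(
    dataset: Sequence[int],
    batch_size: int,
    shuffle: bool = False,
    drop_last: bool = False,
    collate_fn: Callable[[List[int]], List[int]] = list_collate,
    seed: int = 0,
) -> Iterator[List[int]]:
    # 1. "Sampler": generate indices, either sequentially or as a random permutation
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # 2. "Batch sampler": group indices into lists of batch_size
    batch: List[int] = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            # 3. Fetch each sample via __getitem__, then 4. collate the batch
            yield collate_fn([dataset[i] for i in batch])
            batch = []
    # Emit the final partial batch unless drop_last is set
    if batch and not drop_last:
        yield collate_fn([dataset[i] for i in batch])


squares = [i**2 for i in range(14)]
print(list(simple_loader(squares, batch_size=3)))
# [[0, 1, 4], [9, 16, 25], [36, 49, 64], [81, 100, 121], [144, 169]]
print(len(list(simple_loader(squares, batch_size=3, drop_last=True))))
# 4
```

With 14 samples and batch_size=3, the last batch holds only 2 samples; drop_last=True discards it, leaving 4 batches.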

Now let’s look at what an equivalent implementation for DataLoader might look like, built with torchdata.nodes.

from typing import List, Callable
import torchdata.nodes as tn
from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset

class MapAndCollate:
    """A simple transform that takes a batch of indices, maps with dataset, and then applies
    collate.
    TODO: make this a standard utility in torchdata.nodes
    """
    def __init__(self, dataset, collate_fn):
        self.dataset = dataset
        self.collate_fn = collate_fn

    def __call__(self, batch_of_indices: List[int]):
        batch = [self.dataset[i] for i in batch_of_indices]
        return self.collate_fn(batch)

# To keep things simple, let's assume that the following args are provided by the caller
def NodesDataLoader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    collate_fn: Callable | None,
    pin_memory: bool,
    drop_last: bool,
):
    # Assume we're working with a map-style dataset
    assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")
    # Start with a sampler, since caller did not provide one
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    # Sampler wrapper converts a Sampler to a BaseNode
    node = tn.SamplerWrapper(sampler)

    # Now let's batch sampler indices together
    node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

    # Create a Map Function that accepts a list of indices, applies getitem to it, and
    # then collates them
    map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

    # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
    # choose process or thread workers. Note that if you're not using free-threaded
    # Python (e.g. 3.13t) with -Xgil=0, then multi-threading may cause GIL contention
    # and slow down training.
    node = tn.ParallelMapper(
        node,
        map_fn=map_and_collate,
        num_workers=num_workers,
        method="process",  # Set this to "thread" for multi-threading
        in_order=True,
    )

    # Optionally apply pin-memory, and we usually do some pre-fetching
    if pin_memory:
        node = tn.PinMemory(node)
    node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

    # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
    # on it to start a new epoch.
    # Instead, we wrap the node in a Loader, which is an iterable and handles reset. It
    # also provides state_dict and load_state_dict methods.
    return tn.Loader(node)
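
To make the final Prefetcher step more concrete, here is a minimal plain-Python sketch of the general idea behind prefetching: a background thread pulls items from an upstream iterator into a bounded queue, so the consumer usually finds the next item already waiting. This is an illustration under stated assumptions, not tn.Prefetcher's implementation; the `prefetch` function is hypothetical, and it does not handle early shutdown (a producer whose consumer stops early simply stays blocked in a daemon thread).

```python
import queue
import threading
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def prefetch(source: Iterable[T], prefetch_factor: int = 2) -> Iterator[T]:
    # Bounded queue: the producer blocks once prefetch_factor items are buffered
    buffer: "queue.Queue[Tuple[str, object]]" = queue.Queue(maxsize=max(1, prefetch_factor))

    def producer() -> None:
        # Runs in a background thread; each message is tagged with its kind
        try:
            for item in source:
                buffer.put(("item", item))
        except Exception as exc:  # forward upstream errors to the consumer
            buffer.put(("error", exc))
            return
        buffer.put(("end", None))

    # Daemon thread, so an abandoned consumer does not block interpreter exit
    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buffer.get()
        if kind == "end":
            return
        if kind == "error":
            raise payload
        yield payload


print(list(prefetch(range(5), prefetch_factor=2)))
# [0, 1, 2, 3, 4]
```

Because a single producer feeds a FIFO queue, item order is preserved; prefetch_factor only bounds how far ahead of the consumer the producer may run.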

Now let’s test this out with a trivial dataset, and demonstrate how state management works.

class SquaredDataset(Dataset):
    def __init__(self, len: int):
        self.len = len
    def __len__(self):
        return self.len
    def __getitem__(self, i: int) -> int:
        return i**2

loader = NodesDataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
)

batches = []
for idx, batch in enumerate(loader):
    if idx == 2:
        state_dict = loader.state_dict()
        # Saves the state_dict after batch 2 has been returned
    batches.append(batch)

loader.load_state_dict(state_dict)
batches_after_loading = list(loader)
print(batches[3:])
# [tensor([ 81, 100, 121]), tensor([144, 169])]
print(batches_after_loading)
# [tensor([ 81, 100, 121]), tensor([144, 169])]
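The save/restore behaviour exercised above can be mimicked with a toy loader to make the expected semantics explicit: a state captured after batch 2 has been returned resumes at batch 3 on the next iteration. The `ResumableBatches` class and its `{"next_batch": ...}` state format are hypothetical and chosen for illustration only; the real state_dict produced by tn.Loader has its own structure, which this sketch does not attempt to reproduce.

```python
from typing import Any, Dict, Iterator, List, Optional


class ResumableBatches:
    """Hypothetical toy loader: iterates over fixed batches with save/restore."""

    def __init__(self, batches: List[List[int]]):
        self.batches = batches
        self._next = 0  # index of the next batch to yield
        self._pending: Optional[Dict[str, Any]] = None  # state to apply on next iter()

    def state_dict(self) -> Dict[str, Any]:
        return {"next_batch": self._next}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Defer restoring until the next iteration, matching the usage pattern above
        self._pending = dict(state)

    def __iter__(self) -> Iterator[List[int]]:
        # Resume from a restored position if one is pending, else start a fresh epoch
        self._next = self._pending["next_batch"] if self._pending is not None else 0
        self._pending = None
        while self._next < len(self.batches):
            batch = self.batches[self._next]
            self._next += 1  # advance before yielding so state_dict() reflects it
            yield batch


toy_loader = ResumableBatches([[0, 1, 4], [9, 16, 25], [36, 49, 64], [81, 100, 121], [144, 169]])
seen = []
for idx, batch in enumerate(toy_loader):
    if idx == 2:
        state = toy_loader.state_dict()  # captured after batch 2 was returned
    seen.append(batch)

toy_loader.load_state_dict(state)
resumed = list(toy_loader)
print(resumed)
# [[81, 100, 121], [144, 169]]
```

Advancing the counter before yielding is the key design choice here: a state saved while the consumer holds batch 2 already points at batch 3, so nothing is replayed after restoring.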

Let’s also compare this to torch.utils.data.DataLoader, as a sanity check.

import torch

loaderv1 = torch.utils.data.DataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    persistent_workers=False,  # Coming soon to torchdata.nodes!
)
print(list(loaderv1))
# [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
print(batches)
# [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]

IterableDatasets

Coming soon! While you can already plug your IterableDataset into a tn.IterableWrapper, some functions like get_worker_info are not yet supported. However, we believe that sharding work between multi-process workers is often not actually necessary: you can keep some sort of indexing in the main process while parallelizing only the heavier transforms, similar to how map-style datasets work above.
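A minimal sketch of the approach suggested above, assuming a thread pool is acceptable for your transform: the iterable is consumed sequentially in the calling thread (so there is no sharding and no need for get_worker_info), and only the heavy per-item transform runs in parallel. `parallel_map_in_order` and `heavy_transform` are hypothetical names built on the standard library's concurrent.futures; the sketch does not use torchdata.nodes, where the analogous composition would be a tn.IterableWrapper feeding a tn.ParallelMapper. Results are yielded in input order, analogous to in_order=True.

```python
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Deque, Iterable, Iterator, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def parallel_map_in_order(
    source: Iterable[T],
    fn: Callable[[T], R],
    num_workers: int = 4,
    max_in_flight: int = 8,
) -> Iterator[R]:
    # The source iterator is only ever advanced here, in the calling thread
    pending: Deque["Future[R]"] = deque()
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        for item in source:
            pending.append(pool.submit(fn, item))
            # Bound memory: once the window is full, wait for the oldest result
            if len(pending) >= max_in_flight:
                yield pending.popleft().result()
        # Drain the remaining results, still in submission order
        while pending:
            yield pending.popleft().result()


def heavy_transform(record: str) -> int:
    # Hypothetical stand-in for an expensive decode/augment step
    return sum(ord(c) for c in record)


records = (f"record-{i}" for i in range(6))  # e.g. lines streamed from a file
print(list(parallel_map_in_order(records, heavy_transform, num_workers=3)))
# [732, 733, 734, 735, 736, 737]
```

Keeping iteration in one place means the position in the stream is a single counter in the main process, which is what makes this pattern simpler to checkpoint than per-worker shards.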
