
Attention

June 2024 Status Update: Removing DataPipes and DataLoader V2

We are re-focusing the torchdata repo to be an iterative enhancement of torch.utils.data.DataLoader. We do not plan to continue developing or maintaining the DataPipes and DataLoaderV2 solutions, and they will be removed from the torchdata repo. We’ll also be revisiting the DataPipes references in pytorch/pytorch. In release torchdata==0.8.0 (July 2024) they will be marked as deprecated, and in 0.10.0 (Late 2024) they will be deleted. Existing users are advised to pin to torchdata<=0.9.0 or an older version until they are able to migrate away. Subsequent releases will not include DataPipes or DataLoaderV2. Please reach out if you have suggestions or comments (please use this issue for feedback).

torchdata.nodes (beta)

class torchdata.nodes.BaseNode(*args, **kwargs)

Bases: Iterator[T]

BaseNodes are the base class for creating composable dataloading DAGs in torchdata.nodes.

Most end-users will not iterate over a BaseNode instance directly, but instead wrap it in a torchdata.nodes.Loader which converts the DAG into a more familiar Iterable.

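from torchdata.nodes import Loader

node = MyBaseNodeImpl()  # MyBaseNodeImpl: your own BaseNode subclass
loader = Loader(node)
# loader supports state_dict() and load_state_dict()

for epoch in range(5):
    for idx, batch in enumerate(loader):
        ...

# or if using node directly:
node = MyBaseNodeImpl()
for epoch in range(5):
    node.reset()  # start the new epoch from the beginning
    for idx, batch in enumerate(node):  # iterate the node itself, not loader
        ...
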
get_state() → Dict[str, Any]

Subclasses must implement this method instead of state_dict(). It should only be called by BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.

next() → T

Subclasses must implement this method instead of __next__(). It should only be called by BaseNode.

Returns:

T – the next value in the sequence. Raises StopIteration when the sequence is exhausted.

reset(initial_state: Optional[dict] = None)

Resets the iterator to the beginning, or to the state passed in by initial_state.

Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called. Subclasses must call super().reset(initial_state).

Parameters:

initial_state – Optional[dict] - a state dict to pass to the node. If None, reset to the beginning.

state_dict() → Dict[str, Any]

Get a state_dict for this BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.
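To make the reset()/next()/get_state() contract concrete, here is a minimal pure-Python sketch. CounterNode is a hypothetical class that only mimics the contract described above; it does not subclass torchdata.nodes.BaseNode, so it runs without torchdata installed. The lazy call to reset() on first use imitates the documented behavior, as we understand it.

```python
from typing import Any, Dict, Optional


class CounterNode:
    """Hypothetical stand-in mirroring the BaseNode contract; not torchdata code."""

    def __init__(self, n: int) -> None:
        self.n = n
        self._i = 0
        self._initialized = False

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        # Resume from a saved position, or start from the beginning.
        self._i = initial_state["i"] if initial_state is not None else 0
        self._initialized = True

    def next(self) -> int:
        # Return the next integer, or signal exhaustion with StopIteration.
        if self._i >= self.n:
            raise StopIteration
        value = self._i
        self._i += 1
        return value

    def get_state(self) -> Dict[str, Any]:
        # The cursor position is all that is needed to resume.
        return {"i": self._i}

    # Public wrappers that call reset() lazily on first use.
    def __iter__(self) -> "CounterNode":
        return self

    def __next__(self) -> int:
        if not self._initialized:
            self.reset()
        return self.next()

    def state_dict(self) -> Dict[str, Any]:
        if not self._initialized:
            self.reset()
        return self.get_state()


node = CounterNode(5)
first_two = [next(node), next(node)]  # [0, 1]
checkpoint = node.state_dict()        # {'i': 2}

resumed = CounterNode(5)
resumed.reset(checkpoint)
rest = list(resumed)                  # [2, 3, 4]
```

Keeping the subclass hooks (next, get_state) separate from the public entry points (__next__, state_dict) is what lets the base class insert the lazy initialization in one place.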

class torchdata.nodes.Batcher(source: BaseNode[T], batch_size: int, drop_last: bool = True)

Bases: BaseNode[List[T]]

Batcher node batches the data from the source node into batches of size batch_size. Once the source node is exhausted, any remaining items form a final batch smaller than batch_size. If drop_last is True, that last batch is dropped and StopIteration is raised. If drop_last is False, the last batch is returned even though it is smaller than batch_size.

Parameters:
  • source (BaseNode[T]) – The source node to batch the data from.

  • batch_size (int) – The size of the batch.

  • drop_last (bool) – Whether to drop the last batch if it is smaller than batch_size. Default is True.
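The drop_last behavior can be sketched in plain Python as a generator over any iterable. batch_items is a hypothetical helper that illustrates the semantics described above; it is not torchdata's implementation, which operates on BaseNode sources and tracks state.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_items(source: Iterable[T], batch_size: int, drop_last: bool = True) -> Iterator[List[T]]:
    batch: List[T] = []
    for item in source:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch  # a full batch
            batch = []
    # Source exhausted: emit the short remainder only when drop_last is False.
    if batch and not drop_last:
        yield batch


full_only = list(batch_items(range(7), 3))                   # [[0, 1, 2], [3, 4, 5]]
with_tail = list(batch_items(range(7), 3, drop_last=False))  # [[0, 1, 2], [3, 4, 5], [6]]
```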

get_state() → Dict[str, Any]

Subclasses must implement this method instead of state_dict(). It should only be called by BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.

next() → List[T]

Subclasses must implement this method instead of __next__(). It should only be called by BaseNode.

Returns:

List[T] – the next batch. Raises StopIteration when the source is exhausted.

reset(initial_state: Optional[Dict[str, Any]] = None)

Resets the iterator to the beginning, or to the state passed in by initial_state.

Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called. Subclasses must call super().reset(initial_state).

Parameters:

initial_state – Optional[dict] - a state dict to pass to the node. If None, reset to the beginning.

class torchdata.nodes.IterableWrapper(iterable: Iterable[T])

Bases: BaseNode[T]

Thin wrapper that converts any Iterable (including torch.utils.data.IterableDataset) into a BaseNode.

If iterable implements the Stateful protocol, its state will be saved and restored using its state_dict/load_state_dict methods.

Parameters:

iterable (Iterable[T]) – Iterable to convert to BaseNode. IterableWrapper calls iter() on it.

Warning:

Note the distinction between state_dict/load_state_dict defined on an Iterable versus on an Iterator. Only the Iterable’s state_dict/load_state_dict methods are used.
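A sketch of this distinction, assuming a hypothetical ResumableRange Iterable. Its position is stored on the Iterable itself, so the Iterable's state_dict()/load_state_dict() capture progress even though iteration happens through the generator returned by iter().

```python
from typing import Any, Dict, Iterator


class ResumableRange:
    """Hypothetical Iterable implementing state_dict/load_state_dict (Stateful)."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.pos = 0  # progress is stored on the Iterable, not on the iterator

    def __iter__(self) -> Iterator[int]:
        while self.pos < self.n:
            value = self.pos
            self.pos += 1
            yield value

    def state_dict(self) -> Dict[str, Any]:
        return {"pos": self.pos}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.pos = state_dict["pos"]


ds = ResumableRange(5)
it = iter(ds)
consumed = [next(it), next(it)]  # [0, 1]
state = ds.state_dict()          # {'pos': 2}

fresh = ResumableRange(5)
fresh.load_state_dict(state)
remaining = list(fresh)          # [2, 3, 4]
```

Because the position lives on the Iterable, iterating it again after exhaustion yields nothing until its state is reset; a real dataset would need its own epoch logic. Had the state lived only on the generator, the Iterable's state_dict() would have nothing useful to report.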

get_state() → Dict[str, Any]

Subclasses must implement this method instead of state_dict(). It should only be called by BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.

next() → T

Subclasses must implement this method instead of __next__(). It should only be called by BaseNode.

Returns:

T – the next value in the sequence. Raises StopIteration when the sequence is exhausted.

reset(initial_state: Optional[Dict[str, Any]] = None)

Resets the iterator to the beginning, or to the state passed in by initial_state.

Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called. Subclasses must call super().reset(initial_state).

Parameters:

initial_state – Optional[dict] - a state dict to pass to the node. If None, reset to the beginning.

class torchdata.nodes.Loader(root: BaseNode[T], restart_on_stop_iteration: bool = True)

Bases: Generic[T]

Wraps the root BaseNode (an iterator) and provides a stateful iterable interface.

The state of the last-returned iterator is returned by the state_dict() method, and can be loaded using the load_state_dict() method.

Parameters:
  • root (BaseNode[T]) – The root node of the data pipeline.

  • restart_on_stop_iteration (bool) – Whether to restart the iterator when it reaches the end. Default is True.

load_state_dict(state_dict: Dict[str, Any])

Loads a state_dict which will be used to initialize the next iter() requested from this loader.

Parameters:

state_dict (Dict[str, Any]) – The state_dict to load. Should be generated from a call to state_dict().

state_dict() Dict[str, Any]

Returns a state_dict which can be passed to load_state_dict() in the future to resume iteration.

The state_dict will come from the iterator returned by the most recent call to iter(). If no iterator has been created, a new iterator will be created and the state_dict returned from it.
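The relationship between state_dict(), load_state_dict() and iter() can be sketched in plain Python. MiniLoader and ListNode are hypothetical stand-ins written for illustration, not torchdata classes; in particular, this sketch ignores restart_on_stop_iteration and simply starts a fresh pass on every iter() call.

```python
from typing import Any, Dict, Iterator, List, Optional


class ListNode:
    """Hypothetical node over a list, following the reset/next/get_state contract."""

    def __init__(self, items: List[Any]) -> None:
        self.items = items
        self.i = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        self.i = initial_state["i"] if initial_state is not None else 0

    def next(self) -> Any:
        if self.i >= len(self.items):
            raise StopIteration
        self.i += 1
        return self.items[self.i - 1]

    def get_state(self) -> Dict[str, Any]:
        return {"i": self.i}


class MiniLoader:
    """Sketch of Loader's stateful-iterable behavior; not torchdata's implementation."""

    def __init__(self, root: ListNode) -> None:
        self.root = root
        self._pending: Optional[Dict[str, Any]] = None

    def __iter__(self) -> Iterator[Any]:
        # Each new pass resets the root, consuming any state loaded beforehand.
        self.root.reset(self._pending)
        self._pending = None
        while True:
            try:
                item = self.root.next()
            except StopIteration:
                return
            yield item

    def state_dict(self) -> Dict[str, Any]:
        # State of the most recent pass (the root node's current position).
        return self.root.get_state()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Applied at the start of the next iter() call.
        self._pending = state_dict


loader = MiniLoader(ListNode(list("abcde")))
it = iter(loader)
seen = [next(it), next(it)]     # ['a', 'b']
ckpt = loader.state_dict()      # {'i': 2}

restored = MiniLoader(ListNode(list("abcde")))
restored.load_state_dict(ckpt)
resumed_epoch = list(restored)  # ['c', 'd', 'e']
next_epoch = list(restored)     # ['a', 'b', 'c', 'd', 'e']
```

A loaded state affects only the next pass; later passes start from the beginning again.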

torchdata.nodes.MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) → BaseNode[T]

Thin wrapper that converts any map-style dataset into a torchdata.nodes BaseNode. If you want parallelism, copy this function and replace Mapper with ParallelMapper.

Parameters:
  • map_dataset (Mapping[K, T]) – Apply map_dataset.__getitem__ to the outputs of sampler.

  • sampler (Sampler[K]) – The sampler whose outputs (keys of type K) are looked up in map_dataset.
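Conceptually, MapStyleWrapper feeds each key produced by the sampler to map_dataset.__getitem__. The plain-Python generator below sketches that data flow; map_style_iter is a hypothetical name, and a plain list of keys stands in for a torch Sampler. In torchdata itself we understand the wrapper to be composed of SamplerWrapper and Mapper, which is why swapping Mapper for ParallelMapper adds parallelism.

```python
from typing import Iterable, Iterator, Mapping, TypeVar

K = TypeVar("K")
T = TypeVar("T")


def map_style_iter(map_dataset: Mapping[K, T], sampler: Iterable[K]) -> Iterator[T]:
    # Every key the sampler yields is looked up with map_dataset.__getitem__.
    for key in sampler:
        yield map_dataset[key]


dataset = {0: "zero", 1: "one", 2: "two"}
samples = list(map_style_iter(dataset, [2, 0, 1, 0]))  # ['two', 'zero', 'one', 'zero']
```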

torchdata.nodes.Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) → ParallelMapper[T]

Returns a ParallelMapper node with num_workers=0, which will execute map_fn in the current process/thread.

Parameters:
  • source (BaseNode[X]) – The source node to map over.

  • map_fn (Callable[[X], T]) – The function to apply to each item from the source node.

class torchdata.nodes.MultiNodeWeightedSampler(source_nodes: Mapping[str, BaseNode[T]], weights: Dict[str, float], stop_criteria: str = 'CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED', rank: Optional[int] = None, world_size: Optional[int] = None, seed: int = 0)

Bases: BaseNode[T]

A node that samples from multiple datasets with weights.

This node takes a dictionary of source nodes and a dictionary of weights; the two dictionaries must have the same keys. The weights are used to sample from the source nodes with torch.multinomial; see https://pytorch.org/docs/stable/generated/torch.multinomial.html for how the weights are interpreted. seed is used to initialize the random number generator.

The node implements the state using the following keys:
  • DATASET_NODE_STATES_KEY: A dictionary of states for each source node.

  • DATASETS_EXHAUSTED_KEY: A dictionary of booleans indicating whether each source node is exhausted.

  • EPOCH_KEY: An epoch counter used to initialize the random number generator.

  • NUM_YIELDED_KEY: The number of items yielded.

  • WEIGHTED_SAMPLER_STATE_KEY: The state of the weighted sampler.

We support multiple stopping criteria:
  • CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets are exhausted. This is the default behavior.

  • FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.

  • ALL_DATASETS_EXHAUSTED: Stop when all datasets are exhausted.

On complete exhaustion of the source nodes, the node will raise StopIteration.

Parameters:
  • source_nodes (Mapping[str, BaseNode[T]]) – A dictionary of source nodes.

  • weights (Dict[str, float]) – A dictionary of weights for each source node.

  • stop_criteria (str) – The stopping criteria. Default is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED.

  • rank (int) – The rank of the current process. Default is None, in which case the rank will be obtained from the distributed environment.

  • world_size (int) – The world size of the distributed environment. Default is None, in which case the world size will be obtained from the distributed environment.

  • seed (int) – The seed for the random number generator. Default is 0.
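The sampling loop and the three stop criteria can be sketched in plain Python. weighted_interleave is hypothetical and simplified: it uses random.choices on a seeded random.Random instead of torch.multinomial, takes lists instead of BaseNodes, ignores rank/world_size, and keeps no checkpointable state. It only illustrates how the stop criteria differ.

```python
import random
from typing import Dict, Iterator, List, TypeVar

T = TypeVar("T")

VALID = {"CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED", "ALL_DATASETS_EXHAUSTED", "FIRST_DATASET_EXHAUSTED"}


def weighted_interleave(
    sources: Dict[str, List[T]],
    weights: Dict[str, float],
    stop_criteria: str = "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED",
    seed: int = 0,
) -> Iterator[T]:
    if sources.keys() != weights.keys():
        raise ValueError("sources and weights must have the same keys")
    if stop_criteria not in VALID:
        raise ValueError(f"unknown stop_criteria: {stop_criteria}")
    rng = random.Random(seed)
    iters = {k: iter(v) for k, v in sources.items()}
    exhausted = {k: False for k in sources}
    while True:
        if stop_criteria == "ALL_DATASETS_EXHAUSTED":
            # No restarts: only sample among sources that still have data.
            candidates = [k for k in sources if not exhausted[k]]
            if not candidates:
                return
        else:
            candidates = list(sources)
        key = rng.choices(candidates, weights=[weights[k] for k in candidates])[0]
        try:
            item = next(iters[key])
        except StopIteration:
            exhausted[key] = True
            if stop_criteria == "FIRST_DATASET_EXHAUSTED":
                return
            if stop_criteria == "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED":
                if all(exhausted.values()):
                    return
                iters[key] = iter(sources[key])  # wrap around this source
            continue
        yield item


a, b = ["a0", "a1", "a2"], ["b0", "b1"]
srcs, w = {"a": a, "b": b}, {"a": 0.7, "b": 0.3}
every_once = list(weighted_interleave(srcs, w, "ALL_DATASETS_EXHAUSTED"))
until_all = list(weighted_interleave(srcs, w))
until_first = list(weighted_interleave(srcs, w, "FIRST_DATASET_EXHAUSTED"))
```

With ALL_DATASETS_EXHAUSTED every item appears exactly once; with the default, every item appears at least once and earlier-exhausted sources may repeat; FIRST_DATASET_EXHAUSTED may leave items unseen.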

get_state() → Dict[str, Any]

Subclasses must implement this method instead of state_dict(). It should only be called by BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.

next() → T

Subclasses must implement this method instead of __next__(). It should only be called by BaseNode.

Returns:

T – the next value in the sequence. Raises StopIteration when the sequence is exhausted.

reset(initial_state: Optional[Dict[str, Any]] = None)

Resets the iterator to the beginning, or to the state passed in by initial_state.

Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called. Subclasses must call super().reset(initial_state).

Parameters:

initial_state – Optional[dict] - a state dict to pass to the node. If None, reset to the beginning.

class torchdata.nodes.ParallelMapper(source: BaseNode[X], map_fn: Callable[[X], T], num_workers: int, in_order: bool = True, method: Literal['thread', 'process'] = 'thread', multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, snapshot_frequency: int = 1)

Bases: BaseNode[T]

ParallelMapper executes map_fn in parallel in either num_workers threads or num_workers processes. For processes, multiprocessing_context can be spawn, forkserver, fork, or None (which chooses the OS default). At most max_concurrent items will be either in processing or in the iterator’s output queue, which limits CPU and memory utilization. If max_concurrent is None (the default), the limit is 2 * num_workers.

At most one iter() is created from source, and at most one thread will call next() on it at once.

If in_order is True, the iterator returns items in the order in which they arrive from source’s iterator, potentially blocking even if other items are available.

Parameters:
  • source (BaseNode[X]) – The source node to map over.

  • map_fn (Callable[[X], T]) – The function to apply to each item from the source node.

  • num_workers (int) – The number of workers to use for parallel processing.

  • in_order (bool) – Whether to return items in the order in which they arrive from source. Default is True.

  • method (Literal["thread", "process"]) – The method to use for parallel processing. Default is “thread”.

  • multiprocessing_context (Optional[str]) – The multiprocessing context to use for parallel processing. Default is None.

  • max_concurrent (Optional[int]) – The maximum number of items to process at once. Default is None, which means 2 * num_workers.

  • snapshot_frequency (int) – The frequency at which to snapshot the state of the source node. Default is 1.
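The effect of num_workers, in_order and max_concurrent can be sketched with the standard library's concurrent.futures thread pool. parallel_map is a hypothetical helper, not torchdata's implementation: it supports threads only, treats num_workers=0 as inline execution in the calling thread (roughly what Mapper does), and does not snapshot state.

```python
import concurrent.futures as cf
from collections import deque
from typing import Callable, Deque, Iterable, Iterator, Optional, TypeVar

X = TypeVar("X")
T = TypeVar("T")


def parallel_map(
    source: Iterable[X],
    map_fn: Callable[[X], T],
    num_workers: int,
    in_order: bool = True,
    max_concurrent: Optional[int] = None,
) -> Iterator[T]:
    if num_workers == 0:
        # Inline execution in the calling thread.
        for item in source:
            yield map_fn(item)
        return
    # Bound on items submitted but not yet yielded (defaults to 2 * num_workers).
    limit = max(1, max_concurrent if max_concurrent is not None else 2 * num_workers)
    it = iter(source)
    pending: Deque[cf.Future] = deque()
    with cf.ThreadPoolExecutor(max_workers=num_workers) as pool:

        def fill() -> None:
            # Only this thread pulls from the source iterator.
            while len(pending) < limit:
                try:
                    item = next(it)
                except StopIteration:
                    return
                pending.append(pool.submit(map_fn, item))

        fill()
        while pending:
            if in_order:
                fut = pending.popleft()  # block on the oldest item
            else:
                done, _ = cf.wait(pending, return_when=cf.FIRST_COMPLETED)
                fut = done.pop()  # any finished item
                pending.remove(fut)
            yield fut.result()
            fill()


squares = list(parallel_map(range(10), lambda x: x * x, num_workers=4))
```

Bounding the number of in-flight futures is what keeps memory flat when the source is much faster than map_fn.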

get_state() → Dict[str, Any]

Subclasses must implement this method instead of state_dict(). It should only be called by BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.

next()

Subclasses must implement this method instead of __next__(). It should only be called by BaseNode.

Returns:

T – the next value in the sequence. Raises StopIteration when the sequence is exhausted.

reset(initial_state: Optional[Dict[str, Any]] = None)

Resets the iterator to the beginning, or to the state passed in by initial_state.

Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called. Subclasses must call super().reset(initial_state).

Parameters:

initial_state – Optional[dict] - a state dict to pass to the node. If None, reset to the beginning.

class torchdata.nodes.PinMemory(source: BaseNode[T], pin_memory_device: str = '', snapshot_frequency: int = 1)

Bases: BaseNode[T]

Pins the data of the underlying node to a device. This is backed by torch.utils.data._utils.pin_memory._pin_memory_loop.

Parameters:
  • source (BaseNode[T]) – The source node to pin the data from.

  • pin_memory_device (str) – The device to pin the data to. Default is “”.

  • snapshot_frequency (int) – The frequency at which to snapshot the state of the source node. Default is 1, which means that the state of the source node will be snapshotted after every item. If set to a higher value, the state of the source node will be snapshotted after every snapshot_frequency items.

get_state() → Dict[str, Any]

Subclasses must implement this method instead of state_dict(). It should only be called by BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.

next()

Subclasses must implement this method instead of __next__(). It should only be called by BaseNode.

Returns:

T – the next value in the sequence. Raises StopIteration when the sequence is exhausted.

reset(initial_state: Optional[Dict[str, Any]] = None)

Resets the iterator to the beginning, or to the state passed in by initial_state.

Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called. Subclasses must call super().reset(initial_state).

Parameters:

initial_state – Optional[dict] - a state dict to pass to the node. If None, reset to the beginning.

class torchdata.nodes.Prefetcher(source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1)

Bases: BaseNode[T]

Prefetches data from the source node and stores it in a queue.

Parameters:
  • source (BaseNode[T]) – The source node to prefetch data from.

  • prefetch_factor (int) – The number of items to prefetch ahead of time.

  • snapshot_frequency (int) – The frequency at which to snapshot the state of the source node. Default is 1, which means that the state of the source node will be snapshotted after every item. If set to a higher value, the state of the source node will be snapshotted after every snapshot_frequency items.
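Prefetching can be sketched as a background producer thread that fills a bounded queue.Queue, so that at most prefetch_factor items wait ahead of the consumer. prefetch is a hypothetical helper built on the standard library only; unlike torchdata's Prefetcher it does not snapshot state.

```python
import queue
import threading
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def prefetch(source: Iterable[T], prefetch_factor: int) -> Iterator[T]:
    # Each entry is (is_last, payload); payload is an item, an exception, or None.
    q: "queue.Queue[Tuple[bool, object]]" = queue.Queue(maxsize=prefetch_factor)

    def producer() -> None:
        try:
            for item in source:
                q.put((False, item))  # blocks while prefetch_factor items are waiting
            q.put((True, None))       # normal end of stream
        except Exception as exc:
            q.put((True, exc))        # forward errors to the consumer

    threading.Thread(target=producer, daemon=True).start()
    while True:
        is_last, payload = q.get()
        if is_last:
            if isinstance(payload, Exception):
                raise payload
            return
        yield payload


items = list(prefetch(range(100), prefetch_factor=4))
```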

get_state() → Dict[str, Any]

Subclasses must implement this method instead of state_dict(). It should only be called by BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.

next()

Subclasses must implement this method instead of __next__(). It should only be called by BaseNode.

Returns:

T – the next value in the sequence. Raises StopIteration when the sequence is exhausted.

reset(initial_state: Optional[Dict[str, Any]] = None)

Resets the iterator to the beginning, or to the state passed in by initial_state.

Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called. Subclasses must call super().reset(initial_state).

Parameters:

initial_state – Optional[dict] - a state dict to pass to the node. If None, reset to the beginning.

class torchdata.nodes.SamplerWrapper(sampler: Sampler[T], initial_epoch: int = 0, epoch_updater: Optional[Callable[[int], int]] = None)

Bases: BaseNode[T]

Convert a sampler into a BaseNode. This is nearly identical to IterableWrapper except it includes a hook to call set_epoch on the sampler, if it supports it.

Parameters:
  • sampler (Sampler) – Sampler to wrap.

  • initial_epoch (int) – The initial epoch to set on the sampler.

  • epoch_updater (Optional[Callable[[int], int]] = None) – A callback that updates the epoch at the start of a new iteration. It is called at the beginning of each iterator request, except the first one.
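The epoch hook can be sketched in plain Python. ShuffleSampler and EpochAwareWrapper are hypothetical classes written for illustration; the default epoch_updater used here (increment by one) is an assumption of this sketch, and the real SamplerWrapper additionally tracks iterator state.

```python
import random
from typing import Callable, Iterator, List, Optional


class ShuffleSampler:
    """Hypothetical sampler with a set_epoch hook, reshuffling per epoch."""

    def __init__(self, n: int, seed: int = 0) -> None:
        self.n, self.seed, self.epoch = n, seed, 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        order: List[int] = list(range(self.n))
        random.Random(self.seed + self.epoch).shuffle(order)
        return iter(order)


class EpochAwareWrapper:
    """Sketch of SamplerWrapper's epoch handling; not torchdata's implementation."""

    def __init__(
        self,
        sampler: ShuffleSampler,
        initial_epoch: int = 0,
        epoch_updater: Optional[Callable[[int], int]] = None,
    ) -> None:
        self.sampler = sampler
        self.epoch = initial_epoch
        self.epoch_updater = epoch_updater or (lambda e: e + 1)  # assumed default
        self._started = False

    def __iter__(self) -> Iterator[int]:
        # Update the epoch on every iterator request except the first one.
        if self._started:
            self.epoch = self.epoch_updater(self.epoch)
        self._started = True
        if hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(self.epoch)
        return iter(self.sampler)


wrapper = EpochAwareWrapper(ShuffleSampler(5), initial_epoch=3)
epoch_a = list(wrapper)  # sampler.epoch == 3
epoch_b = list(wrapper)  # sampler.epoch == 4
```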

get_state() → Dict[str, Any]

Subclasses must implement this method instead of state_dict(). It should only be called by BaseNode.

Returns:

Dict[str, Any] – a state dict that may be passed to reset() at some point in the future.

next() → T

Subclasses must implement this method instead of __next__(). It should only be called by BaseNode.

Returns:

T – the next value in the sequence. Raises StopIteration when the sequence is exhausted.

reset(initial_state: Optional[Dict[str, Any]] = None)

Resets the iterator to the beginning, or to the state passed in by initial_state.

Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called. Subclasses must call super().reset(initial_state).

Parameters:

initial_state – Optional[dict] - a state dict to pass to the node. If None, reset to the beginning.

class torchdata.nodes.Stateful(*args, **kwargs)

Bases: Protocol

Protocol for objects implementing both state_dict() and load_state_dict(state_dict: Dict[str, Any]).
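Any object with these two methods satisfies the protocol structurally, without inheriting from it. The sketch below defines a hypothetical runtime-checkable StatefulLike protocol with the standard typing module (rather than importing torchdata.nodes.Stateful) and a hypothetical Counter class that conforms to it.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


class Counter:
    """Hypothetical class; conforms structurally, no inheritance needed."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


c = Counter()
c.count = 7
restored = Counter()
restored.load_state_dict(c.state_dict())

print(isinstance(c, StatefulLike))       # True
print(isinstance([1, 2], StatefulLike))  # False
```

Note that a runtime isinstance check against a Protocol only verifies that the methods exist, not their signatures.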

class torchdata.nodes.StopCriteria

Bases: object

Stopping criteria for the dataset samplers.

  1. CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Stop once the last unseen dataset is exhausted. All datasets are seen at least once. In certain cases, some datasets may be seen more than once when there are still non-exhausted datasets.

  2. ALL_DATASETS_EXHAUSTED: Stop once all of the datasets are exhausted. Each dataset is seen exactly once. No wraparound or restart will be performed.

  3. FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
