
Grouper

class torchdata.datapipes.iter.Grouper(datapipe: IterDataPipe[T_co], group_key_fn: Callable, *, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False)

Groups data from the input IterDataPipe by keys generated with group_key_fn, and yields a DataChunk of at most group_size elements if group_size is defined (functional name: groupby).
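Unlike Python's `itertools.groupby`, which only merges runs of consecutive items that share a key, Grouper collects matching items wherever they appear in the stream (within the limits of its buffer). The plain-Python comparison below is a sketch to illustrate that difference. It does not use torchdata, and the dict-based grouping ignores buffering entirely.

```python
import itertools
import os


def group_fn(file):
    # Same key function as in the example further down: file name without extension
    return os.path.basename(file).split(".")[0]


files = ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]

# itertools.groupby only merges *consecutive* items with equal keys,
# so every item ends up in its own group for this input.
consecutive = [list(g) for _, g in itertools.groupby(files, key=group_fn)]

# Collecting into a dict keyed by group_fn merges non-adjacent items too,
# matching the unbuffered result of Grouper in the first example.
groups = {}
for f in files:
    groups.setdefault(group_fn(f), []).append(f)
by_key = list(groups.values())

print(consecutive)  # [['a.png'], ['b.png'], ['a.json'], ['b.json'], ['a.jpg'], ['c.json']]
print(by_key)       # [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
```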

Samples are read sequentially from the source datapipe, and a batch of samples sharing the same key is yielded as soon as its size reaches group_size. When the buffer is full, the DataPipe removes the largest group from the buffer and yields it, provided that its size is at least guaranteed_group_size. If it is smaller, it is dropped when drop_remaining=True.

After the entire source datapipe has been consumed, every group that was not dropped because of the buffer capacity is yielded from the buffer, even if it is smaller than guaranteed_group_size.
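To make these buffering rules concrete, here is a minimal pure-Python model of them. `simulate_groupby` is a hypothetical helper written for illustration, not part of torchdata. It assumes that the buffer counts individual items across all groups and that "largest group" means the first group found with the maximum size. The behaviour for an undersized group with drop_remaining=False is not described above, so this sketch simply raises an error in that case. With those assumptions, it reproduces the three outputs shown in the example section.

```python
from collections import defaultdict
from typing import Callable, Hashable, Iterable, Iterator, List, Optional


def simulate_groupby(
    source: Iterable,
    group_key_fn: Callable[[object], Hashable],
    buffer_size: int = 10000,
    group_size: Optional[int] = None,
    guaranteed_group_size: Optional[int] = None,
    drop_remaining: bool = False,
) -> Iterator[List]:
    """Plain-Python model of the grouping rules described above (not torchdata code)."""
    buffer = defaultdict(list)  # key -> items seen so far; dicts keep insertion order
    buffered = 0                 # assumption: total number of items across all groups

    for item in source:
        key = group_key_fn(item)
        buffer[key].append(item)
        buffered += 1

        # Rule 1: a group is yielded as soon as it reaches group_size.
        if group_size is not None and len(buffer[key]) == group_size:
            buffered -= group_size
            yield buffer.pop(key)

        # Rule 2: when the buffer is full, remove the largest group.
        if buffered == buffer_size:
            # max() returns the first key with the maximal size
            biggest = max(buffer, key=lambda k: len(buffer[k]))
            group = buffer.pop(biggest)
            buffered -= len(group)
            if guaranteed_group_size is None or len(group) >= guaranteed_group_size:
                yield group
            elif not drop_remaining:
                # Assumption: behaviour not specified in the text; this sketch raises.
                raise RuntimeError(f"group {biggest!r} is smaller than guaranteed_group_size")
            # otherwise the undersized group is silently dropped

    # Rule 3: flush everything left, regardless of guaranteed_group_size.
    for key in list(buffer):
        yield buffer.pop(key)
```

For instance, on the example's file list with buffer_size=3, group_size=3, guaranteed_group_size=3 and drop_remaining=True, the sketch drops the two-element 'a' and 'b' groups when they are evicted and yields only [['a.jpg'], ['c.json']].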

Parameters:
  • datapipe – Iterable datapipe to be grouped

  • group_key_fn – Function that generates the group key for each element of the source datapipe

  • buffer_size – The size of the buffer for ungrouped data

  • group_size – The maximum size of each group; a group is yielded as soon as it reaches this size

  • guaranteed_group_size – The minimum size a group must have to be yielded when the buffer is full

  • drop_remaining – Whether a group smaller than guaranteed_group_size is dropped from the buffer when the buffer is full
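The size parameters only make sense in certain combinations. The checks below are a hypothetical sketch of those relationships, derived from the rules described above; `check_grouper_args` is not torchdata's own validation, which may be stricter or looser.

```python
from typing import Optional


def check_grouper_args(
    buffer_size: int,
    group_size: Optional[int] = None,
    guaranteed_group_size: Optional[int] = None,
) -> None:
    """Hypothetical consistency checks for Grouper's size parameters (not torchdata code)."""
    if buffer_size <= 0:
        raise ValueError("buffer_size must be positive")
    if group_size is not None and not 0 < group_size <= buffer_size:
        # The buffer never holds more than buffer_size items, so a larger
        # group_size could never be reached.
        raise ValueError("group_size must be in (0, buffer_size]")
    if guaranteed_group_size is not None:
        if guaranteed_group_size <= 0:
            raise ValueError("guaranteed_group_size must be positive")
        if group_size is not None and guaranteed_group_size > group_size:
            # Groups still in the buffer are always smaller than group_size,
            # so a larger minimum could never be met at eviction time.
            raise ValueError("guaranteed_group_size must not exceed group_size")
```

Under the same rules, even guaranteed_group_size == group_size means every eviction falls short of the minimum, because a group reaching group_size has already been yielded; the sketch still accepts that combination rather than guessing how strict the library is.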

Example

>>> import os
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file):
...    return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0)
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
>>> # A group is yielded as soon as its size equals `group_size`
>>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
>>> list(dp1)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
>>> # Scenario where `buffer` is full, and group 'a' is yielded since its size >= `guaranteed_group_size`
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
>>> list(dp2)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
