Grouper
- class torchdata.datapipes.iter.Grouper(datapipe: IterDataPipe[T_co], group_key_fn: Callable, *, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False)
Groups data from the input IterDataPipe by keys generated from group_key_fn, and yields a DataChunk with batch size up to group_size if it is defined (functional name: groupby).

The samples are read sequentially from the source datapipe, and a batch of samples belonging to the same group is yielded as soon as the batch reaches group_size. When the buffer is full, the DataPipe yields the largest batch that shares a key, provided its size is at least guaranteed_group_size; if it is smaller, the batch is dropped when drop_remaining=True (see the simplified sketch after the parameter list).

After the source datapipe is exhausted, everything not dropped due to buffer capacity is yielded from the buffer, even if those groups are smaller than guaranteed_group_size.
- Parameters:
datapipe – Iterable datapipe to be grouped
group_key_fn – Function used to generate a group key from the data of the source datapipe
buffer_size – The size of the buffer for ungrouped data
group_size – The maximum size of each group; a batch is yielded as soon as it reaches this size
guaranteed_group_size – The guaranteed minimum group size to be yielded in case the buffer is full
drop_remaining – Specifies whether a group smaller than guaranteed_group_size will be dropped from the buffer when the buffer is full
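The interaction of these parameters can be pictured with a simplified, pure-Python model of the buffering logic described above. This is a sketch for intuition only, not the actual Grouper implementation (the real DataPipe yields DataChunk instances and handles graph traversal, resets, and serialization); the error raised for an undersized group when drop_remaining=False is an assumption of this model.

    from collections import defaultdict

    def simple_groupby(iterable, group_key_fn, buffer_size=10000,
                       group_size=None, guaranteed_group_size=None,
                       drop_remaining=False):
        buffer = defaultdict(list)   # group key -> buffered samples
        buffered = 0                 # total number of buffered samples
        for sample in iterable:
            key = group_key_fn(sample)
            buffer[key].append(sample)
            buffered += 1
            # A group that reaches group_size is yielded immediately.
            if group_size is not None and len(buffer[key]) == group_size:
                buffered -= group_size
                yield buffer.pop(key)
                continue
            # Buffer full: evict the largest group.
            if buffered == buffer_size:
                biggest = max(buffer, key=lambda k: len(buffer[k]))
                group = buffer.pop(biggest)
                buffered -= len(group)
                if guaranteed_group_size is None or len(group) >= guaranteed_group_size:
                    yield group
                elif not drop_remaining:
                    raise RuntimeError("failed to guarantee group size", group)
                # Otherwise the undersized group is silently dropped.
        # Flush remaining groups, regardless of guaranteed_group_size.
        for key in list(buffer):
            yield buffer.pop(key)

Running simple_groupby over the file list in the example below should reproduce the dp0, dp1, and dp2 outputs shown there.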
Example
>>> import os
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file):
...     return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0)
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
>>> # A group is yielded as soon as its size equals `group_size`
>>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
>>> list(dp1)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
>>> # Scenario where `buffer` is full, and group 'a' is yielded since its size >= `guaranteed_group_size`
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
>>> list(dp2)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
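One scenario the doctests above do not cover is drop_remaining. Below is a hedged sketch of how it might behave under the eviction rules described earlier: with guaranteed_group_size=3, every group evicted from the full buffer is below the guarantee and, because drop_remaining=True, is silently discarded, so only the final buffer flush is yielded. The dp3 name and the expected output are reconstructed from the documented semantics, not taken from the library's own doctests.

>>> # Undersized evicted groups are dropped when `drop_remaining=True`
>>> dp3 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=3, drop_remaining=True)
>>> list(dp3)
[['a.jpg'], ['c.json']]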