IterKeyZipper
class torchdata.datapipes.iter.IterKeyZipper(source_datapipe: IterDataPipe, ref_datapipe: IterDataPipe, key_fn: Callable, ref_key_fn: Optional[Callable] = None, keep_key: bool = False, buffer_size: int = 10000, merge_fn: Optional[Callable] = None)
Zips two IterDataPipes together based on a matching key (functional name: zip_with_iter). The keys are computed by key_fn and ref_key_fn for the two IterDataPipes, respectively. When an element from ref_datapipe does not match the current element from source_datapipe, it is stored in a buffer and the next element from ref_datapipe is tried. Once a match is found, merge_fn determines how the two elements are combined and returned (a tuple is generated by default).
Parameters:
source_datapipe – IterKeyZipper will yield data based on the order of this IterDataPipe
ref_datapipe – Reference IterDataPipe from which IterKeyZipper will find items with matching key for source_datapipe
key_fn – Callable function that will compute keys using elements from source_datapipe
ref_key_fn – Callable function that will compute keys using elements from ref_datapipe. If it’s not specified, key_fn will also be applied to elements from ref_datapipe
keep_key – Option to yield the matching key along with the items in a tuple, resulting in (key, merge_fn(item1, item2)).
buffer_size – The size of the buffer used to hold key-data pairs from the reference DataPipe until a match is found. If it’s specified as None, the buffer size is set as infinite.
merge_fn – Function that combines the item from source_datapipe and the item from ref_datapipe; by default a tuple is created
Example
>>> from torchdata.datapipes.iter import IterableWrapper
>>> from operator import itemgetter
>>> def merge_fn(t1, t2):
...     return t1[1] + t2[1]
>>> dp1 = IterableWrapper([('a', 100), ('b', 200), ('c', 300)])
>>> dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
>>> res_dp = dp1.zip_with_iter(dp2, key_fn=itemgetter(0),
...                            ref_key_fn=itemgetter(0), keep_key=True, merge_fn=merge_fn)
>>> list(res_dp)
[('a', 101), ('b', 202), ('c', 303)]
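The buffering behaviour described above can be approximated in plain Python, which may help when reasoning about ordering and about buffer_size. This is a minimal sketch and not the torchdata implementation: the function name zip_by_key is hypothetical, and both the eviction policy (oldest entry first) and the KeyError raised when the reference iterable runs out are assumptions made for illustration.

```python
from collections import OrderedDict
from operator import itemgetter
from typing import Any, Callable, Iterable, Iterator, Optional


def zip_by_key(
    source: Iterable[Any],
    ref: Iterable[Any],
    key_fn: Callable[[Any], Any],
    ref_key_fn: Optional[Callable[[Any], Any]] = None,
    keep_key: bool = False,
    buffer_size: Optional[int] = 10000,
    merge_fn: Optional[Callable[[Any, Any], Any]] = None,
) -> Iterator[Any]:
    """Hypothetical stand-in that mimics the documented semantics."""
    # As documented: key_fn is reused when ref_key_fn is not given.
    ref_key_fn = key_fn if ref_key_fn is None else ref_key_fn
    buffer: "OrderedDict[Any, Any]" = OrderedDict()
    ref_it = iter(ref)

    for item in source:  # output follows the order of the source
        key = key_fn(item)
        # Pull reference items into the buffer until the key shows up.
        while key not in buffer:
            try:
                ref_item = next(ref_it)
            except StopIteration:
                # Assumption: signal an unmatched key with KeyError.
                raise KeyError(f"no reference item matches key {key!r}") from None
            if buffer_size is not None and len(buffer) >= buffer_size:
                # Assumption: evict the oldest buffered entry when full.
                buffer.popitem(last=False)
            buffer[ref_key_fn(ref_item)] = ref_item
        matched = buffer.pop(key)
        # A tuple is produced by default, otherwise merge_fn decides.
        merged = (item, matched) if merge_fn is None else merge_fn(item, matched)
        yield (key, merged) if keep_key else merged


source = [("a", 100), ("b", 200), ("c", 300)]
# The reference is deliberately out of order to exercise the buffer.
ref = [("c", 3), ("a", 1), ("b", 2), ("d", 4)]

with_keys = list(
    zip_by_key(
        source,
        ref,
        key_fn=itemgetter(0),
        keep_key=True,
        merge_fn=lambda t1, t2: t1[1] + t2[1],
    )
)
default_pairs = list(zip_by_key(source, ref, key_fn=itemgetter(0)))
```

Here with_keys follows the (key, merge_fn(item1, item2)) shape described for keep_key, while default_pairs holds plain (source_item, ref_item) tuples. Because ('c', 3) arrives before its match, it waits in the buffer until the source reaches 'c'; under the eviction assumption above, a buffer_size that is too small would drop such an entry before it could be matched.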