MaxTokenBucketizer¶
- class torchdata.datapipes.iter.MaxTokenBucketizer(datapipe: ~IterDataPipe[~T_co], max_token_count: int, len_fn: ~Callable = <function _default_len_fn>, min_len: int = 0, max_len: ~Optional[int] = None, buffer_size: int = 1000)¶
Creates mini-batches of data from a min-heap with limited size, where the total length of the samples in each batch, as returned by len_fn, is limited by max_token_count (functional name: max_token_bucketize). If min_len or max_len is set, samples whose length falls outside [min_len, max_len] are filtered out.

The purpose of this DataPipe is to batch samples with similar lengths according to len_fn. A min-heap is used to ensure that samples are yielded in increasing order of length, and the total length of the samples in each batch is guaranteed not to exceed max_token_count. In the audio domain, for example, this can batch samples of similar length so that, given max_token_count, each batch can be concatenated into a Tensor of the same size with minimal padding.
- Parameters:
datapipe – Iterable DataPipe being batched
max_token_count – Maximum total length of the data in each batch
len_fn – Function applied to each element to get its length; len(data) is used by default
min_len – Optional minimum length for a sample to be included in a batch
max_len – Optional maximum length for a sample to be included in a batch
buffer_size – Restricts how many samples are taken from the prior DataPipe to bucketize
Example
>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(['1', '11', '1', '1111', '111', '1', '11', '11', '111'])
>>> # Using default len_fn to sort samples based on length (string length in this case)
>>> batch_dp = source_dp.max_token_bucketize(max_token_count=5)
>>> list(batch_dp)
[['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']]
>>> batch_dp = source_dp.max_token_bucketize(max_token_count=4, buffer_size=4)
>>> list(batch_dp)
[['1', '1', '1'], ['11', '11'], ['11'], ['111'], ['111'], ['1111']]
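To make the buffering and min-heap behavior concrete, the batching logic described above can be sketched in plain Python. This is a simplified model, not the actual torchdata implementation; the helper name and the one-sample-per-pop buffer-refill strategy are assumptions for illustration, though the sketch reproduces the outputs shown in the example.

```python
import heapq

def max_token_bucketize(samples, max_token_count, len_fn=len,
                        min_len=0, max_len=None, buffer_size=1000):
    """Sketch of MaxTokenBucketizer's min-heap batching (illustrative only)."""
    heap, it = [], iter(samples)
    counter = 0  # tie-breaker so the heap never compares samples directly

    def push(sample):
        nonlocal counter
        length = len_fn(sample)
        if length < min_len or (max_len is not None and length > max_len):
            return False  # filtered out by [min_len, max_len]
        heapq.heappush(heap, (length, counter, sample))
        counter += 1
        return True

    # Fill the buffer (a min-heap keyed on sample length).
    for sample in it:
        push(sample)
        if len(heap) >= buffer_size:
            break

    batch, total = [], 0
    while heap:
        length, _, sample = heapq.heappop(heap)
        # Emit the current batch when the next sample would exceed the budget.
        if total + length > max_token_count and batch:
            yield batch
            batch, total = [], 0
        batch.append(sample)
        total += length
        # Refill the buffer with the next unfiltered sample from the source.
        for sample in it:
            if push(sample):
                break
    if batch:
        yield batch

source = ['1', '11', '1', '1111', '111', '1', '11', '11', '111']
print(list(max_token_bucketize(source, max_token_count=5)))
# [['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']]
print(list(max_token_bucketize(source, max_token_count=4, buffer_size=4)))
# [['1', '1', '1'], ['11', '11'], ['11'], ['111'], ['111'], ['1111']]
```

Note how the smaller buffer_size in the second call changes the batches: with only four samples visible at a time, '1111' enters the heap early and the three shortest strings are batched before the fourth '1' (outside the buffer) is seen.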