MaxTokenBucketizer
class torchdata.datapipes.iter.MaxTokenBucketizer(datapipe: ~IterDataPipe[~T_co], max_token_count: int, len_fn: ~Callable = <function _default_len_fn>, min_len: int = 0, max_len: ~Optional[int] = None, buffer_size: int = 1000, include_padding: bool = False)
Creates mini-batches of data from a min-heap of limited size, such that the total length of the samples in each batch, as returned by len_fn, is limited by max_token_count (functional name: max_token_bucketize). If min_len or max_len is set, samples whose length falls outside [min_len, max_len] are filtered out.

The purpose of this DataPipe is to batch samples of similar length according to len_fn. A min-heap is used so that samples are taken from the buffer in increasing order of length, and the total length of the samples in each batch is guaranteed not to exceed max_token_count. In the audio domain, for example, it can batch clips of similar length; given max_token_count, each batch can then be concatenated into a Tensor of the same size with minimal padding.

If include_padding is set to True, the token count of each batch includes the padding that a succeeding DataPipe could add. This guarantees that max_token_count will not be exceeded even after the batch is padded, which can prevent out-of-memory issues for data with large variations in length.

Note that batches are bucketized starting from the shortest sample in the buffer. This can limit the variability of batches if buffer_size is large. To increase variability, apply torchdata.datapipes.iter.Shuffler before and after this DataPipe, and keep buffer_size small.

Parameters:
datapipe – Iterable DataPipe being batched
max_token_count – Maximum total length of the samples in each batch
len_fn – Function applied to each element to get its length; len(data) is used by default
min_len – Optional minimum length for a sample to be included in a batch
max_len – Optional maximum length for a sample to be included in a batch
buffer_size – Maximum number of samples taken from the prior DataPipe into the buffer for bucketizing
include_padding – If True, the token count of each batch includes the extra padding up to the largest length in the batch
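To make the interaction of these parameters concrete, here is a minimal pure-Python sketch of the bucketing procedure described above. It assumes only the behavior stated in this section: samples outside [min_len, max_len] are dropped, at most buffer_size samples are held in a min-heap, and the shortest buffered sample is greedily added to the current batch until max_token_count would be exceeded. The function name max_token_bucketize_sketch, the default max_len = max_token_count, and breaking ties between equal lengths by arrival order are hypothetical choices for illustration, not torchdata's implementation.

```python
import heapq
from typing import Any, Callable, Iterable, List, Optional, Tuple


def max_token_bucketize_sketch(
    data: Iterable[Any],
    max_token_count: int,
    len_fn: Callable[[Any], int] = len,
    min_len: int = 0,
    max_len: Optional[int] = None,
    buffer_size: int = 1000,
    include_padding: bool = False,
) -> List[List[Any]]:
    """Hypothetical sketch of min-heap token bucketing (not torchdata's code)."""
    if max_len is None:
        max_len = max_token_count  # assumption: a single sample must fit the budget
    heap: List[Tuple[int, int, Any]] = []  # (length, arrival index, sample)
    batches: List[List[Any]] = []
    batch: List[Any] = []
    batch_tokens = 0   # token count of the current batch
    batch_max_len = 0  # longest sample in the current batch

    def pop_shortest() -> None:
        nonlocal batch, batch_tokens, batch_max_len
        length, _, sample = heapq.heappop(heap)
        new_max = max(batch_max_len, length)
        if include_padding:
            # Count every sample as if padded to the longest one in the batch.
            new_tokens = (len(batch) + 1) * new_max
        else:
            new_tokens = batch_tokens + length
        if batch and new_tokens > max_token_count:
            # Adding this sample would exceed the budget: close the batch.
            batches.append(batch)
            batch, batch_tokens, batch_max_len = [sample], length, length
        else:
            batch.append(sample)
            batch_tokens, batch_max_len = new_tokens, new_max

    for index, sample in enumerate(data):
        length = len_fn(sample)
        if length < min_len or length > max_len:
            continue  # drop samples outside [min_len, max_len]
        heapq.heappush(heap, (length, index, sample))
        if len(heap) == buffer_size:
            pop_shortest()  # buffer is full: release its shortest sample
    while heap:
        pop_shortest()  # drain the rest in ascending length order
    if batch:
        batches.append(batch)
    return batches


source = ['1', '11', '1', '1111', '111', '1', '11', '11', '111']
print(max_token_bucketize_sketch(source, max_token_count=4, buffer_size=4))
# [['1', '1', '1'], ['11', '11'], ['11'], ['111'], ['111'], ['1111']]
```

On the sample data from the example, this sketch yields the same batches as both max_token_bucketize calls shown there. With buffer_size=4 the heap releases one sample per new arrival once it is full, so the stream is only locally sorted, which is why that call produces more, smaller batches than the call with the default buffer.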
Example
>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(['1', '11', '1', '1111', '111', '1', '11', '11', '111'])
>>> # Using default len_fn to sort samples based on length (string length in this case)
>>> batch_dp = source_dp.max_token_bucketize(max_token_count=5)
>>> list(batch_dp)
[['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']]
>>> batch_dp = source_dp.max_token_bucketize(max_token_count=4, buffer_size=4)
>>> list(batch_dp)
[['1', '1', '1'], ['11', '11'], ['11'], ['111'], ['111'], ['1111']]
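The example above uses the default include_padding=False. The short sketch below shows the arithmetic behind include_padding for the first batch of that example, assuming a succeeding DataPipe pads every sample to the longest one in its batch. The two helper functions are hypothetical and are not part of torchdata.

```python
# Hypothetical helpers for illustration; they are not part of torchdata.
def unpadded_token_count(batch, len_fn=len):
    """Sum of sample lengths: how tokens are counted when include_padding=False."""
    return sum(len_fn(sample) for sample in batch)


def padded_token_count(batch, len_fn=len):
    """Batch size times the longest sample: the cost after padding every sample."""
    lengths = [len_fn(sample) for sample in batch]
    return len(lengths) * max(lengths) if lengths else 0


# First batch produced by max_token_bucketize(max_token_count=5) in the example.
batch = ['1', '1', '1', '11']
print(unpadded_token_count(batch))  # 5, which fits max_token_count=5
print(padded_token_count(batch))    # 8, since 4 samples are padded to length 2
```

Under the semantics described above, include_padding=True counts this batch as 8 tokens, so with max_token_count=5 the sample '11' would start a new batch instead of joining the three '1' samples.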