
MaxTokenBucketizer

class torchdata.datapipes.iter.MaxTokenBucketizer(datapipe: IterDataPipe[T_co], max_token_count: int, len_fn: Callable = _default_len_fn, min_len: int = 0, max_len: Optional[int] = None, buffer_size: int = 1000, include_padding: bool = False)

Creates mini-batches of data from a min-heap of limited size, such that the total length of the samples in each batch, as computed by len_fn, does not exceed max_token_count (functional name: max_token_bucketize). If min_len or max_len is set, samples whose length falls outside [min_len, max_len] are filtered out.
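The length filter can be sketched in plain Python as follows. This is a minimal illustration, not the torchdata implementation. The helper name `filter_by_length` is hypothetical. Treating both bounds as inclusive follows the `[min_len, max_len]` interval stated above.

```python
from typing import Callable, Iterable, Iterator, Optional, Tuple, TypeVar

T = TypeVar("T")


def filter_by_length(
    samples: Iterable[T],
    len_fn: Callable[[T], int] = len,
    min_len: int = 0,
    max_len: Optional[int] = None,
) -> Iterator[Tuple[int, T]]:
    """Hypothetical helper: yield (length, sample) for samples inside [min_len, max_len]."""
    for sample in samples:
        length = len_fn(sample)
        if length < min_len:
            continue  # too short: dropped
        if max_len is not None and length > max_len:
            continue  # too long: dropped
        yield length, sample


print(list(filter_by_length(["1", "11", "111", "1111"], min_len=2, max_len=3)))
# [(2, '11'), (3, '111')]
```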

The purpose of this DataPipe is to batch samples of similar length according to len_fn. A min-heap is used so that samples are popped from the buffer in increasing order of length, and the total length of the samples in each batch is guaranteed not to exceed max_token_count. In the audio domain, for example, this groups clips of similar duration; given max_token_count, each batch can then be padded into a Tensor of uniform size with minimal padding.
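The heap-based batching described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the torchdata source. The function name `bucketize_by_tokens` is hypothetical. The sketch assumes that a sample is popped from the heap each time the buffer reaches buffer_size, and that a batch is emitted as soon as the next sample would push its total past max_token_count. Under those assumptions it reproduces both outputs in the Example section of this page.

```python
import heapq
from itertools import count
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def bucketize_by_tokens(
    samples: Iterable[T],
    max_token_count: int,
    len_fn: Callable[[T], int] = len,
    buffer_size: int = 1000,
) -> Iterator[List[T]]:
    """Hypothetical sketch of greedy token-budget batching over a bounded min-heap."""
    heap: list = []
    tie = count()  # arrival order breaks length ties, so samples are never compared
    batch: List[T] = []
    batch_tokens = 0

    def drain(until: int) -> Iterator[List[T]]:
        # Pop the shortest buffered sample while the heap holds more than `until` items.
        nonlocal batch, batch_tokens
        while len(heap) > until:
            length, _, sample = heapq.heappop(heap)
            if batch and batch_tokens + length > max_token_count:
                yield batch  # adding this sample would exceed the budget
                batch, batch_tokens = [], 0
            batch.append(sample)
            batch_tokens += length

    for sample in samples:
        heapq.heappush(heap, (len_fn(sample), next(tie), sample))
        yield from drain(buffer_size - 1)  # keep at most buffer_size - 1 samples buffered
    yield from drain(0)  # input exhausted: flush the whole heap
    if batch:
        yield batch


source = ["1", "11", "1", "1111", "111", "1", "11", "11", "111"]
print(list(bucketize_by_tokens(source, max_token_count=5)))
# [['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']]
```

Because the heap only ever sees buffer_size samples at a time, ordering is global only when buffer_size is at least the dataset size. In this sketch, a single sample longer than max_token_count ends up alone in its own batch.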

If include_padding is set to True, the token count of each batch includes the padding that a subsequent DataPipe may add to bring every sample up to the longest length in the batch. This guarantees that max_token_count is not exceeded even after the batch is padded, which can prevent out-of-memory errors for data with large variations in length.
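The difference between the two counting modes comes down to simple arithmetic. Padding every sample to the longest one costs (number of samples) × (longest length), rather than the plain sum of lengths. The helpers `batch_cost` and `fits` are hypothetical names for illustration only.

```python
from typing import List


def batch_cost(lengths: List[int], include_padding: bool) -> int:
    """Token count of a batch, with or without padding to its longest sample."""
    if not lengths:
        return 0
    if include_padding:
        # Every sample is padded up to the longest one in the batch.
        return len(lengths) * max(lengths)
    return sum(lengths)


def fits(batch_lengths: List[int], new_length: int, max_token_count: int, include_padding: bool) -> bool:
    """Would adding a sample of new_length keep the batch within max_token_count?"""
    return batch_cost(batch_lengths + [new_length], include_padding) <= max_token_count


print(fits([1, 1, 1], 2, max_token_count=5, include_padding=False))  # True  (1 + 1 + 1 + 2 = 5)
print(fits([1, 1, 1], 2, max_token_count=5, include_padding=True))   # False (4 samples * 2 = 8)
```

The same four samples fit the budget when only real tokens are counted, but not once the three short samples are padded to length 2.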

Note that batches are bucketized starting from the smallest sample in the buffer. This can limit the variability of batches if buffer_size is large. To increase variability, apply torchdata.datapipes.iter.Shuffler before and after this DataPipe, and keep buffer_size small.
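The effect of buffer_size can be seen in the ordering stage alone. A large buffer yields globally sorted output, so batches always run from shortest to longest. A small buffer only sorts locally. The sketch below uses hypothetical helpers `windowed_sort` and `shuffled`, with `random.Random.shuffle` standing in for a Shuffler stage. It is an illustration of the trade-off, not the library code.

```python
import heapq
import random
from itertools import count
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def windowed_sort(samples: Iterable[T], buffer_size: int, len_fn: Callable[[T], int] = len) -> Iterator[T]:
    """Yield samples shortest-first, choosing only among at most buffer_size buffered samples."""
    heap: list = []
    tie = count()  # arrival order breaks length ties
    for sample in samples:
        heapq.heappush(heap, (len_fn(sample), next(tie), sample))
        if len(heap) >= buffer_size:
            yield heapq.heappop(heap)[2]
    while heap:
        yield heapq.heappop(heap)[2]


def shuffled(items: Iterable[T], seed: int) -> List[T]:
    """Return a shuffled copy, standing in for a Shuffler stage."""
    out = list(items)
    random.Random(seed).shuffle(out)
    return out


data = ["111", "1", "1111", "11"]
print(list(windowed_sort(data, buffer_size=100)))  # ['1', '11', '111', '1111']: fully sorted
print(list(windowed_sort(data, buffer_size=2)))    # ['1', '111', '11', '1111']: only locally sorted
```

In a torchdata pipeline, the same idea corresponds to calling `.shuffle()` (the functional form of Shuffler) both before and after `.max_token_bucketize(...)`.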

Parameters:
  • datapipe – Iterable DataPipe being batched

  • max_token_count – Maximum total length of the samples in each batch

  • len_fn – Function applied to each element to get its length. len(data) is used by default.

  • min_len – Minimum sample length; shorter samples are filtered out (default: 0)

  • max_len – Optional maximum sample length; longer samples are filtered out

  • buffer_size – Maximum number of samples from the prior DataPipe held in the min-heap for bucketizing

  • include_padding – If True, the size of each batch includes the extra padding needed to bring every sample up to the largest length in the batch.
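The default len_fn, len(data), measures the sample object itself. This is wrong for structured samples such as (waveform, transcript) pairs, where len() would return the tuple size. A custom len_fn fixes that. In the sketch below, the sample layout and the function name `waveform_len` are hypothetical. A function like this would be passed as `len_fn=waveform_len` to `max_token_bucketize`.

```python
from typing import List, Tuple

# Hypothetical sample layout: (waveform samples, transcript)
AudioSample = Tuple[List[float], str]


def waveform_len(sample: AudioSample) -> int:
    """len_fn for (waveform, transcript) pairs: count waveform frames, not tuple fields."""
    waveform, _transcript = sample
    return len(waveform)


clip = ([0.0, 0.1, -0.1], "hi")
print(len(clip))           # 2: the default len() measures the tuple
print(waveform_len(clip))  # 3: the frame count that should drive bucketing
```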

Example

>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(['1', '11', '1', '1111', '111', '1', '11', '11', '111'])
>>> # Using default len_fn to sort samples based on length (string length in this case)
>>> batch_dp = source_dp.max_token_bucketize(max_token_count=5)
>>> list(batch_dp)
[['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']]
>>> batch_dp = source_dp.max_token_bucketize(max_token_count=4, buffer_size=4)
>>> list(batch_dp)
[['1', '1', '1'], ['11', '11'], ['11'], ['111'], ['111'], ['1111']]
