padded_collate

torchtune.data.padded_collate(batch: List[Dict[str, List[int]]], *, pad_direction: str, keys_to_pad: List[str], padding_idx: Union[int, Dict[str, int]])

A generic padding collation function. It pads the keys_to_pad entries in a batch of sequences, from the side given by pad_direction, to the maximum sequence length of each key across the batch.
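The difference between the two padding directions can be sketched in plain Python. The helper pad_to_longest below is hypothetical and not part of torchtune. It works on lists of integers rather than tensors and only illustrates where the padding value ends up for each direction.

```python
from typing import List


def pad_to_longest(
    seqs: List[List[int]], pad_direction: str, padding_value: int
) -> List[List[int]]:
    """Hypothetical helper: pad every sequence to the longest length in `seqs`."""
    if pad_direction not in ("left", "right"):
        raise ValueError(f"pad_direction must be 'left' or 'right', got {pad_direction!r}")
    max_len = max((len(s) for s in seqs), default=0)
    padded = []
    for s in seqs:
        fill = [padding_value] * (max_len - len(s))
        # Left padding places the fill before the sequence, right padding after it.
        padded.append(fill + list(s) if pad_direction == "left" else list(s) + fill)
    return padded


print(pad_to_longest([[1, 2, 3], [4, 5, 6, 7]], "left", -10))
# [[-10, 1, 2, 3], [4, 5, 6, 7]]
print(pad_to_longest([[1, 2, 3], [4, 5, 6, 7]], "right", -10))
# [[1, 2, 3, -10], [4, 5, 6, 7]]
```

Left padding is common for generation, where every sequence should end at the same position. Right padding is the usual choice for training.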

Note

This function assumes that entries under keys not listed in keys_to_pad do not require any padding or other special collation (see example below).

Parameters:
  • batch (List[Dict[str, List[int]]]) – A list of dictionaries containing inputs.

  • pad_direction (str) – Whether to pad entries from the left or the right. If pad_direction="right", torch.nn.utils.rnn.pad_sequence() is used; if pad_direction="left", torchtune.data.left_pad_sequence() is used.

  • keys_to_pad (List[str]) – Batch element keys to apply padding to. Should be a subset of keys in the batch.

  • padding_idx (Union[int, Dict[str, int]]) – Either a single integer padding value applied to every key in keys_to_pad, or a mapping whose keys are identical to keys_to_pad and whose values give the padding value for each key.
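A per-key padding_idx is useful when keys need different fill values. For example, token ids might be padded with 0 while labels are padded with -100, the default ignore_index of torch.nn.CrossEntropyLoss. The sketch below shows one way to normalize either form into a per-key mapping. resolve_padding_values is a hypothetical helper, not the torchtune implementation.

```python
from typing import Dict, List, Union


def resolve_padding_values(
    keys_to_pad: List[str], padding_idx: Union[int, Dict[str, int]]
) -> Dict[str, int]:
    """Hypothetical helper: turn `padding_idx` into one padding value per key."""
    if isinstance(padding_idx, int):
        # A single integer is shared by every key in keys_to_pad.
        return {k: padding_idx for k in keys_to_pad}
    if set(padding_idx) != set(keys_to_pad):
        # Mirrors the documented rule: dict keys must be identical to keys_to_pad.
        raise ValueError(
            f"padding_idx keys {sorted(padding_idx)} must match keys_to_pad {sorted(keys_to_pad)}"
        )
    return dict(padding_idx)


print(resolve_padding_values(["tokens", "labels"], 0))
# {'tokens': 0, 'labels': 0}
print(resolve_padding_values(["tokens", "labels"], {"tokens": 0, "labels": -100}))
# {'tokens': 0, 'labels': -100}
```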

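Returns:

A dictionary mapping each key in the batch to a tensor. Entries under keys_to_pad are padded to shape [batch_size, max_seq_len], where max_seq_len is the length of the longest sequence for that key in the batch. All other entries are converted to tensors without padding (see example below).

Return type:

Dict[str, torch.Tensor]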

Raises:
  • ValueError – if pad_direction is not one of “left” or “right”.

  • ValueError – if keys_to_pad is empty, is not a list, or is not a subset of the keys in the batch.

  • ValueError – if padding_idx is provided as a dictionary whose keys are not identical to keys_to_pad.
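The first two error conditions can be sketched as simple up-front checks. validate_collate_args is a hypothetical helper. It assumes every example in the batch has the same keys, so it only inspects the first one. torchtune's own checks may be structured differently.

```python
from typing import Any, Dict, List


def validate_collate_args(
    batch: List[Dict[str, Any]], pad_direction: str, keys_to_pad: List[str]
) -> None:
    """Hypothetical checks mirroring the first two documented ValueError cases."""
    if pad_direction not in ("left", "right"):
        raise ValueError(f"pad_direction should be one of 'left' or 'right', got {pad_direction!r}")
    if not isinstance(keys_to_pad, list) or not keys_to_pad:
        raise ValueError("keys_to_pad must be a non-empty list")
    # Assumption: all examples share the keys of the first example.
    batch_keys = set(batch[0].keys()) if batch else set()
    missing = set(keys_to_pad) - batch_keys
    if missing:
        raise ValueError(f"keys_to_pad contains keys not in the batch: {sorted(missing)}")


batch = [{"tokens": [1, 2], "labels": 1}]
validate_collate_args(batch, "left", ["tokens"])  # passes silently
try:
    validate_collate_args(batch, "up", ["tokens"])
except ValueError as e:
    print(e)
# pad_direction should be one of 'left' or 'right', got 'up'
```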

Example

>>> from torchtune.data import padded_collate
>>> a = [1, 2, 3]
>>> b = [4, 5, 6, 7]
>>> c = [8, 9, 10, 11, 12]
>>> batch = [
...     {"tokens": a, "labels": 1},
...     {"tokens": b, "labels": 3},
...     {"tokens": c, "labels": 0},
... ]
>>> padded_collate(
...     batch,
...     pad_direction="left",
...     keys_to_pad=["tokens"],
...     padding_idx=-10
... )
{
    'labels': tensor([1, 3, 0]),
    'tokens': tensor([[-10, -10,   1,   2,   3],
                      [-10,   4,   5,   6,   7],
                      [  8,   9,  10,  11,  12]])
}
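To make the example's semantics concrete without tensors, here is a list-based sketch that reproduces the output above. padded_collate_lists is hypothetical. It returns nested Python lists instead of torch.Tensor values, assumes every example shares the keys of the first one, and skips most of the validation listed under Raises.

```python
from typing import Any, Dict, List, Union


def padded_collate_lists(
    batch: List[Dict[str, Any]],
    *,
    pad_direction: str,
    keys_to_pad: List[str],
    padding_idx: Union[int, Dict[str, int]],
) -> Dict[str, List[Any]]:
    """Hypothetical list-based analogue of padded_collate (no tensors)."""
    if pad_direction not in ("left", "right"):
        raise ValueError("pad_direction should be 'left' or 'right'")
    # Normalize padding_idx to one value per padded key.
    pad_values = (
        padding_idx if isinstance(padding_idx, dict)
        else {k: padding_idx for k in keys_to_pad}
    )
    out: Dict[str, List[Any]] = {}
    for key in batch[0]:
        values = [example[key] for example in batch]
        if key not in keys_to_pad:
            # Non-padded keys are gathered as-is, one value per example.
            out[key] = values
            continue
        max_len = max(len(v) for v in values)
        padded = []
        for v in values:
            fill = [pad_values[key]] * (max_len - len(v))
            padded.append(fill + list(v) if pad_direction == "left" else list(v) + fill)
        out[key] = padded
    return out


batch = [
    {"tokens": [1, 2, 3], "labels": 1},
    {"tokens": [4, 5, 6, 7], "labels": 3},
    {"tokens": [8, 9, 10, 11, 12], "labels": 0},
]
result = padded_collate_lists(
    batch, pad_direction="left", keys_to_pad=["tokens"], padding_idx=-10
)
print(result["tokens"])
# [[-10, -10, 1, 2, 3], [-10, 4, 5, 6, 7], [8, 9, 10, 11, 12]]
print(result["labels"])
# [1, 3, 0]
```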
