
padded_collate

torchtune.utils.padded_collate(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = -100) → Dict[str, Tensor]

Pad a batch of sequences to the longest sequence length in the batch, and convert integer lists to tensors.
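The padding step described above can be sketched as follows. This is a minimal illustration that assumes the function is built on `torch.nn.utils.rnn.pad_sequence`; `padded_collate_sketch` is a hypothetical name, and this is not torchtune's source. The final step, which pads whichever of "tokens" or "labels" is shorter so that both end up with the same sequence length, is also an assumption of this sketch.

```python
from typing import Dict, List

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


def padded_collate_sketch(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """Hypothetical re-implementation for illustration only."""
    # Convert each integer list to a 1-D tensor, then right-pad every
    # sequence to the longest one in the batch: shape (batch_size, seq_len).
    tokens = pad_sequence(
        [torch.tensor(x["tokens"]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.tensor(x["labels"]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
    # Assumption: if the longest token list and the longest label list differ
    # in length, pad the shorter tensor on the right so both share one length.
    diff = tokens.size(1) - labels.size(1)
    if diff > 0:
        labels = F.pad(labels, (0, diff), value=ignore_idx)
    elif diff < 0:
        tokens = F.pad(tokens, (0, -diff), value=padding_idx)
    return {"tokens": tokens.long(), "labels": labels.long()}


batch = [
    {"tokens": [1, 2, 3], "labels": [4, 5, 6]},
    {"tokens": [7], "labels": [10]},
]
out = padded_collate_sketch(batch)
print(out["tokens"].tolist())  # [[1, 2, 3], [7, 0, 0]]
print(out["labels"].tolist())  # [[4, 5, 6], [10, -100, -100]]
```

Padding only to the longest sequence in each batch, rather than to a fixed maximum length, keeps short batches small.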

Parameters:
  • batch (List[Dict[str, List[int]]]) – A list of dicts, each with a "tokens" key (input token IDs) and a "labels" key (label IDs).

  • padding_idx (int) – Padding value for the input token IDs ("tokens"). Defaults to 0.

  • ignore_idx (int) – Padding value for the labels ("labels"). Defaults to -100.
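The default of -100 for ignore_idx matches the default `ignore_index` of PyTorch's `torch.nn.functional.cross_entropy` (and `torch.nn.CrossEntropyLoss`), so labels padded with this value are presumably meant to be skipped by the loss. The short check below illustrates that behavior; the logits and labels are made-up toy values, not output from torchtune.

```python
import torch
import torch.nn.functional as F

# Toy logits for one sequence of length 3 over a vocabulary of 5 tokens.
torch.manual_seed(0)
logits = torch.randn(3, 5)

# Labels as they would look after padding with ignore_idx=-100:
# the last position is padding.
padded_labels = torch.tensor([2, 4, -100])

# cross_entropy skips targets equal to ignore_index (default -100), and the
# mean is taken only over the non-ignored positions, so the padded position
# does not change the loss.
loss_padded = F.cross_entropy(logits, padded_labels)
loss_unpadded = F.cross_entropy(logits[:2], padded_labels[:2])
print(torch.allclose(loss_padded, loss_unpadded))  # True
```

Input tokens, by contrast, are padded with padding_idx (0 by default), since they must be valid token IDs for the model's embedding lookup.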

Returns:

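A dict mapping "tokens" and "labels" to the collated, padded tensors, each of shape [batch_size, seq_len], where seq_len is the padded sequence length.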

Return type:

Dict[str, torch.Tensor]

Example

>>> from torchtune.utils import padded_collate
>>> token_pairs = [
...     {"tokens": [1, 2, 3], "labels": [4, 5, 6]},
...     {"tokens": [7], "labels": [10]},
... ]
>>> collated = padded_collate(
...     batch=token_pairs,
...     padding_idx=0,
...     ignore_idx=-100,
... )
>>> collated["tokens"]
tensor([[1, 2, 3], [7, 0, 0]])
>>> collated["labels"]
tensor([[4, 5, 6], [10, -100, -100]])
