padded_collate_dpo

torchtune.utils.padded_collate_dpo(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = -100) → Tuple[Tensor, Tensor]

Pad a batch of sequences for Direct Preference Optimization (DPO).

This function takes a batch of examples, where each example is a dictionary whose keys name the components of one preference pair, such as chosen_input_ids or rejected_labels. The chosen and rejected sequences are concatenated along the batch dimension and padded to a common length.

Parameters:
  • batch (List[Dict[str, List[int]]]) – A list of dictionaries, where each dictionary represents a sequence with multiple components. The keys ‘chosen_input_ids’, ‘chosen_labels’, ‘rejected_input_ids’, and ‘rejected_labels’ are required.

  • padding_idx (int) – Padding index for input ids. Defaults to 0.

  • ignore_idx (int) – Padding index for labels. Defaults to -100.

Returns:

A tuple of (input ids, labels). In each tensor, the chosen sequences are stacked before the rejected sequences, and every row is padded to the same length.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Raises:
  • AssertionError – if the lengths of chosen_input_ids and rejected_input_ids differ.

  • AssertionError – if the lengths of chosen_labels and rejected_labels differ.
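
Since the four keys listed under Parameters are required, it can be useful to check a batch before collating it. The following is a minimal sketch and not part of torchtune: `REQUIRED_KEYS` and `find_missing_keys` are hypothetical names introduced here for illustration only.

```python
from typing import Dict, List

# Keys that the Parameters section lists as required for every example.
REQUIRED_KEYS = (
    "chosen_input_ids",
    "chosen_labels",
    "rejected_input_ids",
    "rejected_labels",
)


def find_missing_keys(batch: List[Dict[str, List[int]]]) -> Dict[int, List[str]]:
    """Map each example index to the required keys it lacks (hypothetical helper)."""
    missing = {}
    for i, example in enumerate(batch):
        absent = [key for key in REQUIRED_KEYS if key not in example]
        if absent:
            missing[i] = absent
    return missing


batch = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5],
     "chosen_labels": [6, 7, 8], "rejected_labels": [9, 10]},
    # This example deliberately omits both label entries.
    {"chosen_input_ids": [11, 12], "rejected_input_ids": [13, 14, 15]},
]
print(find_missing_keys(batch))  # {1: ['chosen_labels', 'rejected_labels']}
```

An empty result means every example carries all four required keys.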

Example

>>> from torchtune.utils import padded_collate_dpo
>>> batch = [
...    {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [4, 5],
...      'chosen_labels': [6, 7, 8], 'rejected_labels': [9, 10]},
...    {'chosen_input_ids': [11, 12], 'rejected_input_ids': [13, 14, 15],
...      'chosen_labels': [16, 17], 'rejected_labels': [18, 19, 20]},
... ]
>>> padded_collate_dpo(batch)
(tensor([[ 1,  2,  3],
         [11, 12,  0],
         [ 4,  5,  0],
         [13, 14, 15]]),
 tensor([[ 6,  7,  8],
         [16, 17, -100],
         [ 9, 10, -100],
         [18, 19, 20]]))
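
The row order and padding shown in the example can be reproduced in plain Python. This is a sketch of the observable behavior using nested lists instead of tensors, not torchtune's implementation; `pad_rows` and `collate_dpo_sketch` are hypothetical names used only for illustration.

```python
from typing import Dict, List, Tuple


def pad_rows(rows: List[List[int]], pad_value: int) -> List[List[int]]:
    """Right-pad every row with pad_value to the length of the longest row."""
    max_len = max(len(row) for row in rows)
    return [row + [pad_value] * (max_len - len(row)) for row in rows]


def collate_dpo_sketch(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Stack chosen rows before rejected rows, then pad (illustrative only)."""
    input_ids = [ex["chosen_input_ids"] for ex in batch] + [
        ex["rejected_input_ids"] for ex in batch
    ]
    labels = [ex["chosen_labels"] for ex in batch] + [
        ex["rejected_labels"] for ex in batch
    ]
    # Input ids are padded with padding_idx, labels with ignore_idx.
    return pad_rows(input_ids, padding_idx), pad_rows(labels, ignore_idx)


batch = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5],
     "chosen_labels": [6, 7, 8], "rejected_labels": [9, 10]},
    {"chosen_input_ids": [11, 12], "rejected_input_ids": [13, 14, 15],
     "chosen_labels": [16, 17], "rejected_labels": [18, 19, 20]},
]
input_ids, labels = collate_dpo_sketch(batch)

# The first len(batch) rows are the chosen sequences; the rest are rejected.
n = len(batch)
chosen_ids, rejected_ids = input_ids[:n], input_ids[n:]
print(chosen_ids)    # [[1, 2, 3], [11, 12, 0]]
print(rejected_ids)  # [[4, 5, 0], [13, 14, 15]]
```

Given the row order shown in the example, the same slicing along the first dimension (`[:n]` and `[n:]`) separates the chosen and rejected halves of the tensors returned by `padded_collate_dpo`.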
