padded_collate_dpo
- torchtune.utils.padded_collate_dpo(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = -100) → Tuple[Tensor, Tensor]
Pad a batch of sequences for Direct Preference Optimization (DPO).
This function takes a batch of samples, where each sample is a dictionary holding a chosen and a rejected sequence, each with its own input ids and labels. All chosen sequences are stacked ahead of all rejected sequences, and every row is right-padded to the length of the longest sequence in the batch.
- Parameters:
batch (List[Dict[str, List[int]]]) – A list of dictionaries, one per sample. Each dictionary must contain the keys ‘chosen_input_ids’, ‘chosen_labels’, ‘rejected_input_ids’, and ‘rejected_labels’.
padding_idx (int) – Padding index for input ids. Defaults to 0.
ignore_idx (int) – Padding index for labels. Defaults to -100, which is also the default ignore_index of PyTorch's cross-entropy loss.
- Returns:
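A tuple (input_ids, labels) of padded tensors. In each tensor, the chosen sequences occupy the first len(batch) rows and the rejected sequences occupy the remaining len(batch) rows.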
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- Raises:
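AssertionError – if the number of chosen_input_ids sequences differs from the number of rejected_input_ids sequences. Token lengths within a sample may differ.
AssertionError – if the number of chosen_labels sequences differs from the number of rejected_labels sequences.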
Example
>>> batch = [
...     {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [4, 5],
...      'chosen_labels': [6, 7, 8], 'rejected_labels': [9, 10]},
...     {'chosen_input_ids': [11, 12], 'rejected_input_ids': [13, 14, 15],
...      'chosen_labels': [16, 17], 'rejected_labels': [18, 19, 20]},
... ]
>>> padded_collate_dpo(batch)
(tensor([[ 1,  2,  3],
         [11, 12,  0],
         [ 4,  5,  0],
         [13, 14, 15]]),
 tensor([[   6,    7,    8],
         [  16,   17, -100],
         [   9,   10, -100],
         [  18,   19,   20]]))
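To make the row ordering and padding concrete, here is a minimal pure-Python sketch that reproduces the output of the example. It is not torchtune's implementation: the names collate_dpo_sketch and pad_to_length are hypothetical, and the sketch returns nested lists rather than tensors so that it runs without PyTorch.

```python
from typing import Dict, List, Tuple


def pad_to_length(seq: List[int], length: int, pad_value: int) -> List[int]:
    """Right-pad seq with pad_value until it has the given length."""
    return seq + [pad_value] * (length - len(seq))


def collate_dpo_sketch(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical stand-in that mirrors the documented behavior with lists."""
    # Stack every chosen sequence first, then every rejected sequence.
    input_ids = [ex["chosen_input_ids"] for ex in batch]
    input_ids += [ex["rejected_input_ids"] for ex in batch]
    labels = [ex["chosen_labels"] for ex in batch]
    labels += [ex["rejected_labels"] for ex in batch]

    # Pad each group to the longest sequence it contains.
    max_ids = max(len(seq) for seq in input_ids)
    max_labels = max(len(seq) for seq in labels)
    padded_ids = [pad_to_length(seq, max_ids, padding_idx) for seq in input_ids]
    padded_labels = [pad_to_length(seq, max_labels, ignore_idx) for seq in labels]
    return padded_ids, padded_labels


batch = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5],
     "chosen_labels": [6, 7, 8], "rejected_labels": [9, 10]},
    {"chosen_input_ids": [11, 12], "rejected_input_ids": [13, 14, 15],
     "chosen_labels": [16, 17], "rejected_labels": [18, 19, 20]},
]

ids, labels = collate_dpo_sketch(batch)

# The first len(batch) rows are the chosen sequences, the rest the rejected ones.
n = len(batch)
chosen_ids, rejected_ids = ids[:n], ids[n:]
chosen_labels, rejected_labels = labels[:n], labels[n:]
```

Because the chosen and rejected rows sit in two contiguous halves, the same slicing at len(batch) can be applied to the tensors returned by padded_collate_dpo to recover each half.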