DPOLoss¶
- class torchtune.modules.loss.DPOLoss(beta: float = 0.1, label_smoothing: float = 0.0, loss_type: str = 'sigmoid')[source]¶
Direct Preference Optimization (DPO) Loss module: https://arxiv.org/abs/2305.18290.
Based on the implementation in HF’s TRL library: https://github.com/huggingface/trl/blob/5d1deb1445828cfd0e947cb3a7925b1c03a283fc/trl/trainer/dpo_trainer.py#L844
- Parameters:
beta (float) – Temperature parameter for the DPO loss, typically in the range of 0.1 to 0.5. Default is 0.1.
label_smoothing (float) – Parameter encoding uncertainty about the preference labels (e.g. the probability that a label is flipped). Default is 0.0.
loss_type (str) – Type of loss function to be used. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']. Default is 'sigmoid'. A sketch of the default 'sigmoid' variant is shown below.
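The following is a minimal sketch of how the default 'sigmoid' variant combines the four per-example log probabilities, following the formulation in the referenced TRL implementation; the tensor values and variable names are illustrative placeholders, not part of the library API.

```python
import torch
import torch.nn.functional as F

beta, label_smoothing = 0.1, 0.0

# Placeholder per-example log probabilities, shape (batch_size,).
policy_chosen_logps = torch.tensor([-12.3, -10.1])
policy_rejected_logps = torch.tensor([-14.0, -13.2])
reference_chosen_logps = torch.tensor([-12.8, -10.5])
reference_rejected_logps = torch.tensor([-13.5, -12.9])

# Log-ratio of chosen vs. rejected under the policy and the reference model.
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
logits = pi_logratios - ref_logratios

# Sigmoid (Bradley-Terry) DPO loss with optional conservative label smoothing.
losses = (
    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
    - F.logsigmoid(-beta * logits) * label_smoothing
)
print(losses)  # per-example losses, shape (batch_size,)
```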
- forward(policy_chosen_logps: Tensor, policy_rejected_logps: Tensor, reference_chosen_logps: Tensor, reference_rejected_logps: Tensor) → Tuple[Tensor, Tensor, Tensor][source]¶
Compute the DPO loss for a batch of policy and reference model log probabilities.
- Parameters:
policy_chosen_logps (torch.Tensor) – Log probabilities of the policy model for the chosen responses. Shape: (batch_size)
policy_rejected_logps (torch.Tensor) – Log probabilities of the policy model for the rejected responses. Shape: (batch_size)
reference_chosen_logps (torch.Tensor) – Log probabilities of the reference model for the chosen responses. Shape: (batch_size)
reference_rejected_logps (torch.Tensor) – Log probabilities of the reference model for the rejected responses. Shape: (batch_size)
- Returns:
- A tuple of three tensors:
  - losses: The DPO loss for each example in the batch.
  - chosen_rewards: Rewards for the chosen responses.
  - rejected_rewards: Rewards for the rejected responses.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- Raises:
ValueError – If an unknown loss type is specified.
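A brief usage sketch, assuming the constructor and forward signatures documented above; the log-probability values below are placeholders standing in for outputs of a real policy model and a frozen reference model.

```python
import torch
from torchtune.modules.loss import DPOLoss

loss_fn = DPOLoss(beta=0.1, loss_type="sigmoid")

# Per-example log probabilities, shape (batch_size,); placeholder values.
policy_chosen_logps = torch.tensor([-12.3, -10.1])
policy_rejected_logps = torch.tensor([-14.0, -13.2])
reference_chosen_logps = torch.tensor([-12.8, -10.5])
reference_rejected_logps = torch.tensor([-13.5, -12.9])

losses, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen_logps,
    policy_rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
)

# Typically the mean loss is backpropagated; the rewards are useful for logging.
print(losses.mean(), chosen_rewards, rejected_rewards)
```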