DPOLoss

class torchtune.rlhf.loss.DPOLoss(beta: float = 0.1, label_smoothing: float = 0.0)

Direct Preference Optimization (DPO) loss module: https://arxiv.org/abs/2305.18290. In the paper's own words:

"Intuitively, the DPO update increases the relative log probability of preferred to dispreferred responses, but it incorporates a dynamic, per-example importance weight that prevents the model degeneration that we find occurs with a naive probability ratio objective."

Based on the implementation in Hugging Face's TRL library: https://github.com/huggingface/trl/blob/5d1deb1445828cfd0e947cb3a7925b1c03a283fc/trl/trainer/dpo_trainer.py#L844

Like PPO-based RLHF (https://arxiv.org/abs/2009.01325), DPO optimizes a policy (language) model to align with human preferences and regularizes the loss against a baseline reference (the frozen, initial language model) to prevent overfitting to the preference dataset. Unlike PPO, it optimizes the policy model directly on labelled preference data rather than relying on a separate reward model for feedback, which simplifies training and reduces compute overhead.
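For reference, the objective from the DPO paper can be written as below. The notation is the paper's rather than this module's: the policy is \pi_\theta, the frozen reference model is \pi_{\mathrm{ref}}, y_w and y_l are the chosen and rejected responses to a prompt x, \sigma is the logistic sigmoid, and \beta corresponds to the beta parameter.

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}})
  = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}
    \left[ \log \sigma\!\left(
      \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
      - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
    \right) \right]
```

The reference model enters only through the two log-ratios, which is how the regularization toward the initial model is expressed without a separate reward model.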

Parameters:
  • beta (float) – Temperature parameter for the DPO loss, typically in the range of 0.1 to 0.5. Default is 0.1.

  • label_smoothing (float) – Parameter encoding uncertainty about the labels. Default is 0.0.
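As a sketch of how the two parameters might enter the per-example loss, assume the module follows the label-smoothed form used in the TRL implementation it is based on. The symbols h and \varepsilon are notation introduced here for illustration, with \varepsilon standing for label_smoothing and \beta for beta.

```latex
h = \left[ \log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x) \right]
  - \left[ \log \pi_{\mathrm{ref}}(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x) \right]

\ell = -(1 - \varepsilon)\,\log \sigma(\beta h) \;-\; \varepsilon\,\log \sigma(-\beta h)
```

With \varepsilon = 0 this reduces to -\log \sigma(\beta h), the standard DPO loss. A nonzero \varepsilon mixes in the loss for the flipped preference, which is one way to encode the possibility that a label is wrong.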

forward(policy_chosen_logps: Tensor, policy_rejected_logps: Tensor, reference_chosen_logps: Tensor, reference_rejected_logps: Tensor) → Tuple[Tensor, Tensor, Tensor]

Compute the DPO loss for a batch of policy and reference model log probabilities.

Parameters:
  • policy_chosen_logps (torch.Tensor) – Log probabilities of the policy model for the chosen responses. Shape: (batch_size)

  • policy_rejected_logps (torch.Tensor) – Log probabilities of the policy model for the rejected responses. Shape: (batch_size)

  • reference_chosen_logps (torch.Tensor) – Log probabilities of the reference model for the chosen responses. Shape: (batch_size)

  • reference_rejected_logps (torch.Tensor) – Log probabilities of the reference model for the rejected responses. Shape: (batch_size)

Returns:

A tuple of three tensors:
  • losses: The DPO loss for each example in the batch.

  • chosen_rewards: Rewards for the chosen responses.

  • rejected_rewards: Rewards for the rejected responses.

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
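The arithmetic behind the three returned values can be sketched without any framework, so it can be checked by hand. This is a minimal illustration, not the module's code: it assumes the computation mirrors the TRL implementation cited above, uses plain Python lists in place of tensors, and the names dpo_loss_sketch and log_sigmoid are hypothetical helpers introduced here.

```python
import math
from typing import List, Sequence, Tuple


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    # Identity: log(sigmoid(x)) = min(x, 0) - log(1 + exp(-|x|))
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def dpo_loss_sketch(
    policy_chosen_logps: Sequence[float],
    policy_rejected_logps: Sequence[float],
    reference_chosen_logps: Sequence[float],
    reference_rejected_logps: Sequence[float],
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical list-based sketch of a DPO loss (assumed TRL-style form)."""
    losses: List[float] = []
    chosen_rewards: List[float] = []
    rejected_rewards: List[float] = []
    for pc, pr, rc, rr in zip(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    ):
        # How much more the policy prefers chosen over rejected
        # than the reference model does.
        logits = (pc - pr) - (rc - rr)
        # Label-smoothed negative log-sigmoid loss.
        loss = (
            -(1.0 - label_smoothing) * log_sigmoid(beta * logits)
            - label_smoothing * log_sigmoid(-beta * logits)
        )
        losses.append(loss)
        # Implicit rewards: scaled log-ratio of policy to reference.
        chosen_rewards.append(beta * (pc - rc))
        rejected_rewards.append(beta * (pr - rr))
    return losses, chosen_rewards, rejected_rewards


# Two examples: the policy favours the chosen response in the first
# and the rejected response in the second.
losses, chosen_rewards, rejected_rewards = dpo_loss_sketch(
    policy_chosen_logps=[-10.0, -12.0],
    policy_rejected_logps=[-14.0, -11.0],
    reference_chosen_logps=[-11.0, -11.0],
    reference_rejected_logps=[-12.0, -12.0],
)
print(losses, chosen_rewards, rejected_rewards)
```

In this sketch the first example yields the lower loss, since the policy's preference margin exceeds the reference's. In a tensor implementation such as TRL's, the rewards are detached from the computation graph, so they serve as logging metrics rather than as a training signal.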
