
RSOLoss

class torchtune.rlhf.loss.RSOLoss(gamma: float = 0.1)[source]

Statistical Rejection Sampling Optimization (RSO) or “hinge” loss module: https://arxiv.org/abs/2309.06657. Intuition from the paper:

DPO is a logistic regression on human preference data, and SLiC (https://arxiv.org/abs/2305.10425) is almost equivalent to a support vector machine (SVM) with hinge loss. [RSO] improve[s] SLiC as the SVM counterpart of DPO.

Based on the implementation in HF’s TRL library: https://github.com/huggingface/trl/blob/4dce042a3863db1d375358e8c8092b874b02934b/trl/trainer/dpo_trainer.py#L1141

Parameters:

gamma (float) – Temperature parameter for the RSO loss, analogous to the temperature (beta) in DPO.
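
For intuition, here is a minimal sketch of the hinge-style objective, following the TRL implementation linked above; rso_hinge_loss_sketch is an illustrative helper, not part of torchtune, and the actual class internals may differ in detail:

import torch

def rso_hinge_loss_sketch(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    gamma: float = 0.1,
) -> torch.Tensor:
    # Log-ratio margin between the chosen and rejected responses,
    # measured relative to the frozen reference model.
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        reference_chosen_logps - reference_rejected_logps
    )
    # SVM-style hinge: the loss is zero once the scaled margin exceeds 1,
    # in contrast to DPO's -logsigmoid(beta * logits).
    return torch.relu(1 - gamma * logits)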

forward(policy_chosen_logps: Tensor, policy_rejected_logps: Tensor, reference_chosen_logps: Tensor, reference_rejected_logps: Tensor) → Tuple[Tensor, Tensor, Tensor][source]

Compute the RSO loss for a batch of policy and reference model log probabilities.

Parameters:
  • policy_chosen_logps (torch.Tensor) – Log probabilities of the policy model for the chosen responses. Shape: (batch_size)

  • policy_rejected_logps (torch.Tensor) – Log probabilities of the policy model for the rejected responses. Shape: (batch_size)

  • reference_chosen_logps (torch.Tensor) – Log probabilities of the reference model for the chosen responses. Shape: (batch_size)

  • reference_rejected_logps (torch.Tensor) – Log probabilities of the reference model for the rejected responses. Shape: (batch_size)

Returns:

A tuple of three tensors:
  • losses: The RSO loss for each example in the batch.

  • chosen_rewards: Rewards for the chosen responses.

  • rejected_rewards: Rewards for the rejected responses.

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
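
A usage example with dummy inputs (shapes only; in practice the log-probabilities are sequence-level sums of per-token log-probs from the policy and reference models):

import torch
from torchtune.rlhf.loss import RSOLoss

loss_fn = RSOLoss(gamma=0.1)

batch_size = 4
# Dummy sequence-level log-probabilities, each of shape (batch_size,).
policy_chosen_logps = torch.randn(batch_size)
policy_rejected_logps = torch.randn(batch_size)
reference_chosen_logps = torch.randn(batch_size)
reference_rejected_logps = torch.randn(batch_size)

losses, chosen_rewards, rejected_rewards = loss_fn(
    policy_chosen_logps,
    policy_rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
)
print(losses.shape)           # torch.Size([4])
print(chosen_rewards.shape)   # torch.Size([4])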
