PPOLoss

class torchtune.rlhf.loss.PPOLoss(epsilon: float = 0.1, value_clip_range: float = 0.2, value_coeff: float = 0.1)[source]

Proximal Policy Optimization (PPO) Loss module. This implementation is based on the following references; the clipped objective they describe is summarized after the list:

https://arxiv.org/abs/1707.06347 eqn. 7

https://github.com/vwxyzjn/lm-human-preference-details/blob/ccc19538e817e98a60d3253242ac15e2a562cb49/lm_human_preference_details/train_policy_accelerate.py#L719

https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L68-L75
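
For context, the clipped surrogate objective from the first reference (PPO paper, eqn. 7) is summarized below in the paper's own notation; the symbols are the paper's, not identifiers from the torchtune source:

    % PPO clipped surrogate objective (Schulman et al., 2017, eqn. 7)
    L^{\mathrm{CLIP}}(\theta)
        = \hat{\mathbb{E}}_t\left[ \min\left( r_t(\theta)\,\hat{A}_t,\;
          \mathrm{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t \right) \right],
    \qquad
    r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

The module effectively minimizes the negation of this objective, with the ratio clipping controlled by epsilon, and adds a clipped value function loss weighted by value_coeff, following the baselines reference above.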

Parameters:
  • epsilon (float) – clipping range for the PPO policy ratio update.

  • value_clip_range (float) – clipping range for the value function update.

  • value_coeff (float) – coefficient weighting the value function loss contribution to the total loss (see the construction sketch below).
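
A minimal construction sketch (illustrative only; the keyword values simply echo the defaults above and assume torchtune is installed):

>>> from torchtune.rlhf.loss import PPOLoss
>>> # epsilon clips the policy ratio, value_clip_range clips the value update,
>>> # and value_coeff weights the value loss in the combined loss.
>>> loss_fn = PPOLoss(epsilon=0.1, value_clip_range=0.2, value_coeff=0.1)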

forward(pi_old_logprobs: Tensor, pi_logprobs: Tensor, advantages: Tensor, phi_old_values: Tensor, phi_values: Tensor, returns: Tensor, padding_masks: Optional[Tensor] = None, value_padding_masks: Optional[Tensor] = None) → Tuple[Tensor, Tensor, Tensor, Tensor, Tensor][source]

Forward pass of the PPO loss module.

Parameters:
  • pi_old_logprobs (torch.Tensor) – Log probabilities of the old policy.

  • pi_logprobs (torch.Tensor) – Log probabilities of the current policy.

  • advantages (torch.Tensor) – Advantage values.

  • phi_old_values (torch.Tensor) – Value predictions of the old value function.

  • phi_values (torch.Tensor) – Value predictions of the current value function.

  • returns (torch.Tensor) – Return estimates, used as targets for the value function.

  • padding_masks (Optional[torch.Tensor]) – Padding token masks of the same shape as pi_logprobs, where True indicates the corresponding positions should participate in the policy loss calculation.

  • value_padding_masks (Optional[torch.Tensor]) – Padding token masks of the same shape as pi_logprobs, where True indicates the corresponding positions should participate in the value loss calculation.

Returns:

A tuple of five tensors:
  • loss: The total PPO loss.

  • policy_loss: The policy function loss.

  • value_loss: The value function loss.

  • ratios: The ratio between the current and old policy probabilities.

  • clipfrac: The fraction of ratios that were clipped.

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
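
A hedged end-to-end sketch of the forward call (the random tensors and the [batch, seq_len] shape are illustrative stand-ins for real policy/value model outputs, not requirements stated on this page):

>>> import torch
>>> from torchtune.rlhf.loss import PPOLoss
>>> loss_fn = PPOLoss()
>>> b, s = 2, 8  # illustrative batch and sequence sizes
>>> pi_old_logprobs = torch.randn(b, s)
>>> pi_logprobs = torch.randn(b, s)
>>> advantages = torch.randn(b, s)
>>> phi_old_values = torch.randn(b, s)
>>> phi_values = torch.randn(b, s)
>>> returns = torch.randn(b, s)
>>> padding_masks = torch.ones(b, s, dtype=torch.bool)  # True = position contributes to the loss
>>> loss, policy_loss, value_loss, ratios, clipfrac = loss_fn(
...     pi_old_logprobs, pi_logprobs, advantages,
...     phi_old_values, phi_values, returns,
...     padding_masks=padding_masks, value_padding_masks=padding_masks,
... )

Only the first element is the combined objective; the remaining four are typically logged as PPO training diagnostics.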
