PPOLoss¶
- class torchtune.rlhf.loss.PPOLoss(epsilon: float = 0.1, value_clip_range: float = 0.2, value_coeff: float = 0.1)[source]¶
Proximal Policy Optimization (PPO) loss module. This implementation follows the clipped surrogate objective from https://arxiv.org/abs/1707.06347 (eqn. 7).
- Parameters:
epsilon (float) – clipping range for the PPO policy ratio update.
value_clip_range (float) – clipping range for the value function update.
value_coeff (float) – coefficient scaling the value function loss contribution to the total loss.
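To illustrate how these parameters interact, the sketch below computes the standard clipped PPO policy loss and a clipped value loss in plain PyTorch. This is a minimal sketch of the clipped objective, not the module's exact implementation; the function name and the omission of padding masks are assumptions made for brevity.

```python
import torch

# Illustrative sketch of the clipped PPO objective (not the exact torchtune implementation).
# Assumes per-token log-probs, advantages, values, and returns with matching shapes.
def ppo_loss_sketch(pi_old_logprobs, pi_logprobs, advantages,
                    phi_old_values, phi_values, returns,
                    epsilon=0.1, value_clip_range=0.2, value_coeff=0.1):
    # Probability ratio between the current and old policy.
    ratios = torch.exp(pi_logprobs - pi_old_logprobs)
    clipped_ratios = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon)

    # Clipped surrogate policy objective (eqn. 7 of https://arxiv.org/abs/1707.06347).
    policy_losses_unclipped = -advantages * ratios
    policy_losses_clipped = -advantages * clipped_ratios
    policy_loss = torch.maximum(policy_losses_unclipped, policy_losses_clipped).mean()

    # Clipped value-function loss.
    values_clipped = torch.clamp(
        phi_values, phi_old_values - value_clip_range, phi_old_values + value_clip_range
    )
    value_losses_unclipped = (phi_values - returns) ** 2
    value_losses_clipped = (values_clipped - returns) ** 2
    value_loss = 0.5 * torch.maximum(value_losses_unclipped, value_losses_clipped).mean()

    # Fraction of tokens where the clipped term dominated.
    clipfrac = (policy_losses_clipped > policy_losses_unclipped).float().mean()

    loss = policy_loss + value_coeff * value_loss
    return loss, policy_loss, value_loss, ratios, clipfrac
```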
- forward(pi_old_logprobs: Tensor, pi_logprobs: Tensor, advantages: Tensor, phi_old_values: Tensor, phi_values: Tensor, returns: Tensor, padding_masks: Optional[Tensor] = None, value_padding_masks: Optional[Tensor] = None) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor] [source]¶
Forward pass of the PPO loss module.
- Parameters:
pi_old_logprobs (torch.Tensor) – Log probabilities of the old policy.
pi_logprobs (torch.Tensor) – Log probabilities of the current policy.
advantages (torch.Tensor) – Advantage values.
phi_old_values (torch.Tensor) – Value predictions of the old value function.
phi_values (torch.Tensor) – Value predictions of the current value function.
returns (torch.Tensor) – Return values.
padding_masks (Optional[torch.Tensor]) – Padding token masks of the same shape as pi_logprobs, where True indicates the corresponding loss values should participate in the policy loss calculation.
value_padding_masks (Optional[torch.Tensor]) – Padding token masks of the same shape as pi_logprobs, where True indicates the corresponding loss values should participate in the value loss calculation.
- Returns:
- A tuple of five tensors:
loss: The total PPO loss.
policy_loss: The policy function loss.
value_loss: The value function loss.
ratios: The ratio between the current and old policy probabilities.
clipfrac: The fraction of ratios that were clipped.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
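A minimal usage sketch based on the signature above, assuming the module is callable like a standard nn.Module; the tensor shapes and random values are placeholders standing in for rollout log-probabilities, value predictions, advantages, and returns.

```python
import torch
from torchtune.rlhf.loss import PPOLoss

loss_fn = PPOLoss(epsilon=0.1, value_clip_range=0.2, value_coeff=0.1)

# Placeholder tensors of shape (batch_size, seq_len); real values come from
# policy/value model rollouts and advantage estimation.
batch_size, seq_len = 2, 8
pi_old_logprobs = torch.randn(batch_size, seq_len)
pi_logprobs = torch.randn(batch_size, seq_len)
advantages = torch.randn(batch_size, seq_len)
phi_old_values = torch.randn(batch_size, seq_len)
phi_values = torch.randn(batch_size, seq_len)
returns = torch.randn(batch_size, seq_len)

loss, policy_loss, value_loss, ratios, clipfrac = loss_fn(
    pi_old_logprobs, pi_logprobs, advantages, phi_old_values, phi_values, returns
)
```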