estimate_advantages
- torchtune.rlhf.estimate_advantages(values: Tensor, rewards: Tensor, gamma: float, lmbda: float, masks: Optional[Tensor] = None) → Tuple[Tensor, Tensor]
Estimates the advantages and returns for the PPO algorithm using Generalized Advantage Estimation (GAE): https://arxiv.org/pdf/1506.02438.pdf
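For reference, GAE computes the advantage at step t as an exponentially weighted sum of temporal-difference residuals. Notation below follows the linked paper, where V, r, γ, and λ correspond to values, rewards, gamma, and lmbda here; in practice the infinite sum is truncated at the end of the response:

$$
\hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l},
\qquad
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
$$

The returns are then typically recovered as $R_t = \hat{A}_t + V(s_t)$.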
- Parameters:
  - values (torch.Tensor) – The predicted values for each state. Shape: (b, response_len)
  - rewards (torch.Tensor) – The rewards received at each time step. Shape: (b, response_len)
  - gamma (float) – The discount factor.
  - lmbda (float) – The GAE-Lambda parameter.
  - masks (Optional[torch.Tensor]) – A boolean mask tensor, where True indicates that the corresponding value in values should participate in the mean calculation. Default: None.
- Returns:
  A tuple containing the estimated advantages and returns:
  - advantages (torch.Tensor): The estimated advantages. Shape: (b, response_len)
  - returns (torch.Tensor): The estimated returns. Shape: (b, response_len)
- Return type:
  Tuple[torch.Tensor, torch.Tensor]
- Notation:
  - b: batch size
  - response_len: model response length
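A minimal usage sketch, assuming torchtune is installed; the batch size, reward placement (a single reward on the final token, as is common in RLHF), and hyperparameter values are illustrative:

```python
import torch
from torchtune import rlhf

b, response_len = 2, 5  # illustrative batch size and response length

values = torch.randn(b, response_len)   # value-head predictions per token
rewards = torch.zeros(b, response_len)  # sparse reward on the final token only
rewards[:, -1] = 1.0
masks = torch.ones(b, response_len, dtype=torch.bool)  # all positions participate

advantages, returns = rlhf.estimate_advantages(
    values, rewards, gamma=1.0, lmbda=0.95, masks=masks
)
print(advantages.shape, returns.shape)  # torch.Size([2, 5]) torch.Size([2, 5])
```

With gamma=1.0 there is no discounting, so the sparse final-token reward propagates back through the whole response, attenuated only by lmbda.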