next_state_value
- torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: Optional[TensorDictModule] = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Optional[Tensor] = None, **kwargs)
Computes the next state value (without gradient) to compute a target value.
- The target value is usually used to compute a distance loss (e.g. MSE):
L = Sum[ (q_value - target_value)^2 ]
- The target value is computed as
r + gamma ** n_steps_to_next * value_next_state
If the reward is the immediate reward, n_steps_to_next=1. If n-step rewards are used, n_steps_to_next is gathered from the input tensordict.
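A minimal sketch of this target computation with plain tensors (the example values and the convention of zeroing the next-state value on terminal transitions are assumptions for illustration, not part of the formula above):

```python
import torch

# Target: r + gamma ** n_steps_to_next * V(s'), with V(s') zeroed on
# terminal transitions (a common convention, assumed here).
reward = torch.tensor([1.0, 0.5])
done = torch.tensor([False, True])
value_next_state = torch.tensor([2.0, 3.0])   # V(s') from some value network
gamma = 0.99
n_steps_to_next = torch.tensor([1, 1])        # 1 for immediate rewards

target_value = reward + gamma ** n_steps_to_next * value_next_state * (~done).float()
```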
- Parameters:
tensordict (TensorDictBase) – Tensordict containing a reward and done key (and an n_steps_to_next key for n-step rewards).
operator (ProbabilisticTDModule, optional) – the value function operator. Should write a ‘next_val_key’ key-value pair in the input tensordict when called. It does not need to be provided if pred_next_val is given.
next_val_key (str, optional) – key where the next value will be written. Default: ‘state_action_value’
gamma (float, optional) – return discount rate. Default: 0.99
pred_next_val (Tensor, optional) – the next state value can be provided if it is not computed with the operator.
- Returns:
a Tensor of the size of the input tensordict containing the predicted state value.
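For illustration, a hedged usage sketch where the next-state value is supplied directly through pred_next_val, so no value operator is needed; the TensorDict import path and the example keys, shapes, and values are assumptions for this sketch:

```python
import torch
from tensordict import TensorDict  # import path assumed; may differ across versions
from torchrl.objectives import next_state_value

# Batch of 4 transitions with the required "reward" and "done" keys.
td = TensorDict(
    {
        "reward": torch.randn(4, 1),
        "done": torch.zeros(4, 1, dtype=torch.bool),
    },
    batch_size=[4],
)

# Supply V(s') directly instead of computing it with a value operator.
pred_next_val = torch.randn(4, 1)

target = next_state_value(td, gamma=0.99, pred_next_val=pred_next_val)
print(target.shape)  # one target value per transition in the batch
```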