
next_state_value

torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: Optional[TensorDictModule] = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Optional[Tensor] = None, **kwargs)

Computes the next state value (without gradient), which is used to build a target value.

The target value is usually used to compute a distance loss (e.g. MSE):

L = Sum[ (q_value - target_value)^2 ]

The target value is computed as

r + gamma ** n_steps_to_next * value_next_state

If the reward is the immediate reward, n_steps_to_next=1. If N-step rewards are used, n_steps_to_next is gathered from the input tensordict.
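For example, with an immediate reward r = 1.0, gamma = 0.99, n_steps_to_next = 3 and value_next_state = 5.0, the target is

1.0 + 0.99 ** 3 * 5.0 = 1.0 + 0.970299 * 5.0 ≈ 5.85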

Parameters:
  • tensordict (TensorDictBase) – Tensordict containing a reward and done key (and an n_steps_to_next key for N-step rewards).

  • operator (TensorDictModule, optional) – the value function operator. Should write a ‘next_val_key’ key-value in the input tensordict when called. It does not need to be provided if pred_next_val is given.

  • next_val_key (str, optional) – key where the next value will be written. Default: ‘state_action_value’

  • gamma (float, optional) – return discount rate. Default: 0.99

  • pred_next_val (Tensor, optional) – the next state value can be provided if it is not computed with the operator.

Returns:

a Tensor of the size of the input tensordict containing the predicted state value.
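Example:

A minimal usage sketch (the tensor shapes, dummy values and the TensorDict import path are assumptions, not taken from this page; on older releases TensorDict lives under torchrl.data). Since pred_next_val is supplied directly, no value operator is required:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.objectives import next_state_value
>>> batch = 8
>>> tensordict = TensorDict(
...     {
...         "reward": torch.randn(batch, 1),
...         "done": torch.zeros(batch, 1, dtype=torch.bool),
...     },
...     batch_size=[batch],
... )
>>> # assumed to come from a target value network evaluated on the next observation
>>> pred_next_val = torch.randn(batch, 1)
>>> target = next_state_value(
...     tensordict,
...     gamma=0.99,
...     pred_next_val=pred_next_val,
... )
>>> # roughly reward + gamma * (not done) * pred_next_val, detached from the graph
>>> target.requires_grad
False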
