ValueEstimatorBase

class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)[source]

An abstract parent class for value function modules.

Its ValueEstimatorBase.forward() method computes the value (given by the value network), the value estimate (given by the value estimator) and the advantage, and writes these values to the output tensordict.

If only the value estimate is needed, ValueEstimatorBase.value_estimate() should be used instead.

abstract forward(tensordict: TensorDictBase, *, params: Optional[TensorDictBase] = None, target_params: Optional[TensorDictBase] = None) TensorDictBase[source]

Computes the advantage estimate given the data in tensordict.

If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.

Parameters:

tensordict (TensorDictBase) – A TensorDict containing the data (an observation key, "action", ("next", "reward"), ("next", "done"), ("next", "terminated"), and "next" tensordict state as returned by the environment) necessary to compute the value estimates and the TD estimate. The data passed to this module should be structured as [*B, T, *F], where B is the batch size, T the time dimension and F the feature dimension(s). The tensordict must have shape [*B, T].

Keyword Arguments:
  • params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • device (torch.device, optional) – the device where the buffers will be instantiated. Defaults to torch.get_default_device().

Returns:

An updated TensorDict with the advantage and value_error keys as defined in the constructor.
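
Because forward() is abstract, the arithmetic behind the advantage depends on the concrete subclass. As a rough illustration only, the sketch below uses a one-step temporal-difference (TD(0)) rule on plain Python lists for a single trajectory of length T. The function name td0_advantage, the gamma argument and the list-based inputs are assumptions made for this example and are not part of the torchrl API; a real estimator reads these quantities from the tensordict and operates on tensors of shape [*B, T, *F].

```python
from typing import List, Sequence, Tuple


def td0_advantage(
    rewards: Sequence[float],
    values: Sequence[float],
    next_values: Sequence[float],
    terminated: Sequence[bool],
    gamma: float = 0.99,
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: one-step TD target and advantage per time step."""
    targets: List[float] = []
    advantages: List[float] = []
    for r, v, nv, term in zip(rewards, values, next_values, terminated):
        # No bootstrapping from the next state once the episode has terminated.
        bootstrap = 0.0 if term else nv
        target = r + gamma * bootstrap
        targets.append(target)
        # The advantage is the gap between the target and the current value.
        advantages.append(target - v)
    return advantages, targets


# Toy trajectory with T = 3; the last step is terminal.
rewards = [1.0, 0.0, 2.0]
values = [0.5, 1.0, 1.5]
next_values = [1.0, 1.5, 3.0]
terminated = [False, False, True]

adv, tgt = td0_advantage(rewards, values, next_values, terminated, gamma=0.5)
print(adv, tgt)
```

Other estimators replace the one-step target with a multi-step or exponentially weighted one, but the inputs they need (rewards, termination flags, current and next values) are the ones listed under Parameters above.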

set_keys(**kwargs) None[source]

Set tensordict key names.

value_estimate(tensordict, target_params: Optional[TensorDictBase] = None, next_value: Optional[Tensor] = None, **kwargs)[source]

Gets a value estimate, usually used as a target value for the value network.

If the state value key is present under tensordict.get(("next", self.tensor_keys.value)), this value is used directly, without calling the value network.

Parameters:
  • tensordict (TensorDictBase) – the tensordict containing the data to read.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • next_value (torch.Tensor, optional) – the value of the next state or state-action pair. Exclusive with target_params.

  • **kwargs – the keyword arguments to be passed to the value network.

Returns: a tensor corresponding to the state value.
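
The lookup rules described above (an explicit next_value, a next-state value already stored in the data, or a fresh call to the value network) can be pictured with the following plain-Python sketch. It is not the torchrl implementation: resolve_next_value, toy_value_fn, the "state_value" key (standing in for self.tensor_keys.value) and the precedence given to an explicit next_value over a stored one are all assumptions made for illustration.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Key = Tuple[str, str]


def resolve_next_value(
    data: Dict[Key, Sequence[float]],
    value_fn: Callable[[float, Optional[dict]], float],
    next_value: Optional[Sequence[float]] = None,
    target_params: Optional[dict] = None,
) -> List[float]:
    """Hypothetical helper: pick the next-state value from the cheapest source."""
    if next_value is not None and target_params is not None:
        # The two arguments are documented as mutually exclusive.
        raise ValueError("next_value and target_params are mutually exclusive")
    if next_value is not None:
        return list(next_value)
    stored = data.get(("next", "state_value"))
    if stored is not None:
        # A value is already present in the data: skip the value network.
        return list(stored)
    return [value_fn(obs, target_params) for obs in data[("next", "observation")]]


calls: List[float] = []


def toy_value_fn(obs: float, params: Optional[dict]) -> float:
    """Stand-in value network that records every call it receives."""
    calls.append(obs)
    scale = 1.0 if params is None else params["scale"]
    return scale * obs


data_cached = {
    ("next", "observation"): [1.0, 2.0],
    ("next", "state_value"): [10.0, 20.0],
}
data_plain = {("next", "observation"): [1.0, 2.0]}

from_cache = resolve_next_value(data_cached, toy_value_fn)
calls_after_cache = list(calls)  # still empty: the network was not called
from_network = resolve_next_value(
    data_plain, toy_value_fn, target_params={"scale": 2.0}
)
```

In this sketch the stored value short-circuits the network call, which is the behaviour the description above attributes to value_estimate() when the next-state value key is already present.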
