GAE

class torchrl.objectives.value.GAE(*args, **kwargs)

A class wrapper around the generalized advantage estimate functional.

Refer to “HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION” https://arxiv.org/pdf/1506.02438.pdf for more context.

Parameters:
  • gamma (scalar) – exponential mean discount.

  • lmbda (scalar) – trajectory discount.

  • value_network (TensorDictModule) – value operator used to retrieve the value estimates.

  • average_gae (bool) – if True, the resulting GAE values will be standardized. Default is False.

  • differentiable (bool, optional) –

    if True, gradients are propagated through the computation of the value function. Default is False.

    Note

    The proper way to make the function call non-differentiable is to wrap it in torch.no_grad() (used as a context manager or decorator) or to pass detached parameters to functional modules.

  • vectorized (bool, optional) – whether to use the vectorized version of the lambda return. Default is True.

  • skip_existing (bool, optional) – if True, the value network will skip modules whose outputs are already present in the tensordict. Defaults to None, i.e. the value of tensordict.nn.skip_existing() is not affected.

  • advantage_key (str or tuple of str, optional) – [Deprecated] the key of the advantage entry. Defaults to "advantage".

  • value_target_key (str or tuple of str, optional) – [Deprecated] the key of the value target entry. Defaults to "value_target".

  • value_key (str or tuple of str, optional) – [Deprecated] the value key to read from the input tensordict. Defaults to "state_value".
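As a rough illustration of what the average_gae option above refers to, the following pure-Python sketch standardizes a list of advantage values by subtracting their mean and dividing by their standard deviation. The function name standardize_sketch, the use of the sample standard deviation, and the eps floor are assumptions made for this example; they are not taken from the library.

```python
import statistics
from typing import List, Sequence


def standardize_sketch(advantages: Sequence[float], eps: float = 1e-4) -> List[float]:
    """Return advantages shifted to zero mean and scaled to unit deviation.

    Hypothetical helper: the eps floor and the sample (n - 1) deviation
    are assumptions of this sketch, not documented library behaviour.
    Requires at least two values, since a sample deviation is undefined otherwise.
    """
    loc = statistics.fmean(advantages)
    # Floor the scale so that near-constant advantages do not divide by ~0.
    scale = max(statistics.stdev(advantages), eps)
    return [(a - loc) / scale for a in advantages]


standardized = standardize_sketch([1.0, 2.0, 3.0])
print(standardized)
```

Standardizing this way keeps the sign structure of the advantages relative to their mean while making their scale comparable across batches.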

GAE will return an "advantage" entry containing the advantage value. It will also return a "value_target" entry with the return value that is to be used to train the value network. Finally, if differentiable is True, an additional, differentiable "value_error" entry will be returned, which simply represents the signed difference between the return and the value network output (i.e. a distance loss should still be applied to that signed value).

Note

As other advantage functions do, if the value_key is already present in the input tensordict, the GAE module will ignore the calls to the value network (if any) and use the provided value instead.
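To make the "advantage" and "value_target" entries described above more concrete, here is a minimal pure-Python sketch of the backward recursion from the cited paper for a single trajectory: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), and A_t = delta_t + gamma * lmbda * A_{t+1}. The function name gae_sketch, the list-based inputs, and the choice of value_target = advantage + value are assumptions of this example; it is not the library's vectorized implementation.

```python
from typing import List, Sequence, Tuple


def gae_sketch(
    rewards: Sequence[float],
    values: Sequence[float],
    next_values: Sequence[float],
    dones: Sequence[bool],
    gamma: float,
    lmbda: float,
) -> Tuple[List[float], List[float]]:
    """Compute advantages and value targets for one trajectory of length T.

    Hypothetical helper for illustration only.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk the trajectory backwards so A_{t+1} is known when computing A_t.
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD residual; bootstrapping is cut at episode ends.
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # Assumption: the value target is the advantage added back onto the value.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


adv, target = gae_sketch(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    dones=[False, False, True],
    gamma=0.5,
    lmbda=1.0,
)
print(adv, target)
```

With lmbda=1.0 and zero value estimates, the advantages reduce to discounted reward sums, which makes the recursion easy to check by hand.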

forward(tensordict: TensorDictBase, *unused_args, params: List[Tensor] | None = None, target_params: List[Tensor] | None = None) → TensorDictBase

Computes the GAE given the data in tensordict.

If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.

Parameters:
  • tensordict (TensorDictBase) – A TensorDict containing the data (an observation key, “action”, “reward”, “done” and “next” tensordict state as returned by the environment) necessary to compute the value estimates and the GAE. The data passed to this module should be structured as [*B, T, F], where B is the batch size, T the time dimension and F the feature dimension(s).

  • params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

Returns:

An updated TensorDict with advantage and value_error keys as defined in the constructor.

Examples

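>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.objectives.value import GAE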
>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = GAE(
...     gamma=0.98,
...     lmbda=0.94,
...     value_network=value_net,
...     differentiable=False,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10])
>>> _ = module(tensordict)
>>> assert "advantage" in tensordict.keys()

The module supports non-tensordict (i.e. unpacked tensordict) inputs too:

Examples

>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = GAE(
...     gamma=0.98,
...     lmbda=0.94,
...     value_network=value_net,
...     differentiable=False,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)

value_estimate(tensordict, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None, **kwargs)

Gets a value estimate, usually used as a target value for the value network.

If the state value key is present under tensordict.get(("next", self.tensor_keys.value)), then this value will be used without resorting to the value network.

Parameters:
  • tensordict (TensorDictBase) – the tensordict containing the data to read.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • **kwargs – the keyword arguments to be passed to the value network.

Returns: a tensor corresponding to the state value.
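The lookup-before-compute behaviour described for value_estimate can be pictured with a small pure-Python sketch. Plain dictionaries stand in for tensordicts here, and the names value_estimate_sketch, value_fn, and the "obs" key are hypothetical; the sketch only illustrates the idea of reusing a value already stored under the "next" entry, not the library's code path.

```python
from typing import Any, Callable, Dict


def value_estimate_sketch(
    data: Dict[str, Any],
    value_fn: Callable[[Any], float],
    value_key: str = "state_value",
) -> float:
    """Return the next-state value, calling value_fn only when it is missing.

    Hypothetical helper: `data` mimics a tensordict with a nested "next" entry.
    """
    next_data = data["next"]
    if value_key in next_data:
        # A precomputed value is present, so the value function is not called.
        return next_data[value_key]
    return value_fn(next_data["obs"])


calls = []


def toy_value_fn(obs: float) -> float:
    """Toy stand-in for a value network; records each call."""
    calls.append(obs)
    return 2.0 * obs


cached = value_estimate_sketch({"next": {"obs": 1.0, "state_value": 7.0}}, toy_value_fn)
computed = value_estimate_sketch({"next": {"obs": 3.0}}, toy_value_fn)
print(cached, computed, calls)
```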
