TDLambdaEstimator
- class torchrl.objectives.value.TDLambdaEstimator(*args, **kwargs)
TD(\(\lambda\)) estimate of the advantage function.
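To illustrate the quantity being estimated, the sketch below computes the \(\lambda\)-return for a single trajectory with the backward recursion \(G_t = r_t + \gamma [(1 - \lambda) V(s_{t+1}) + \lambda G_{t+1}]\). It bootstraps from the last next-state value and takes the advantage as \(G_t - V(s_t)\). It is plain Python with a hypothetical helper name (`td_lambda_advantage`), it ignores done/terminated flags, and it is not torchrl's implementation.

```python
def td_lambda_advantage(rewards, values, next_values, gamma, lmbda):
    """Return (advantages, lambda_returns) for one trajectory (sketch).

    rewards[t] is r_t, values[t] is V(s_t), next_values[t] is V(s_{t+1}).
    Done/terminated flags are deliberately ignored here.
    """
    T = len(rewards)
    returns = [0.0] * T
    # Beyond the last step, bootstrap from the last next-state value.
    g = next_values[-1]
    for t in reversed(range(T)):
        # Blend the one-step bootstrap with the lambda-return of step t+1.
        g = rewards[t] + gamma * ((1 - lmbda) * next_values[t] + lmbda * g)
        returns[t] = g
    advantages = [ret - v for ret, v in zip(returns, values)]
    return advantages, returns


adv, ret = td_lambda_advantage(
    rewards=[1.0, 0.0, 1.0],
    values=[0.5, 0.4, 0.3],
    next_values=[0.4, 0.3, 0.2],
    gamma=0.9,
    lmbda=0.8,
)
# ret[0] is about 1.7226 and adv[0] about 1.2226
```

With lmbda=0 this reduces to one-step TD targets. With lmbda=1 it becomes the discounted sum of rewards plus a final bootstrap.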
- Parameters:
gamma (scalar) – exponential mean discount.
lmbda (scalar) – trajectory discount.
value_network (TensorDictModule) – value operator used to retrieve the value estimates.
average_rewards (bool, optional) – if True, rewards will be standardized before the TD is computed.
differentiable (bool, optional) – if True, gradients are propagated through the computation of the value function. Default is False.
Note
The proper way to make the function call non-differentiable is to wrap it in a torch.no_grad() context manager or decorator, or to pass detached parameters for functional modules.
vectorized (bool, optional) – whether to use the vectorized version of the lambda return. Default is True.
skip_existing (bool, optional) – if True, the value network will skip modules whose outputs are already present in the tensordict. Defaults to None, i.e., the value of tensordict.nn.skip_existing() is not affected.
advantage_key (str or tuple of str, optional) – [Deprecated] the key of the advantage entry. Defaults to "advantage".
value_target_key (str or tuple of str, optional) – [Deprecated] the key of the value target entry. Defaults to "value_target".
value_key (str or tuple of str, optional) – [Deprecated] the value key to read from the input tensordict. Defaults to "state_value".
shifted (bool, optional) – if True, the value and next value are estimated with a single call to the value network. This is faster but only valid when (1) the "next" value is shifted by exactly one time step (which is not the case with multi-step value estimation, for instance) and (2) the parameters used at time t and t+1 are identical (which is not the case when target parameters are used). Defaults to False.
device (torch.device, optional) – device of the module.
time_dim (int, optional) – the dimension corresponding to time in the input tensordict. If not provided, defaults to the dimension marked with the "time" name, if any, and to the last dimension otherwise. Can be overridden during a call to value_estimate(). Negative dimensions are considered with respect to the input tensordict.
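The `vectorized` option refers to computing the lambda return without a step-by-step recursion. One way to see why this is possible: for a trajectory without early termination, and when next_values[t] == values[t + 1], the TD(\(\lambda\)) advantage equals a discounted sum of one-step TD errors, \(A_t = \sum_{k \ge t} (\gamma\lambda)^{k-t} \delta_k\) with \(\delta_k = r_k + \gamma V(s_{k+1}) - V(s_k)\). The plain-Python sketch below checks this identity against the backward recursion. It illustrates the math only and does not describe how torchrl vectorizes the computation. Both function names are hypothetical.

```python
def recursive_advantage(rewards, values, next_values, gamma, lmbda):
    """Advantage via the backward lambda-return recursion (no done flags)."""
    g = next_values[-1]  # bootstrap beyond the last step
    adv = [0.0] * len(rewards)
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * ((1 - lmbda) * next_values[t] + lmbda * g)
        adv[t] = g - values[t]
    return adv


def closed_form_advantage(rewards, values, next_values, gamma, lmbda):
    """Advantage as a discounted sum of TD errors, with no recursion.

    Every entry is an independent weighted sum of the deltas, so all of them
    could be computed at once (e.g. as an upper-triangular matrix-vector
    product); plain loops are used here for readability.
    """
    T = len(rewards)
    deltas = [rewards[t] + gamma * next_values[t] - values[t] for t in range(T)]
    return [
        sum((gamma * lmbda) ** (k - t) * deltas[k] for k in range(t, T))
        for t in range(T)
    ]


# next_values[t] == values[t + 1], as the identity requires.
rewards = [1.0, 0.0, 1.0]
values = [0.5, 0.4, 0.3]
next_values = [0.4, 0.3, 0.2]
a = recursive_advantage(rewards, values, next_values, 0.9, 0.8)
b = closed_form_advantage(rewards, values, next_values, 0.9, 0.8)
assert all(abs(x - y) < 1e-9 for x, y in zip(a, b))
```

The closed form trades the sequential dependency for extra arithmetic (quadratic in T in this naive form), which is the usual trade-off behind a vectorized option.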
- forward(tensordict: TensorDictBase, *, params: List[Tensor] | None = None, target_params: List[Tensor] | None = None) → TensorDictBase
Computes the TD(\(\lambda\)) advantage given the data in tensordict.
If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.
- Parameters:
tensordict (TensorDictBase) – A TensorDict containing the data (an observation key, "action", ("next", "reward"), ("next", "done"), ("next", "terminated"), and the "next" tensordict state as returned by the environment) necessary to compute the value estimates and the TD(\(\lambda\)) estimate. The data passed to this module should be structured as [*B, T, *F], where B is the batch size, T the time dimension and F the feature dimension(s). The tensordict must have shape [*B, T].
- Keyword Arguments:
params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.
target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.
- Returns:
An updated TensorDict with advantage and value target entries, stored under the keys defined in the constructor.
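The forward pass reads ("next", "done") and ("next", "terminated") alongside the reward. A common convention, assumed here rather than taken from the torchrl source, is that a terminated step contributes no bootstrap value, while any done step (terminated or truncated) stops the lambda chaining so the return does not leak across episode boundaries. The sketch below applies that convention to a batch laid out as [B][T] nested lists (time last, matching the default time_dim). It returns both the advantage and the value target, with advantage = value_target - value. The function name is hypothetical.

```python
def td_lambda_batch(rewards, values, next_values, dones, terminated, gamma, lmbda):
    """Lambda-return advantage for a [B][T] batch of trajectories (sketch).

    Assumed convention, not torchrl's code:
    - terminated[b][t]: the next state is terminal, so nothing is bootstrapped.
    - dones[b][t]: the trajectory ends here (terminated or truncated), so the
      recursion restarts from the one-step bootstrap instead of chaining.
    """
    advantages, value_targets = [], []
    for r, v, nv, d, term in zip(rewards, values, next_values, dones, terminated):
        T = len(r)
        adv, tgt = [0.0] * T, [0.0] * T
        g = 0.0
        for t in reversed(range(T)):
            not_term = 0.0 if term[t] else 1.0
            # At the last step or an episode boundary, do not chain g.
            chained = nv[t] if (d[t] or t == T - 1) else g
            g = r[t] + gamma * not_term * ((1 - lmbda) * nv[t] + lmbda * chained)
            tgt[t] = g
            adv[t] = g - v[t]
        advantages.append(adv)
        value_targets.append(tgt)
    return advantages, value_targets


# Row 0 terminates at t=1; row 1 is truncated (done, not terminated) at t=1.
adv, tgt = td_lambda_batch(
    rewards=[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]],
    values=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    next_values=[[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]],
    dones=[[False, True, False], [False, True, False]],
    terminated=[[False, True, False], [False, False, False]],
    gamma=0.5,
    lmbda=0.5,
)
# tgt == [[2.75, 2.0, 5.5], [3.375, 4.5, 5.5]]
```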
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.objectives.value import TDLambdaEstimator
>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = TDLambdaEstimator(
...     gamma=0.98,
...     lmbda=0.94,
...     value_network=value_net,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> tensordict = TensorDict(
...     {"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}},
...     [1, 10],
... )
>>> _ = module(tensordict)
>>> assert "advantage" in tensordict.keys()
The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.objectives.value import TDLambdaEstimator
>>> value_net = TensorDictModule(
...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
... )
>>> module = TDLambdaEstimator(
...     gamma=0.98,
...     lmbda=0.94,
...     value_network=value_net,
... )
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
>>> reward = torch.randn(1, 10, 1)
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
>>> advantage, value_target = module(
...     obs=obs,
...     next_reward=reward,
...     next_done=done,
...     next_obs=next_obs,
...     next_terminated=terminated,
... )
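In the unpacked call, each keyword appears to correspond to a tensordict key whose nested levels are joined by an underscore: ("next", "reward") is passed as next_reward. The hypothetical helper below, which is not part of torchrl or tensordict, builds such keyword names from a nested dictionary. It assumes "_" as the separator, and the resulting dict could then be unpacked into the call with module(**kwargs).

```python
def flatten_to_kwargs(nested, prefix=(), sep="_"):
    """Flatten a nested dict into keyword names such as "next_reward"."""
    flat = {}
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            # Recurse into sub-dictionaries such as the "next" entry.
            flat.update(flatten_to_kwargs(value, path, sep))
        else:
            flat[sep.join(path)] = value
    return flat


# Placeholder strings stand in for the tensors of the example above.
data = {
    "obs": "obs_tensor",
    "next": {
        "obs": "next_obs_tensor",
        "reward": "reward_tensor",
        "done": "done_tensor",
        "terminated": "terminated_tensor",
    },
}
kwargs = flatten_to_kwargs(data)
# keys: obs, next_obs, next_reward, next_done, next_terminated
```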
- value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: torch.Tensor | None = None, time_dim: int | None = None, **kwargs)
Gets a value estimate, usually used as a target value for the value network.
If the state value key is present under tensordict.get(("next", self.tensor_keys.value)), this value will be used without calling the value network.
- Parameters:
tensordict (TensorDictBase) – the tensordict containing the data to read.
target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.
next_value (torch.Tensor, optional) – the value of the next state or state-action pair. Mutually exclusive with target_params.
**kwargs – the keyword arguments to be passed to the value network.
Returns: a tensor corresponding to the state value.
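The lookup rule described above can be sketched with plain dictionaries: reuse a next-state value already stored under ("next", value_key), and otherwise evaluate the value network. The function `next_state_values`, the toy `value_fn`, and the dict layout are hypothetical stand-ins for the tensordict and the value network. Giving an explicit next_value precedence over the stored entry is also an assumption. Only the reuse-or-evaluate fallback and the exclusivity with target_params come from the description above.

```python
def next_state_values(data, value_fn, value_key="state_value",
                      next_value=None, target_params=None):
    """Return V(s_{t+1}) following the documented lookup order (sketch)."""
    if next_value is not None and target_params is not None:
        # The documentation describes these two arguments as exclusive.
        raise ValueError("next_value and target_params are mutually exclusive")
    if next_value is not None:
        # Assumption: an explicitly supplied value takes precedence.
        return next_value
    stored = data["next"].get(value_key)
    if stored is not None:
        # Reuse the cached estimate instead of calling the value network.
        return stored
    # Fall back to evaluating the (hypothetical) value function on next obs.
    return [value_fn(obs, target_params) for obs in data["next"]["obs"]]


calls = []


def toy_value(obs, params=None):
    """Toy value function that records every call it receives."""
    calls.append(obs)
    return 2.0 * obs


data = {"obs": [1.0, 2.0], "next": {"obs": [2.0, 3.0]}}
first = next_state_values(data, toy_value)   # calls toy_value twice
data["next"]["state_value"] = [0.1, 0.2]
second = next_state_values(data, toy_value)  # reuses the stored values
```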