TargetReturn¶
- class torchrl.envs.transforms.TargetReturn(target_return: float, mode: str = 'reduce', in_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, out_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None)[source]¶
Sets a target return for the agent to achieve in the environment.
In goal-conditioned RL, the TargetReturn is defined as the expected cumulative reward obtained from the current state to the goal state or the end of the episode. It is used as an input for the policy to guide its behaviour. For a trained policy, the maximum return in the environment is typically chosen as the target return. However, as it is used as an input to the policy module, it should be scaled accordingly. With the TargetReturn transform, the tensordict can be updated to include the user-specified target return. The mode parameter can be used to specify whether the target return is updated at every step by subtracting the reward achieved at each step, or remains constant. TargetReturn should only be used during inference when interacting with the environment, as the actual return received from the environment might differ from the target return. Therefore, to obtain correct return labels for training the policy, the TargetReturn transform should be used in conjunction with hindsight return relabelling, for example the Reward2GoTransform, which updates the return label with the return actually achieved (see the sketch at the end of this section).
- Parameters:
target_return (float) – target return to be achieved by the agent.
mode (str) – mode to be used to update the target return. Can be either “reduce” or “constant”. Default: “reduce”.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import TargetReturn
>>> transform = TargetReturn(10.0, mode="reduce")
>>> td = TensorDict({}, [10])
>>> td = transform.reset(td)
>>> td["target_return"]
tensor([[10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.]])
>>> # take a step with mode "reduce"
>>> # target return is updated by subtracting the reward
>>> reward = torch.ones((10, 1))
>>> td.set(("next", "reward"), reward)
>>> td = transform._step(td)
>>> td["next", "target_return"]
tensor([[9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.]])
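For comparison, below is a minimal sketch of the "constant" mode under the same setup (same imports as above, smaller batch size chosen for brevity): the target return is carried over to the next step unchanged instead of being reduced by the reward.
>>> transform = TargetReturn(10.0, mode="constant")
>>> td = TensorDict({}, [2])
>>> td = transform.reset(td)
>>> td.set(("next", "reward"), torch.ones((2, 1)))
>>> td = transform._step(td)
>>> td["next", "target_return"]  # stays at the initial target
tensor([[10.],
        [10.]])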
- forward(tensordict: TensorDictBase) TensorDictBase [source]¶
Reads the input tensordict, and for the selected keys, applies the transform.
- reset(tensordict: TensorDict)[source]¶
Resets a transform if it is stateful.
- transform_observation_spec(observation_spec: CompositeSpec) CompositeSpec [source]¶
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
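Because TargetReturn is meant for inference only, a common pattern is to append it to the environment for data collection and to relabel the stored data with Reward2GoTransform before training, as described above. The following is a minimal sketch of that pairing, not a prescribed recipe: the environment name ("CartPole-v1"), the target value, the buffer size and the "return_to_go" output key are illustrative assumptions, and a Gym backend is assumed to be installed.
>>> # Sketch only: illustrative names and values, see the note above.
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.envs.transforms import Reward2GoTransform, TargetReturn
>>> # Inference / data collection: the target return is written at reset and
>>> # reduced by the reward received at every step (mode="reduce").
>>> env = TransformedEnv(GymEnv("CartPole-v1"), TargetReturn(200.0, mode="reduce"))
>>> rollout = env.rollout(5)
>>> ret = rollout["target_return"]  # one target-return entry per collected step
>>> # Training: relabel the stored data with the return actually achieved,
>>> # using Reward2GoTransform as an inverse transform on the replay buffer.
>>> rb = ReplayBuffer(storage=LazyTensorStorage(1000))
>>> rb.append_transform(Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]))
>>> rb.extend(rollout)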