TargetReturn

class torchrl.envs.transforms.TargetReturn(target_return: float, mode: str = 'reduce', in_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, out_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None)[source]

Sets a target return for the agent to achieve in the environment.

In goal-conditioned RL, the target return is the expected cumulative reward obtained from the current state to the goal state or to the end of the episode. It is used as an input to the policy to guide its behaviour. For a trained policy, the maximum return achievable in the environment is typically chosen as the target return; since it is fed to the policy module, it should be scaled accordingly. The TargetReturn transform updates the tensordict to include the user-specified target return. The mode parameter specifies whether the target return is updated at every step, by subtracting the reward obtained at that step, or kept constant. TargetReturn should only be used during inference, when interacting with the environment, as the return actually received may differ from the target return. To obtain correct return labels for training the policy, TargetReturn should therefore be combined with a hindsight return-relabelling method such as the Reward2GoTransform, which replaces the return label with the return that was actually achieved (see the sketch below).
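A minimal sketch of that training-time relabelling, assuming Reward2GoTransform can be applied as an inverse transform directly to collected data; the gamma value, the ("next", "target_return") output key and the hand-built data below are illustrative assumptions, not part of this class:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import Reward2GoTransform
>>> # stand-in for data collected with TargetReturn applied during inference
>>> done = torch.zeros(10, 1, dtype=torch.bool)
>>> done[-1] = True  # mark the end of the episode
>>> data = TensorDict({}, [10])
>>> data.set(("next", "reward"), torch.ones(10, 1))
>>> data.set(("next", "done"), done)
>>> # overwrite the target with the return actually achieved (reward-to-go)
>>> r2g = Reward2GoTransform(gamma=1.0, out_keys=[("next", "target_return")])
>>> data = r2g.inv(data)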

Parameters:
  • target_return (float) – target return to be achieved by the agent.

  • mode (str) – mode to be used to update the target return. Can be either “reduce” or “constant”. Default: “reduce”.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import TargetReturn
>>> transform = TargetReturn(10.0, mode="reduce")
>>> td = TensorDict({}, [10])
>>> td = transform.reset(td)
>>> td["target_return"]
tensor([[10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.]])
>>> # take a step with mode "reduce"
>>> # target return is updated by subtracting the reward
>>> reward = torch.ones((10,1))
>>> td.set(("next", "reward"), reward)
>>> td = transform._step(td)
>>> td["next", "target_return"]
tensor([[9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.],
        [9.]])
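For comparison, a sketch of the same step with mode="constant": per the description above, the target return is not reduced by the reward, so the value is expected to stay at its initial value of 10 (the output shown is an expectation following the pattern of the example above, not a verified doctest):

>>> # hypothetical continuation with mode "constant"
>>> transform = TargetReturn(10.0, mode="constant")
>>> td = TensorDict({}, [10])
>>> td = transform.reset(td)
>>> td.set(("next", "reward"), torch.ones((10, 1)))
>>> td = transform._step(td)
>>> td["next", "target_return"]  # expected to remain at 10. for every entry
tensor([[10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.]])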
forward(tensordict: TensorDictBase) → TensorDictBase[source]

Reads the input tensordict, and for the selected keys, applies the transform.

reset(tensordict: TensorDict)[source]

Resets a transform if it is stateful.

transform_observation_spec(observation_spec: CompositeSpec) → CompositeSpec[source]

Transforms the observation spec so that the resulting spec matches the transform mapping.

Parameters:
  • observation_spec (TensorSpec) – spec before the transform

Returns:
  expected spec after the transform
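As a quick check of the spec change, one might inspect the observation spec of a transformed environment; the snippet below is a sketch assuming a Gym backend is installed, and the environment name is an arbitrary choice:

>>> from torchrl.envs import TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), TargetReturn(10.0))
>>> # the transform is expected to add a "target_return" entry to the observation spec
>>> "target_return" in env.observation_spec.keys()
True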
