TargetReturn

class torchrl.envs.transforms.TargetReturn(target_return: float, mode: str = 'reduce', in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, reset_key: NestedKey | None = None)

Sets a target return for the agent to achieve in the environment.

In goal-conditioned RL, the target return is defined as the expected cumulative reward obtained from the current state to the goal state or the end of the episode. It is used as an input to the policy to guide its behaviour. For a trained policy, the maximum return achievable in the environment is typically chosen as the target return; however, since it is used as an input to the policy module, it should be scaled accordingly. The TargetReturn transform updates the tensordict to include the user-specified target return. The mode parameter specifies whether the target return is updated at every step, by subtracting the reward achieved at that step ("reduce"), or remains constant ("constant").

Parameters:
  • target_return (float) – target return to be achieved by the agent.

  • mode (str) – mode to be used to update the target return. Can be either “reduce” or “constant”. Default: “reduce”.

  • in_keys (sequence of NestedKey, optional) – keys pointing to the reward entries. Defaults to the reward keys of the parent env.

  • out_keys (sequence of NestedKey, optional) – keys pointing to the target keys. Defaults to a copy of in_keys where the last element has been substituted by "target_return", and raises an exception if these keys aren’t unique. A construction sketch with explicit keys follows this list.

  • reset_key (NestedKey, optional) – the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) and raises an exception otherwise.
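
The following is a minimal construction sketch with explicit keys; the nested ("agents", "reward") entry is hypothetical and only illustrates how, per the defaults above, an output key mirrors its input key with the last element replaced by "target_return":

>>> from torchrl.envs.transforms import TargetReturn
>>> # hypothetical nested reward entry, for illustration only
>>> t = TargetReturn(
...     200.0,
...     mode="constant",
...     in_keys=[("agents", "reward")],
...     out_keys=[("agents", "target_return")],
... )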

Examples

>>> import torch
>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.envs.transforms import TargetReturn
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1"),
...     TargetReturn(10.0, mode="reduce"))
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> env.rollout(20)['target_return'].squeeze()
tensor([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.,  0., -1., -2., -3.])
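
For comparison, a minimal sketch of "constant" mode under the same setup: the target return is written unchanged at every step rather than being decremented by the reward. The ellipsis stands in for the remaining steps; the exact episode length depends on the rollout.

>>> env = TransformedEnv(
...     GymEnv("CartPole-v1"),
...     TargetReturn(10.0, mode="constant"))
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> env.rollout(20)['target_return'].squeeze()
tensor([10., 10., 10., ..., 10.])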
forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict and applies the transform to the selected keys.

transform_input_spec(input_spec: TensorSpec) → TensorSpec

Transforms the input spec such that the resulting spec matches the transform mapping.

Parameters:
  • input_spec (TensorSpec) – spec before the transform

Returns:
  expected spec after the transform

transform_observation_spec(observation_spec: TensorSpec) → TensorSpec

Transforms the observation spec such that the resulting spec matches the transform mapping.

Parameters:
  • observation_spec (TensorSpec) – spec before the transform

Returns:
  expected spec after the transform
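
As an observable effect of these spec hooks, the transformed environment's observation spec gains a "target_return" entry, since the target return is provided to the policy as an input. A minimal sketch, assuming the CartPole setup from the examples above:

>>> env = TransformedEnv(GymEnv("CartPole-v1"), TargetReturn(10.0))
>>> "target_return" in env.observation_spec.keys()
True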
