Shortcuts

LineariseRewards

class torchrl.envs.transforms.LineariseRewards(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey] | None = None, *, weights: Sequence[float] | Tensor | None = None)[source]

Transforms a multi-objective reward signal to a single-objective one via a weighted sum.

Parameters:
  • in_keys (List[NestedKey]) – The keys under which the multi-objective rewards are found.

  • out_keys (List[NestedKey], optional) – The keys under which single-objective rewards should be written. Defaults to in_keys.

  • weights (List[float], Tensor, optional) – Dictates how to weight each reward when summing them. Defaults to [1.0, 1.0, …].

Warning

If a sequence of in_keys of length strictly greater than one is passed (e.g. one group for each agent in a multi-agent set-up), the same weights will be applied for each entry. If you need to aggregate rewards differently for each group, use several LineariseRewards in a row.

Example

>>> import mo_gymnasium as mo_gym
>>> from torchrl.envs import MOGymWrapper
>>> mo_env = MOGymWrapper(mo_gym.make("deep-sea-treasure-v0"))
>>> mo_env.reward_spec
BoundedContinuous(
    shape=torch.Size([2]),
    space=ContinuousBox(
    low=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True),
    high=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True)),
    ...)
>>> so_env = TransformedEnv(mo_env, LineariseRewards(in_keys=("reward",)))
>>> so_env.reward_spec
BoundedContinuous(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    ...)
>>> td = so_env.rollout(5)
>>> td["next", "reward"].shape
torch.Size([5, 1])
transform_reward_spec(reward_spec: TensorSpec) TensorSpec[source]

Transforms the reward spec such that the resulting spec matches transform mapping.

Parameters:

reward_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources