# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import torch
from tensordict import NestedKey, TensorDictBase, unravel_key
from tensordict.nn import TensorDictModuleBase
from tensordict.utils import expand_right
from torch import nn
def _get_reward(
gamma: float,
reward: torch.Tensor,
done: torch.Tensor,
max_steps: int,
"""Sums the rewards up to max_steps in the future with a gamma decay.
Supports multiple consecutive trajectories.
Assumes that the time dimension is the *last* dim of reward and done.
filt = torch.tensor(
[gamma**i for i in range(max_steps + 1)],
).view(1, 1, -1)
# make one done mask per trajectory
done_cumsum = done.cumsum(-1)
done_cumsum = torch.cat(
[torch.zeros_like(done_cumsum[..., :1]), done_cumsum[..., :-1]], -1
num_traj = done_cumsum.max().item() + 1
done_cumsum = done_cumsum.expand(num_traj, *done.shape)
traj_ids = done_cumsum == torch.arange(
num_traj, device=done.device, dtype=done_cumsum.dtype
).view(num_traj, *[1 for _ in range(done_cumsum.ndim - 1)])
# an expanded reward tensor where each index along dim 0 is a different trajectory
# Note: rewards could have a different shape than done (e.g. multi-agent with a single
# done per group).
# we assume that reward has the same leading dimension as done.
if reward.shape != traj_ids.shape[1:]:
# We'll expand the ids on the right first
traj_ids_expand = expand_right(traj_ids, (num_traj, *reward.shape))
reward_traj = traj_ids_expand * reward
# we must make sure that the last dimension of the reward is the time
reward_traj = reward_traj.transpose(-1, traj_ids.ndim - 1)
# simpler use case: reward shape and traj_ids match
reward_traj = traj_ids * reward
reward_traj = torch.nn.functional.pad(reward_traj, [0, max_steps], value=0.0)
shape = reward_traj.shape[:-1]
if len(shape) > 1:
reward_traj = reward_traj.flatten(0, reward_traj.ndim - 2)
reward_traj = reward_traj.unsqueeze(-2)
summed_rewards = torch.conv1d(reward_traj, filt)
summed_rewards = summed_rewards.squeeze(-2)
if len(shape) > 1:
summed_rewards = summed_rewards.unflatten(0, shape)
# let's check that our summed rewards have the right size
if reward.shape != traj_ids.shape[1:]:
summed_rewards = summed_rewards.transpose(-1, traj_ids.ndim - 1)
summed_rewards = (summed_rewards * traj_ids_expand).sum(0)
summed_rewards = (summed_rewards * traj_ids).sum(0)
# time_to_obs is the tensor of the time delta to the next obs
# 0 = take the next obs (ie do nothing)
# 1 = take the obs after the next
time_to_obs = (
traj_ids.flip(-1).cumsum(-1).clamp_max(max_steps + 1).flip(-1) * traj_ids
time_to_obs = time_to_obs.sum(0)
time_to_obs = time_to_obs - 1
return summed_rewards, time_to_obs
class MultiStep(nn.Module):
    """Multistep reward transform.

    Presented in

    | Sutton, R. S. 1988. Learning to predict by the methods of temporal differences. Machine learning 3(1):9–44.

    This module maps the "next" observation to the t + n "next" observation.
    :attr:`n_steps` must be at least 1; with ``n_steps=1`` no look-ahead is
    applied and the "next" entries are left at ``t + 1``.

    Args:
        gamma (:obj:`float`): Discount factor for return computation
        n_steps (integer): maximum look-ahead steps.

    .. note:: This class is meant to be used within a ``DataCollector``.
        It will only treat the data passed to it at the end of a collection,
        and ignore data preceding that collection or coming in the next batch.
        As such, results on the last steps of the batch may likely be biased
        by the early truncation of the trajectory.
        To mitigate this effect, please use :class:`~torchrl.envs.transforms.MultiStepTransform`
        within the replay buffer instead.

    Examples:
        >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
        >>> from torchrl.data.postprocs import MultiStep
        >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter
        >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
        >>> env.set_seed(0)
        >>> collector = SyncDataCollector(env, policy=RandomPolicy(env.action_spec),
        ...     frames_per_batch=10, total_frames=2000, postproc=MultiStep(n_steps=4, gamma=0.99))
        >>> for data in collector:
        ...     break
        >>> print(data["step_count"])
        >>> # the next step count is shifted by 3 steps in the future
        >>> print(data["next", "step_count"])
        tensor([[ 5],
                [ 6],
                [ 7],
                [ 8],
                [ 9],
                ...

    """

    def __init__(
        self,
        gamma: float,
        n_steps: int,
    ):
        super().__init__()
        if n_steps <= 0:
            # 0 is rejected as well, hence "strictly positive".
            raise ValueError("n_steps must be a strictly positive integer.")
        if not (gamma > 0 and gamma <= 1):
            raise ValueError(f"got out-of-bounds gamma decay: gamma={gamma}")

        self.gamma = gamma
        self.n_steps = n_steps
        # Discount factors [1, gamma, ..., gamma**n_steps], shaped (1, 1, n_steps + 1).
        # NOTE(review): the buffer name was lost in the damaged source;
        # "gammas" is a reconstruction -- confirm against the upstream file.
        self.register_buffer(
            "gammas",
            torch.tensor(
                [gamma**i for i in range(n_steps + 1)],
                dtype=torch.float,
            ).reshape(1, 1, -1),
        )
        # Keys looked up under the "next" entry of the incoming tensordict.
        self.done_key = "done"
        self.done_keys = ("done", "terminated", "truncated")
        self.reward_keys = ("reward",)
        self.mask_key = ("collector", "mask")

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Re-writes a tensordict following the multi-step transform.

        Args:
            tensordict: :class:`tensordict.TensorDictBase` instance with
                ``[*Batch x Time-steps]`` shape.
                The TensorDict must contain a ``("next", "reward")`` and
                ``("next", "done")`` keys.
                All keys that are contained within the "next" nested tensordict
                will be shifted by (at most) :attr:`~.n_steps` frames.
                The TensorDict will also be updated with new key-value pairs:

                - gamma: indicating the discount to be used for the next
                  reward;
                - nonterminal: boolean value indicating whether a step is
                  non-terminal (not done or not last of trajectory);
                - original_reward: previous reward collected in the
                  environment (i.e. before multi-step);
                - steps_to_next_obs: number of steps separating each entry
                  from the observation it was paired with;
                - The "reward" values will be replaced by the newly computed
                  rewards.

                The ``"done"`` key can have either the shape of the tensordict
                OR the shape of the tensordict followed by a singleton
                dimension OR the shape of the tensordict followed by other
                dimensions. In the latter case, the tensordict *must* be
                compatible with a reshape that follows the done shape (ie. the
                leading dimensions of every tensor it contains must match the
                shape of the ``"done"`` entry).
                The ``"reward"`` tensor can have either the shape of the
                tensordict (or done state) or this shape followed by a singleton
                dimension.

        Returns:
            The transformed tensordict. The work is done on
            ``tensordict.clone(False)``, so the returned object is not the
            input instance.

        """
        return _multi_step_func(
            tensordict,
            done_key=self.done_key,
            done_keys=self.done_keys,
            reward_keys=self.reward_keys,
            mask_key=self.mask_key,
            n_steps=self.n_steps,
            gamma=self.gamma,
        )
def _multi_step_func(
    tensordict,
    *,
    done_key="done",
    done_keys=("done", "terminated", "truncated"),
    reward_keys=("reward",),
    mask_key=("collector", "mask"),
    n_steps,
    gamma,
):
    """Applies the multi-step transform to a ``[*Batch x Time]`` tensordict.

    NOTE(review): the parameter list was lost in the damaged source. It is
    rebuilt from the names the body reads; the defaults mirror the values
    ``MultiStep`` sets on itself. Confirm against the upstream file.

    Args:
        tensordict: tensordict holding a ``"next"`` entry with the done and
            reward keys. A ``clone(False)`` of it is modified and returned.
        done_key: key (under ``"next"``) of the done flag used to delimit
            trajectories.
        done_keys: keys (under ``"next"``) that must *not* be shifted.
        reward_keys: reward keys (under ``"next"``) to be replaced by their
            discounted multi-step sum.
        mask_key: optional key of a mask; entries where the mask is ``False``
            are flagged as not ``"nonterminal"``. ``None`` disables it.
        n_steps: number of steps of the return (``1`` means no look-ahead).
        gamma: discount factor.

    Returns:
        The transformed tensordict, with ``"steps_to_next_obs"``, ``"gamma"``,
        ``"nonterminal"`` and ``("next", "original_reward")`` entries added.

    Raises:
        RuntimeError: if the tensordict batch size cannot be made to follow
            the shape of the done entry.
    """
    # in accordance with common understanding of what n_steps should be
    n_steps = n_steps - 1
    tensordict = tensordict.clone(False)
    done = tensordict.get(("next", done_key))

    # we'll be using the done states to index the tensordict.
    # if the shapes don't match we're in trouble.
    ndim = tensordict.ndim
    if done.shape != tensordict.shape:
        if done.shape[-1] == 1 and done.shape[:-1] == tensordict.shape:
            done = done.squeeze(-1)
        else:
            try:
                # let's try to reshape the tensordict
                tensordict.batch_size = done.shape
                # bring the original time dim (ndim - 1) to the last position
                tensordict = tensordict.transpose(ndim - 1, tensordict.ndim - 1)
                done = tensordict.get(("next", done_key))
            except Exception as err:
                raise RuntimeError(
                    "tensordict shape must be compatible with the done's shape "
                    "(trailing singleton dimension excluded)."
                ) from err

    if mask_key is not None:
        mask = tensordict.get(mask_key, None)
    else:
        mask = None

    *batch, T = tensordict.batch_size

    summed_rewards = []
    for reward_key in reward_keys:
        reward = tensordict.get(("next", reward_key))

        # sum rewards
        # time_to_obs only depends on done and n_steps, so the value left by
        # the last iteration is valid for every reward key.
        summed_reward, time_to_obs = _get_reward(gamma, reward, done, n_steps)
        summed_rewards.append(summed_reward)

    idx_to_gather = torch.arange(
        T, device=time_to_obs.device, dtype=time_to_obs.dtype
    ).expand(*batch, T)
    idx_to_gather = idx_to_gather + time_to_obs

    # idx_to_gather looks like  tensor([[ 2,  3,  4,  5,  5,  5,  8,  9, 10, 10, 10]])
    # with a done state         tensor([[ 0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  1]])
    # meaning that the first obs will be replaced by the third, the second by the fourth etc.
    # The fifth remains the fifth as it is terminal
    tensordict_gather = (
        tensordict.get("next")
        .exclude(*reward_keys, *done_keys)
        .gather(-1, idx_to_gather)
    )

    tensordict.set("steps_to_next_obs", time_to_obs + 1)
    for reward_key, summed_reward in zip(reward_keys, summed_rewards):
        tensordict.rename_key_(("next", reward_key), ("next", "original_reward"))
        tensordict.set(("next", reward_key), summed_reward)

    # write the shifted "next" entries back (rewards and dones were excluded)
    tensordict.get("next").update(tensordict_gather)
    tensordict.set("gamma", gamma ** (time_to_obs + 1))
    nonterminal = time_to_obs != 0
    if mask is not None:
        mask = mask.view(*batch, T)
        nonterminal[~mask] = False
    tensordict.set("nonterminal", nonterminal)
    if tensordict.ndim != ndim:
        # undo the transpose done above and restore the original batch dims
        tensordict = tensordict.apply(
            lambda x: x.transpose(ndim - 1, tensordict.ndim - 1),
            batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape,
        )
        tensordict.batch_size = tensordict.batch_size[:ndim]
    return tensordict
class DensifyReward(TensorDictModuleBase):
    """A util to reassign the reward at done state to the rest of the trajectory.

    This transform is to be used with sparse rewards to assign a reward to each step of a trajectory when only the
    reward at `done` is non-null.

    .. note:: The class calls the :func:`~torchrl.objectives.value.functional.reward2go` function, which will
        also sum intermediate rewards. Make sure you understand what the `reward2go` function returns before using
        this module.

    Args:
        reward_key (NestedKey, optional): The key in the input TensorDict where the reward is stored.
            Defaults to `"reward"`.
        done_key (NestedKey, optional): The key in the input TensorDict where the done flag is stored.
            Defaults to `"done"`.
        reward_key_out (NestedKey | None, optional): The key in the output TensorDict where the reassigned reward
            will be stored. If None, it defaults to the value of `reward_key`.
            Defaults to `None`.
        time_dim (int, optional): The dimension in the input TensorDict where the time is unrolled.
            Defaults to `2`.
        discount (float, optional): The discount factor to use for computing the discounted cumulative sum of rewards.
            Defaults to `1.0` (no discounting).

    Returns:
        TensorDict: The input TensorDict with the reassigned reward stored under the key specified by `reward_key_out`.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data import DensifyReward
        >>> # Create a sample TensorDict
        >>> tensordict = TensorDict({
        ...     "next": {
        ...         "reward": torch.zeros(10, 1),
        ...         "done": torch.zeros(10, 1, dtype=torch.bool)
        ...     }
        ... }, batch_size=[10])
        >>> # Set some done flags and rewards
        >>> tensordict["next", "done"][[3, 7]] = True
        >>> tensordict["next", "reward"][3] = 3
        >>> tensordict["next", "reward"][7] = 7
        >>> # Create an instance of DensifyReward
        >>> last_reward_to_traj = DensifyReward()
        >>> # Apply the transform
        >>> new_tensordict = last_reward_to_traj(tensordict)
        >>> # Print the reassigned rewards
        >>> print(new_tensordict["next", "reward"])

    """

    def __init__(
        self,
        reward_key: NestedKey = "reward",
        done_key: NestedKey = "done",
        reward_key_out: NestedKey | None = None,
        time_dim: int = 2,
        discount: float = 1.0,
    ):
        # Imported here rather than at module level, as in the original.
        from torchrl.objectives.value.functional import reward2go

        super().__init__()
        # in_keys[0] is the reward key, in_keys[1] the done key; both are read
        # under the "next" entry in forward().
        self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
        if reward_key_out is None:
            reward_key_out = reward_key
        self.out_keys = [unravel_key(reward_key_out)]
        # NOTE(review): time_dim is stored but forward() currently passes
        # time_dim=-2 to reward2go regardless of this value.
        self.time_dim = time_dim
        self.discount = discount
        self.reward2go = reward2go

    def forward(self, tensordict):
        """Replaces the ``("next", reward)`` entry by its reward-to-go.

        Args:
            tensordict: tensordict with reward and done entries of identical
                shape under ``"next"``.

        Returns:
            The same tensordict, with the result written under
            ``("next", out_keys[0])``.

        Raises:
            RuntimeError: if the reward and done shapes differ.
        """
        # Get done
        done = tensordict.get(("next", self.in_keys[1]))
        # Get reward
        reward = tensordict.get(("next", self.in_keys[0]))
        if reward.shape != done.shape:
            raise RuntimeError(
                f"reward and done state are expected to have the same shape. Got reward.shape={reward.shape} "
                f"and done.shape={done.shape}."
            )
        reward = self.reward2go(reward, done, time_dim=-2, gamma=self.discount)
        tensordict.set(("next", self.out_keys[0]), reward)
        return tensordict