Source code for torchrl.objectives.value.functional
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import math
import warnings
from functools import wraps
from typing import Optional, Tuple, Union
import torch
try:
from torch.compiler import is_dynamo_compiling
except ImportError:
from torch._dynamo import is_compiling as is_dynamo_compiling
__all__ = [
"generalized_advantage_estimate",
"vec_generalized_advantage_estimate",
"td0_advantage_estimate",
"td0_return_estimate",
"td1_return_estimate",
"vec_td1_return_estimate",
"td1_advantage_estimate",
"vec_td1_advantage_estimate",
"td_lambda_return_estimate",
"vec_td_lambda_return_estimate",
"td_lambda_advantage_estimate",
"vec_td_lambda_advantage_estimate",
"vtrace_advantage_estimate",
]
from torchrl.objectives.value.utils import (
_custom_conv1d,
_get_num_per_traj,
_inv_pad_sequence,
_make_gammas_tensor,
_split_and_pad_sequence,
)
SHAPE_ERR = (
"All input tensors (value, reward and done states) must share a unique shape."
)
def _transpose_time(fun):
"""Checks the time_dim argument of the function to allow for any dim.
If not -2, makes a transpose of all the multi-dim input tensors to bring
time at -2, and does the opposite transform for the outputs.
"""
ERROR = (
"The tensor shape and the time dimension are not compatible: "
"got {} and time_dim={}."
)
@wraps(fun)
def transposed_fun(*args, **kwargs):
time_dim = kwargs.pop("time_dim", -2)
def transpose_tensor(tensor):
if not isinstance(tensor, torch.Tensor) or tensor.numel() <= 1:
return tensor, False
if time_dim >= 0:
timedim = time_dim - tensor.ndim
else:
timedim = time_dim
if timedim < -tensor.ndim or timedim >= 0:
raise RuntimeError(ERROR.format(tensor.shape, timedim))
if tensor.ndim >= 2:
single_dim = False
tensor = tensor.transpose(timedim, -2)
elif tensor.ndim == 1 and timedim == -1:
single_dim = True
tensor = tensor.unsqueeze(-1)
else:
raise RuntimeError(ERROR.format(tensor.shape, timedim))
return tensor, single_dim
if time_dim != -2:
single_dim = False
if args:
args, single_dim = zip(*(transpose_tensor(arg) for arg in args))
single_dim = any(single_dim)
for k, item in list(kwargs.items()):
item, sd = transpose_tensor(item)
single_dim = single_dim or sd
kwargs[k] = item
# We don't pass time_dim because it isn't supposed to be used thereafter
out = fun(*args, **kwargs)
if isinstance(out, torch.Tensor):
out = transpose_tensor(out)[0]
if single_dim:
out = out.squeeze(-2)
return out
if single_dim:
return tuple(transpose_tensor(_out)[0].squeeze(-2) for _out in out)
return tuple(transpose_tensor(_out)[0] for _out in out)
# We don't pass time_dim because it isn't supposed to be used thereafter
out = fun(*args, **kwargs)
if isinstance(out, tuple):
for _out in out:
if _out.ndim < 2:
raise RuntimeError(ERROR.format(_out.shape, time_dim))
else:
if out.ndim < 2:
raise RuntimeError(ERROR.format(out.shape, time_dim))
return out
return transposed_fun
########################################################################
# GAE
# ---
[docs]@_transpose_time
def generalized_advantage_estimate(
gamma: float,
lmbda: float,
state_value: torch.Tensor,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
*,
time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generalized advantage estimate of a trajectory.
Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
https://arxiv.org/pdf/1506.02438.pdf for more context.
Args:
gamma (scalar): exponential mean discount.
lmbda (scalar): trajectory discount.
state_value (Tensor): value function result with old_state input.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (
next_state_value.shape
== state_value.shape
== reward.shape
== done.shape
== terminated.shape
):
raise RuntimeError(SHAPE_ERR)
dtype = next_state_value.dtype
device = state_value.device
not_done = (~done).int()
not_terminated = (~terminated).int()
*batch_size, time_steps, lastdim = not_done.shape
advantage = torch.empty(
*batch_size, time_steps, lastdim, device=device, dtype=dtype
)
prev_advantage = 0
g_not_terminated = gamma * not_terminated
delta = reward + (g_not_terminated * next_state_value) - state_value
discount = lmbda * gamma * not_done
for t in reversed(range(time_steps)):
prev_advantage = advantage[..., t, :] = delta[..., t, :] + (
prev_advantage * discount[..., t, :]
)
value_target = advantage + state_value
return advantage, value_target
def _geom_series_like(t, r, thr):
"""Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`.
Drops all elements which are smaller than `thr` (unless in compile mode).
"""
if is_dynamo_compiling():
if isinstance(r, torch.Tensor):
rs = r.expand_as(t)
else:
rs = torch.full_like(t, r)
else:
if isinstance(r, torch.Tensor):
r = r.item()
if r == 0.0:
return torch.zeros_like(t)
elif r >= 1.0:
lim = t.numel()
else:
lim = int(math.log(thr) / math.log(r))
rs = torch.full_like(t[:lim], r)
rs[0] = 1.0
rs = rs.cumprod(0)
rs = rs.unsqueeze(-1)
return rs
def _fast_vec_gae(
reward: torch.Tensor,
state_value: torch.Tensor,
next_state_value: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor,
gamma: float,
lmbda: float,
thr: float = 1e-7,
):
"""Fast vectorized Generalized Advantage Estimate when gamma and lmbda are scalars.
In contrast to `vec_generalized_advantage_estimate` this function does not need
to allocate a big tensor of the form [B, T, T].
Args:
reward (torch.Tensor): a [*B, T, F] tensor containing rewards
state_value (torch.Tensor): a [*B, T, F] tensor containing state values (value function)
next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
done (torch.Tensor): a [B, T] boolean tensor containing the done states.
terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states.
gamma (scalar): the gamma decay (trajectory discount)
lmbda (scalar): the lambda decay (exponential mean discount)
thr (:obj:`float`): threshold for the filter. Below this limit, components will ignored.
Defaults to 1e-7.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.
"""
# _get_num_per_traj and _split_and_pad_sequence need
# time dimension at last position
done = done.transpose(-2, -1)
terminated = terminated.transpose(-2, -1)
reward = reward.transpose(-2, -1)
state_value = state_value.transpose(-2, -1)
next_state_value = next_state_value.transpose(-2, -1)
gammalmbda = gamma * lmbda
not_terminated = (~terminated).int()
td0 = reward + not_terminated * gamma * next_state_value - state_value
num_per_traj = _get_num_per_traj(done)
td0_flat, mask = _split_and_pad_sequence(td0, num_per_traj, return_mask=True)
gammalmbdas = _geom_series_like(td0_flat[0], gammalmbda, thr=thr)
advantage = _custom_conv1d(td0_flat.unsqueeze(1), gammalmbdas)
advantage = advantage.squeeze(1)
advantage = advantage[mask].view_as(reward)
value_target = advantage + state_value
advantage = advantage.transpose(-1, -2)
value_target = value_target.transpose(-1, -2)
return advantage, value_target
[docs]@_transpose_time
def vec_generalized_advantage_estimate(
gamma: Union[float, torch.Tensor],
lmbda: Union[float, torch.Tensor],
state_value: torch.Tensor,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
*,
time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Vectorized Generalized advantage estimate of a trajectory.
Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
https://arxiv.org/pdf/1506.02438.pdf for more context.
Args:
gamma (scalar): exponential mean discount.
lmbda (scalar): trajectory discount.
state_value (Tensor): value function result with old_state input.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (
next_state_value.shape
== state_value.shape
== reward.shape
== done.shape
== terminated.shape
):
raise RuntimeError(SHAPE_ERR)
dtype = state_value.dtype
*batch_size, time_steps, lastdim = terminated.shape
value = gamma * lmbda
if isinstance(value, torch.Tensor) and value.numel() > 1:
# create tensor while ensuring that gradients are passed
not_done = (~done).to(dtype)
gammalmbdas = not_done * value
else:
# when gamma and lmbda are scalars, use fast_vec_gae implementation
return _fast_vec_gae(
reward=reward,
state_value=state_value,
next_state_value=next_state_value,
done=done,
terminated=terminated,
gamma=gamma,
lmbda=lmbda,
)
gammalmbdas = _make_gammas_tensor(gammalmbdas, time_steps, True)
gammalmbdas = gammalmbdas.cumprod(-2)
first_below_thr = gammalmbdas < 1e-7
# if we have multiple gammas, we only want to truncate if _all_ of
# the geometric sequences fall below the threshold
first_below_thr = first_below_thr.flatten(0, 1).all(0).all(-1)
if first_below_thr.any():
first_below_thr = torch.where(first_below_thr)[0][0].item()
gammalmbdas = gammalmbdas[..., :first_below_thr, :]
not_terminated = (~terminated).to(dtype)
td0 = reward + not_terminated * gamma * next_state_value - state_value
if len(batch_size) > 1:
td0 = td0.flatten(0, len(batch_size) - 1)
elif not len(batch_size):
td0 = td0.unsqueeze(0)
td0_r = td0.transpose(-2, -1)
shapes = td0_r.shape[:2]
if lastdim != 1:
# then we flatten again the first dims and reset a singleton in between
td0_r = td0_r.flatten(0, 1).unsqueeze(1)
advantage = _custom_conv1d(td0_r, gammalmbdas)
if lastdim != 1:
advantage = advantage.squeeze(1).unflatten(0, shapes)
if len(batch_size) > 1:
advantage = advantage.unflatten(0, batch_size)
elif not len(batch_size):
advantage = advantage.squeeze(0)
advantage = advantage.transpose(-2, -1)
value_target = advantage + state_value
return advantage, value_target
########################################################################
# TD(0)
# -----
[docs]def td0_advantage_estimate(
gamma: float,
state_value: torch.Tensor,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""TD(0) advantage estimate of a trajectory.
Also known as bootstrapped Temporal Difference or one-step return.
Args:
gamma (scalar): exponential mean discount.
state_value (Tensor): value function result with old_state input.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (
next_state_value.shape
== state_value.shape
== reward.shape
== done.shape
== terminated.shape
):
raise RuntimeError(SHAPE_ERR)
returns = td0_return_estimate(gamma, next_state_value, reward, terminated)
advantage = returns - state_value
return advantage
[docs]def td0_return_estimate(
gamma: float,
next_state_value: torch.Tensor,
reward: torch.Tensor,
terminated: torch.Tensor | None = None,
*,
done: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# noqa: D417
"""TD(0) discounted return estimate of a trajectory.
Also known as bootstrapped Temporal Difference or one-step return.
Args:
gamma (scalar): exponential mean discount.
next_state_value (Tensor): value function result with new_state input.
must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
reward (Tensor): reward of taking actions in the environment.
must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
Keyword Args:
done (Tensor): Deprecated. Use ``terminated`` instead.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if done is not None and terminated is None:
terminated = done.clone()
warnings.warn(
"done for td0_return_estimate is deprecated. Pass ``terminated`` instead."
)
if not (next_state_value.shape == reward.shape == terminated.shape):
raise RuntimeError(SHAPE_ERR)
not_terminated = (~terminated).int()
advantage = reward + gamma * not_terminated * next_state_value
return advantage
########################################################################
# TD(1)
# ----------
[docs]@_transpose_time
def td1_return_estimate(
gamma: float,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
*,
time_dim: int = -2,
) -> torch.Tensor:
r"""TD(1) return estimate.
Args:
gamma (scalar): exponential mean discount.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
of a gamma tensor is tied to a single event:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
... v2 + g2 v3 + g2 g3 v4,
... v3 + g3 v4,
... v4,
... ]
if ``False``, it is assumed that each gamma is tied to the upcoming
trajectory:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1**2 v3 + g**3 v4,
... v2 + g2 v3 + g2**2 v4,
... v3 + g3 v4,
... v4,
... ]
Default is ``True``.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
raise RuntimeError(SHAPE_ERR)
not_done = (~done).int()
not_terminated = (~terminated).int()
returns = torch.empty_like(next_state_value)
T = returns.shape[-2]
single_gamma = False
if not (isinstance(gamma, torch.Tensor) and gamma.shape == not_done.shape):
single_gamma = True
gamma = torch.full_like(next_state_value, gamma)
if rolling_gamma is None:
rolling_gamma = True
elif not rolling_gamma and single_gamma:
raise RuntimeError(
"rolling_gamma=False is expected only with time-sensitive gamma values"
)
done_but_not_terminated = (done & ~terminated).int()
if rolling_gamma:
gamma = gamma * not_terminated
g = next_state_value[..., -1, :]
for i in reversed(range(T)):
# if not done (and hence not terminated), get the bootstrapped value
# if done but not terminated, get nex_val
# if terminated, take nothing (gamma = 0)
dnt = done_but_not_terminated[..., i, :]
g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * (
(1 - dnt) * g + dnt * next_state_value[..., i, :]
)
else:
for k in range(T):
g = 0
_gamma = gamma[..., k, :]
nd = not_terminated
_gamma = _gamma.unsqueeze(-2) * nd
for i in reversed(range(k, T)):
dnt = done_but_not_terminated[..., i, :]
g = reward[..., i, :] + _gamma[..., i, :] * (
(1 - dnt) * g + dnt * next_state_value[..., i, :]
)
returns[..., k, :] = g
return returns
[docs]def td1_advantage_estimate(
gamma: float,
state_value: torch.Tensor,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
time_dim: int = -2,
) -> torch.Tensor:
"""TD(1) advantage estimate.
Args:
gamma (scalar): exponential mean discount.
state_value (Tensor): value function result with old_state input.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
of a gamma tensor is tied to a single event:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
... v2 + g2 v3 + g2 g3 v4,
... v3 + g3 v4,
... v4,
... ]
if ``False``, it is assumed that each gamma is tied to the upcoming
trajectory:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1**2 v3 + g**3 v4,
... v2 + g2 v3 + g2**2 v4,
... v3 + g3 v4,
... v4,
... ]
Default is ``True``.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (
next_state_value.shape
== state_value.shape
== reward.shape
== done.shape
== terminated.shape
):
raise RuntimeError(SHAPE_ERR)
if not state_value.shape == next_state_value.shape:
raise RuntimeError("shape of state_value and next_state_value must match")
returns = td1_return_estimate(
gamma,
next_state_value,
reward,
done,
terminated=terminated,
rolling_gamma=rolling_gamma,
time_dim=time_dim,
)
advantage = returns - state_value
return advantage
[docs]@_transpose_time
def vec_td1_return_estimate(
gamma,
next_state_value,
reward,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: Optional[bool] = None,
time_dim: int = -2,
):
"""Vectorized TD(1) return estimate.
Args:
gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
of the gamma tensor is tied to a single event:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
... v2 + g2 v3 + g2 g3 v4,
... v3 + g3 v4,
... v4,
... ]
if ``False``, it is assumed that each gamma is tied to the upcoming
trajectory:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1**2 v3 + g**3 v4,
... v2 + g2 v3 + g2**2 v4,
... v3 + g3 v4,
... v4,
... ]
Default is ``True``.
time_dim (int): dimension where the time is unrolled. Defaults to ``-2``.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
return vec_td_lambda_return_estimate(
gamma=gamma,
next_state_value=next_state_value,
reward=reward,
done=done,
terminated=terminated,
rolling_gamma=rolling_gamma,
lmbda=1,
time_dim=time_dim,
)
[docs]def vec_td1_advantage_estimate(
gamma,
state_value,
next_state_value,
reward,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
time_dim: int = -2,
):
"""Vectorized TD(1) advantage estimate.
Args:
gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
state_value (Tensor): value function result with old_state input.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
of a gamma tensor is tied to a single event:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
... v2 + g2 v3 + g2 g3 v4,
... v3 + g3 v4,
... v4,
... ]
if ``False``, it is assumed that each gamma is tied to the upcoming
trajectory:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1**2 v3 + g**3 v4,
... v2 + g2 v3 + g2**2 v4,
... v3 + g3 v4,
... v4,
... ]
Default is ``True``.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (
next_state_value.shape
== state_value.shape
== reward.shape
== done.shape
== terminated.shape
):
raise RuntimeError(SHAPE_ERR)
return (
vec_td1_return_estimate(
gamma,
next_state_value,
reward,
done=done,
terminated=terminated,
rolling_gamma=rolling_gamma,
time_dim=time_dim,
)
- state_value
)
########################################################################
# TD(lambda)
# ----------
[docs]@_transpose_time
def td_lambda_return_estimate(
gamma: float,
lmbda: float,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
*,
time_dim: int = -2,
) -> torch.Tensor:
r"""TD(:math:`\lambda`) return estimate.
Args:
gamma (scalar): exponential mean discount.
lmbda (scalar): trajectory discount.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
of a gamma tensor is tied to a single event:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
... v2 + g2 v3 + g2 g3 v4,
... v3 + g3 v4,
... v4,
... ]
if ``False``, it is assumed that each gamma is tied to the upcoming
trajectory:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1**2 v3 + g**3 v4,
... v2 + g2 v3 + g2**2 v4,
... v3 + g3 v4,
... v4,
... ]
Default is ``True``.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
raise RuntimeError(SHAPE_ERR)
not_terminated = (~terminated).int()
returns = torch.empty_like(next_state_value)
next_state_value = next_state_value * not_terminated
*batch, T, lastdim = returns.shape
# if gamma is not a tensor of the same shape as other inputs, we use rolling_gamma = True
single_gamma = False
if not (isinstance(gamma, torch.Tensor) and gamma.shape == done.shape):
single_gamma = True
gamma = torch.full_like(next_state_value, gamma)
single_lambda = False
if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == done.shape):
single_lambda = True
lmbda = torch.full_like(next_state_value, lmbda)
if rolling_gamma is None:
rolling_gamma = True
elif not rolling_gamma and single_gamma and single_lambda:
raise RuntimeError(
"rolling_gamma=False is expected only with time-sensitive gamma or lambda values"
)
if rolling_gamma:
g = next_state_value[..., -1, :]
for i in reversed(range(T)):
dn = done[..., i, :].int()
nv = next_state_value[..., i, :]
lmd = lmbda[..., i, :]
# if done, the bootstrapped gain is the next value, otherwise it's the
# value we computed during the previous iter
g = g * (1 - dn) + nv * dn
g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * (
(1 - lmd) * nv + lmd * g
)
else:
for k in range(T):
g = next_state_value[..., -1, :]
_gamma = gamma[..., k, :]
_lambda = lmbda[..., k, :]
for i in reversed(range(k, T)):
dn = done[..., i, :].int()
nv = next_state_value[..., i, :]
g = g * (1 - dn) + nv * dn
g = reward[..., i, :] + _gamma * ((1 - _lambda) * nv + _lambda * g)
returns[..., k, :] = g
return returns
[docs]def td_lambda_advantage_estimate(
gamma: float,
lmbda: float,
state_value: torch.Tensor,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
# not a kwarg because used directly
time_dim: int = -2,
) -> torch.Tensor:
r"""TD(:math:`\lambda`) advantage estimate.
Args:
gamma (scalar): exponential mean discount.
lmbda (scalar): trajectory discount.
state_value (Tensor): value function result with old_state input.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
of a gamma tensor is tied to a single event:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
... v2 + g2 v3 + g2 g3 v4,
... v3 + g3 v4,
... v4,
... ]
if ``False``, it is assumed that each gamma is tied to the upcoming
trajectory:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1**2 v3 + g**3 v4,
... v2 + g2 v3 + g2**2 v4,
... v3 + g3 v4,
... v4,
... ]
Default is ``True``.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (
next_state_value.shape
== state_value.shape
== reward.shape
== done.shape
== terminated.shape
):
raise RuntimeError(SHAPE_ERR)
if not state_value.shape == next_state_value.shape:
raise RuntimeError("shape of state_value and next_state_value must match")
returns = td_lambda_return_estimate(
gamma,
lmbda,
next_state_value,
reward,
done,
terminated=terminated,
rolling_gamma=rolling_gamma,
time_dim=time_dim,
)
advantage = returns - state_value
return advantage
def _fast_td_lambda_return_estimate(
gamma: Union[torch.Tensor, float],
lmbda: float,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor,
thr: float = 1e-7,
):
"""Fast vectorized TD lambda return estimate.
In contrast to the generalized `vec_td_lambda_return_estimate` this function does not need
to allocate a big tensor of the form [B, T, T], but it only works with gamma/lmbda being scalars.
Args:
gamma (scalar): the gamma decay, can be a tensor with a single element (trajectory discount)
lmbda (scalar): the lambda decay (exponential mean discount)
next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
reward (torch.Tensor): a [*B, T, F] tensor containing rewards
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for end of episode.
thr (:obj:`float`): threshold for the filter. Below this limit, components will ignored.
Defaults to 1e-7.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.
"""
device = reward.device
done = done.transpose(-2, -1)
terminated = terminated.transpose(-2, -1)
reward = reward.transpose(-2, -1)
next_state_value = next_state_value.transpose(-2, -1)
# the only valid next states are those where the trajectory does not terminate
next_state_value = (~terminated).int() * next_state_value
gamma_tensor = torch.tensor([gamma], device=device)
gammalmbda = gamma_tensor * lmbda
num_per_traj = _get_num_per_traj(done)
done = done.clone()
done[..., -1] = 1
not_done = (~done).int()
t = reward + next_state_value * gamma_tensor * (1 - not_done * lmbda)
t_flat, mask = _split_and_pad_sequence(t, num_per_traj, return_mask=True)
gammalmbdas = _geom_series_like(t_flat[0], gammalmbda, thr=thr)
ret_flat = _custom_conv1d(t_flat.unsqueeze(1), gammalmbdas)
ret = ret_flat.squeeze(1)[mask]
return ret.view_as(reward).transpose(-1, -2)
[docs]@_transpose_time
def vec_td_lambda_return_estimate(
gamma,
lmbda,
next_state_value,
reward,
done,
terminated: torch.Tensor | None = None,
rolling_gamma: Optional[bool] = None,
*,
time_dim: int = -2,
):
r"""Vectorized TD(:math:`\lambda`) return estimate.
Args:
gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
must be a [Batch x TimeSteps x 1] tensor.
lmbda (scalar): trajectory discount.
next_state_value (Tensor): value function result with new_state input.
must be a [Batch x TimeSteps x 1] tensor
reward (Tensor): reward of taking actions in the environment.
must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
of a gamma tensor is tied to a single event:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
... v2 + g2 v3 + g2 g3 v4,
... v3 + g3 v4,
... v4,
... ]
if ``False``, it is assumed that each gamma is tied to the upcoming
trajectory:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1**2 v3 + g**3 v4,
... v2 + g2 v3 + g2**2 v4,
... v3 + g3 v4,
... v4,
... ]
Default is ``True``.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
raise RuntimeError(SHAPE_ERR)
gamma_thr = 1e-7
shape = next_state_value.shape
*batch, T, lastdim = shape
def _is_scalar(tensor):
return not isinstance(tensor, torch.Tensor) or tensor.numel() == 1
# There are two use-cases: if gamma/lmbda are scalars we can use the
# fast implementation, if not we must construct a gamma tensor.
if _is_scalar(gamma) and _is_scalar(lmbda):
return _fast_td_lambda_return_estimate(
gamma=gamma,
lmbda=lmbda,
next_state_value=next_state_value,
reward=reward,
done=done,
terminated=terminated,
thr=gamma_thr,
)
next_state_value = next_state_value.transpose(-2, -1).unsqueeze(-2)
if len(batch):
next_state_value = next_state_value.flatten(0, len(batch))
reward = reward.transpose(-2, -1).unsqueeze(-2)
if len(batch):
reward = reward.flatten(0, len(batch))
"""Vectorized version of td_lambda_advantage_estimate"""
device = reward.device
not_done = (~done).int()
not_terminated = (~terminated).int().transpose(-2, -1).unsqueeze(-2)
if len(batch):
not_terminated = not_terminated.flatten(0, len(batch))
next_state_value = next_state_value * not_terminated
if rolling_gamma is None:
rolling_gamma = True
if not rolling_gamma:
terminated_follows_terminated = terminated[..., 1:, :][
terminated[..., :-1, :]
].all()
if not terminated_follows_terminated:
raise NotImplementedError(
"When using rolling_gamma=False and vectorized TD(lambda) with time-dependent gamma, "
"make sure that conseducitve trajectories are separated as different batch "
"items. Propagating a gamma value across trajectories is not permitted with "
"this method. Check that you need to use rolling_gamma=False, and if so "
"consider using the non-vectorized version of the return computation or splitting "
"your trajectories."
)
if rolling_gamma:
# Make the coefficient table
gammas = _make_gammas_tensor(gamma * not_done, T, rolling_gamma)
gammas_cp = torch.cumprod(gammas, -2)
lambdas = torch.ones(T + 1, 1, device=device)
lambdas[1:] = lmbda
lambdas_cp = torch.cumprod(lambdas, -2)
lambdas = lambdas[1:]
dec = gammas_cp * lambdas_cp
gammas = _make_gammas_tensor(gamma, T, rolling_gamma)
gammas = gammas[..., 1:, :]
if gammas.ndimension() == 4 and gammas.shape[1] > 1:
gammas = gammas[:, :1]
if lambdas.ndimension() == 4 and lambdas.shape[1] > 1:
lambdas = lambdas[:, :1]
not_done = not_done.transpose(-2, -1).unsqueeze(-2)
if len(batch):
not_done = not_done.flatten(0, len(batch))
# lambdas = lambdas * not_done
v3 = (gammas * lambdas).squeeze(-1) * next_state_value * not_done
v3[..., :-1] = 0
out = _custom_conv1d(
reward
+ gammas.squeeze(-1)
* next_state_value
* (1 - lambdas.squeeze(-1) * not_done)
+ v3,
dec,
)
return out.view(*batch, lastdim, T).transpose(-2, -1)
else:
raise NotImplementedError(
"The vectorized version of TD(lambda) with rolling_gamma=False is currently not available. "
"To use this feature, use the non-vectorized version of TD(lambda). You can expect "
"good speed improvements by decorating the function with torch.compile!"
)
[docs]def vec_td_lambda_advantage_estimate(
gamma,
lmbda,
state_value,
next_state_value,
reward,
done,
terminated: torch.Tensor | None = None,
rolling_gamma: bool = None,
# not a kwarg because used directly
time_dim: int = -2,
):
r"""Vectorized TD(:math:`\lambda`) advantage estimate.
Args:
gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
lmbda (scalar): trajectory discount.
state_value (Tensor): value function result with old_state input.
next_state_value (Tensor): value function result with new_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of trajectory.
terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
if not provided.
rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
of a gamma tensor is tied to a single event:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
... v2 + g2 v3 + g2 g3 v4,
... v3 + g3 v4,
... v4,
... ]
if ``False``, it is assumed that each gamma is tied to the upcoming
trajectory:
>>> gamma = [g1, g2, g3, g4]
>>> value = [v1, v2, v3, v4]
>>> return = [
... v1 + g1 v2 + g1**2 v3 + g**3 v4,
... v2 + g2 v3 + g2**2 v4,
... v3 + g3 v4,
... v4,
... ]
Default is ``True``.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if terminated is None:
terminated = done.clone()
if not (
next_state_value.shape
== state_value.shape
== reward.shape
== done.shape
== terminated.shape
):
raise RuntimeError(SHAPE_ERR)
return (
vec_td_lambda_return_estimate(
gamma,
lmbda,
next_state_value,
reward,
done=done,
terminated=terminated,
rolling_gamma=rolling_gamma,
time_dim=time_dim,
)
- state_value
)
########################################################################
# V-Trace
# -----
@_transpose_time
def vtrace_advantage_estimate(
gamma: float,
log_pi: torch.Tensor,
log_mu: torch.Tensor,
state_value: torch.Tensor,
next_state_value: torch.Tensor,
reward: torch.Tensor,
done: torch.Tensor,
terminated: torch.Tensor | None = None,
rho_thresh: Union[float, torch.Tensor] = 1.0,
c_thresh: Union[float, torch.Tensor] = 1.0,
# not a kwarg because used directly
time_dim: int = -2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes V-Trace off-policy actor critic targets.
Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
https://arxiv.org/abs/1802.01561 for more context.
Args:
gamma (scalar): exponential mean discount.
log_pi (Tensor): collection actor log probability of taking actions in the environment.
log_mu (Tensor): current actor log probability of taking actions in the environment.
state_value (Tensor): value function result with state input.
next_state_value (Tensor): value function result with next_state input.
reward (Tensor): reward of taking actions in the environment.
done (Tensor): boolean flag for end of episode.
terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states.
rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights.
c_thresh (Union[float, Tensor]): c clipping parameter for importance weights.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
All tensors (values, reward and done) must have shape
``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
"""
if not (next_state_value.shape == state_value.shape == reward.shape == done.shape):
raise RuntimeError(SHAPE_ERR)
device = state_value.device
if not isinstance(rho_thresh, torch.Tensor):
rho_thresh = torch.tensor(rho_thresh, device=device)
if not isinstance(c_thresh, torch.Tensor):
c_thresh = torch.tensor(c_thresh, device=device)
c_thresh = c_thresh.to(device)
rho_thresh = rho_thresh.to(device)
not_done = (~done).int()
not_terminated = not_done if terminated is None else (~terminated).int()
*batch_size, time_steps, lastdim = not_done.shape
done_discounts = gamma * not_done
terminated_discounts = gamma * not_terminated
rho = (log_pi - log_mu).exp()
clipped_rho = rho.clamp_max(rho_thresh)
deltas = clipped_rho * (
reward + terminated_discounts * next_state_value - state_value
)
clipped_c = rho.clamp_max(c_thresh)
vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])]
for i in reversed(range(time_steps)):
discount_t, c_t, delta_t = (
done_discounts[..., i, :],
clipped_c[..., i, :],
deltas[..., i, :],
)
vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1])
vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim)
vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim])
vs = vs_minus_v_xs + state_value
vs_t_plus_1 = torch.cat(
[vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim
)
advantages = clipped_rho * (
reward + terminated_discounts * vs_t_plus_1 - state_value
)
return advantages, vs
########################################################################
# Reward to go
# ------------
[docs]@_transpose_time
def reward2go(
reward,
done,
gamma,
*,
time_dim: int = -2,
):
"""Compute the discounted cumulative sum of rewards given multiple trajectories and the episode ends.
Args:
reward (torch.Tensor): A tensor containing the rewards
received at each time step over multiple trajectories.
done (Tensor): boolean flag for end of episode. Differs from
truncated, where the episode did not end but was interrupted.
gamma (:obj:`float`, optional): The discount factor to use for computing the
discounted cumulative sum of rewards. Defaults to 1.0.
time_dim (int): dimension where the time is unrolled. Defaults to -2.
Returns:
torch.Tensor: A tensor of shape [B, T] containing the discounted cumulative
sum of rewards (reward-to-go) at each time step.
Examples:
>>> reward = torch.ones(1, 10)
>>> done = torch.zeros(1, 10, dtype=torch.bool)
>>> done[:, [3, 7]] = True
>>> reward2go(reward, done, 0.99, time_dim=-1)
tensor([[3.9404],
[2.9701],
[1.9900],
[1.0000],
[3.9404],
[2.9701],
[1.9900],
[1.0000],
[1.9900],
[1.0000]])
"""
shape = reward.shape
if shape != done.shape:
raise ValueError(
f"reward and done must share the same shape, got {reward.shape} and {done.shape}"
)
# flatten if needed
if reward.ndim > 2:
# we know time dim is at -2, let's put it at -3
rflip = reward.transpose(-2, -3)
rflip_shape = rflip.shape[-2:]
r2go = reward2go(
rflip.flatten(-2, -1), done.transpose(-2, -3).flatten(-2, -1), gamma=gamma
).unflatten(-1, rflip_shape)
return r2go.transpose(-2, -3)
# place time at dim -1
reward = reward.transpose(-2, -1)
done = done.transpose(-2, -1)
num_per_traj = _get_num_per_traj(done)
td0_flat = _split_and_pad_sequence(reward, num_per_traj)
gammas = _geom_series_like(td0_flat[0], gamma, thr=1e-7)
cumsum = _custom_conv1d(td0_flat.unsqueeze(1), gammas)
cumsum = cumsum.squeeze(1)
cumsum = _inv_pad_sequence(cumsum, num_per_traj)
cumsum = cumsum.reshape_as(reward)
cumsum = cumsum.transpose(-2, -1)
if cumsum.shape != shape:
try:
cumsum = cumsum.reshape(shape)
except RuntimeError:
raise RuntimeError(
f"Wrong shape for output reward2go: {cumsum.shape} when {shape} was expected."
)
return cumsum