Source code for torchrl.envs.transforms.rb_transforms

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import List

import torch

from tensordict import NestedKey, TensorDictBase
from torchrl.data.postprocs.postprocs import _multi_step_func
from torchrl.envs.transforms.transforms import Transform


class MultiStepTransform(Transform):
    """A MultiStep transformation for ReplayBuffers.

    This transform keeps the previous ``n_steps`` observations in a local buffer.
    The inverse transform (called during :meth:`~torchrl.data.ReplayBuffer.extend`)
    outputs the transformed previous ``n_steps`` together with the ``T-n_steps``
    current frames.

    All entries in the ``"next"`` tensordict that are not part of the ``done_keys``
    or ``reward_keys`` will be mapped to their respective ``t + n_steps - 1``
    counterpart.

    This transform is a more hyperparameter-resistant version of
    :class:`~torchrl.data.postprocs.postprocs.MultiStep`:
    the replay buffer transform makes the multi-step transform insensitive
    to the collector's hyperparameters, whereas the post-process
    version outputs results that are sensitive to these
    (because collectors have no memory of previous output).

    Args:
        n_steps (int): Number of steps in multi-step. The number of steps can be
            dynamically changed by changing the ``n_steps`` attribute of this
            transform.
        gamma (float): Discount factor.

    Keyword Args:
        reward_keys (list of NestedKey, optional): the reward keys in the input tensordict.
            The reward entries indicated by these keys will be accumulated and discounted
            across ``n_steps`` steps in the future. A corresponding ``<reward_key>_orig``
            entry will be written in the ``"next"`` entry of the output tensordict
            to keep track of the original value of the reward.
            Defaults to ``["reward"]``.
        done_key (NestedKey, optional): the done key in the input tensordict, used
            to indicate an end of trajectory.
            Defaults to ``"done"``.
        done_keys (list of NestedKey, optional): the list of end keys in the input
            tensordict. All the entries indicated by these keys will be left untouched
            by the transform.
            Defaults to ``["done", "truncated", "terminated"]``.
        mask_key (NestedKey, optional): the mask key in the input tensordict.
            The mask represents the valid frames in the input tensordict and
            should have a shape with which the input tensordict can be masked.
            Defaults to ``"mask"``.

    Examples:
        >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter, MultiStepTransform, SerialEnv
        >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
        >>> rb = ReplayBuffer(
        ...     storage=LazyTensorStorage(100, ndim=2),
        ...     transform=MultiStepTransform(n_steps=3, gamma=0.95)
        ... )
        >>> base_env = SerialEnv(2, lambda: GymEnv("CartPole"))
        >>> env = TransformedEnv(base_env, StepCounter())
        >>> _ = env.set_seed(0)
        >>> _ = torch.manual_seed(0)
        >>> tdreset = env.reset()
        >>> for _ in range(100):
        ...     rollout = env.rollout(max_steps=50, break_when_any_done=False,
        ...         tensordict=tdreset, auto_reset=False)
        ...     indices = rb.extend(rollout)
        ...     tdreset = rollout[..., -1]["next"]
        >>> print("step_count", rb[:]["step_count"][:, :5])
        step_count tensor([[[ 9],
                 [10],
                 [11],
                 [12],
                 [13]],
        <BLANKLINE>
                [[12],
                 [13],
                 [14],
                 [15],
                 [16]]])
        >>> # The next step_count is 3 steps in the future
        >>> print("next step_count", rb[:]["next", "step_count"][:, :5])
        next step_count tensor([[[13],
                 [14],
                 [15],
                 [16],
                 [17]],
        <BLANKLINE>
                [[16],
                 [17],
                 [18],
                 [19],
                 [20]]])

    """

    ENV_ERR = (
        "The MultiStepTransform is only an inverse transform and can "
        "be applied exclusively to replay buffers."
    )

    def __init__(
        self,
        n_steps,
        gamma,
        *,
        reward_keys: List[NestedKey] | None = None,
        done_key: NestedKey | None = None,
        done_keys: List[NestedKey] | None = None,
        mask_key: NestedKey | None = None,
    ):
        super().__init__()
        self.n_steps = n_steps
        self.reward_keys = reward_keys
        self.done_key = done_key
        self.done_keys = done_keys
        self.mask_key = mask_key
        self.gamma = gamma
        self._buffer = None
        self._validated = False

    @property
    def n_steps(self):
        """The look-ahead window of the transform.

        This value can be dynamically edited during training.
        """
        return self._n_steps

    @n_steps.setter
    def n_steps(self, value):
        if not isinstance(value, int) or not (value >= 1):
            raise ValueError(
                "The value of n_steps must be a strictly positive integer."
            )
        self._n_steps = value

    @property
    def done_key(self):
        return self._done_key

    @done_key.setter
    def done_key(self, value):
        if value is None:
            value = "done"
        self._done_key = value

    @property
    def done_keys(self):
        return self._done_keys

    @done_keys.setter
    def done_keys(self, value):
        if value is None:
            value = ["done", "terminated", "truncated"]
        self._done_keys = value

    @property
    def reward_keys(self):
        return self._reward_keys

    @reward_keys.setter
    def reward_keys(self, value):
        if value is None:
            value = [
                "reward",
            ]
        self._reward_keys = value

    @property
    def mask_key(self):
        return self._mask_key

    @mask_key.setter
    def mask_key(self, value):
        if value is None:
            value = "mask"
        self._mask_key = value

    def _validate(self):
        if self.parent is not None:
            raise ValueError(self.ENV_ERR)
        self._validated = True

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        if not self._validated:
            self._validate()
        total_cat = self._append_tensordict(tensordict)
        if total_cat.shape[-1] >= self.n_steps:
            out = _multi_step_func(
                total_cat,
                done_key=self.done_key,
                done_keys=self.done_keys,
                reward_keys=self.reward_keys,
                mask_key=self.mask_key,
                n_steps=self.n_steps,
                gamma=self.gamma,
            )
            return out[..., : -self.n_steps]

    def _append_tensordict(self, data):
        if self._buffer is None:
            total_cat = data
            self._buffer = data[..., -self.n_steps :].copy()
        else:
            total_cat = torch.cat([self._buffer, data], -1)
            self._buffer = total_cat[..., -self.n_steps :].copy()
        return total_cat
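
The reward accumulation that this transform delegates to ``_multi_step_func`` corresponds, per the docstring above, to the standard n-step discounted return. The stand-alone sketch below (a hypothetical helper, not part of this module or of ``_multi_step_func``'s actual implementation) illustrates the intended semantics for a single 1D reward tensor, assuming no early termination, no masking, and no remapping of ``"next"`` entries:

import torch

def n_step_return_sketch(reward: torch.Tensor, gamma: float, n_steps: int) -> torch.Tensor:
    # reward: 1D tensor of per-step rewards, shape [T].
    # Entry t of the output holds sum_{k=0}^{n_steps-1} gamma**k * reward[t + k],
    # with the sum truncated where t + k runs past the end of the trajectory.
    T = reward.shape[0]
    out = torch.zeros_like(reward)
    for k in range(n_steps):
        out[: T - k] += (gamma**k) * reward[k:]
    return out

# Example: constant rewards of 1.0, gamma=0.95, n_steps=3.
# Every entry except the last two equals 1 + 0.95 + 0.95**2 = 2.8525.
print(n_step_return_sketch(torch.ones(5), gamma=0.95, n_steps=3))

The real ``_multi_step_func`` additionally stops the accumulation at ``done`` boundaries, honors the mask indicated by ``mask_key``, writes the ``<reward_key>_orig`` entries, and shifts the non-done/non-reward ``"next"`` entries forward by ``n_steps - 1``; the sketch deliberately omits all of that.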
