Shortcuts

BurnInTransform

class torchrl.envs.transforms.BurnInTransform(modules: Sequence[TensorDictModuleBase], burn_in: int, out_keys: Sequence[NestedKey] | None = None)[source]

Transform to partially burn-in data sequences.

This transform is useful to obtain up-to-date recurrent states when they are not available. It burns-in a number of steps along the time dimension from sampled sequential data slices and returs the remaining data sequence with the burnt-in data in its initial time step. This transform is intended to be used as a replay buffer transform, not as an environment transform.

Parameters:
  • modules (sequence of TensorDictModule) – A list of modules used to burn-in data sequences.

  • burn_in (int) – The number of time steps to burn in.

  • out_keys (sequence of NestedKey, optional) – destination keys. Defaults to

  • ` (all the modules out_keys that point to the next time step (e.g. "hidden" if) –

  • ("next"

  • module). ("hidden")` is part of the out_keys of a) –

Note

This transform expects as inputs TensorDicts with its last dimension being the time dimension. It also assumes that all provided modules can process sequential data.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import BurnInTransform
>>> from torchrl.modules import GRUModule
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     in_keys=["observation", "hidden"],
...     out_keys=["intermediate", ("next", "hidden")],
... ).set_recurrent_mode(True)
>>> burn_in_transform = BurnInTransform(
...     modules=[gru_module],
...     burn_in=5,
... )
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...      "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...      "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> td = burn_in_transform(td)
>>> td.shape
torch.Size([2, 5])
>>> td.get("hidden").abs().sum()
tensor(86.3008)
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> buffer = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(2),
...     batch_size=1,
... )
>>> buffer.append_transform(burn_in_transform)
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...      "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...      "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> buffer.extend(td)
>>> td = buffer.sample(1)
>>> td.shape
torch.Size([1, 5])
>>> td.get("hidden").abs().sum()
tensor(37.0344)
forward(tensordict: TensorDictBase) TensorDictBase[source]

Reads the input tensordict, and for the selected keys, applies the transform.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources