CatFrames

class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, out_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, padding='same', as_inverse=False)[source]

Concatenates successive observation frames into a single tensor.

Stacking frames lets a policy infer, for instance, the movement or velocity of an observed feature, which a single frame cannot convey. Proposed in “Playing Atari with Deep Reinforcement Learning” (https://arxiv.org/pdf/1312.5602.pdf).

When used within a transformed environment, CatFrames is a stateful class: it keeps a buffer of past frames, which can be restored to its initial state by calling the reset() method. This method accepts tensordicts with a "_reset" entry that indicates which buffers to reset.
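The statefulness described above can be pictured with a small pure-Python sketch. It is not the TorchRL implementation: `FrameBufferSketch`, its `push` method, and the boolean `reset_mask` list are hypothetical names used only for illustration, and the mask merely plays the role of the "_reset" entry.

```python
from collections import deque


class FrameBufferSketch:
    """Hypothetical helper (not part of torchrl): keeps the last n frames per env."""

    def __init__(self, n, num_envs):
        self.n = n
        self.buffers = [deque(maxlen=n) for _ in range(num_envs)]

    def reset(self, reset_mask):
        # Clear only the buffers whose mask entry is True,
        # in the spirit of a "_reset" entry selecting sub-environments.
        for buf, flag in zip(self.buffers, reset_mask):
            if flag:
                buf.clear()

    def push(self, frames):
        # frames holds one frame per env. An empty buffer is filled with
        # copies of the incoming frame; otherwise the oldest frame is dropped.
        out = []
        for buf, frame in zip(self.buffers, frames):
            if not buf:
                buf.extend([frame] * self.n)
            else:
                buf.append(frame)
            out.append(list(buf))
        return out


stack = FrameBufferSketch(n=3, num_envs=2)
stack.push(["a0", "b0"])
stack.push(["a1", "b1"])
stack.reset([True, False])  # only the first env starts a new episode
result = stack.push(["A0", "b2"])
print(result)
```

After the partial reset, the first environment's window contains only frames from its new episode, while the second environment's history is left untouched.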

Parameters:
  • N (int) – number of observations to concatenate.

  • dim (int) – dimension along which to concatenate the observations. Should be negative, to ensure that it is compatible with environments of different batch_size.

  • in_keys (sequence of NestedKey, optional) – keys pointing to the frames that have to be concatenated. Defaults to [“pixels”].

  • out_keys (sequence of NestedKey, optional) – keys pointing to where the output has to be written. Defaults to the value of in_keys.

  • padding (str, optional) – the padding method. One of "same" or "zeros". Defaults to "same", i.e. the first value is used for padding.

  • as_inverse (bool, optional) – if True, the transform is applied as an inverse transform. Defaults to False.
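To make the two padding modes concrete, here is a minimal pure-Python sketch that works on plain lists rather than tensors. The helper `stack_frames` and its `zero` argument are hypothetical and are not part of TorchRL; the sketch assumes that "zeros" fills the missing leading frames with a zero value, while "same" repeats the first frame as described above.

```python
def stack_frames(frames, n, padding="same", zero=0):
    """Hypothetical helper: for each step, return the window of the last n frames."""
    if padding not in ("same", "zeros"):
        raise ValueError(f"unknown padding: {padding!r}")
    if not frames:
        return []
    # "same" repeats the first frame; "zeros" uses a zero placeholder instead.
    pad_value = frames[0] if padding == "same" else zero
    padded = [pad_value] * (n - 1) + list(frames)
    return [padded[i:i + n] for i in range(len(frames))]


frames = [5, 6, 7]
same = stack_frames(frames, n=3, padding="same")
zeros = stack_frames(frames, n=3, padding="zeros")
print(same)   # [[5, 5, 5], [5, 5, 6], [5, 6, 7]]
print(zeros)  # [[0, 0, 5], [0, 5, 6], [5, 6, 7]]
```

The two modes differ only during the first N - 1 steps of an episode; once the window is full, they produce identical outputs.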

Examples

>>> from torchrl.envs import CatFrames, Compose, TransformedEnv, UnsqueezeTransform
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv('Pendulum-v1'),
...     Compose(
...         UnsqueezeTransform(-1, in_keys=["observation"]),
...         CatFrames(N=4, dim=-1, in_keys=["observation"]),
...     )
... )
>>> print(env.rollout(3))

The CatFrames transform can also be used offline to reproduce the effect of the online frame concatenation at a different scale (or to limit memory consumption). The following example gives the complete picture, together with the usage of a torchrl.data.ReplayBuffer:

Examples

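>>> from torchrl.envs import (
...     CatFrames, Compose, GrayScale, Resize, ToTensorImage, TransformedEnv, UnsqueezeTransform,
... )
>>> from torchrl.envs.libs.gym import GymEnv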
>>> from torchrl.collectors import SyncDataCollector, RandomPolicy
>>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1", from_pixels=True),
...     Compose(
...         ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
...         Resize(in_keys=["pixels_trsf"], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf"]),
...         UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
...     )
... )
>>> # we design a collector
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=10,
...     total_frames=1000,
... )
>>> for data in collector:
...     print(data)
...     break
>>> # Now let's create a transform for the replay buffer. We don't need to unsqueeze the data here;
>>> # however, we need to point to the pixel entries at both the root and the "next" levels:
>>> t = Compose(
...         ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...         Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
... )
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16)
>>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
>>> rb.add(data_exclude)
>>> s = rb.sample(1) # the buffer has only one element
>>> # let's check that our sample is the same as the batch collected during inference
>>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all()

forward(tensordict: TensorDictBase) → TensorDictBase[source]

Reads the input tensordict, and for the selected keys, applies the transform.

reset(tensordict: TensorDictBase) → TensorDictBase[source]

Resets _buffers.

transform_observation_spec(observation_spec: TensorSpec) → TensorSpec[source]

Transforms the observation spec such that the resulting spec matches the transform mapping.

Parameters:

observation_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
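As a rough illustration of what the spec transformation has to account for, the sketch below computes the shape obtained when N frames are concatenated along a negative dim. `cat_frames_shape` is a hypothetical helper, not a TorchRL function, and rejecting non-negative values of dim is a choice made by this sketch; it only assumes that concatenation multiplies the size of the chosen dimension by N.

```python
def cat_frames_shape(shape, n, dim):
    """Hypothetical helper: shape after concatenating n frames along dim."""
    if dim >= 0:
        # Sketch-level choice: only accept dims counted from the last axis.
        raise ValueError("dim should be negative so it is counted from the last axis")
    if -dim > len(shape):
        raise ValueError(f"dim={dim} is out of range for shape {tuple(shape)}")
    out = list(shape)
    out[dim] *= n
    return tuple(out)


# The same negative dim addresses the channel axis with or without a batch dimension.
single = cat_frames_shape((1, 64, 64), n=4, dim=-3)
batched = cat_frames_shape((8, 1, 64, 64), n=4, dim=-3)
print(single)   # (4, 64, 64)
print(batched)  # (8, 4, 64, 64)
```

Because the dimension is counted from the end, leading batch dimensions do not shift the axis being concatenated, which is why a negative dim stays valid across environments with different batch sizes.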
