CatFrames¶
- class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, padding='same', padding_value=0, as_inverse=False, reset_key: Optional[NestedKey] = None, done_key: Optional[NestedKey] = None)[source]¶
Concatenates successive observation frames into a single tensor.
This transform is useful for creating a sense of movement or velocity in the observed features. It can also be used with models that require access to past observations such as transformers and the like. It was first proposed in “Playing Atari with Deep Reinforcement Learning” (https://arxiv.org/pdf/1312.5602.pdf).
When used within a transformed environment, CatFrames is a stateful class, and it can be reset to its native state by calling the reset() method. This method accepts tensordicts with a "_reset" entry that indicates which buffer to reset.
- Parameters:
N (int) – number of observations to concatenate.
dim (int) – dimension along which to concatenate the observations. Should be negative, to ensure that it is compatible with environments of different batch_size.
in_keys (sequence of NestedKey, optional) – keys pointing to the frames that have to be concatenated. Defaults to [“pixels”].
out_keys (sequence of NestedKey, optional) – keys pointing to where the output has to be written. Defaults to the value of in_keys.
padding (str, optional) – the padding method. One of "same" or "constant". Defaults to "same", i.e. the first value is used for padding.
padding_value (float, optional) – the value to use for padding if padding="constant". Defaults to 0.
as_inverse (bool, optional) – if True, the transform is applied as an inverse transform. Defaults to False.
reset_key (NestedKey, optional) – the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) and raises an exception otherwise.
done_key (NestedKey, optional) – the done key to be used as partial done indicator. Must be unique. If not provided, defaults to "done".
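To make the two padding modes concrete, here is a minimal plain-Python sketch (not the torchrl implementation, and the helper name `cat_frames` is hypothetical) of how a CatFrames-style buffer behaves before N frames have been observed:

```python
# Illustrative sketch only: how "same" vs "constant" padding fills the
# frame buffer when fewer than N frames have been seen so far.
def cat_frames(frames, N, padding="same", padding_value=0):
    """Return the last N frames, left-padded when fewer than N exist."""
    if len(frames) >= N:
        return frames[-N:]
    if padding == "same":
        # repeat the first observed value to fill the missing slots
        pad = [frames[0]] * (N - len(frames))
    elif padding == "constant":
        # fill the missing slots with a fixed value (default 0)
        pad = [padding_value] * (N - len(frames))
    else:
        raise ValueError(f"unknown padding {padding!r}")
    return pad + frames

print(cat_frames([10, 20], N=4))                      # [10, 10, 10, 20]
print(cat_frames([10, 20], N=4, padding="constant"))  # [0, 0, 10, 20]
```

In the real transform the same logic applies per element along `dim`, which is why `dim` must be negative: counting from the end keeps the index valid across batch sizes.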
Examples
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv('Pendulum-v1'),
...     Compose(
...         UnsqueezeTransform(-1, in_keys=["observation"]),
...         CatFrames(N=4, dim=-1, in_keys=["observation"]),
...     )
... )
>>> print(env.rollout(3))
The CatFrames transform can also be used offline to reproduce the effect of the online frame concatenation at a different scale (or for the purpose of limiting the memory consumption). The following example gives the complete picture, together with the usage of a torchrl.data.ReplayBuffer:

Examples
>>> from torchrl.envs.utils import RandomPolicy
>>> from torchrl.envs import UnsqueezeTransform, CatFrames
>>> from torchrl.collectors import SyncDataCollector
>>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1", from_pixels=True),
...     Compose(
...         ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
...         Resize(in_keys=["pixels_trsf"], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf"]),
...         UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
...     )
... )
>>> # we design a collector
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=10,
...     total_frames=1000,
... )
>>> for data in collector:
...     print(data)
...     break
>>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here.
>>> # however, we need to point to both the pixel entry at the root and at the next levels:
>>> t = Compose(
...     ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...     Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
...     GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...     CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
... )
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16)
>>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
>>> rb.add(data_exclude)
>>> s = rb.sample(1)  # the buffer has only one element
>>> # let's check that our sample is the same as the batch collected during inference
>>> assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()
Note

CatFrames currently only supports a "done" signal at the root. Nested done entries, such as those found in MARL settings, are currently not supported. If this feature is needed, please raise an issue on the TorchRL repo.

Note
Storing stacks of frames in the replay buffer can significantly increase memory consumption (by N times). To mitigate this, you can store trajectories directly in the replay buffer and apply CatFrames at sampling time. This approach involves sampling slices of the stored trajectories and then applying the frame-stacking transform. For convenience, CatFrames provides a make_rb_transform_and_sampler() method that creates:

- A modified version of the transform suitable for use in replay buffers
- A corresponding SliceSampler to use with the buffer
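The memory saving above comes from storing each frame once and reassembling stacks only when sampling. A minimal plain-Python sketch of that idea (the function `sample_stacked` is hypothetical, standing in for what the SliceSampler/transform pair does inside torchrl):

```python
import random

# Illustrative sketch: store single frames, and at sampling time draw a
# contiguous slice of length N to build the stack, instead of storing
# every N-frame stack (which would multiply memory usage by N).
def sample_stacked(trajectory, N, rng=random):
    """Sample one length-N contiguous window from a stored trajectory."""
    start = rng.randrange(len(trajectory) - N + 1)
    # a slice of consecutive frames, as a slice-based sampler would return
    window = trajectory[start:start + N]
    # the "stacked" observation handed to the model
    return list(window)

traj = list(range(100))          # 100 stored frames, not 100 stacks of 4
stack = sample_stacked(traj, N=4)
print(len(stack))                # 4 consecutive frames
```

The trade-off is a little extra work at sampling time in exchange for roughly an N-fold reduction in stored observation data.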
- forward(tensordict: TensorDictBase) TensorDictBase [source]¶
Reads the input tensordict, and for the selected keys, applies the transform.
- make_rb_transform_and_sampler(batch_size: int, **sampler_kwargs) Tuple[Transform, 'torchrl.data.replay_buffers.SliceSampler'] [source]¶
Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data.
This method helps reduce redundancy in stored data by avoiding the need to store the entire stack of frames in the buffer. Instead, it creates a transform that stacks frames on-the-fly during sampling, and a sampler that ensures the correct sequence length is maintained.
- Parameters:
batch_size (int) – The batch size to use for the sampler.
**sampler_kwargs – Additional keyword arguments to pass to the
SliceSampler
constructor.
- Returns:
transform (Transform): A transform that stacks frames on-the-fly during sampling.
sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained.
- Return type:
A tuple containing the transform and the sampler.
Example
>>> env = TransformedEnv(...)
>>> catframes = CatFrames(N=4, ...)
>>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
>>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
Note

When working with images, it's recommended to use distinct in_keys and out_keys in the preceding ToTensorImage transform. This ensures that the tensors stored in the buffer are separate from their processed counterparts, which we don't want to store. For non-image data, consider inserting a RenameTransform before CatFrames to create a copy of the data that will be stored in the buffer.

Note
When adding the transform to the replay buffer, one should also pass the transforms that precede CatFrames, such as ToTensorImage or UnsqueezeTransform, so that the CatFrames transform sees data formatted as it was during data collection.

Note

For a more complete example, refer to the examples folder in TorchRL's GitHub repo: https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
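For CatFrames, the resulting spec differs from the input only along the concatenation dimension, whose size is multiplied by N. A small sketch of that arithmetic (the helper `cat_frames_shape` is hypothetical, for illustration only):

```python
# Illustrative sketch: CatFrames multiplies the size of the concatenation
# dimension by N and leaves all other dimensions unchanged.
def cat_frames_shape(shape, N, dim):
    """Expected observation shape after CatFrames(N=N, dim=dim), dim negative."""
    out = list(shape)
    out[dim] = out[dim] * N
    return tuple(out)

# e.g. the Pendulum example above: UnsqueezeTransform(-1) yields shape (3, 1),
# then CatFrames(N=4, dim=-1) gives (3, 4).
print(cat_frames_shape((3, 1), N=4, dim=-1))        # (3, 4)
# pixel example: (1, 1, 64, 64) with CatFrames(N=4, dim=-4) -> (4, 1, 64, 64)
print(cat_frames_shape((1, 1, 64, 64), N=4, dim=-4))
```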