CatFrames
- class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, padding='same', padding_value=0, as_inverse=False, reset_key: Optional[NestedKey] = None, done_key: Optional[NestedKey] = None)
Concatenates successive observation frames into a single tensor.
This can, for instance, account for the movement or velocity of the observed features. The technique was proposed in “Playing Atari with Deep Reinforcement Learning” (https://arxiv.org/pdf/1312.5602.pdf).
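As a minimal illustration of why concatenation helps, the plain-PyTorch sketch below (not TorchRL code; the positions are made-up values) stacks each 1-D position with its predecessor. A single frame only carries a position, while two stacked frames let a model recover velocity as a simple difference.

```python
import torch

# Hypothetical 1-D positions of an object observed at 5 consecutive steps: shape [T, 1].
positions = torch.tensor([[0.0], [1.0], [3.0], [6.0], [10.0]])

N = 2
T = positions.shape[0]
# Concatenate each frame with its predecessor along the last dim: shape [T - N + 1, N].
stacked = torch.cat([positions[i : T - N + 1 + i] for i in range(N)], dim=-1)

# With two frames available, velocity is just the difference between them.
velocity = stacked[:, 1] - stacked[:, 0]
print(stacked)
print(velocity)  # tensor([1., 2., 3., 4.])
```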
When used within a transformed environment, CatFrames is a stateful class, and it can be reset to its initial state by calling the reset() method. This method accepts tensordicts with a "_reset" entry that indicates which buffer to reset.
- Parameters:
N (int) – number of observations to concatenate.
dim (int) – dimension along which to concatenate the observations. Should be negative, to ensure that it is compatible with environments of different batch sizes.
in_keys (sequence of NestedKey, optional) – keys pointing to the frames that have to be concatenated. Defaults to [“pixels”].
out_keys (sequence of NestedKey, optional) – keys pointing to where the output has to be written. Defaults to the value of in_keys.
padding (str, optional) – the padding method. One of "same" or "constant". Defaults to "same", i.e. the first value is used for padding.
padding_value (float, optional) – the value to use for padding if padding="constant". Defaults to 0.
as_inverse (bool, optional) – if True, the transform is applied as an inverse transform. Defaults to False.
reset_key (NestedKey, optional) – the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) and raises an exception otherwise.
done_key (NestedKey, optional) – the done key to be used as partial done indicator. Must be unique. If not provided, defaults to "done".
Examples
>>> from torchrl.envs import TransformedEnv, Compose, UnsqueezeTransform, CatFrames
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv('Pendulum-v1'),
...     Compose(
...         UnsqueezeTransform(-1, in_keys=["observation"]),
...         CatFrames(N=4, dim=-1, in_keys=["observation"]),
...     )
... )
>>> print(env.rollout(3))
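The shapes involved in the example above can be checked with plain PyTorch. Assuming Pendulum-v1 observations of shape [3], UnsqueezeTransform(-1) adds a trailing axis, and CatFrames(N=4, dim=-1) concatenates four such frames along it. The sketch fakes the first step, where padding="same" repeats the initial frame. It also shows why dim should be negative: dim=1 happens to work for a single environment but hits the wrong axis once a batch dimension is prepended.

```python
import torch

N = 4
obs = torch.randn(3)                      # one Pendulum-v1 observation: shape [3]
frame = obs.unsqueeze(-1)                 # effect of UnsqueezeTransform(-1): shape [3, 1]
stacked = torch.cat([frame] * N, dim=-1)  # first step with padding="same": shape [3, 4]

# A negative dim keeps pointing at the frame axis when batch dims are prepended.
batched_frame = torch.randn(5, 3).unsqueeze(-1)           # 5 environments: shape [5, 3, 1]
batched_stacked = torch.cat([batched_frame] * N, dim=-1)  # shape [5, 3, 4]

# A positive dim=1 would now concatenate along the feature axis instead.
wrong = torch.cat([batched_frame] * N, dim=1)             # shape [5, 12, 1]

print(stacked.shape, batched_stacked.shape)  # torch.Size([3, 4]) torch.Size([5, 3, 4])
print(wrong.shape)                           # torch.Size([5, 12, 1])
```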
The CatFrames transform can also be used offline to reproduce the effect of the online frame concatenation at a different scale (or to limit memory consumption, for instance by storing single frames and stacking them only at sampling time). The following example gives the complete picture, together with the usage of a torchrl.data.ReplayBuffer:
Examples
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.utils import RandomPolicy
>>> from torchrl.envs import TransformedEnv, Compose, ToTensorImage, Resize, GrayScale, UnsqueezeTransform, CatFrames
>>> from torchrl.collectors import SyncDataCollector
>>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1", from_pixels=True),
...     Compose(
...         ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
...         Resize(in_keys=["pixels_trsf"], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf"]),
...         UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
...     )
... )
>>> # We design a collector
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=10,
...     total_frames=1000,
... )
>>> for data in collector:
...     print(data)
...     break
>>> # Now let's create a transform for the replay buffer. We don't need to unsqueeze the data here.
>>> # However, we need to point to both the pixel entry at the root and at the "next" level:
>>> t = Compose(
...     ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...     Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
...     GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...     CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
... )
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16)
>>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
>>> rb.add(data_exclude)
>>> s = rb.sample(1)  # the buffer has only one element
>>> # Let's check that our sample is the same as the batch collected during inference
>>> assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()
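To see how offline stacking can reproduce the online result, here is a minimal plain-PyTorch sketch for a single stored trajectory. It is not TorchRL's implementation: the stack_offline function and its is_init flag (True at the first step of each episode) are hypothetical. Windows that cross an episode boundary are padded with that episode's first frame, mimicking padding="same".

```python
import torch


def stack_offline(frames: torch.Tensor, is_init: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical offline frame stacking for one stored trajectory.

    frames:  [T, *F] tensor of consecutive observations
    is_init: [T] bool tensor, True at the first step of each episode
    Returns a [T, n, *F] tensor whose entry t holds frames t-n+1..t, where
    indices before the current episode start are clamped to that start."""
    T = frames.shape[0]
    steps = torch.arange(T)
    # Most recent episode start for each step (running maximum of init indices).
    start = torch.where(is_init, steps, torch.zeros_like(steps))
    start = torch.cummax(start, dim=0).values
    offsets = torch.arange(-n + 1, 1)              # [-n+1, ..., 0]
    idx = steps.unsqueeze(1) + offsets             # [T, n] window indices
    idx = torch.maximum(idx, start.unsqueeze(1))   # never look past the episode start
    return frames[idx]


frames = torch.arange(6.0).unsqueeze(-1)  # 6 steps with 1 feature each: [6, 1]
is_init = torch.tensor([True, False, False, True, False, False])  # two 3-step episodes
out = stack_offline(frames, is_init, n=3)
print(out.squeeze(-1))
# rows: [0, 0, 0], [0, 0, 1], [0, 1, 2], [3, 3, 3], [3, 3, 4], [3, 4, 5]
```

Storing only the single frames and gathering windows like this at sampling time avoids keeping a stacked tensor that is N times larger than each frame.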
Note
CatFrames currently only supports a "done" signal at the root. Nested done signals, such as those found in MARL settings, are currently not supported. If this feature is needed, please raise an issue on the TorchRL repo.
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
- transform_observation_spec(observation_spec: TensorSpec) → TensorSpec
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
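As a rough sketch of the shape bookkeeping, assuming the spec is resized by concatenating N copies along dim (an assumption, not taken from the TorchRL source), the hypothetical helper catframes_output_shape below computes the expected observation shape. The first call matches the Pendulum example; the second uses the 64x64 grayscale frames of the replay-buffer example after UnsqueezeTransform(-4).

```python
from typing import Sequence, Tuple


def catframes_output_shape(shape: Sequence[int], n: int, dim: int) -> Tuple[int, ...]:
    """Hypothetical helper: observation shape after concatenating n frames along dim."""
    out = list(shape)
    out[dim] = out[dim] * n  # the concatenation axis grows by a factor of n
    return tuple(out)


print(catframes_output_shape((3, 1), n=4, dim=-1))          # (3, 4)
print(catframes_output_shape((1, 1, 64, 64), n=4, dim=-4))  # (4, 1, 64, 64)
```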