CatFrames
- class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, out_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, padding='same', as_inverse=False)[source]
Concatenates successive observation frames into a single tensor.
This can, for instance, account for movement/velocity of the observed feature. Proposed in “Playing Atari with Deep Reinforcement Learning” (https://arxiv.org/pdf/1312.5602.pdf).
When used within a transformed environment, CatFrames is a stateful class, and it can be reset to its native state by calling the reset() method. This method accepts tensordicts with a "_reset" entry that indicates which buffers to reset.
- Parameters:
N (int) – number of observations to concatenate.
dim (int) – dimension along which to concatenate the observations. Should be negative, to ensure compatibility with environments of different batch_size.
in_keys (sequence of NestedKey, optional) – keys pointing to the frames that have to be concatenated. Defaults to ["pixels"].
out_keys (sequence of NestedKey, optional) – keys pointing to where the output has to be written. Defaults to the value of in_keys.
padding (str, optional) – the padding method. One of "same" or "zeros". Defaults to "same", i.e. the first value is used for padding.
as_inverse (bool, optional) – if True, the transform is applied as an inverse transform. Defaults to False.
Examples
>>> from torchrl.envs import CatFrames, Compose, TransformedEnv, UnsqueezeTransform
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv('Pendulum-v1'),
...     Compose(
...         UnsqueezeTransform(-1, in_keys=["observation"]),
...         CatFrames(N=4, dim=-1, in_keys=["observation"]),
...     )
... )
>>> print(env.rollout(3))
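A quick sanity check of the pipeline above (a sketch; the shapes assume Pendulum-v1's 3-dimensional observation): with the default padding="same", the buffer is filled with copies of the first observation at reset, so all N frames coincide right after a reset.
>>> # right after a reset, "same" padding repeats the first frame N times
>>> td = env.reset()
>>> frames = td["observation"]  # expected shape: (3, 4)
>>> assert (frames == frames[..., :1]).all()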
The CatFrames transform can also be used offline to reproduce the effect of the online frame concatenation at a different scale (or for the purpose of limiting memory consumption). The following example gives the complete picture, together with the usage of a torchrl.data.ReplayBuffer:
Examples
>>> from torchrl.envs import (
...     CatFrames,
...     Compose,
...     GrayScale,
...     Resize,
...     ToTensorImage,
...     TransformedEnv,
...     UnsqueezeTransform,
... )
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.collectors import SyncDataCollector, RandomPolicy
>>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1", from_pixels=True),
...     Compose(
...         ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
...         Resize(in_keys=["pixels_trsf"], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf"]),
...         UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
...     )
... )
>>> # we design a collector
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=10,
...     total_frames=1000,
... )
>>> for data in collector:
...     print(data)
...     break
>>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here.
>>> # however, we need to point to both the pixel entry at the root and at the next levels:
>>> t = Compose(
...     ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...     Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
...     GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...     CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
... )
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16)
>>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
>>> rb.add(data_exclude)
>>> s = rb.sample(1)  # the buffer has only one element
>>> # let's check that our sample is the same as the batch collected during inference
>>> assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()
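As a rough check of the shapes produced above (a sketch under the stated assumptions: frames_per_batch=10, N=4 stacked grayscale 64x64 frames, and a single stored trajectory sampled back; the exact printout may vary across torchrl versions), the transformed pixel entries should carry the stacked-frame dimension at -4:
>>> # hypothetical shape check: 1 sample x 10 steps x 4 frames x 1 channel x 64 x 64
>>> print(s["pixels_trsf"].shape)
torch.Size([1, 10, 4, 1, 64, 64])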
- forward(tensordict: TensorDictBase) → TensorDictBase[source]
Reads the input tensordict, and for the selected keys, applies the transform.
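This is the code path used offline, e.g. when a replay buffer applies the transform on sampled data as in the example above. A minimal sketch of a direct call, reusing t and data from that example (an illustration, not part of the documented API surface):
>>> # calling the Compose invokes forward() on each sub-transform, including CatFrames
>>> out = t(data.exclude("pixels_trsf", ("next", "pixels_trsf")).clone())
>>> assert (out["pixels_trsf"] == data["pixels_trsf"]).all()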
- transform_observation_spec(observation_spec: TensorSpec) → TensorSpec[source]
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
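For CatFrames, the returned spec expands the concatenation dimension by a factor of N. A sketch using the Pendulum environment from the first example above (assuming the shapes stated there; env.base_env points to the untransformed environment):
>>> # base spec: (3,); after UnsqueezeTransform(-1) and CatFrames(N=4, dim=-1): (3, 4)
>>> print(env.base_env.observation_spec["observation"].shape)
torch.Size([3])
>>> print(env.observation_spec["observation"].shape)
torch.Size([3, 4])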