class torchrl.envs.transforms.FlattenObservation(first_dim: int, last_dim: int, in_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, out_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, allow_positive_dim: bool = False)

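Flatten adjacent dimensions of a tensor: for each selected entry, the dimensions from first_dim to last_dim (both inclusive) are merged into a single dimension.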

Parameters:

  • first_dim (int) – first dimension of the dimensions to flatten.

  • last_dim (int) – last dimension of the dimensions to flatten (inclusive).

  • in_keys (sequence of NestedKey, optional) – the entries to flatten. If None, ["pixels"] is used.

  • out_keys (sequence of NestedKey, optional) – the keys under which the flattened observations are written. If None, in_keys is used, so the input entries are overwritten.

  • allow_positive_dim (bool, optional) – if True, non-negative dimensions are accepted. FlattenObservation maps such a dimension n to the n-th feature dimension of the input tensor, i.e. the n-th dimension after the batch dimensions of the parent env (counting from 0). Defaults to False, in which case first_dim and last_dim must be negative.
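How first_dim, last_dim and allow_positive_dim combine can be illustrated with a shape-only sketch in plain Python. It assumes the observation carries batch_ndim leading batch dimensions (those of the parent env) and follows the inclusive start/end indexing of torch.flatten. The helpers resolve_dim and flattened_shape are hypothetical and not part of TorchRL.

```python
from math import prod
from typing import Sequence, Tuple


def resolve_dim(dim: int, batch_ndim: int, allow_positive_dim: bool) -> int:
    """Map a FlattenObservation-style dim to a tensor dim (hypothetical helper)."""
    if dim >= 0:
        if not allow_positive_dim:
            raise ValueError("non-negative dims require allow_positive_dim=True")
        # A non-negative dim counts feature dims, i.e. dims after the batch dims.
        return batch_ndim + dim
    # A negative dim counts from the end, so it does not depend on the batch size.
    return dim


def flattened_shape(
    shape: Sequence[int],
    first_dim: int,
    last_dim: int,
    batch_ndim: int = 0,
    allow_positive_dim: bool = False,
) -> Tuple[int, ...]:
    """Shape obtained by merging dims first_dim..last_dim (inclusive)."""
    ndim = len(shape)
    start = resolve_dim(first_dim, batch_ndim, allow_positive_dim)
    end = resolve_dim(last_dim, batch_ndim, allow_positive_dim)
    # Normalize negative indices to absolute positions.
    start = start + ndim if start < 0 else start
    end = end + ndim if end < 0 else end
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid dims {first_dim}, {last_dim} for shape {tuple(shape)}")
    merged = prod(shape[start : end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1 :])


# A batch of 4 RGB frames of size 64x64: batch dims (4,), feature dims (3, 64, 64).
print(flattened_shape((4, 3, 64, 64), -3, -1, batch_ndim=1))  # (4, 12288)
print(flattened_shape((4, 3, 64, 64), 0, 2, batch_ndim=1, allow_positive_dim=True))  # (4, 12288)
print(flattened_shape((4, 3, 64, 64), -2, -1, batch_ndim=1))  # (4, 3, 4096)
```

Negative dimensions refer to the same feature dimensions whatever the parent env's batch size, which is why non-negative dimensions are opt-in.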

forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict and, for the selected keys, applies the transform.

For any operation that relates exclusively to the parent env (e.g. FrameSkip), modify the _step method instead. _call() should only be overridden if the input tensordict itself needs to be modified.

_call() is called by TransformedEnv.step() and TransformedEnv.reset().
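The read, transform and write cycle can be sketched with a plain dict standing in for a TensorDict and nested lists standing in for tensors. This only illustrates the key handling (out_keys defaulting to in_keys); it is not TorchRL's implementation, and apply_to_keys and flatten_2d are hypothetical names.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


def flatten_2d(rows: List[List[float]]) -> List[float]:
    """Merge the two dims of a nested list in row-major order (like flattening dims -2..-1)."""
    return [x for row in rows for x in row]


def apply_to_keys(
    data: Dict[str, Any],
    fn: Callable[[Any], Any],
    in_keys: Sequence[str],
    out_keys: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Apply fn to each in_key and store the result under the matching out_key."""
    # As documented for FlattenObservation, out_keys defaults to in_keys.
    out_keys = list(in_keys) if out_keys is None else list(out_keys)
    if len(out_keys) != len(in_keys):
        raise ValueError("in_keys and out_keys must have the same length")
    for in_key, out_key in zip(in_keys, out_keys):
        data[out_key] = fn(data[in_key])  # a missing in_key raises KeyError here
    return data


obs = {"pixels": [[1, 2], [3, 4]], "reward": 0.5}
print(apply_to_keys(dict(obs), flatten_2d, ["pixels"]))
# {'pixels': [1, 2, 3, 4], 'reward': 0.5}
print(apply_to_keys(dict(obs), flatten_2d, ["pixels"], ["pixels_flat"]))
# {'pixels': [[1, 2], [3, 4]], 'reward': 0.5, 'pixels_flat': [1, 2, 3, 4]}
```

Entries not listed in in_keys (here "reward") pass through untouched; writing to a distinct out_key keeps the original entry alongside the flattened one.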

transform_observation_spec(observation_spec: TensorSpec) → TensorSpec

Transforms the observation spec so that the resulting spec matches the transform mapping, i.e. it describes the flattened observations produced by forward().

Parameters:

  • observation_spec (TensorSpec) – spec before the transform

Returns:

  expected spec after the transform
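The spec transform must describe exactly what forward() returns. A minimal sketch, assuming a hypothetical BoxSpec with scalar bounds (TorchRL's TensorSpec classes are richer): flattening preserves the element count and only changes the shape, so scalar bounds carry over unchanged. Per-element bounds would instead need the same flattening as the data.

```python
from dataclasses import dataclass, replace
from math import prod
from typing import Tuple


@dataclass(frozen=True)
class BoxSpec:
    """Hypothetical stand-in for a bounded observation spec with scalar bounds."""

    shape: Tuple[int, ...]
    low: float
    high: float


def flatten_spec(spec: BoxSpec, first_dim: int, last_dim: int) -> BoxSpec:
    """Return a spec whose shape has dims first_dim..last_dim merged into one."""
    ndim = len(spec.shape)
    start = first_dim + ndim if first_dim < 0 else first_dim
    end = last_dim + ndim if last_dim < 0 else last_dim
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid dims {first_dim}, {last_dim} for shape {spec.shape}")
    merged = prod(spec.shape[start : end + 1])
    new_shape = spec.shape[:start] + (merged,) + spec.shape[end + 1 :]
    # Scalar bounds apply to every element, so only the shape changes.
    return replace(spec, shape=new_shape)


spec = BoxSpec(shape=(3, 64, 64), low=0.0, high=1.0)
print(flatten_spec(spec, -3, -1))  # BoxSpec(shape=(12288,), low=0.0, high=1.0)
```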

