FlattenObservation

class torchrl.envs.transforms.FlattenObservation(first_dim: int, last_dim: int, in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, allow_positive_dim: bool = False)

Flatten adjacent dimensions of a tensor.

Parameters:
  • first_dim (int) – first dimension of the dimensions to flatten.

  • last_dim (int) – last dimension of the dimensions to flatten.

  • in_keys (sequence of NestedKey, optional) – the entries to flatten. If None, ["pixels"] is assumed.

  • out_keys (sequence of NestedKey, optional) – the keys under which the flattened observations are written. If None, in_keys is assumed, i.e. the entries are flattened in place.

  • allow_positive_dim (bool, optional) – if True, non-negative dimensions are accepted. FlattenObservation maps a non-negative dimension n to the n-th feature dimension of the input tensor, i.e. the n-th dimension after the batch dimensions of the parent env. Defaults to False, i.e. only negative dimensions are permitted.
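The shape arithmetic implied by first_dim, last_dim and allow_positive_dim can be sketched in plain Python. The helper flattened_shape below is hypothetical and not part of TorchRL; it assumes that negative dims count from the end of the full shape, that non-negative dims (only allowed when allow_positive_dim=True) are offset by the number of batch dims of the parent env (passed here as batch_ndim), and that both bounds are inclusive, as in torch.flatten.

```python
from math import prod
from typing import Sequence, Tuple


def flattened_shape(
    shape: Sequence[int],
    first_dim: int,
    last_dim: int,
    batch_ndim: int = 0,
    allow_positive_dim: bool = False,
) -> Tuple[int, ...]:
    """Shape obtained by merging dims first_dim..last_dim (inclusive).

    Hypothetical helper: it only illustrates the shape arithmetic.
    """
    if not allow_positive_dim and (first_dim >= 0 or last_dim >= 0):
        raise ValueError("dims must be negative unless allow_positive_dim=True")

    def resolve(dim: int) -> int:
        # Non-negative dims count from the first feature dim (after the batch
        # dims of the parent env); negative dims count from the end.
        return dim + batch_ndim if dim >= 0 else dim + len(shape)

    start, end = resolve(first_dim), resolve(last_dim)
    if not 0 <= start <= end < len(shape):
        raise ValueError(f"invalid dims {first_dim}, {last_dim} for shape {tuple(shape)}")
    merged = prod(shape[start : end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1 :])


# 8 RGB frames of 64x64 with one leading batch dim: merge channel, height, width.
print(flattened_shape((8, 3, 64, 64), first_dim=-3, last_dim=-1))  # (8, 12288)
# Same result with feature-relative (non-negative) dims.
print(flattened_shape((8, 3, 64, 64), 0, 2, batch_ndim=1, allow_positive_dim=True))  # (8, 12288)
```

Negative dims are the default because they refer to the same feature dimensions whatever the batch size of the parent env, which is presumably why non-negative dims must be opted into.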

forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict, and for the selected keys, applies the transform.

For any operation that relates exclusively to the parent env (e.g. FrameSkip), modify the _step() method instead. _call() should only be overridden if a modification of the input tensordict is needed.

_call() will be called by TransformedEnv.step() and TransformedEnv.reset().
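The "read the selected keys, apply the transform, write the result" pattern can be sketched without TorchRL. In this hypothetical example a plain dict stands in for the tensordict and nested lists stand in for tensors; apply_to_keys and flatten_feature_dims are illustrative names, not library functions. Raising KeyError on a missing entry is a choice made for this sketch only.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


def apply_to_keys(
    data: Dict[str, Any],
    fn: Callable[[Any], Any],
    in_keys: Sequence[str],
    out_keys: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Apply fn to each in_keys entry and store the result under out_keys.

    Hypothetical stand-in for a transform's forward pass.
    """
    # Like FlattenObservation's out_keys, fall back to in_keys (in-place update).
    keys_out = list(in_keys) if out_keys is None else list(out_keys)
    if len(keys_out) != len(in_keys):
        raise ValueError("in_keys and out_keys must have the same length")
    for in_key, out_key in zip(in_keys, keys_out):
        if in_key not in data:
            raise KeyError(f"missing entry {in_key!r}")
        data[out_key] = fn(data[in_key])
    return data


def flatten_feature_dims(batch: List[List[List[float]]]) -> List[List[float]]:
    """Merge the two feature dims of a [batch][height][width] nested list."""
    return [[value for row in sample for value in row] for sample in batch]


data = {"pixels": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], "reward": [0.0, 1.0]}
apply_to_keys(data, flatten_feature_dims, in_keys=["pixels"], out_keys=["pixels_flat"])
print(data["pixels_flat"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Entries not listed in in_keys (here "reward") pass through unchanged, and writing to a distinct out_key keeps the original "pixels" entry alongside the flattened one.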

transform_observation_spec(observation_spec: TensorSpec) → TensorSpec

Transforms the observation spec such that the resulting spec matches the transform mapping.

Parameters:

observation_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
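For a flattening transform, matching the transform mapping means that each affected spec entry must report the flattened shape. The sketch below models a spec as a plain key-to-shape dict, which is a simplification (a real TensorSpec also carries dtype, device and, for bounded specs, bounds). The function transform_spec_shapes is hypothetical and, to stay short, handles negative dims only.

```python
from math import prod
from typing import Dict, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def transform_spec_shapes(
    spec: Dict[str, Shape],
    first_dim: int,
    last_dim: int,
    in_keys: Sequence[str] = ("pixels",),
    out_keys: Optional[Sequence[str]] = None,
) -> Dict[str, Shape]:
    """Return a copy of spec whose out_keys entries carry flattened shapes."""
    if first_dim >= 0 or last_dim >= 0:
        raise ValueError("this sketch only handles negative dims")
    keys_out = list(in_keys) if out_keys is None else list(out_keys)
    new_spec = dict(spec)  # leave the input spec untouched
    for in_key, out_key in zip(in_keys, keys_out):
        shape = spec[in_key]
        # Negative dims count from the end of the shape.
        start, end = first_dim + len(shape), last_dim + len(shape)
        if not 0 <= start <= end < len(shape):
            raise ValueError(f"cannot flatten dims {first_dim}..{last_dim} of {shape}")
        merged = prod(shape[start : end + 1])
        new_spec[out_key] = shape[:start] + (merged,) + shape[end + 1 :]
    return new_spec


spec = {"pixels": (3, 84, 84), "velocity": (2,)}
print(transform_spec_shapes(spec, first_dim=-3, last_dim=-1))
# {'pixels': (21168,), 'velocity': (2,)}
```

Keeping the spec and the data consistent is the point: the same first_dim/last_dim arithmetic is applied to the spec shape as to the observation tensor, so downstream consumers of the spec see the shape they will actually receive.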
