class torchrl.envs.transforms.ObservationNorm(loc: Optional[Union[float, torch.Tensor]] = None, scale: Optional[Union[float, torch.Tensor]] = None, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, standard_normal: bool = False)

Observation affine transformation layer.

Normalizes an observation according to

\[obs = obs \cdot scale + loc\]

Parameters:
  • loc (number or tensor) – location of the affine transform

  • scale (number or tensor) – scale of the affine transform

  • in_keys (sequence of NestedKey, optional) – entries to be normalized. Defaults to ["observation", "pixels"]. All entries are normalized with the same values; if a different behaviour is desired (e.g. different normalizations for pixels and states), separate ObservationNorm objects should be used.

  • out_keys (sequence of NestedKey, optional) – output entries. Defaults to the value of in_keys.

  • in_keys_inv (sequence of NestedKey, optional) – ObservationNorm also supports inverse transforms. The inverse transform is only applied if a list of keys is provided to in_keys_inv; if none is provided, only the forward transform is called.

  • out_keys_inv (sequence of NestedKey, optional) – output entries for the inverse transform. Defaults to the value of in_keys_inv.

  • standard_normal (bool, optional) – if True, the transform will be

    \[obs = (obs - loc) / scale\]

    as is done for standardization. Defaults to False.
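The two parametrizations described in the parameter list, and the inverse map used for in_keys_inv, can be sketched in plain Python for a single scalar feature. This is a minimal sketch under stated assumptions, not torchrl's implementation: `affine_forward` and `affine_inverse` are hypothetical helpers, and the inverse is assumed to be the algebraic inverse of the forward map.

```python
def affine_forward(obs: float, loc: float, scale: float, standard_normal: bool) -> float:
    """Forward map: (obs - loc) / scale if standard_normal, else obs * scale + loc."""
    if standard_normal:
        return (obs - loc) / scale  # standardization form
    return obs * scale + loc  # default affine form


def affine_inverse(obs: float, loc: float, scale: float, standard_normal: bool) -> float:
    """Hypothetical inverse map, assumed to be the algebraic inverse of affine_forward."""
    if standard_normal:
        return obs * scale + loc
    return (obs - loc) / scale


# Standardize the value 10.0 given mean=2.0 and std=4.0, using both parametrizations.
mean, std = 2.0, 4.0
a = affine_forward(10.0, loc=mean, scale=std, standard_normal=True)
b = affine_forward(10.0, loc=-mean / std, scale=1.0 / std, standard_normal=False)
print(a, b)  # 2.0 2.0

# Round trip back to the raw observation.
print(affine_inverse(a, loc=mean, scale=std, standard_normal=True))  # 10.0
print(affine_inverse(b, loc=-mean / std, scale=1.0 / std, standard_normal=False))  # 10.0
```

The second call shows that the default form can also standardize: choosing scale = 1/std and loc = -mean/std yields the same output as standard_normal=True with loc = mean and scale = std.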


Examples:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import ObservationNorm
>>> torch.set_default_dtype(torch.float64)
>>> r = torch.randn(100, 3) * torch.randn(3) + torch.randn(3)
>>> td = TensorDict({"obs": r}, [100])
>>> transform = ObservationNorm(
...     loc=td.get("obs").mean(0),
...     scale=td.get("obs").std(0),
...     in_keys=["obs"],
...     standard_normal=True)
>>> td = transform(td)
>>> print(torch.isclose(td.get("obs").mean(0),
...     torch.zeros(3)).all())
tensor(True)
>>> print(torch.isclose(td.get("obs").std(0),
...     torch.ones(3)).all())
tensor(True)

The normalization stats can also be computed automatically:

>>> import torch
>>> from torchrl.envs import TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import ObservationNorm
>>> _ = torch.manual_seed(0)
>>> env = GymEnv("Pendulum-v1")
>>> env = TransformedEnv(env, ObservationNorm(in_keys=["observation"]))
>>> _ = env.set_seed(0)
>>> env.transform.init_stats(100)
>>> print(env.transform.loc, env.transform.scale)
tensor([-1.3752e+01, -6.5087e-03,  2.9294e-03], dtype=torch.float32) tensor([14.9636,  2.5608,  0.6408], dtype=torch.float32)
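Because the example uses the default standard_normal=False, the printed loc and scale parametrize obs * scale + loc. Assuming init_stats chooses them so that this map standardizes the data (scale = 1/std, loc = -mean/std; an inference from the two formulas, not a documented guarantee), the underlying empirical statistics can be read back with a little arithmetic:

```python
# Values printed by env.transform.loc and env.transform.scale above.
loc = [-1.3752e01, -6.5087e-03, 2.9294e-03]
scale = [14.9636, 2.5608, 0.6408]

# Assumed relation for standard_normal=False: scale = 1 / std and loc = -mean / std,
# so std = 1 / scale and mean = -loc / scale.
std = [1.0 / s for s in scale]
mean = [-l / s for l, s in zip(loc, scale)]

print("mean:", [round(m, 4) for m in mean])
print("std: ", [round(s, 4) for s in std])
```

Under that assumption, the large loc and scale on the first feature reflect a small empirical standard deviation rather than a large mean.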

init_stats(num_iter: int, reduce_dim: Union[int, Tuple[int]] = 0, cat_dim: Optional[int] = None, key: Optional[Union[str, Tuple[str, ...]]] = None, keep_dims: Optional[Tuple[int]] = None) → None

Initializes the loc and scale stats of the transform from data collected in the parent environment.

The normalization constants should ideally bring the observation statistics close to those of a standard Gaussian distribution. This method computes location and scale tensors from the empirical mean and standard deviation of a Gaussian distribution fitted on data generated by running the parent environment with random actions for a given number of steps.

Parameters:
  • num_iter (int) – number of random iterations to run in the environment.

  • reduce_dim (int or tuple of int, optional) – dimension to compute the mean and std over. Defaults to 0.

  • cat_dim (int, optional) – dimension along which the collected batches will be concatenated. It must be equal to reduce_dim (if an integer) or be part of the reduce_dim tuple. Defaults to the same value as reduce_dim.

  • key (NestedKey, optional) – if provided, the summary statistics will be retrieved from that key in the resulting tensordicts. Otherwise, the first key in ObservationNorm.in_keys will be used.

  • keep_dims (tuple of int, optional) – the dimensions to keep in the loc and scale. For instance, one may want the location and scale to have shape [C, 1, 1] when normalizing a 3D tensor over its last two dimensions but not the first. Defaults to None.
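The statistics-gathering procedure described above can be sketched in plain Python for the simplest case: observations are flat vectors, batches are concatenated along the time dimension, and the mean and standard deviation are reduced over that same dimension (cat_dim == reduce_dim == 0). This is a minimal sketch under stated assumptions, not torchrl's implementation: `random_rollout` is a hypothetical stand-in for the parent environment, the batch size of 32 and the `eps` guard are choices of this sketch, keep_dims and tuple-valued reduce_dim are not modelled, and the sample (unbiased) standard deviation is used to mirror the default of torch.std.

```python
import random
import statistics
from typing import List, Tuple

Obs = List[float]


def random_rollout(num_steps: int, rng: random.Random) -> List[Obs]:
    """Hypothetical stand-in for running the parent env with random actions."""
    return [[rng.gauss(1.0, 0.5), rng.gauss(-3.0, 2.0)] for _ in range(num_steps)]


def init_stats_sketch(
    num_iter: int, standard_normal: bool, seed: int = 0, eps: float = 1e-6
) -> Tuple[Obs, Obs]:
    """Return (loc, scale) computed from num_iter random steps."""
    rng = random.Random(seed)
    data: List[Obs] = []
    # Collect batches until num_iter steps are gathered, concatenating them
    # along the time dimension (dimension 0 of the stacked data).
    while len(data) < num_iter:
        data.extend(random_rollout(min(32, num_iter - len(data)), rng))
    columns = list(zip(*data))  # one column per observation feature
    mean = [statistics.fmean(c) for c in columns]
    # Sample std, with a small guard against zero variance (sketch assumption).
    std = [max(statistics.stdev(c), eps) for c in columns]
    if standard_normal:
        return mean, std  # used as (obs - loc) / scale
    scale = [1.0 / s for s in std]
    loc = [-m * s for m, s in zip(mean, scale)]
    return loc, scale  # used as obs * scale + loc


loc, scale = init_stats_sketch(1000, standard_normal=False)
print(loc, scale)
```

Both branches describe the same standardization; they differ only in which parametrization of the affine map the returned loc and scale are meant for.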

transform_input_spec(input_spec: TensorSpec) → TensorSpec

Transforms the input spec such that the resulting spec matches the transform mapping.

Parameters:
input_spec (TensorSpec) – spec before the transform

Returns:
expected spec after the transform

transform_observation_spec(observation_spec: TensorSpec) → TensorSpec

Transforms the observation spec such that the resulting spec matches the transform mapping.

Parameters:
observation_spec (TensorSpec) – spec before the transform

Returns:
expected spec after the transform
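For a bounded observation spec, one natural way to make the spec match the transform mapping is to push its lower and upper bounds through the same affine map that is applied to the observations. The sketch below assumes exactly that; `Box` and `transform_box` are hypothetical and do not mirror torchrl's spec classes, and the reordering of bounds for a negative scale is a choice of this sketch.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Box:
    """Hypothetical stand-in for a bounded continuous observation spec."""
    low: List[float]
    high: List[float]


def transform_box(box: Box, loc: List[float], scale: List[float], standard_normal: bool) -> Box:
    """Map each bound through the observation transform, feature by feature."""

    def f(x: float, l: float, s: float) -> float:
        return (x - l) / s if standard_normal else x * s + l

    lo = [f(x, l, s) for x, l, s in zip(box.low, loc, scale)]
    hi = [f(x, l, s) for x, l, s in zip(box.high, loc, scale)]
    # A negative scale would flip the interval, so reorder to keep low <= high.
    return Box(
        low=[min(a, b) for a, b in zip(lo, hi)],
        high=[max(a, b) for a, b in zip(lo, hi)],
    )


spec = Box(low=[-1.0, -1.0, -8.0], high=[1.0, 1.0, 8.0])
print(transform_box(spec, loc=[0.0, 0.0, 0.0], scale=[1.0, 1.0, 2.0], standard_normal=True))
# Box(low=[-1.0, -1.0, -4.0], high=[1.0, 1.0, 4.0])
```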