class torchrl.envs.transforms.CatTensors(in_keys: Sequence[NestedKey] | None = None, out_key: NestedKey = 'observation_vector', dim: int = -1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)

Concatenates several keys into a single tensor.

This is especially useful when multiple keys describe a single state (e.g. “observation_position” and “observation_velocity”).

Parameters:
  • in_keys (sequence of NestedKey) – keys to be concatenated. If None (or not provided), the keys are retrieved from the parent environment the first time the transform is used. This only works if a parent is set.

  • out_key (NestedKey) – key of the resulting tensor.

  • dim (int, optional) – dimension along which the concatenation will occur. Default is -1.

Keyword Arguments:
  • del_keys (bool, optional) – if True, the input values will be deleted after concatenation. Default is True.

  • unsqueeze_if_oor (bool, optional) – if True, CatTensors will check that the indicated dimension exists for the tensors to concatenate. If it does not, the tensors will be unsqueezed along that dimension. Default is False.

  • sort (bool, optional) – if True, the keys will be sorted in the transform. Otherwise, the order provided by the user will prevail. Defaults to True.
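
The interaction of dim, sort and unsqueeze_if_oor can be illustrated with plain torch. The sketch below is not the library's implementation; cat_sketch is a hypothetical helper written only to mirror the behaviour described in the argument list, assuming keys are plain strings and that a single unsqueeze is enough to make dim valid.

```python
import torch


def cat_sketch(tensors, dim=-1, sort=True, unsqueeze_if_oor=False):
    """Hypothetical helper (not part of torchrl) mimicking the documented options."""
    # sort=True orders the keys; otherwise the dict's insertion order is kept.
    names = sorted(tensors) if sort else list(tensors)
    values = [tensors[name] for name in names]
    if unsqueeze_if_oor:
        # Number of dimensions a tensor needs for `dim` to be a valid index.
        needed = dim + 1 if dim >= 0 else -dim
        values = [
            v if v.ndim >= needed
            else (v.unsqueeze(-1) if dim >= 0 else v.unsqueeze(0))
            for v in values
        ]
    return torch.cat(values, dim=dim)


# Keys are sorted before concatenation, so "key1" comes first.
flat = cat_sketch({"key2": torch.ones(1, 1), "key1": torch.zeros(1, 1)})

# dim=-2 does not exist for 1-d tensors, so they are unsqueezed first.
stacked = cat_sketch(
    {"key1": torch.zeros(1), "key2": torch.ones(1)},
    dim=-2,
    unsqueeze_if_oor=True,
)
```

Without unsqueeze_if_oor=True, the second call would fail inside torch.cat because dim=-2 is out of range for one-dimensional tensors.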


Examples:
>>> transform = CatTensors(in_keys=["key1", "key2"])
>>> td = TensorDict({"key1": torch.zeros(1, 1),
...     "key2": torch.ones(1, 1)}, [1])
>>> _ = transform(td)
>>> print(td.get("observation_vector"))
tensor([[0., 1.]])
>>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True)
>>> td = TensorDict({"key1": torch.zeros(1),
...     "key2": torch.ones(1)}, [])
>>> _ = transform(td)
>>> print(td.get("observation_vector").shape)
torch.Size([2, 1])
forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict, and for the selected keys, applies the transform.

transform_observation_spec(observation_spec: TensorSpec) → TensorSpec

Transforms the observation spec such that the resulting spec matches the transform mapping.


Parameters:
  • observation_spec (TensorSpec) – spec before the transform

Returns:
  expected spec after the transform

