CatTensors
- class torchrl.envs.transforms.CatTensors(in_keys: Optional[Sequence[NestedKey]] = None, out_key: NestedKey = 'observation_vector', dim: int = -1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)
Concatenates the tensors stored under several keys into a single tensor.
This is especially useful when multiple keys describe a single state (e.g. "observation_position" and "observation_velocity").
- Parameters:
in_keys (sequence of NestedKey) – keys to be concatenated. If None (or not provided) the keys will be retrieved from the parent environment the first time the transform is used. This behavior will only work if a parent is set.
out_key (NestedKey) – key of the resulting tensor.
dim (int, optional) – dimension along which the concatenation will occur. Default is -1.
- Keyword Arguments:
del_keys (bool, optional) – if True, the input values will be deleted after concatenation. Default is True.
unsqueeze_if_oor (bool, optional) – if True, CatTensors will check that the indicated dimension exists for the tensors to concatenate. If not, the tensors will be unsqueezed along that dimension. Default is False.
sort (bool, optional) – if True, the keys will be sorted in the transform. Otherwise, the order provided by the user will prevail. Default is True.
Examples
>>> transform = CatTensors(in_keys=["key1", "key2"])
>>> td = TensorDict({"key1": torch.zeros(1, 1),
...     "key2": torch.ones(1, 1)}, [1])
>>> _ = transform(td)
>>> print(td.get("observation_vector"))
tensor([[0., 1.]])
>>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True)
>>> td = TensorDict({"key1": torch.zeros(1),
...     "key2": torch.ones(1)}, [])
>>> _ = transform(td)
>>> print(td.get("observation_vector").shape)
torch.Size([2, 1])
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
- transform_observation_spec(observation_spec: TensorSpec) → TensorSpec
Transforms the observation spec so that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform