CatTensors
- class torchrl.envs.transforms.CatTensors(in_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, out_key: Union[str, Tuple[str, ...]] = 'observation_vector', dim: int = -1, del_keys: bool = True, unsqueeze_if_oor: bool = False)
Concatenates several keys into a single tensor.
This is especially useful when multiple keys describe a single state (e.g. “observation_position” and “observation_velocity”).
- Parameters:
in_keys (sequence of NestedKey) – keys to be concatenated. If None (or not provided), the keys will be retrieved from the parent environment the first time the transform is used. This only works if a parent is set.
out_key (NestedKey) – key of the resulting tensor.
dim (int, optional) – dimension along which the concatenation will occur. Default is -1.
del_keys (bool, optional) – if True, the input values will be deleted after concatenation. Default is True.
unsqueeze_if_oor (bool, optional) – if True, CatTensors will check that the dimension indicated exists for the tensors to concatenate. If not, the tensors will be unsqueezed along that dimension. Default is False.
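One way to picture what unsqueeze_if_oor does, written with plain torch. This is only a sketch of the documented behaviour, not the library's implementation; the helper name cat_with_optional_unsqueeze is hypothetical, and the rule used to decide that a dimension is "out of range" is an assumption.

```python
import torch


def cat_with_optional_unsqueeze(tensors, dim=-1, unsqueeze_if_oor=False):
    """Hypothetical helper (not part of torchrl) mimicking the documented flag."""
    if unsqueeze_if_oor:
        fixed = []
        for t in tensors:
            # Assumption: `dim` is out of range when it does not index an
            # existing dimension of `t`; in that case add a singleton dim there.
            if dim < -t.ndim or dim >= t.ndim:
                t = t.unsqueeze(dim)
            fixed.append(t)
        tensors = fixed
    return torch.cat(tensors, dim=dim)


# Two 1-D tensors have no dimension -2, so each becomes shape [1, 1] first.
out = cat_with_optional_unsqueeze(
    [torch.zeros(1), torch.ones(1)], dim=-2, unsqueeze_if_oor=True
)
print(out.shape)  # torch.Size([2, 1])
```

Without the flag, the same call falls through to torch.cat directly, which rejects a dim that the input tensors do not have.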
Examples
>>> transform = CatTensors(in_keys=["key1", "key2"])
>>> td = TensorDict({"key1": torch.zeros(1, 1),
...     "key2": torch.ones(1, 1)}, [1])
>>> _ = transform(td)
>>> print(td.get("observation_vector"))
tensor([[0., 1.]])
>>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True)
>>> td = TensorDict({"key1": torch.zeros(1),
...     "key2": torch.ones(1)}, [])
>>> _ = transform(td)
>>> print(td.get("observation_vector").shape)
torch.Size([2, 1])
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
- transform_observation_spec(observation_spec: TensorSpec) → TensorSpec
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform