DoubleToFloat
- class torchrl.envs.transforms.DoubleToFloat(in_keys: Sequence[str] | None = None, in_keys_inv: Sequence[str] | None = None)
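Converts double-precision (torch.float64) entries listed in in_keys to single precision (torch.float32) before they are exposed outside the environment. It also converts float entries listed in in_keys_inv, typically actions, back to double before they are passed to the environment.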
- Parameters:
in_keys (list of str, optional) – list of double keys to be converted to float before being exposed to external objects and functions.
in_keys_inv (list of str, optional) – list of float keys to be converted to double before being passed to the contained base_env or storage.
Examples
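>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import DoubleToFloat
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double)}, [])
>>> transform = DoubleToFloat(in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32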
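Both directions of the conversion can be sketched with plain PyTorch. The helpers to_float and to_double below are hypothetical stand-ins, not torchrl functions, and a plain dict stands in for a TensorDict. Unlike the example above, which reads the converted value back from td, these helpers return a new dict and leave their input untouched.

```python
import torch

# Hypothetical helpers for illustration only; they are not part of torchrl.

def to_float(data: dict, in_keys: list) -> dict:
    """Forward direction: cast the entries in in_keys to float32 for the caller."""
    out = dict(data)  # shallow copy: the input dict keeps its original tensors
    for key in in_keys:
        out[key] = out[key].to(torch.float32)
    return out


def to_double(data: dict, in_keys_inv: list) -> dict:
    """Inverse direction: cast the entries in in_keys_inv to float64 for the env."""
    out = dict(data)
    for key in in_keys_inv:
        out[key] = out[key].to(torch.float64)
    return out


# An observation produced by an environment that works in float64
env_output = {"obs": torch.ones(3, dtype=torch.float64)}
policy_input = to_float(env_output, in_keys=["obs"])

# An action produced by a policy that works in float32
policy_output = {"action": torch.zeros(2, dtype=torch.float32)}
env_input = to_double(policy_output, in_keys_inv=["action"])

print(policy_input["obs"].dtype)  # torch.float32
print(env_input["action"].dtype)  # torch.float64
```

Keeping the two key lists separate mirrors the in_keys / in_keys_inv split: observations only travel outward and actions only travel inward, so each list names the keys for a single direction.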
- transform_input_spec(input_spec: TensorSpec) → TensorSpec
Transforms the input spec so that the resulting spec matches the transform mapping: double entries listed in in_keys_inv are exposed as float.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
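For the input spec, the mapping is the mirror image of what happens to the data: the environment declares float64 actions, while the spec exposed to the policy should advertise float32 for every key in in_keys_inv. The sketch below assumes a hypothetical minimal SimpleSpec dataclass and a hypothetical transform_input_spec_sketch function in place of torchrl's TensorSpec classes and method; only shape and dtype are modelled.

```python
from dataclasses import dataclass, replace

import torch


@dataclass(frozen=True)
class SimpleSpec:
    """Hypothetical stand-in for a TensorSpec: only shape and dtype are modelled."""
    shape: tuple
    dtype: torch.dtype


def transform_input_spec_sketch(input_spec: dict, in_keys_inv: list) -> dict:
    """Return a new spec dict in which float64 entries listed in in_keys_inv become float32."""
    out = dict(input_spec)  # the environment's own spec is left unchanged
    for key in in_keys_inv:
        if out[key].dtype == torch.float64:
            out[key] = replace(out[key], dtype=torch.float32)
    return out


env_input_spec = {"action": SimpleSpec(shape=(2,), dtype=torch.float64)}
exposed = transform_input_spec_sketch(env_input_spec, in_keys_inv=["action"])
print(exposed["action"].dtype)  # torch.float32
```

Returning a new spec rather than mutating the original keeps the environment's declaration (float64) available for the inverse cast, while callers outside the environment see float32.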
- transform_observation_spec(observation_spec)
Transforms the observation spec so that the resulting spec matches the transform mapping: double entries listed in in_keys are exposed as float.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_reward_spec(observation_spec)
Transforms the reward spec so that the resulting spec matches the transform mapping: a double reward entry listed in in_keys is exposed as float.
- Parameters:
observation_spec (TensorSpec) – reward spec before the transform
- Returns:
expected spec after the transform
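The observation and reward specs follow the forward direction: float64 entries named in in_keys are advertised as float32. The sketch below reuses a hypothetical SimpleSpec dataclass and a hypothetical to_float_spec helper, neither of which is a torchrl API. It also assumes that a single in_keys list may name both observation keys and a "reward" key, which is why each spec only converts the keys it actually contains.

```python
from dataclasses import dataclass, replace

import torch


@dataclass(frozen=True)
class SimpleSpec:
    """Hypothetical stand-in for a TensorSpec: only shape and dtype are modelled."""
    shape: tuple
    dtype: torch.dtype


def to_float_spec(spec: dict, in_keys: list) -> dict:
    """Return a new spec dict in which float64 entries named in in_keys become float32.

    Keys in in_keys that this spec does not contain are skipped, and entries
    not named in in_keys are returned unchanged.
    """
    out = dict(spec)
    for key in in_keys:
        if key in out and out[key].dtype == torch.float64:
            out[key] = replace(out[key], dtype=torch.float32)
    return out


observation_spec = {
    "obs": SimpleSpec((3,), torch.float64),
    "pixels": SimpleSpec((3, 64, 64), torch.uint8),  # not in in_keys: untouched
}
reward_spec = {"reward": SimpleSpec((1,), torch.float64)}

in_keys = ["obs", "reward"]
new_obs_spec = to_float_spec(observation_spec, in_keys)
new_reward_spec = to_float_spec(reward_spec, in_keys)

print(new_obs_spec["obs"].dtype)        # torch.float32
print(new_obs_spec["pixels"].dtype)     # torch.uint8
print(new_reward_spec["reward"].dtype)  # torch.float32
```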