DTypeCastTransform
- class torchrl.envs.transforms.DTypeCastTransform(dtype_in: torch.dtype, dtype_out: torch.dtype, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)
Casts one dtype to another for selected keys.
Depending on whether in_keys or in_keys_inv are provided during construction, the class behaviour will change:

- If the keys are provided, those entries and those entries only will be transformed from dtype_in to dtype_out entries;
- If the keys are not provided and the object is within an environment register of transforms, the input and output specs that have a dtype set to dtype_in will be used as in_keys_inv / in_keys respectively.
- If the keys are not provided and the object is used without an environment, the forward / inverse pass will scan through the input tensordict for all dtype_in values and map them to dtype_out tensors. For large data structures, this can impact performance, as this scanning does not come for free. The keys to be transformed are not cached. Note that, in this case, the out_keys (resp. out_keys_inv) cannot be passed, as the order in which the keys are processed cannot be anticipated precisely.
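The two key-selection modes above can be sketched in pure Python, with a plain (possibly nested) dict standing in for the tensordict and dtype strings standing in for torch dtypes. All names here are hypothetical illustrations, not TorchRL API:

```python
# Minimal sketch of the explicit vs. automatic key-selection modes.
# Entries are (value, dtype) pairs; dtypes are modeled as strings.

def cast_entries(data, dtype_in, dtype_out, in_keys=None):
    """Cast matching entries from dtype_in to dtype_out, in place."""
    if in_keys is not None:
        # Explicit mode: only the listed keys are converted.
        for key in in_keys:
            value, dtype = data[key]
            if dtype == dtype_in:
                data[key] = (value, dtype_out)
    else:
        # Automatic mode: scan every entry (recursing into nested dicts)
        # for dtype_in values -- this scan is what costs time on large
        # structures, and it is re-run on every call (no caching).
        for key, entry in data.items():
            if isinstance(entry, dict):
                cast_entries(entry, dtype_in, dtype_out)
            else:
                value, dtype = entry
                if dtype == dtype_in:
                    data[key] = (value, dtype_out)
    return data
```

In explicit mode, untouched keys keep their dtype; in automatic mode every matching entry is converted, which is why destination keys cannot be pre-declared there.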
- Parameters:
dtype_in (torch.dtype) – the input dtype (from the env).
dtype_out (torch.dtype) – the output dtype (for model training).
in_keys (sequence of NestedKey, optional) – list of dtype_in keys to be converted to dtype_out before being exposed to external objects and functions.
out_keys (sequence of NestedKey, optional) – list of destination keys. Defaults to in_keys if not provided.
in_keys_inv (sequence of NestedKey, optional) – list of dtype_out keys to be converted to dtype_in before being passed to the contained base_env or storage.
out_keys_inv (sequence of NestedKey, optional) – list of destination keys for the inverse transform. Defaults to in_keys_inv if not provided.
Examples
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...      'not_transformed': torch.ones(1, dtype=torch.double),
...     }, [])
>>> transform = DTypeCastTransform(torch.double, torch.float, in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float64
In “automatic” mode, all float64 entries are transformed:
Examples
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...      'not_transformed': torch.ones(1, dtype=torch.double),
...     }, [])
>>> transform = DTypeCastTransform(torch.double, torch.float)
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float32
The same behaviour is the rule when environments are constructed without specifying the transform keys:
Examples
>>> class MyEnv(EnvBase):
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
...         self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
...         self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
...     def _reset(self, data=None):
...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
...     def _step(self, data):
...         assert data["action"].dtype == torch.float64
...         reward = self.reward_spec.rand()
...         done = torch.zeros((1,), dtype=torch.bool)
...         obs = self.observation_spec.rand()
...         assert reward.dtype == torch.float64
...         assert obs["obs"].dtype == torch.float64
...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
...     def _set_seed(self, seed):
...         pass
>>> env = TransformedEnv(MyEnv(), DTypeCastTransform(torch.double, torch.float))
>>> assert env.action_spec.dtype == torch.float32
>>> assert env.observation_spec["obs"].dtype == torch.float32
>>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> assert env.transform.in_keys == ["obs", "reward"]
>>> assert env.transform.in_keys_inv == ["action"]
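The round trip the transformed env performs can be sketched without TorchRL: the forward pass casts env outputs from dtype_in to dtype_out before the policy sees them, while the inverse pass casts the policy's action from dtype_out back to dtype_in before the base env receives it. Dtypes are again modeled as strings, and every name below is a hypothetical stand-in:

```python
# Sketch of one transformed env step: inverse cast on the way in,
# forward cast on the way out.

DTYPE_IN, DTYPE_OUT = "float64", "float32"

def base_env_step(action):
    # The wrapped env only accepts dtype_in actions and emits dtype_in outputs.
    assert action[1] == DTYPE_IN, "base env expects dtype_in"
    return {"obs": (0.5, DTYPE_IN), "reward": (1.0, DTYPE_IN)}

def transformed_step(action):
    # Inverse pass (in_keys_inv): policy action dtype_out -> dtype_in.
    value, dtype = action
    if dtype == DTYPE_OUT:
        action = (value, DTYPE_IN)
    out = base_env_step(action)
    # Forward pass (in_keys): env outputs dtype_in -> dtype_out.
    return {k: (v, DTYPE_OUT if d == DTYPE_IN else d)
            for k, (v, d) in out.items()}
```

This mirrors the asymmetry in the parameter names: in_keys / out_keys govern the env-to-model direction, in_keys_inv / out_keys_inv the model-to-env direction.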
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
- transform_input_spec(input_spec: TensorSpec) → TensorSpec
Transforms the input spec such that the resulting spec matches transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_observation_spec(observation_spec)
Transforms the observation spec such that the resulting spec matches transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
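What these spec-transforming methods conceptually do can be sketched in pure Python: walk a (possibly nested) spec, represented here as a dict mapping key to dtype string, and rewrite every leaf whose dtype is dtype_in to dtype_out. All names are hypothetical stand-ins, not TorchRL API:

```python
# Sketch of spec transformation: composite specs recurse, leaf specs
# swap their dtype when it matches dtype_in.

def transform_spec(spec, dtype_in="float64", dtype_out="float32"):
    out = {}
    for key, value in spec.items():
        if isinstance(value, dict):
            # Composite spec: recurse into nested specs.
            out[key] = transform_spec(value, dtype_in, dtype_out)
        else:
            # Leaf spec: swap the dtype if it matches dtype_in.
            out[key] = dtype_out if value == dtype_in else value
    return out
```

Non-matching leaves (e.g. a bool done spec) pass through unchanged, which matches the behaviour shown in the env example above, where only the float64 specs are remapped.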
- transform_output_spec(output_spec: CompositeSpec) → CompositeSpec
Transforms the output spec such that the resulting spec matches transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_done_spec().
- Parameters:
output_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform