
DoubleToFloat

class torchrl.envs.transforms.DoubleToFloat(in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)

Casts float64 (double) entries to float32 for selected keys, and float32 entries back to float64 on the inverse pass.
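At the value level, this cast is lossy: a float64 value rounded to float32 keeps only about 7 significant decimal digits, and casting it back to float64 does not restore the discarded bits. The standard-library sketch below is not TorchRL code; it only emulates the round trip with `struct`, which packs a Python float into the 32-bit IEEE 754 format. The helper name `to_float32` is hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    """Hypothetical helper: round a Python float (an IEEE 754 double) to the nearest float32."""
    # "<f" packs into 4 bytes of single precision; unpacking widens the result
    # back to a Python float (double), the way a float32 -> float64 cast would.
    return struct.unpack("<f", struct.pack("<f", x))[0]


original = 0.1                     # a float64 value
narrowed = to_float32(original)    # what survives a float64 -> float32 cast
print(narrowed)                    # 0.10000000149011612
print(narrowed == original)        # False: widening back does not restore the lost bits
print(to_float32(0.5) == 0.5)      # True: 0.5 is exactly representable in float32
```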

Depending on whether in_keys or in_keys_inv are provided at construction, the class behaves differently:

  • If the keys are provided, only those entries are transformed: from float64 to float32 in the forward pass (in_keys), and from float32 back to float64 in the inverse pass (in_keys_inv);

  • If the keys are not provided and the transform is registered in an environment (for instance, appended to a TransformedEnv), the entries of the input and output specs whose dtype is float64 are used as in_keys_inv and in_keys, respectively.

  • If the keys are not provided and the transform is used without an environment, the forward / inverse pass scans the input tensordict for all float64 values and maps them to float32 tensors. This scan is not free and can hurt performance on large data structures, because the keys to be transformed are not cached and are searched for again on every call. Note that, in this mode, out_keys (resp. out_keys_inv) cannot be passed, as the order in which the keys are processed cannot be anticipated precisely.
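The "automatic" scan described in the last bullet can be pictured with the following standard-library sketch. It is an illustration only, not TorchRL's implementation: a nested `dict` stands in for a TensorDict, `array.array` buffers with typecode `"d"` (C double) and `"f"` (C float) stand in for float64 and float32 tensors, and the function name `cast_all_doubles` is hypothetical. Every call walks the whole structure again, mirroring the note that the keys to transform are not cached.

```python
from array import array
from typing import Any, Dict, List, Tuple

NestedKey = Tuple[str, ...]


def cast_all_doubles(data: Dict[str, Any], prefix: NestedKey = ()) -> List[NestedKey]:
    """Hypothetical sketch: cast every float64 ("d") leaf of a nested dict to float32 ("f") in place.

    Returns the nested keys that were converted. Nothing is cached, so the
    full structure is rescanned on every call.
    """
    converted: List[NestedKey] = []
    for key, value in data.items():
        nested_key = prefix + (key,)
        if isinstance(value, dict):
            # Recurse into sub-dicts, the analogue of nested TensorDicts.
            converted.extend(cast_all_doubles(value, nested_key))
        elif isinstance(value, array) and value.typecode == "d":
            # Storing the values in an "f" array rounds each one to single precision.
            data[key] = array("f", value.tolist())
            converted.append(nested_key)
    return converted


td = {
    "obs": array("d", [1.0, 2.0]),
    "count": array("q", [3]),              # integer leaf: left untouched
    "next": {"reward": array("d", [0.5])},
}
print(cast_all_doubles(td))  # [('obs',), ('next', 'reward')]
print(td["obs"].typecode, td["next"]["reward"].typecode, td["count"].typecode)  # f f q
```

Because the keys are only discovered during the walk, their order depends on the structure of the input, which is consistent with out_keys not being accepted in this mode.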

Parameters:
  • in_keys (sequence of NestedKey, optional) – list of double keys to be converted to float before being exposed to external objects and functions.

  • out_keys (sequence of NestedKey, optional) – list of destination keys. Defaults to in_keys if not provided.

  • in_keys_inv (sequence of NestedKey, optional) – list of float keys to be converted to double before being passed to the contained base_env or storage.

  • out_keys_inv (sequence of NestedKey, optional) – list of destination keys for inverse transform. Defaults to in_keys_inv if not provided.
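The four key arguments pair up position by position: `in_keys[i]` is read, cast and written to `out_keys[i]`, and likewise for `in_keys_inv` / `out_keys_inv` on the inverse pass. The toy class below is a minimal sketch of that bookkeeping only, assuming this positional pairing (the usual TorchRL convention); `ToyDoubleToFloat` is a hypothetical name, not the TorchRL class, and `array.array` typecodes `"d"` / `"f"` again stand in for float64 / float32 tensors.

```python
from array import array
from typing import Dict, List, Optional


class ToyDoubleToFloat:
    """Hypothetical sketch of the key bookkeeping only; not torchrl's DoubleToFloat."""

    def __init__(
        self,
        in_keys: List[str],
        out_keys: Optional[List[str]] = None,
        in_keys_inv: Optional[List[str]] = None,
        out_keys_inv: Optional[List[str]] = None,
    ) -> None:
        self.in_keys = list(in_keys)
        # Destination keys default to the source keys, i.e. an in-place cast.
        self.out_keys = list(out_keys) if out_keys is not None else list(self.in_keys)
        self.in_keys_inv = list(in_keys_inv) if in_keys_inv is not None else []
        self.out_keys_inv = (
            list(out_keys_inv) if out_keys_inv is not None else list(self.in_keys_inv)
        )
        if len(self.in_keys) != len(self.out_keys):
            raise ValueError("in_keys and out_keys must have the same length")
        if len(self.in_keys_inv) != len(self.out_keys_inv):
            raise ValueError("in_keys_inv and out_keys_inv must have the same length")

    def forward(self, data: Dict[str, array]) -> Dict[str, array]:
        # double -> float for entries exposed to external objects and functions
        for src, dst in zip(self.in_keys, self.out_keys):
            data[dst] = array("f", data[src].tolist())
        return data

    def inverse(self, data: Dict[str, array]) -> Dict[str, array]:
        # float -> double for entries passed back to the base env or storage
        for src, dst in zip(self.in_keys_inv, self.out_keys_inv):
            data[dst] = array("d", data[src].tolist())
        return data


t = ToyDoubleToFloat(in_keys=["obs"], out_keys=["obs_f32"], in_keys_inv=["action"])
data = {"obs": array("d", [1.0]), "action": array("f", [0.5])}
t.forward(data)
t.inverse(data)
print(sorted((k, v.typecode) for k, v in data.items()))
# [('action', 'd'), ('obs', 'd'), ('obs_f32', 'f')]
```

With a distinct out_key, the original float64 entry is kept next to its float32 copy; leaving out_keys unset overwrites the entry in place instead.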

Examples

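>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import DoubleToFloat
>>> td = TensorDict(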
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DoubleToFloat(in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float64

In “automatic” mode, all float64 entries are transformed:

Examples

>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DoubleToFloat()
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float32

The same behaviour applies when environments are constructed without specifying the transform keys:

Examples

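>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.envs import EnvBase, TransformedEnv
>>> from torchrl.envs.transforms import DoubleToFloat
>>> class MyEnv(EnvBase):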
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
...         self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
...         self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
...     def _reset(self, data=None):
...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
...     def _step(self, data):
...         assert data["action"].dtype == torch.float64
...         reward = self.reward_spec.rand()
...         done = torch.zeros((1,), dtype=torch.bool)
...         obs = self.observation_spec.rand()
...         assert reward.dtype == torch.float64
...         assert obs["obs"].dtype == torch.float64
...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
...     def _set_seed(self, seed):
...         pass
>>> env = TransformedEnv(MyEnv(), DoubleToFloat())
>>> assert env.action_spec.dtype == torch.float32
>>> assert env.observation_spec["obs"].dtype == torch.float32
>>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> assert env.transform.in_keys == ["obs", "reward"]
>>> assert env.transform.in_keys_inv == ["action"]
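The two assertions above follow from the spec-based rule: float64 entries of the output specs (observation and reward) become in_keys, float64 entries of the input spec (the action) become in_keys_inv, and the boolean done entry is ignored. A minimal sketch of that selection rule, with specs reduced to hypothetical plain `{key: dtype_name}` dicts rather than TorchRL spec objects; `infer_keys` is an illustrative name, not a TorchRL function.

```python
from typing import Dict, List, Tuple


def infer_keys(
    output_spec: Dict[str, str], input_spec: Dict[str, str]
) -> Tuple[List[str], List[str]]:
    """Hypothetical sketch: select the keys whose dtype is float64 on each side."""
    in_keys = [key for key, dtype in output_spec.items() if dtype == "float64"]
    in_keys_inv = [key for key, dtype in input_spec.items() if dtype == "float64"]
    return in_keys, in_keys_inv


# Dtypes mirroring MyEnv above, before the transform is applied.
output_spec = {"obs": "float64", "reward": "float64", "done": "bool"}
input_spec = {"action": "float64"}
print(infer_keys(output_spec, input_spec))  # (['obs', 'reward'], ['action'])
```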
