
DoubleToFloat

class torchrl.envs.transforms.DoubleToFloat(in_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, in_keys_inv: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None)[source]

Casts float64 entries to float32 for selected keys, and float32 entries back to float64 on the inverse pass.

The behaviour of the class depends on whether in_keys or in_keys_inv are provided at construction:

  • If the keys are provided, those entries and those entries only will be transformed from float64 to float32 entries;

  • If the keys are not provided and the object is registered among the transforms of an environment, the input and output specs that have a dtype set to float64 will be used as in_keys_inv / in_keys respectively;

  • If the keys are not provided and the object is used without an environment, the forward / inverse pass will scan through the input tensordict for all float64 values and map them to float32 tensors. The keys to be transformed are not cached, so the scan is repeated on each pass, which can impact performance for large data structures.
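The difference between explicit keys and the scanning mode can be sketched without torchrl. The following is a minimal illustration under stated assumptions, not the library's implementation: plain nested dictionaries stand in for a TensorDict, the array.array typecodes "d" and "f" stand in for float64 and float32, and the helper names find_double_keys and double_to_float are hypothetical.

```python
from array import array


def find_double_keys(data, prefix=()):
    """Return the nested keys (as tuples) of every float64 leaf in a nested dict."""
    keys = []
    for name, value in data.items():
        key = prefix + (name,)
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, extending the nested key.
            keys.extend(find_double_keys(value, key))
        elif isinstance(value, array) and value.typecode == "d":
            keys.append(key)
    return keys


def double_to_float(data, in_keys=None):
    """Cast the selected float64 leaves to float32 in place.

    If in_keys is given, only those entries are touched. Otherwise the whole
    structure is scanned on every call, since nothing is cached here.
    """
    if in_keys is None:
        in_keys = find_double_keys(data)
    for key in in_keys:
        node = data
        for name in key[:-1]:
            node = node[name]
        node[key[-1]] = array("f", node[key[-1]].tolist())
    return data


# Explicit keys: only "obs" is cast.
explicit = {"obs": array("d", [1.0]), "not_transformed": array("d", [1.0])}
double_to_float(explicit, in_keys=[("obs",)])
explicit_codes = {k: v.typecode for k, v in explicit.items()}

# No keys: every float64 leaf is found by scanning, including nested ones.
scanned = {"obs": array("d", [1.0]), "next": {"reward": array("d", [0.5])}}
double_to_float(scanned)
```

In this sketch, passing in_keys skips the scan entirely, which is the reason explicit keys are preferable when the data structure is large.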

Parameters:
  • in_keys (sequence of NestedKey, optional) – list of double keys to be converted to float before being exposed to external objects and functions.

  • in_keys_inv (sequence of NestedKey, optional) – list of float keys to be converted to double before being passed to the contained base_env or storage.
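The two parameters act in opposite directions, which a small stand-in can make concrete. This is a hedged sketch only: base_step and transformed_step are hypothetical functions rather than torchrl APIs, and array.array typecodes "d" and "f" are used in place of float64 and float32 tensors.

```python
from array import array


def base_step(action):
    """Hypothetical base environment: expects float64 in, returns float64 out."""
    assert action.typecode == "d"
    return {"obs": array("d", [2.0 * a for a in action])}


def transformed_step(action):
    """Wrap base_step the way a float/double casting transform would."""
    # Inverse pass (role of in_keys_inv=["action"]): float32 -> float64
    # before the data reaches the base environment.
    action = array("d", action.tolist())
    out = base_step(action)
    # Forward pass (role of in_keys=["obs"]): float64 -> float32
    # before the data is exposed to the caller.
    out["obs"] = array("f", out["obs"].tolist())
    return out


result = transformed_step(array("f", [0.5]))
```

The caller only ever sees float32 data, while the wrapped function only ever sees float64 data.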

Examples

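>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import DoubleToFloat
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...      'not_transformed': torch.ones(1, dtype=torch.double),
...     }, [])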
>>> transform = DoubleToFloat(in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float64

In “automatic” mode, all float64 entries are transformed:

Examples

>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DoubleToFloat()
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float32

The same behaviour applies when environments are constructed without specifying the transform keys:

Examples

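>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.envs import EnvBase, TransformedEnv
>>> class MyEnv(EnvBase):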
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
...         self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
...         self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
...     def _reset(self, data=None):
...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
...     def _step(self, data):
...         assert data["action"].dtype == torch.float64
...         reward = self.reward_spec.rand()
...         done = torch.zeros((1,), dtype=torch.bool)
...         obs = self.observation_spec.rand()
...         assert reward.dtype == torch.float64
...         assert obs["obs"].dtype == torch.float64
...         return obs.select().set("next", obs.update({"reward": reward, "done": done}))
...     def _set_seed(self, seed):
...         pass
>>> env = TransformedEnv(MyEnv(), DoubleToFloat())
>>> assert env.action_spec.dtype == torch.float32
>>> assert env.observation_spec["obs"].dtype == torch.float32
>>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> assert env.transform.in_keys == ["obs", "reward"]
>>> assert env.transform.in_keys_inv == ["action"]
