DeviceCastTransform
- class torchrl.envs.transforms.DeviceCastTransform(device, orig_device=None, *, in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)[source]
Moves data from one device to another.
- Parameters:
device (torch.device or equivalent) – the destination device.
orig_device (torch.device or equivalent) – the origin device. If not specified and a parent environment exists, it is retrieved from it. In all other cases, it remains unspecified.
- Keyword Arguments:
in_keys (list of NestedKey) – the list of entries to map to a different device. Defaults to None.
out_keys (list of NestedKey) – the output names of the entries mapped onto a device. Defaults to the values of in_keys.
in_keys_inv (list of NestedKey) – the list of entries to map to a different device. in_keys_inv are the names expected by the base environment. Defaults to None.
out_keys_inv (list of NestedKey) – the output names of the entries mapped onto a device. out_keys_inv are the names of the keys as seen from outside the transformed env. Defaults to the values of in_keys_inv.
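The defaulting rules for the key arguments can be sketched in plain Python. This is an illustration only, not torchrl's implementation: resolve_keys is a hypothetical helper, and it models nothing beyond the two documented fallbacks (out_keys to in_keys, out_keys_inv to in_keys_inv).

```python
def resolve_keys(in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None):
    """Hypothetical helper mirroring the documented defaults of the key arguments."""
    # out_keys falls back to the values of in_keys.
    if out_keys is None:
        out_keys = in_keys
    # out_keys_inv falls back to the values of in_keys_inv.
    if out_keys_inv is None:
        out_keys_inv = in_keys_inv
    return {
        "in_keys": in_keys,
        "out_keys": out_keys,
        "in_keys_inv": in_keys_inv,
        "out_keys_inv": out_keys_inv,
    }


# Only in_keys is given: out_keys copies it, the inverse keys stay None.
forward_only = resolve_keys(in_keys=["obs"])

# Explicit out_keys_inv is kept as provided.
renamed = resolve_keys(in_keys_inv=["action"], out_keys_inv=["action_cpu"])
```

Here "action_cpu" is an arbitrary name chosen for the illustration; any key name could stand in for it.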
Examples
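>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import DeviceCastTransform
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     }, [], device="cpu:0")
>>> transform = DeviceCastTransform(device=torch.device("cpu:2"))
>>> td = transform(td)
>>> print(td.device)
cpu:2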
- forward(tensordict: TensorDictBase = None) TensorDictBase [source]
Reads the input tensordict, and for the selected keys, applies the transform.
By default, this method:
calls _apply_transform() directly.
does not call _step() or _call().
This method is not called within env.step at any point. However, it is called within sample().
Note
forward also works with regular keyword arguments, using dispatch to cast the argument names to the keys.
Examples
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         # Store the measured size, not the builtin `bytes` type.
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
- transform_action_spec(full_action_spec: Composite) Composite [source]
Transforms the action spec such that the resulting spec matches transform mapping.
- Parameters:
full_action_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_done_spec(full_done_spec: Composite) Composite [source]
Transforms the done spec such that the resulting spec matches transform mapping.
- Parameters:
full_done_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_env_device(device)[source]
Transforms the device of the parent env.
- transform_input_spec(input_spec: Composite) Composite [source]
Transforms the input spec such that the resulting spec matches transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_observation_spec(observation_spec: Composite) Composite [source]
Transforms the observation spec such that the resulting spec matches transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
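As a conceptual illustration of a spec "matching the transform mapping", the toy sketch below re-declares selected entries on a destination device. It is written in plain Python under stated assumptions: ToySpec and cast_observation_spec are hypothetical stand-ins, not torchrl classes or methods, and the real spec objects carry far more information than a shape and a device string.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical stand-in for a spec leaf: only a shape and a device string."""
    shape: tuple
    device: str


def cast_observation_spec(spec, device, in_keys, out_keys):
    """Return a copy of `spec` where each in_key entry is declared on `device` under its out_key."""
    result = dict(spec)  # shallow copy, the input mapping is left untouched
    for in_key, out_key in zip(in_keys, out_keys):
        result[out_key] = replace(spec[in_key], device=device)
    return result


spec = {
    "obs": ToySpec(shape=(3,), device="cpu"),
    "pixels": ToySpec(shape=(3, 64, 64), device="cpu"),
}
# Only "obs" is selected, so "pixels" keeps its original device.
new_spec = cast_observation_spec(spec, "cuda:0", in_keys=["obs"], out_keys=["obs"])
```

The point of the sketch is that the spec describes data as seen after the transform, so entries the transform moves should be advertised on the destination device, while unselected entries are left as they were.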
- transform_output_spec(output_spec: Composite) Composite [source]
Transforms the output spec such that the resulting spec matches transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().
- Parameters:
output_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_reward_spec(full_reward_spec: Composite) Composite [source]
Transforms the reward spec such that the resulting spec matches transform mapping.
- Parameters:
full_reward_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_state_spec(full_state_spec: Composite) Composite [source]
Transforms the state spec such that the resulting spec matches transform mapping.
- Parameters:
full_state_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform