DeviceCastTransform
- class torchrl.envs.transforms.DeviceCastTransform(device, orig_device=None, *, in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)
Moves data from one device to another.
- Parameters:
device (torch.device or equivalent) – the destination device.
orig_device (torch.device or equivalent) – the origin device. If not specified and a parent environment exists, it is retrieved from it. In all other cases, it remains unspecified.
- Keyword Arguments:
in_keys (list of NestedKey) – the list of entries to map to a different device. Defaults to None.
out_keys (list of NestedKey) – the output names of the entries mapped onto a device. Defaults to the values of in_keys.
in_keys_inv (list of NestedKey) – the list of entries to map to a different device. in_keys_inv are the names expected by the base environment. Defaults to None.
out_keys_inv (list of NestedKey) – the output names of the entries mapped onto a device. out_keys_inv are the names of the keys as seen from outside the transformed env. Defaults to the values of in_keys_inv.
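The defaulting rules for the four key lists can be summarized in a few lines of plain Python. This is a minimal sketch that assumes only the defaults stated above; `resolve_keys` is a hypothetical helper, not part of the TorchRL API, and the length check is an assumption that follows from keys being paired one-to-one.

```python
from typing import List, Optional, Tuple, Union

# Simplified stand-in for NestedKey: a string or a tuple of strings.
NestedKey = Union[str, Tuple[str, ...]]
KeyList = Optional[List[NestedKey]]


def resolve_keys(
    in_keys: KeyList = None,
    out_keys: KeyList = None,
    in_keys_inv: KeyList = None,
    out_keys_inv: KeyList = None,
) -> Tuple[KeyList, KeyList, KeyList, KeyList]:
    """Hypothetical helper applying the documented key defaults."""
    if out_keys is None:
        # out_keys defaults to the values of in_keys
        out_keys = None if in_keys is None else list(in_keys)
    if out_keys_inv is None:
        # out_keys_inv defaults to the values of in_keys_inv
        out_keys_inv = None if in_keys_inv is None else list(in_keys_inv)
    # Assumption: keys are paired one-to-one, so lengths must match.
    if in_keys is not None and len(in_keys) != len(out_keys):
        raise ValueError("in_keys and out_keys must have the same length")
    if in_keys_inv is not None and len(in_keys_inv) != len(out_keys_inv):
        raise ValueError("in_keys_inv and out_keys_inv must have the same length")
    return in_keys, out_keys, in_keys_inv, out_keys_inv


print(resolve_keys(["obs"], in_keys_inv=["action"]))
```

Keeping the resolution in one place makes it explicit that the forward keys and the inverse keys are defaulted independently of each other.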
Examples
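>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import DeviceCastTransform
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     }, [], device="cpu:0")
>>> transform = DeviceCastTransform(device=torch.device("cpu:2"))
>>> td = transform(td)
>>> print(td.device)
cpu:2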
- forward(tensordict: TensorDictBase = None) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
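Conceptually, the forward pass casts each selected entry to the destination device and stores it under the matching output key. The sketch below mimics this on a plain dict of tensors instead of a TensorDict. `cast_selected` is a hypothetical helper, and the rule that omitting in_keys moves every entry is inferred from the class example, not taken from the TorchRL source. The `meta` device is used only so that the snippet runs without a GPU.

```python
from typing import Dict, List, Optional

import torch


def cast_selected(
    data: Dict[str, torch.Tensor],
    device: torch.device,
    in_keys: Optional[List[str]] = None,
    out_keys: Optional[List[str]] = None,
) -> Dict[str, torch.Tensor]:
    """Hypothetical sketch of a device-casting forward pass on a plain dict."""
    if in_keys is None:
        # Assumption: with no keys selected, every entry is moved.
        return {key: value.to(device) for key, value in data.items()}
    if out_keys is None:
        out_keys = in_keys  # out_keys defaults to in_keys
    result = dict(data)  # unselected entries stay on their current device
    for in_key, out_key in zip(in_keys, out_keys):
        # Assumption: when out_key differs, the source entry is kept as well.
        result[out_key] = data[in_key].to(device)
    return result


data = {"obs": torch.ones(1, dtype=torch.double), "reward": torch.zeros(1)}
moved = cast_selected(data, torch.device("meta"), in_keys=["obs"])
print(moved["obs"].device, moved["reward"].device)  # meta cpu
```

Note that `Tensor.to(device)` changes only the device; the dtype (here `torch.double`) is preserved.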
- transform_done_spec(full_done_spec: Composite) → Composite
Transforms the done spec such that the resulting spec matches transform mapping.
- Parameters:
full_done_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_input_spec(input_spec: Composite) → Composite
Transforms the input spec such that the resulting spec matches transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_observation_spec(observation_spec: Composite) → Composite
Transforms the observation spec such that the resulting spec matches transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
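For a device cast, one would expect these spec transforms to leave each entry's shape and dtype alone and update only its device, so that the specs advertised by the transformed environment agree with the data returned by forward(). The sketch below illustrates that expectation with a hypothetical `ToySpec` dataclass and `cast_spec` helper. It does not use TorchRL's Composite or TensorSpec classes, and it is not how the library implements these methods.

```python
from dataclasses import dataclass, replace
from typing import Dict, Optional, Sequence, Tuple


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical stand-in for a TensorSpec: only shape, dtype and device."""

    shape: Tuple[int, ...]
    dtype: str
    device: str


def cast_spec(
    specs: Dict[str, ToySpec],
    device: str,
    keys: Optional[Sequence[str]] = None,
) -> Dict[str, ToySpec]:
    """Return a copy of `specs` whose selected entries report `device`."""
    selected = list(specs) if keys is None else list(keys)
    out = dict(specs)  # the input mapping is not mutated
    for key in selected:
        # Only the device changes; shape and dtype are preserved.
        out[key] = replace(specs[key], device=device)
    return out


specs = {"obs": ToySpec((3,), "float64", "cpu")}
print(cast_spec(specs, "cuda:0")["obs"].device)  # cuda:0
```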
- transform_output_spec(output_spec: Composite) → Composite
Transforms the output spec such that the resulting spec matches transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_done_spec().
- Parameters:
output_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
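The delegation described for transform_output_spec can be pictured as a small base class whose output-spec method routes each part of the output spec to a per-part hook, so that a subclass overrides only those hooks. In this hypothetical sketch, specs are plain dicts, and the part names `full_observation_spec`, `full_reward_spec` and `full_done_spec` are assumed to match TorchRL's output-spec layout. It is an illustration of the pattern, not the library's implementation.

```python
from typing import Any, Dict

Spec = Dict[str, Any]  # hypothetical: a spec is modelled as a plain dict


class SpecTransformSketch:
    """Hypothetical base class: per-part hooks default to the identity."""

    def transform_observation_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_reward_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_done_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_output_spec(self, output_spec: Spec) -> Spec:
        # Meant to be left untouched: it only dispatches each part of the
        # output spec to the matching hook and reassembles the result.
        result = dict(output_spec)
        result["full_observation_spec"] = self.transform_observation_spec(
            output_spec["full_observation_spec"]
        )
        result["full_reward_spec"] = self.transform_reward_spec(
            output_spec["full_reward_spec"]
        )
        result["full_done_spec"] = self.transform_done_spec(
            output_spec["full_done_spec"]
        )
        return result


class TagDeviceSketch(SpecTransformSketch):
    """Hypothetical subclass that overrides only the per-part hooks."""

    def __init__(self, device: str) -> None:
        self.device = device

    def _tag(self, spec: Spec) -> Spec:
        # Build new entry dicts with an updated "device" field.
        return {key: {**entry, "device": self.device} for key, entry in spec.items()}

    def transform_observation_spec(self, spec: Spec) -> Spec:
        return self._tag(spec)

    def transform_reward_spec(self, spec: Spec) -> Spec:
        return self._tag(spec)

    def transform_done_spec(self, spec: Spec) -> Spec:
        return self._tag(spec)
```

Because the subclass never touches transform_output_spec, any change to how the output spec is split into parts stays confined to the base class.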
- transform_reward_spec(full_reward_spec: Composite) → Composite
Transforms the reward spec such that the resulting spec matches transform mapping.
- Parameters:
full_reward_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform