SelectTransform
- class torchrl.envs.transforms.SelectTransform(*selected_keys: NestedKey, keep_rewards: bool = True, keep_dones: bool = True)
Select keys from the input tensordict.
- In general, the ExcludeTransform should be preferred. This transform also selects the “action” (or other keys from input_spec), “done” and “reward” keys, but any other key needed downstream must be listed explicitly or it will be dropped.
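The practical difference between selecting and excluding can be illustrated with a minimal, hypothetical sketch in plain Python. It assumes plain dicts as stand-ins for TensorDicts, and `select_keys` and `exclude_keys` are illustrative helpers, not torchrl functions. Unlike the real SelectTransform, the sketch does not keep action, done or reward keys automatically; it only shows why a selection list must be exhaustive while an exclusion list need not be.

```python
# Hypothetical illustration: plain dicts stand in for TensorDicts, and these
# two helpers are NOT torchrl APIs.
def select_keys(data: dict, keys: set) -> dict:
    """Keep only the listed keys; anything not listed is dropped."""
    return {k: v for k, v in data.items() if k in keys}


def exclude_keys(data: dict, keys: set) -> dict:
    """Drop only the listed keys; everything else is kept."""
    return {k: v for k, v in data.items() if k not in keys}


step = {
    "observation": [0.1, 0.2, 0.3],
    "pixels": "<image>",
    "action": [0.5],
    "reward": [-1.0],
    "done": [False],
    "terminated": [False],
    "truncated": [False],
}

# Selecting means enumerating everything downstream code still needs:
# "terminated" and "truncated" silently disappear here.
selected = select_keys(step, {"observation", "action", "reward", "done"})
# Excluding only names what should disappear.
excluded = exclude_keys(step, {"pixels"})

print(sorted(selected))  # ['action', 'done', 'observation', 'reward']
print(sorted(excluded))  # ['action', 'done', 'observation', 'reward', 'terminated', 'truncated']
```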
- Parameters:
*selected_keys (iterable of NestedKey) – The names of the keys to select. If a key is not present, it is simply ignored.
- Keyword Arguments:
keep_rewards (bool, optional) – if False, the reward keys are no longer kept automatically and must be listed in selected_keys to be kept. Defaults to True.
keep_dones (bool, optional) – if False, the done keys are no longer kept automatically and must be listed in selected_keys to be kept. Defaults to True.
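The interplay between selected_keys, keep_rewards and keep_dones can be sketched as follows. This is a hypothetical illustration, not torchrl's implementation: nested dicts stand in for TensorDicts, tuples of strings stand in for NestedKey, and `resolve_kept_keys` and `select_nested` are names invented for this example. The usage mirrors the Pendulum example in this section: with keep_dones=False, only the explicitly listed "done" survives among the done keys, and a missing key is ignored rather than raising.

```python
# Hypothetical sketch; resolve_kept_keys and select_nested are not torchrl APIs.
from typing import Iterable, Set, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def _as_tuple(key: NestedKey) -> Tuple[str, ...]:
    # Normalise "reward" and ("reward",) to the same tuple form.
    return (key,) if isinstance(key, str) else tuple(key)


def resolve_kept_keys(
    selected_keys: Iterable[NestedKey],
    action_keys: Iterable[NestedKey],
    reward_keys: Iterable[NestedKey],
    done_keys: Iterable[NestedKey],
    keep_rewards: bool = True,
    keep_dones: bool = True,
) -> Set[Tuple[str, ...]]:
    kept = {_as_tuple(k) for k in selected_keys}
    # Keys from the input spec (e.g. "action") are always kept.
    kept |= {_as_tuple(k) for k in action_keys}
    if keep_rewards:
        kept |= {_as_tuple(k) for k in reward_keys}
    if keep_dones:
        kept |= {_as_tuple(k) for k in done_keys}
    return kept


def select_nested(data: dict, kept: Set[Tuple[str, ...]]) -> dict:
    """Copy only the kept leaves; keys absent from `data` are silently ignored."""
    out: dict = {}
    for key in kept:
        node = data
        for part in key:
            if not isinstance(node, dict) or part not in node:
                break  # missing key: skip it
            node = node[part]
        else:
            # Rebuild the nested path in the output and store the leaf.
            target = out
            for part in key[:-1]:
                target = target.setdefault(part, {})
            target[key[-1]] = node
    return out


step_output = {
    "observation": [0.1, 0.2, 0.3],
    "reward": [-1.0],
    "done": [False],
    "terminated": [False],
    "truncated": [False],
    "info": {"cost": [0.0], "extra": [1.0]},
}
kept = resolve_kept_keys(
    ["observation", "reward", "done", ("info", "cost"), "missing"],
    action_keys=["action"],
    reward_keys=["reward"],
    done_keys=["done", "terminated", "truncated"],
    keep_dones=False,
)
result = select_nested(step_output, kept)
print(sorted(result))  # ['done', 'info', 'observation', 'reward']
```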
Examples
>>> import gymnasium
>>> from torchrl.envs import GymWrapper
>>> from torchrl.envs.transforms import SelectTransform, TransformedEnv
>>> env = TransformedEnv(
...     GymWrapper(gymnasium.make("Pendulum-v1")),
...     SelectTransform("observation", "reward", "done", keep_dones=False),  # only "done" is kept among the done keys
... )
>>> env.rollout(3)  # "terminated" and "truncated" are now absent
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
- transform_output_spec(output_spec: Composite) → Composite
Transforms the output spec so that the resulting spec matches the transform's mapping.
This method should generally be left untouched. Changes should instead be implemented using transform_observation_spec(), transform_reward_spec() and transform_done_spec().
- Parameters:
output_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
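The delegation described for transform_output_spec can be sketched as a small template-method pattern. This is a hypothetical sketch, not torchrl's Transform: plain dicts stand in for Composite specs, and the sub-spec names full_observation_spec, full_reward_spec and full_done_spec, as well as the done hook name transform_done_spec, are assumptions. A subclass customises a single hook and inherits transform_output_spec unchanged.

```python
# Hypothetical sketch: plain dicts stand in for torchrl's Composite specs, the
# classes below are not torchrl's Transform, and the sub-spec names are assumptions.
class SpecTransformSketch:
    # Identity hooks: subclasses override only the ones they need.
    def transform_observation_spec(self, spec: dict) -> dict:
        return spec

    def transform_reward_spec(self, spec: dict) -> dict:
        return spec

    def transform_done_spec(self, spec: dict) -> dict:
        return spec

    def transform_output_spec(self, output_spec: dict) -> dict:
        # Route each component of the output spec to its dedicated hook.
        return {
            "full_observation_spec": self.transform_observation_spec(
                output_spec["full_observation_spec"]
            ),
            "full_reward_spec": self.transform_reward_spec(
                output_spec["full_reward_spec"]
            ),
            "full_done_spec": self.transform_done_spec(
                output_spec["full_done_spec"]
            ),
        }


class DropTruncated(SpecTransformSketch):
    # Only the done hook is customised: "truncated" disappears from the done spec.
    def transform_done_spec(self, spec: dict) -> dict:
        return {k: v for k, v in spec.items() if k != "truncated"}


output_spec = {
    "full_observation_spec": {"observation": "Box(3)"},
    "full_reward_spec": {"reward": "Box(1)"},
    "full_done_spec": {"done": "Bool(1)", "terminated": "Bool(1)", "truncated": "Bool(1)"},
}
new_spec = DropTruncated().transform_output_spec(output_spec)
print(sorted(new_spec["full_done_spec"]))  # ['done', 'terminated']
```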