SelectTransform
- class torchrl.envs.transforms.SelectTransform(*selected_keys: NestedKey, keep_rewards: bool = True, keep_dones: bool = True)
Select keys from the input tensordict.
- In general, the ExcludeTransform should be preferred: this transform also selects the “action” key (or other keys from input_spec) and, by default, the “done” and “reward” keys, but other keys the environment relies on may still need to be listed explicitly.
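To illustrate why excluding is often the safer choice, here is a minimal pure-Python sketch that models a tensordict as a plain dict. It is not torchrl code. The helpers select_keys and exclude_keys, the ALWAYS_KEPT tuple and the "hidden" key are hypothetical stand-ins that only mimic the key bookkeeping described above.

```python
from typing import Any, Dict, Iterable

# Hypothetical stand-in for the keys a select-style transform keeps on its own
# (the "action" from input_spec, plus "done" and "reward" by default).
ALWAYS_KEPT = ("action", "done", "reward")


def select_keys(data: Dict[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
    """Keep the listed keys plus ALWAYS_KEPT; every other key is dropped."""
    wanted = set(keys) | set(ALWAYS_KEPT)
    return {k: v for k, v in data.items() if k in wanted}


def exclude_keys(data: Dict[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
    """Drop only the listed keys; every other key is kept."""
    unwanted = set(keys)
    return {k: v for k, v in data.items() if k not in unwanted}


step = {
    "observation": [0.1, 0.2, 0.3],
    "pixels": "frame",
    "hidden": [0.0],  # hypothetical recurrent state a downstream module needs
    "action": [1.0],
    "done": False,
    "reward": 0.5,
}

# Selecting "observation" silently loses "hidden" as well as "pixels".
print(sorted(select_keys(step, ["observation"])))  # ['action', 'done', 'observation', 'reward']
# Excluding "pixels" keeps "hidden" without having to name it.
print(sorted(exclude_keys(step, ["pixels"])))  # ['action', 'done', 'hidden', 'observation', 'reward']
```

The asymmetry is the point of the recommendation. A selection has to enumerate everything that must survive, while an exclusion only names what should go.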
- Parameters:
*selected_keys (iterable of NestedKey) – The names of the keys to select. Keys that are not present are simply ignored.
- Keyword Arguments:
keep_rewards (bool, optional) – if False, the reward keys are kept only if they are listed explicitly in selected_keys. Defaults to True.
keep_dones (bool, optional) – if False, the done keys are kept only if they are listed explicitly in selected_keys. Defaults to True.
- Examples:
>>> import gymnasium
>>> from torchrl.envs import GymWrapper, SelectTransform, TransformedEnv
>>> env = TransformedEnv(
...     GymWrapper(gymnasium.make(...)),
...     SelectTransform(...),
... )
>>> env.rollout(3)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]), device=cpu, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]), device=cpu, is_shared=False)
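The interplay between selected_keys, keep_rewards and keep_dones can be sketched in plain Python. This is a hypothetical model, not torchrl's implementation, and it rests on several assumptions:

- keys_to_keep and apply_selection are illustrative helpers, and tensordicts are plain dicts.
- The gymnasium done keys ("done", "terminated", "truncated") are assumed.
- The same key list is assumed to apply both at the root and inside "next".

Under these assumptions, selecting "observation" and "done" with keep_dones=False yields the same key layout as the rollout output in the example above. It also shows a key that is not present ("not_a_key") being ignored.

```python
from typing import Any, Dict, List, Sequence


def keys_to_keep(
    selected_keys: Sequence[str],
    input_keys: Sequence[str],
    reward_keys: Sequence[str],
    done_keys: Sequence[str],
    keep_rewards: bool = True,
    keep_dones: bool = True,
) -> List[str]:
    """Hypothetical model of how the keyword arguments extend the selection."""
    keep = list(selected_keys) + list(input_keys)  # input keys (e.g. "action") always stay
    if keep_rewards:
        keep += list(reward_keys)
    if keep_dones:
        keep += list(done_keys)
    return list(dict.fromkeys(keep))  # drop duplicates, keep first-seen order


def apply_selection(step: Dict[str, Any], keep: Sequence[str]) -> Dict[str, Any]:
    """Filter the root and the "next" sub-dict with the same key list."""
    wanted = set(keep)
    # Keys in `keep` that are absent from `step` simply never show up.
    out = {k: v for k, v in step.items() if k in wanted}
    if isinstance(step.get("next"), dict):
        out["next"] = {k: v for k, v in step["next"].items() if k in wanted}
    return out


done_keys = ["done", "terminated", "truncated"]  # assumed done keys of a gymnasium env
keep = keys_to_keep(
    ["observation", "done", "not_a_key"],
    ["action"],
    ["reward"],
    done_keys,
    keep_dones=False,
)

step = {
    "observation": [1.0, 0.0, 0.0],
    "action": [0.5],
    "done": False,
    "terminated": False,
    "truncated": False,
    "next": {
        "observation": [0.9, 0.1, 0.2],
        "reward": -0.3,
        "done": False,
        "terminated": False,
        "truncated": False,
    },
}
result = apply_selection(step, keep)
print(sorted(result))  # ['action', 'done', 'next', 'observation']
print(sorted(result["next"]))  # ['done', 'observation', 'reward']
```

Note that "reward" survives without being selected because keep_rewards defaults to True. Only "done" survives among the done keys because keep_dones=False and "done" was listed explicitly.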
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict and applies the transform to the selected keys.
- transform_output_spec(output_spec: CompositeSpec) → CompositeSpec
Transforms the output spec so that the resulting spec matches the transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_done_spec().
- Parameters:
output_spec (CompositeSpec) – spec before the transform
- Returns:
expected spec after the transform
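A rough model of the delegation described above follows. The output spec is treated as a container of three sub-specs, and each one is handed to its own hook. That hook is where a select-style transform would prune unselected keys. The dict-based specs, the placeholder strings such as "Box(3,)" and the function itself are hypothetical. Only the sub-spec names full_observation_spec, full_reward_spec and full_done_spec follow torchrl's environment attributes.

```python
from typing import Any, Callable, Dict, Set

Spec = Dict[str, Any]  # hypothetical: leaf key -> spec description


def transform_output_spec(
    output_spec: Dict[str, Spec],
    selected: Set[str],
    keep_rewards: bool = True,
    keep_dones: bool = True,
) -> Dict[str, Spec]:
    """Route each sub-spec through its own hook instead of editing the container."""

    def prune(spec: Spec) -> Spec:
        # Keep only the leaves that the selection retains.
        return {k: v for k, v in spec.items() if k in selected}

    # dict() copies a sub-spec unchanged; prune() drops unselected leaves.
    hooks: Dict[str, Callable[[Spec], Spec]] = {
        "full_observation_spec": prune,
        "full_reward_spec": dict if keep_rewards else prune,
        "full_done_spec": dict if keep_dones else prune,
    }
    # Unknown sub-specs are copied through untouched.
    return {name: hooks.get(name, dict)(sub) for name, sub in output_spec.items()}


output_spec = {
    "full_observation_spec": {"observation": "Box(3,)", "pixels": "Box(64, 64, 3)"},
    "full_reward_spec": {"reward": "Box(1,)"},
    "full_done_spec": {"done": "Bool(1,)", "terminated": "Bool(1,)", "truncated": "Bool(1,)"},
}
new_spec = transform_output_spec(output_spec, {"observation", "done"}, keep_dones=False)
print(new_spec["full_observation_spec"])  # {'observation': 'Box(3,)'}
print(new_spec["full_done_spec"])  # {'done': 'Bool(1,)'}
```

Keeping the container logic fixed and varying only the per-part hooks mirrors the advice to override the observation, reward and done hooks rather than the output-spec method itself.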