SelectTransform

class torchrl.envs.transforms.SelectTransform(*selected_keys: NestedKey, keep_rewards: bool = True, keep_dones: bool = True)

Select keys from the input tensordict.

In general, ExcludeTransform should be preferred: this transform also selects the “action” (or other keys from input_spec), “done” and “reward” keys, but other keys may also be necessary.

Parameters:

*selected_keys (iterable of NestedKey) – The names of the keys to select. If a key is not present, it is simply ignored.

Keyword Arguments:
  • keep_rewards (bool, optional) – if False, the reward keys are kept only if they are listed explicitly in selected_keys. Defaults to True.

  • keep_dones (bool, optional) – if False, the done keys are kept only if they are listed explicitly in selected_keys. Defaults to True.
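
The interaction between the selected keys and the two keyword arguments can be modelled with plain Python dictionaries. This is a conceptual sketch only, not TorchRL's implementation: select_keys is a hypothetical helper, it works on a flat dict rather than a tensordict, and the key names ("action", "reward", "done", "terminated", "truncated") are assumed to be the ones declared by the environment's specs.

```python
def select_keys(
    data,
    selected_keys,
    *,
    input_keys=("action",),                           # assumed input_spec keys
    reward_keys=("reward",),                          # assumed reward keys
    done_keys=("done", "terminated", "truncated"),    # assumed done keys
    keep_rewards=True,
    keep_dones=True,
):
    """Hypothetical model of the selection rule on a flat dict."""
    # Input keys (e.g. "action") are always retained.
    keep = set(selected_keys) | set(input_keys)
    if keep_rewards:
        keep |= set(reward_keys)
    if keep_dones:
        keep |= set(done_keys)
    # Selected keys that are missing from the data are simply ignored.
    return {key: value for key, value in data.items() if key in keep}


step = {
    "observation": [0.1, 0.2, 0.3],
    "action": [0.5],
    "reward": [1.0],
    "done": [False],
    "terminated": [False],
    "truncated": [False],
    "info": "debug",
}

# Default flags: reward and every done key come along automatically.
default = select_keys(step, ["observation"])
print(sorted(default))
# ['action', 'done', 'observation', 'reward', 'terminated', 'truncated']

# keep_dones=False: only the done keys listed explicitly survive.
strict = select_keys(step, ["observation", "reward", "done"], keep_dones=False)
print(sorted(strict))
# ['action', 'done', 'observation', 'reward']
```

In this model, setting keep_dones=False while still listing "done" keeps that single key and drops "terminated" and "truncated".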

Examples

>>> import gymnasium
>>> from torchrl.envs import GymWrapper, TransformedEnv
>>> from torchrl.envs.transforms import SelectTransform
>>> env = TransformedEnv(
...     GymWrapper(gymnasium.make("Pendulum-v1")),
...     SelectTransform("observation", "reward", "done", keep_dones=False),  # only "done" is kept among the done keys
... )
>>> env.rollout(3)  # the truncated key is now absent
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)

forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict, and for the selected keys, applies the transform.

transform_output_spec(output_spec: Composite) → Composite

Transforms the output spec such that the resulting spec matches the transform mapping.

This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().

Parameters:

output_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
