UnaryTransform

class torchrl.envs.transforms.UnaryTransform(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, fn: Callable[[Any], Tensor | TensorDictBase], inv_fn: Callable[[Any], Any] | None = None, use_raw_nontensor: bool = False)[source]

Applies a unary operation to the specified inputs.

Parameters:
  • in_keys (sequence of NestedKey) – the keys of inputs to the unary operation.

  • out_keys (sequence of NestedKey) – the keys of the outputs of the unary operation.

  • in_keys_inv (sequence of NestedKey, optional) – the keys of inputs to the unary operation during the inverse call.

  • out_keys_inv (sequence of NestedKey, optional) – the keys of the outputs of the unary operation during the inverse call.

Keyword Arguments:
  • fn (Callable[[Any], Tensor | TensorDictBase]) – the function to use as the unary operation. If it accepts a non-tensor input, it must also accept None.

  • inv_fn (Callable[[Any], Any], optional) – the function to use as the unary operation during inverse calls. If it accepts a non-tensor input, it must also accept None. Can be omitted, in which case fn will be used for inverse maps.

  • use_raw_nontensor (bool, optional) – if False, data is extracted from NonTensorData/NonTensorStack inputs before fn is called on them. If True, the raw NonTensorData/NonTensorStack inputs are given directly to fn, which must support those inputs. Default is False.

Example

>>> from torchrl.envs import GymEnv, UnaryTransform
>>> env = GymEnv("Pendulum-v1")
>>> env = env.append_transform(
...     UnaryTransform(
...         in_keys=["observation"],
...         out_keys=["observation_trsf"],
...         fn=lambda tensor: str(tensor.numpy().tobytes())))
>>> env.observation_spec
Composite(
    observation: BoundedContinuous(
        shape=torch.Size([3]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    observation_trsf: NonTensor(
        shape=torch.Size([]),
        space=None,
        device=cpu,
        dtype=None,
        domain=None),
    device=None,
    shape=torch.Size([]))
>>> env.rollout(3)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                observation_trsf: NonTensorStack(
                    ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x...,
                    batch_size=torch.Size([3]),
                    device=None),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_trsf: NonTensorStack(
            ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\...,
            batch_size=torch.Size([3]),
            device=None),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> env.check_env_specs()
[torchrl][INFO] check_env_specs succeeded!
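
The example above maps the observation to a non-tensor string. As a second, hedged sketch (not part of the original documentation), fn can also return tensors, and the inverse keys together with inv_fn let the transform rewrite entries (such as the action) on their way into the base environment. The key name "observation_scaled" and the scaling factors below are illustrative only.

>>> from torchrl.envs import GymEnv, UnaryTransform
>>> env = GymEnv("Pendulum-v1")
>>> env = env.append_transform(
...     UnaryTransform(
...         in_keys=["observation"],
...         out_keys=["observation_scaled"],
...         in_keys_inv=["action"],
...         out_keys_inv=["action"],
...         fn=lambda obs: 2.0 * obs,       # applied to step/reset outputs
...         inv_fn=lambda act: act / 2.0))  # applied to inputs passed to the base env
>>> td = env.rollout(3)
>>> td["observation_scaled"].shape
torch.Size([3, 3])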
transform_action_spec(action_spec: TensorSpec, test_input_spec: TensorSpec) → TensorSpec[source]

Transforms the action spec such that the resulting spec matches the transform mapping.

Parameters:

action_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

transform_done_spec(done_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec[source]

Transforms the done spec such that the resulting spec matches the transform mapping.

Parameters:

done_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

transform_input_spec(input_spec: Composite) → Composite[source]

Transforms the input spec such that the resulting spec matches the transform mapping.

Parameters:

input_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

transform_observation_spec(observation_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec[source]

Transforms the observation spec such that the resulting spec matches the transform mapping.

Parameters:

observation_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

transform_output_spec(output_spec: Composite) → Composite[source]

Transforms the output spec such that the resulting spec matches the transform mapping.

This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec(). A hedged sketch of this pattern follows this entry.

Parameters:

output_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
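
As a hedged illustration of the guidance above (using the base Transform signatures, without the test_output_spec argument that UnaryTransform adds; the class name, key names and the use of Unbounded are assumptions, not part of this reference), a custom transform keeps transform_output_spec() untouched and mirrors its data change in transform_observation_spec() instead:

>>> import torch
>>> from torchrl.data import Unbounded
>>> from torchrl.envs.transforms import Transform
>>> class FlattenCopy(Transform):
...     """Writes a flattened copy of 'observation' under 'observation_flat'."""
...     def __init__(self):
...         super().__init__(in_keys=["observation"], out_keys=["observation_flat"])
...     def _apply_transform(self, obs):
...         # called on each in_keys entry of the step/reset output
...         return obs.reshape(-1)
...     def transform_observation_spec(self, observation_spec):
...         # the spec change lives here, not in transform_output_spec
...         obs_spec = observation_spec["observation"]
...         observation_spec["observation_flat"] = Unbounded(
...             shape=torch.Size([obs_spec.shape.numel()]),
...             dtype=obs_spec.dtype,
...             device=obs_spec.device)
...         return observation_spec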

transform_reward_spec(reward_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec[source]

Transforms the reward spec such that the resulting spec matches the transform mapping.

Parameters:

reward_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

transform_state_spec(state_spec: TensorSpec, test_input_spec: TensorSpec) → TensorSpec[source]

Transforms the state spec such that the resulting spec matches the transform mapping.

Parameters:

state_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
