UnaryTransform
- class torchrl.envs.transforms.UnaryTransform(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, fn: Callable[[Any], Tensor | TensorDictBase], inv_fn: Callable[[Any], Any] | None = None, use_raw_nontensor: bool = False)
Applies a unary operation to the specified inputs.
- Parameters:
in_keys (sequence of NestedKey) – the keys of inputs to the unary operation.
out_keys (sequence of NestedKey) – the keys of the outputs of the unary operation.
in_keys_inv (sequence of NestedKey, optional) – the keys of the inputs to the unary operation during inverse calls.
out_keys_inv (sequence of NestedKey, optional) – the keys of the outputs of the unary operation during inverse calls.
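As a rough mental model of the forward call, each entry of in_keys is paired with the entry at the same position in out_keys: the input is read, `fn` is applied, and the result is written under the output key, while the input entry stays in place (as `observation` does in the example further down). The sketch below is plain Python, not torchrl's implementation: a flat `dict` stands in for a TensorDict, and `apply_unary` is a hypothetical helper name.

```python
from typing import Any, Callable, Dict, Sequence


def apply_unary(
    data: Dict[str, Any],
    in_keys: Sequence[str],
    out_keys: Sequence[str],
    fn: Callable[[Any], Any],
) -> Dict[str, Any]:
    """Hypothetical model of a UnaryTransform forward pass on a flat dict."""
    if len(in_keys) != len(out_keys):
        # This sketch requires one output key per input key.
        raise ValueError("in_keys and out_keys must have the same length")
    result = dict(data)  # copy so the caller's mapping is left untouched
    for in_key, out_key in zip(in_keys, out_keys):
        # Each input key maps to the output key at the same position.
        result[out_key] = fn(data[in_key])
    return result


obs = {"observation": [0.5, -1.0, 2.0]}
out = apply_unary(obs, ["observation"], ["observation_sum"], fn=sum)
print(out)  # {'observation': [0.5, -1.0, 2.0], 'observation_sum': 1.5}
```

The inverse keys (in_keys_inv / out_keys_inv) are paired positionally in the same spirit; this sketch deliberately leaves the inverse direction out.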
- Keyword Arguments:
fn (Callable[[Any], Tensor | TensorDictBase]) – the function to use as the unary operation. If it accepts a non-tensor input, it must also accept None.
inv_fn (Callable[[Any], Any], optional) – the function to use as the unary operation during inverse calls. If it accepts a non-tensor input, it must also accept None. It can be omitted, in which case fn is used for inverse maps.
use_raw_nontensor (bool, optional) – if False, data is extracted from NonTensorData/NonTensorStack inputs before fn is called on them. If True, the raw NonTensorData/NonTensorStack inputs are given directly to fn, which must support those inputs. Default is False.
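The keyword-argument rules can be illustrated without torchrl. In the plain-Python sketch below, `FakeNonTensorData` is a hypothetical stand-in for a NonTensorData wrapper, and `call_fn` and `resolve_inverse` are hypothetical helpers that mimic the documented behaviour of use_raw_nontensor and of omitting inv_fn. `shout` shows a non-tensor fn that also accepts None, as required above.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class FakeNonTensorData:
    """Hypothetical stand-in for a NonTensorData wrapper around a Python object."""
    data: Any


def resolve_inverse(fn: Callable, inv_fn: Optional[Callable]) -> Callable:
    # Documented rule: when inv_fn is omitted, fn is used for inverse maps.
    return fn if inv_fn is None else inv_fn


def call_fn(fn: Callable, value: Any, use_raw_nontensor: bool = False) -> Any:
    # use_raw_nontensor=False: unwrap the payload before calling fn.
    # use_raw_nontensor=True: fn receives the wrapper itself and must handle it.
    if isinstance(value, FakeNonTensorData) and not use_raw_nontensor:
        value = value.data
    return fn(value)


def shout(text: Optional[str]) -> Optional[str]:
    # Accepts a non-tensor input, so it must also accept None.
    return None if text is None else text.upper()


print(call_fn(shout, FakeNonTensorData("hello")))  # HELLO
print(call_fn(lambda w: type(w).__name__, FakeNonTensorData("hello"),
              use_raw_nontensor=True))  # FakeNonTensorData
print(shout(None))  # None
print(resolve_inverse(shout, None) is shout)  # True
```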
Example
>>> from torchrl.envs import GymEnv, UnaryTransform
>>> env = GymEnv("Pendulum-v1")
>>> env = env.append_transform(
...     UnaryTransform(
...         in_keys=["observation"],
...         out_keys=["observation_trsf"],
...         fn=lambda tensor: str(tensor.numpy().tobytes())))
>>> env.observation_spec
Composite(
    observation: BoundedContinuous(
        shape=torch.Size([3]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    observation_trsf: NonTensor(
        shape=torch.Size([]),
        space=None,
        device=cpu,
        dtype=None,
        domain=None),
    device=None,
    shape=torch.Size([]))
>>> env.rollout(3)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                observation_trsf: NonTensorStack(
                    ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x...,
                    batch_size=torch.Size([3]),
                    device=None),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_trsf: NonTensorStack(
            ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\...,
            batch_size=torch.Size([3]),
            device=None),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> env.check_env_specs()
[torchrl][INFO] check_env_specs succeeded!
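The fn in this example returns a Python str rather than a tensor, which is why observation_trsf shows up as a NonTensor spec and as a NonTensorStack in the rollout. The snippet below reproduces the same conversion with the standard-library array module instead of numpy, assuming float32 storage (typecode "f"). The observation values are made up, and the exact bytes depend on the machine's byte order.

```python
from array import array

# Made-up float32 values standing in for a 3-dim Pendulum observation.
observation = array("f", [1.0, 0.0, -0.5])

raw = observation.tobytes()  # 3 float32 values -> 12 raw bytes
as_text = str(raw)           # the bytes literal rendered as a Python str

print(len(raw))       # 12
print(type(as_text))  # <class 'str'>
print(as_text[:2])    # b'
```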
- transform_action_spec(action_spec: TensorSpec, test_input_spec: TensorSpec) → TensorSpec
Transforms the action spec such that the resulting spec matches the transform mapping.
- Parameters:
action_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_done_spec(done_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec
Transforms the done spec such that the resulting spec matches the transform mapping.
- Parameters:
done_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_input_spec(input_spec: Composite) → Composite
Transforms the input spec such that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_observation_spec(observation_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_output_spec(output_spec: Composite) → Composite
Transforms the output spec such that the resulting spec matches the transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().
- Parameters:
output_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
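The advice to leave transform_output_spec() untouched suggests a template-method design: the output spec is dispatched piecewise to per-component hooks, and subclasses override only the hook they need. The sketch below models that pattern with plain dicts. ToySpecTransform, AddTextObservation, the "observation"/"reward"/"done" keys, and the string spec labels are all hypothetical and do not reproduce torchrl's Composite layout; the real UnaryTransform hooks also take a test_output_spec argument, which this toy omits.

```python
from typing import Any, Dict


class ToySpecTransform:
    """Hypothetical base class modelling the documented division of labour."""

    def transform_output_spec(self, output_spec: Dict[str, Any]) -> Dict[str, Any]:
        # Left untouched by subclasses: it only dispatches to the per-part hooks.
        return {
            "observation": self.transform_observation_spec(output_spec["observation"]),
            "reward": self.transform_reward_spec(output_spec["reward"]),
            "done": self.transform_full_done_spec(output_spec["done"]),
        }

    # Default hooks return their part of the spec unchanged.
    def transform_observation_spec(self, spec: Dict[str, Any]) -> Dict[str, Any]:
        return spec

    def transform_reward_spec(self, spec: Dict[str, Any]) -> Dict[str, Any]:
        return spec

    def transform_full_done_spec(self, spec: Dict[str, Any]) -> Dict[str, Any]:
        return spec


class AddTextObservation(ToySpecTransform):
    # Customises only the observation hook, keeping existing entries.
    def transform_observation_spec(self, spec: Dict[str, Any]) -> Dict[str, Any]:
        return {**spec, "observation_trsf": "NonTensor"}


output_spec = {
    "observation": {"observation": "BoundedContinuous"},
    "reward": {"reward": "Unbounded"},
    "done": {"done": "Categorical"},
}
new_spec = AddTextObservation().transform_output_spec(output_spec)
print(new_spec["observation"])
# {'observation': 'BoundedContinuous', 'observation_trsf': 'NonTensor'}
print(new_spec["reward"] == output_spec["reward"])  # True
```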
- transform_reward_spec(reward_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec
Transforms the reward spec such that the resulting spec matches the transform mapping.
- Parameters:
reward_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_state_spec(state_spec: TensorSpec, test_input_spec: TensorSpec) → TensorSpec
Transforms the state spec such that the resulting spec matches the transform mapping.
- Parameters:
state_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform