

class torchrl.envs.transforms.Compose(*transforms: Transform)

Composes a chain of transforms.

Transform or callable objects are accepted.


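>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.envs.transforms import Compose, RewardClipping, RewardScaling
>>> env = GymEnv("Pendulum-v1")
>>> transforms = [RewardScaling(1.0, 1.0), RewardClipping(-2.0, 2.0)]
>>> transforms = Compose(*transforms)
>>> transformed_env = TransformedEnv(env, transforms)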
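Because the transforms are applied in the order they are passed, the order of the list matters. The sketch below illustrates this without torchrl: plain dicts stand in for tensordicts, and `compose`, `scale_reward` and `clip_reward` are hypothetical helpers that only mimic the arithmetic assumed for `RewardScaling(1.0, 1.0)` (taken here as `reward * scale + loc`) and `RewardClipping(-2.0, 2.0)`.

```python
from functools import reduce
from typing import Callable, Dict

TD = Dict[str, float]  # stand-in for a TensorDict: key -> value


def compose(*fns: Callable[[TD], TD]) -> Callable[[TD], TD]:
    """Hypothetical helper: apply fns left to right, like a chain of transforms."""
    def chained(td: TD) -> TD:
        return reduce(lambda acc, fn: fn(acc), fns, td)
    return chained


def scale_reward(td: TD, loc: float = 1.0, scale: float = 1.0) -> TD:
    # Assumed scaling arithmetic: reward * scale + loc
    return {**td, "reward": td["reward"] * scale + loc}


def clip_reward(td: TD, low: float = -2.0, high: float = 2.0) -> TD:
    # Clamp the reward into [low, high]
    return {**td, "reward": min(max(td["reward"], low), high)}


scale_then_clip = compose(scale_reward, clip_reward)
clip_then_scale = compose(clip_reward, scale_reward)
print(scale_then_clip({"reward": 1.5}))  # {'reward': 2.0}
print(clip_then_scale({"reward": 1.5}))  # {'reward': 2.5}
```

With a reward of 1.5, scaling first pushes the value to 2.5 before clipping brings it back to 2.0, whereas clipping first leaves 1.5 untouched and scaling then yields 2.5.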
append(transform: Transform | Callable[[TensorDictBase], TensorDictBase]) → None

Appends a transform to the end of the chain.

Transform or callable objects are accepted.
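Since `append` adds to the end of the chain, the new transform receives the output of every transform already present. The following is a small model of that behaviour; `Chain` is a hypothetical class written for illustration, not torchrl's `Compose`.

```python
from typing import Callable, Dict, List

TD = Dict[str, float]


class Chain:
    """Hypothetical stand-in for a transform chain: callables applied in order."""

    def __init__(self, *transforms: Callable[[TD], TD]) -> None:
        self.transforms: List[Callable[[TD], TD]] = list(transforms)

    def append(self, transform: Callable[[TD], TD]) -> None:
        # The new transform runs after all existing ones.
        self.transforms.append(transform)

    def __call__(self, td: TD) -> TD:
        for t in self.transforms:
            td = t(td)
        return td


chain = Chain(lambda td: {**td, "reward": td["reward"] * 2})
chain.append(lambda td: {**td, "reward": td["reward"] - 1})
print(chain({"reward": 3.0}))  # {'reward': 5.0}
```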

forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict and, for the selected keys, applies the transform.

By default, this method:

  • calls _apply_transform() directly.

  • does not call _step() or _call().

This method is not called within env.step at any point. However, it is called within ReplayBuffer.sample().
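The default behaviour described above can be pictured as a loop over paired input and output keys that applies a per-entry operation. The sketch below models it with plain dicts; `SketchTransform` and its `apply_transform` method are hypothetical, and the doubling operation is arbitrary. The real `Transform` works on tensordicts and handles more cases, such as nested keys.

```python
from typing import Dict, List, Sequence

TD = Dict[str, float]


class SketchTransform:
    """Hypothetical, simplified transform that rewrites only its selected keys."""

    def __init__(self, in_keys: Sequence[str], out_keys: Sequence[str]) -> None:
        self.in_keys: List[str] = list(in_keys)
        self.out_keys: List[str] = list(out_keys)

    def apply_transform(self, value: float) -> float:
        # Per-entry operation; here, an arbitrary doubling.
        return 2.0 * value

    def forward(self, td: TD) -> TD:
        # Read each selected input key, transform it, write it under the paired output key.
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            td[out_key] = self.apply_transform(td[in_key])
        return td


t = SketchTransform(in_keys=["obs"], out_keys=["obs_doubled"])
print(t.forward({"obs": 1.5, "reward": 0.0}))
# {'obs': 1.5, 'reward': 0.0, 'obs_doubled': 3.0}
```

Keys that are not listed in `in_keys` (here `"reward"`) pass through unchanged.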


forward also accepts regular keyword arguments: through dispatch, the argument names are mapped to tensordict keys.


>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs import GymEnv
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td  # store the measured size, not the builtin `bytes`
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = GymEnv("Pendulum-v1")
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # works offline too
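The keyword-argument form of `forward` mentioned above relies on a dispatch mechanism that maps argument names to tensordict keys. The sketch below only imitates that idea with a hypothetical `kwargs_dispatch` decorator over plain dicts; its name, signature and return convention (a single value for one output key, a tuple otherwise) are assumptions, not the tensordict API.

```python
from functools import wraps
from typing import Any, Callable, Dict, Optional, Sequence

TD = Dict[str, Any]


def kwargs_dispatch(out_keys: Sequence[str]) -> Callable[[Callable[[TD], TD]], Callable[..., Any]]:
    """Hypothetical decorator: lets a dict-based forward also accept keyword arguments."""

    def decorator(forward: Callable[[TD], TD]) -> Callable[..., Any]:
        @wraps(forward)
        def wrapper(td: Optional[TD] = None, **kwargs: Any) -> Any:
            if td is not None:
                # Regular call: the dict is passed straight through.
                return forward(td)
            # Keyword call: each argument name becomes a key.
            out = forward(dict(kwargs))
            values = tuple(out[key] for key in out_keys)
            # One output key gives a bare value; several give a tuple.
            return values[0] if len(values) == 1 else values

        return wrapper

    return decorator


@kwargs_dispatch(out_keys=["total"])
def add_entries(td: TD) -> TD:
    td["total"] = td["a"] + td["b"]
    return td


print(add_entries({"a": 1, "b": 2}))  # {'a': 1, 'b': 2, 'total': 3}
print(add_entries(a=1, b=2))  # 3
```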
init(tensordict: TensorDictBase) → None

Runs init steps for the transform.

insert(index: int, transform: Transform | Callable[[TensorDictBase], TensorDictBase]) → None

Inserts a transform in the chain at the desired index.

Transform or callable objects are accepted.
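The sketch below assumes that `insert` follows the same index convention as Python's `list.insert`, so index 0 makes the new transform run first. That convention is an assumption for illustration; in particular, how torchrl treats negative or out-of-range indices is not covered here. `run_chain`, `add_one` and `double` are hypothetical.

```python
from typing import Callable, Dict, List

TD = Dict[str, float]
Fn = Callable[[TD], TD]


def run_chain(transforms: List[Fn], td: TD) -> TD:
    # Apply each transform in list order.
    for t in transforms:
        td = t(td)
    return td


def add_one(td: TD) -> TD:
    return {**td, "x": td["x"] + 1}


def double(td: TD) -> TD:
    return {**td, "x": td["x"] * 2}


chain: List[Fn] = [double]
print(run_chain(chain, {"x": 3.0}))  # {'x': 6.0}

# Insert at index 0 (list.insert semantics, assumed): add_one now runs before double.
chain.insert(0, add_one)
print(run_chain(chain, {"x": 3.0}))  # {'x': 8.0}
```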

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but it only accepts floating point or complex dtypes. In addition, this method only casts the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers are moved to device, if that is given, but their dtypes are unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.


This method modifies the module in-place.

Parameters:

  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)



Returns:

self

Return type:

Module

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> import torch
>>> from torch import nn
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
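The casting rule for `to()` described above can be summarised as: floating point and complex entries take the requested dtype, integral entries keep theirs, and an integral target dtype is rejected. The hypothetical `plan_cast` helper below models that rule on a dict of dtype names; it does not call PyTorch, and raising `TypeError` mirrors, to our understanding, what `nn.Module.to` does for a non-floating target.

```python
from typing import Dict, Optional

# Dtype names treated as floating point or complex in this model.
FLOATING_OR_COMPLEX = {"float16", "bfloat16", "float32", "float64", "complex64", "complex128"}


def plan_cast(param_dtypes: Dict[str, str], target: Optional[str]) -> Dict[str, str]:
    """Hypothetical model of the rule: return the dtype each entry ends up with."""
    if target is not None and target not in FLOATING_OR_COMPLEX:
        raise TypeError(f"only floating point or complex dtypes are accepted, got {target}")
    result: Dict[str, str] = {}
    for name, dtype in param_dtypes.items():
        if target is not None and dtype in FLOATING_OR_COMPLEX:
            result[name] = target  # floating point / complex entries are cast
        else:
            result[name] = dtype  # integral entries keep their dtype
    return result


print(plan_cast({"weight": "float32", "num_batches_tracked": "int64"}, "float64"))
# {'weight': 'float64', 'num_batches_tracked': 'int64'}
```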
transform_action_spec(action_spec: TensorSpec) → TensorSpec

Transforms the action spec such that the resulting spec matches transform mapping.


Parameters:

action_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

transform_env_batch_size(batch_size: torch.Size)

Transforms the batch-size of the parent env.

transform_env_device(device: device)

Transforms the device of the parent env.

transform_input_spec(input_spec: TensorSpec) → TensorSpec

Transforms the input spec such that the resulting spec matches transform mapping.


Parameters:

input_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

transform_observation_spec(observation_spec: TensorSpec) → TensorSpec

Transforms the observation spec such that the resulting spec matches transform mapping.


Parameters:

observation_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
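When transforms are composed, each spec method receives the spec produced by the previous transform, so the final observation spec reflects the whole chain. The sketch models a composite spec as a dict from key to shape; `concat_spec`, `add_counter_spec` and `compose_specs` are hypothetical, and applying the spec methods in chain order is an assumption made for illustration.

```python
from functools import reduce
from typing import Callable, Dict, Tuple

Spec = Dict[str, Tuple[int, ...]]  # key -> shape, a stand-in for a composite spec


def concat_spec(spec: Spec) -> Spec:
    # Hypothetical transform: concatenates "pos" and "vel" along the last dim into "state".
    out = {k: v for k, v in spec.items() if k not in ("pos", "vel")}
    out["state"] = spec["pos"][:-1] + (spec["pos"][-1] + spec["vel"][-1],)
    return out


def add_counter_spec(spec: Spec) -> Spec:
    # Hypothetical transform: adds a one-element "counter" entry.
    return {**spec, "counter": (1,)}


def compose_specs(*fns: Callable[[Spec], Spec]) -> Callable[[Spec], Spec]:
    # Each spec function receives the spec returned by the previous one.
    return lambda spec: reduce(lambda acc, fn: fn(acc), fns, spec)


obs_spec: Spec = {"pos": (3,), "vel": (3,)}
print(compose_specs(concat_spec, add_counter_spec)(obs_spec))
# {'state': (6,), 'counter': (1,)}
```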

transform_output_spec(output_spec: TensorSpec) → TensorSpec

Transforms the output spec such that the resulting spec matches transform mapping.

This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().

Parameters:

output_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
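The recommendation above amounts to routing: the output spec is a container whose parts are each handed to a dedicated hook, so subclasses override the hooks rather than `transform_output_spec` itself. The sketch below assumes the container holds entries named `full_observation_spec`, `full_reward_spec` and `full_done_spec` and represents every spec as a plain dict; `SketchSpecTransform` and `ClipRewardBounds` are hypothetical classes.

```python
from typing import Any, Dict

Spec = Dict[str, Any]


class SketchSpecTransform:
    """Hypothetical base: subclasses override the per-part hooks only."""

    def transform_observation_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_reward_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_full_done_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_output_spec(self, output_spec: Spec) -> Spec:
        # Route each part to its dedicated hook on a shallow copy of the container.
        out = dict(output_spec)
        out["full_observation_spec"] = self.transform_observation_spec(out["full_observation_spec"])
        if "full_reward_spec" in out:
            out["full_reward_spec"] = self.transform_reward_spec(out["full_reward_spec"])
        if "full_done_spec" in out:
            out["full_done_spec"] = self.transform_full_done_spec(out["full_done_spec"])
        return out


class ClipRewardBounds(SketchSpecTransform):
    """Overrides only the reward hook: intersects the reward bounds with [-2, 2]."""

    def transform_reward_spec(self, spec: Spec) -> Spec:
        low, high = spec["reward"]
        return {**spec, "reward": (max(low, -2.0), min(high, 2.0))}


output_spec = {
    "full_observation_spec": {"obs": (3,)},  # shape
    "full_reward_spec": {"reward": (-16.3, 0.0)},  # (low, high) bounds
    "full_done_spec": {"done": (False, True)},
}
print(ClipRewardBounds().transform_output_spec(output_spec)["full_reward_spec"])
# {'reward': (-2.0, 0.0)}
```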

transform_reward_spec(reward_spec: TensorSpec) → TensorSpec

Transforms the reward spec such that the resulting spec matches transform mapping.


Parameters:

reward_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
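A reward spec has to follow the same mapping as the reward values. For instance, if a transform computes `reward * scale + loc` (the affine form assumed here), the bounds `[low, high]` of a bounded reward spec map the same way, with the endpoints swapped when `scale` is negative. The hypothetical helper below works on plain `(low, high)` tuples rather than on a `TensorSpec`.

```python
from typing import Tuple

Bounds = Tuple[float, float]


def scaled_reward_bounds(bounds: Bounds, loc: float, scale: float) -> Bounds:
    """Hypothetical helper: bounds of reward * scale + loc for reward in [low, high]."""
    low, high = bounds
    a, b = low * scale + loc, high * scale + loc
    # A negative scale flips the interval, so order the endpoints.
    return (min(a, b), max(a, b))


print(scaled_reward_bounds((-16.0, 0.0), loc=1.0, scale=0.5))  # (-7.0, 1.0)
print(scaled_reward_bounds((-16.0, 0.0), loc=1.0, scale=-1.0))  # (1.0, 17.0)
```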

transform_state_spec(state_spec: TensorSpec) → TensorSpec

Transforms the state spec such that the resulting spec matches transform mapping.


Parameters:

state_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform

