class torchrl.envs.transforms.TensorDictPrimer(primers: Optional[dict] = None, random=False, default_value=0.0, **kwargs)[source]

A primer for TensorDict initialization at reset time.

This transform populates the tensordict at reset with values drawn from the corresponding tensor specs provided at initialization. If the transform is used outside an env context (e.g. as an nn.Module or appended to a replay buffer), a call to forward will also populate the tensordict with the desired features.

  • primers (dict, optional) – a dictionary containing key-spec pairs which will be used to populate the input tensordict.

  • random (bool, optional) – if True, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to False.

  • default_value (float, optional) – if non-random filling is chosen, this value will be used to populate the tensors. Defaults to 0.0.

  • **kwargs – each keyword argument corresponds to a key in the tensordict. The corresponding value has to be a TensorSpec instance indicating what the value must be.
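The interaction between random and default_value can be illustrated with a small sketch. It is not TorchRL's implementation: fill_entry, low and high are hypothetical names introduced here, and the bounded case assumes a uniform draw over the interval.

```python
import torch


def fill_entry(shape, random=False, default_value=0.0, low=None, high=None):
    """Toy stand-in for filling one primed entry (hypothetical, not TorchRL code)."""
    if not random:
        # Non-random filling: every element takes the fixed default value.
        return torch.full(shape, default_value)
    if low is None or high is None:
        # Unbounded domain: draw from a unit Gaussian.
        return torch.randn(shape)
    # Bounded domain: assumed here to be a uniform draw in [low, high).
    return low + (high - low) * torch.rand(shape)


zeros = fill_entry((2, 3))
bounded = fill_entry((2, 3), random=True, low=-1.0, high=1.0)
unbounded = fill_entry((2, 3), random=True)
```

With the defaults (random=False, default_value=0.0) the entry is all zeros, which matches the behaviour described for the transform.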

When used in a TransformedEnv, the spec shapes must match the env's shape if the parent env is batch-locked (env.batch_locked=True). If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is given by the input tensordict instead.


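>>> import torch
>>> from torchrl.data import UnboundedContinuousTensorSpec
>>> from torchrl.envs import SerialEnv, TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import TensorDictPrimer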
>>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> env = TransformedEnv(base_env)
>>> # the env is batch-locked, so the leading dims of the spec must match those of the env
>>> env.append_transform(TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 3])))
>>> td = env.reset()
>>> print(td)
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        mykey: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
>>> # the entry is populated with 0s
>>> print(td.get("mykey"))
tensor([[0., 0., 0.],
        [0., 0., 0.]])

When calling env.step(), the current value of the key will be carried in the "next" tensordict unless it already exists.


>>> td = env.rand_step(td)
>>> print(td.get(("next", "mykey")))
tensor([[0., 0., 0.],
        [0., 0., 0.]])
>>> # with another value for "mykey", the previous value is not carried on
>>> td = env.reset()
>>> td = td.set(("next", "mykey"), torch.ones(2, 3))
>>> td = env.rand_step(td)
>>> print(td.get(("next", "mykey")))
tensor([[1., 1., 1.],
        [1., 1., 1.]])
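The carry rule shown in the example above can be mimicked with plain dictionaries. This is only an illustration of the "unless it already exists" logic under simplifying assumptions: carry_to_next is a hypothetical helper, and nested dicts stand in for a tensordict.

```python
import torch


def carry_to_next(td, key):
    """Copy td[key] into td["next"][key] unless "next" already holds that key."""
    nxt = td.setdefault("next", {})
    if key not in nxt:
        # No value was written for the next step: carry the current one over.
        nxt[key] = td[key]
    return td


# Case 1: nothing under "next", so the zeros are carried forward.
fresh = carry_to_next({"mykey": torch.zeros(2, 3)}, "mykey")

# Case 2: "next" already holds a value, which is left untouched.
preset = carry_to_next(
    {"mykey": torch.zeros(2, 3), "next": {"mykey": torch.ones(2, 3)}},
    "mykey",
)
```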
forward(tensordict: TensorDictBase) → TensorDictBase[source]

Reads the input tensordict, and for the selected keys, applies the transform.

reset(tensordict: TensorDictBase) → TensorDictBase[source]

Sets the default values in the input tensordict.

If the parent is batch-locked, we assume that the specs have the appropriate leading shape. We allow for execution when the parent is missing, in which case the spec shape is assumed to match the tensordict’s.


to(*args, **kwargs)[source]

Moves and/or casts the parameters and buffers.

This can be called as:

to(device=None, dtype=None, non_blocking=False)[source]
to(dtype, non_blocking=False)[source]
to(tensor, non_blocking=False)[source]

Its signature is similar to torch.Tensor.to(), but it only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.


This method modifies the module in-place.

  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)



Return type: Module

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
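The dtype rule described for to() (floating point parameters are cast, integral buffers keep their dtype, and the module is modified in place) can be checked with a small module. Counter and its steps buffer are hypothetical names used only for this sketch.

```python
import torch
from torch import nn


class Counter(nn.Module):
    """Hypothetical module with one float parameter and one integer buffer."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2, 2))
        self.register_buffer("steps", torch.zeros(1, dtype=torch.long))


module = Counter()
returned = module.to(torch.double)

print(returned is module)   # the same object is returned (in-place)
print(module.weight.dtype)  # floating point parameter is cast
print(module.steps.dtype)   # integral buffer keeps its dtype
```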
transform_observation_spec(observation_spec: CompositeSpec) → CompositeSpec[source]

Transforms the observation spec such that the resulting spec matches the transform mapping.


  • observation_spec (TensorSpec) – spec before the transform

Returns: the expected spec after the transform

