
split_trajectories

torchrl.collectors.utils.split_trajectories(rollout_tensordict: TensorDictBase, *, prefix=None, trajectory_key: NestedKey | None = None, done_key: NestedKey | None = None) → TensorDictBase

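A utility function for trajectory separation.

Takes a tensordict with a key (by default "traj_ids") that indicates the id of each trajectory, or recovers these ids from the done signal.

From there, builds a B x T x … zero-padded tensordict, where B is the number of trajectories and T is the maximum trajectory duration.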
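The padding step can be illustrated on a single flat tensor. The sketch below is not TorchRL's implementation; `pad_by_trajectory` is a hypothetical helper that assumes 1-D values, contiguous trajectories, and consecutive integer ids starting at 0, listed in order.

```python
import torch


def pad_by_trajectory(values: torch.Tensor, traj_ids: torch.Tensor):
    """Split a flat 1-D tensor into a zero-padded [B, T] tensor plus a validity mask.

    Hypothetical helper: assumes trajectories are contiguous, appear in order,
    and have consecutive integer ids starting at 0.
    """
    traj_ids = traj_ids.long()
    # Length of each trajectory, indexed by trajectory id
    lengths = torch.bincount(traj_ids)
    num_traj, max_len = lengths.numel(), int(lengths.max())
    # Flat index at which each trajectory starts
    starts = torch.cumsum(lengths, dim=0) - lengths
    # Position of each element inside its own trajectory
    positions = torch.arange(values.numel()) - starts[traj_ids]
    out = torch.zeros(num_traj, max_len, dtype=values.dtype)
    mask = torch.zeros(num_traj, max_len, dtype=torch.bool)
    # Scatter values into their (trajectory, step) slots; the rest stays zero
    out[traj_ids, positions] = values
    mask[traj_ids, positions] = True
    return out, mask


obs = torch.cat([torch.arange(10), torch.arange(5)])
ids = torch.cat([torch.zeros(10, dtype=torch.int32), torch.ones(5, dtype=torch.int32)])
padded, mask = pad_by_trajectory(obs, ids)
```

Here `padded` has shape `[2, 10]`, and the second row holds `0..4` followed by five zeros of padding, which `mask` marks as invalid.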

Parameters:
  • rollout_tensordict (TensorDictBase) – a rollout with adjacent trajectories along the last dimension.

  • prefix (NestedKey, optional) – the prefix used to read and write meta-data, such as "traj_ids" (the optional integer id of each trajectory) and the "mask" entry indicating which data are valid and which aren’t. Defaults to "collector" if the input has a "collector" entry, () (no prefix) otherwise. prefix is kept as a legacy feature and will be deprecated eventually. Prefer trajectory_key or done_key whenever possible.

  • trajectory_key (NestedKey, optional) – the key pointing to the trajectory ids. Supersedes done_key and prefix. If not provided, defaults to (prefix, "traj_ids").

  • done_key (NestedKey, optional) – the key pointing to the "done" signal, used if the trajectory ids could not be directly recovered. Defaults to "done".
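When no trajectory-id entry is available, ids can in principle be recovered from the done signal. The sketch below shows one way to do that and is not necessarily how TorchRL computes them; `traj_ids_from_done` is a hypothetical helper that assumes a flat 1-D done tensor where `done[t]` is True on the last step of a trajectory.

```python
import torch


def traj_ids_from_done(done: torch.Tensor) -> torch.Tensor:
    """Assign an integer trajectory id to each step of a flat rollout.

    Hypothetical helper: assumes ``done[t]`` is True on the last step of a
    trajectory, so a new trajectory starts at step t + 1.
    """
    done = done.long()
    # cumsum counts the dones up to and including step t; subtracting done[t]
    # excludes the current step, so a final step keeps its trajectory's id
    return torch.cumsum(done, dim=0) - done


# Same done signal as in the example below: one trajectory ends at step 9
done = torch.zeros(15, dtype=torch.bool)
done[9] = True
ids = traj_ids_from_done(done)
```

With this done signal, steps 0 to 9 get id 0 and steps 10 to 14 get id 1, which matches the "trajectory" entry built in the example below.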

Returns:

A new tensordict with a leading dimension corresponding to the trajectory. A boolean "mask" entry, sharing the trajectory_key prefix and the tensordict shape, is also added; it indicates the valid (non-padded) elements of the tensordict. A "traj_ids" entry is added as well if trajectory_key could not be found.
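The "mask" entry is what lets downstream code ignore the zero padding. The sketch below uses plain tensors built by hand to mirror the shapes of the example below (two trajectories of lengths 10 and 5, padded to T = 10); it is an illustration, not output captured from `split_trajectories`.

```python
import torch

# Hand-built padded data: trajectory 0 has 10 steps, trajectory 1 has 5
obs = torch.zeros(2, 10, dtype=torch.int64)
obs[0] = torch.arange(10)
obs[1, :5] = torch.arange(5)
mask = torch.zeros(2, 10, dtype=torch.bool)
mask[0] = True
mask[1, :5] = True

# Boolean indexing keeps only the valid steps, in row-major order
valid_obs = obs[mask]
# Number of valid steps per trajectory
lengths = mask.sum(dim=-1)
# Per-trajectory mean over valid steps only, so padding does not bias it
mean_obs = (obs * mask).sum(dim=-1) / lengths
```

Because the trajectories were contiguous and ordered, `valid_obs` reproduces the original flat sequence of 15 observations.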

Examples

>>> from tensordict import TensorDict
>>> import torch
>>> from torchrl.collectors.utils import split_trajectories
>>> obs = torch.cat([torch.arange(10), torch.arange(5)])
>>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)])
>>> done = torch.zeros(15, dtype=torch.bool)
>>> done[9] = True
>>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32),
...     torch.ones(5, dtype=torch.int32)])
>>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15])
>>> data_split = split_trajectories(data, done_key="done")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
>>> # check that split_trajectories recovered the trajectories correctly from the done signal
>>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
>>> print(data_split["mask"])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False, False, False, False]])
>>> data_split = split_trajectories(data, trajectory_key="trajectory")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
