split_trajectories
- torchrl.collectors.utils.split_trajectories(rollout_tensordict: TensorDictBase, *, prefix=None, trajectory_key: tensordict._nestedkey.NestedKey | None = None, done_key: tensordict._nestedkey.NestedKey | None = None, as_nested: bool = False) → TensorDictBase
A utility function for trajectory separation.
Takes a tensordict with a "traj_ids" key that indicates the id of each trajectory.
From there, builds a B x T x … zero-padded tensordict, where B is the number of trajectories and T is the maximum trajectory duration.
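To make the B x T layout concrete, here is a minimal pure-Python sketch of the idea, assuming trajectories are identified by integer ids and that rows are ordered by ascending id. pad_by_traj_ids is a hypothetical helper written for illustration; it works on plain lists and is not torchrl's implementation.

```python
from typing import Dict, List, Sequence, Tuple


def pad_by_traj_ids(
    values: Sequence[int], traj_ids: Sequence[int], pad_value: int = 0
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Group a flat rollout into a B x T padded layout plus a validity mask.

    Hypothetical helper for illustration only; not torchrl's implementation.
    """
    # Collect the steps of each trajectory, preserving their order in the rollout.
    groups: Dict[int, List[int]] = {}
    for value, traj_id in zip(values, traj_ids):
        groups.setdefault(traj_id, []).append(value)
    # Assumption of this sketch: rows are ordered by ascending trajectory id.
    rows = [groups[traj_id] for traj_id in sorted(groups)]
    # B = number of trajectories, T = length of the longest trajectory.
    max_len = max(len(row) for row in rows)
    padded = [row + [pad_value] * (max_len - len(row)) for row in rows]
    mask = [[True] * len(row) + [False] * (max_len - len(row)) for row in rows]
    return padded, mask


# Ten steps of trajectory 0 followed by five steps of trajectory 1.
obs = list(range(10)) + list(range(5))
traj_ids = [0] * 10 + [1] * 5
padded, mask = pad_by_traj_ids(obs, traj_ids)
print(mask[1])  # [True, True, True, True, True, False, False, False, False, False]
```

With these ids, which mirror the rollout used in the Examples section, the second row of the mask has the same pattern as the "mask" entry printed there.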
- Parameters:
rollout_tensordict (TensorDictBase) – a rollout with adjacent trajectories along the last dimension.
- Keyword Arguments:
prefix (NestedKey, optional) – the prefix used to read and write meta-data, such as "traj_ids" (the optional integer id of each trajectory) and the "mask" entry indicating which data are valid and which aren't. Defaults to "collector" if the input has a "collector" entry, () (no prefix) otherwise. prefix is kept as a legacy feature and will eventually be deprecated. Prefer trajectory_key or done_key whenever possible.
trajectory_key (NestedKey, optional) – the key pointing to the trajectory ids. Supersedes done_key and prefix. If not provided, defaults to (prefix, "traj_ids").
done_key (NestedKey, optional) – the key pointing to the "done" signal, used if the trajectories cannot be recovered directly. Defaults to "done".
as_nested (bool or torch.layout, optional) – whether to return the results as nested tensors. Defaults to False. If a torch.layout is provided, it is used to construct the nested tensor; otherwise the default layout is used.
Note
Using split_trajectories(tensordict, as_nested=True).to_padded_tensor(mask=mask_key) should give exactly the same result as as_nested=False. Since this is an experimental feature that relies on nested tensors, whose API may change in the future, we made it optional. The runtime should be faster with as_nested=True.
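The equivalence described in this note can be pictured with plain Python lists: a nested layout keeps each trajectory at its own length, and padding it to the longest trajectory while recording a validity mask yields a B x T layout. to_padded is a hypothetical stand-in for to_padded_tensor that works on lists only; it does not use the nested-tensor API.

```python
from typing import List, Tuple


def to_padded(
    nested: List[List[int]], pad_value: int = 0
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Pad ragged trajectories to a common length and build the matching mask.

    Hypothetical helper on plain lists; not the nested-tensor API.
    """
    max_len = max(len(traj) for traj in nested)
    padded = [traj + [pad_value] * (max_len - len(traj)) for traj in nested]
    mask = [[True] * len(traj) + [False] * (max_len - len(traj)) for traj in nested]
    return padded, mask


# A nested layout stores each trajectory at its own length: 10 steps, then 5 steps.
nested = [list(range(10)), list(range(5))]
padded, mask = to_padded(nested)

# Going back: keeping only the valid entries recovers the nested layout exactly.
recovered = [
    [value for value, ok in zip(row, row_mask) if ok]
    for row, row_mask in zip(padded, mask)
]
assert recovered == nested
```

The round trip is lossless because the mask records exactly which padded cells carry data.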
Note
Providing a layout lets the user choose whether the nested tensor uses the torch.strided or the torch.jagged layout. While the former has slightly more capabilities at the time of writing, the latter will be the main focus of the PyTorch team in the future due to its better compatibility with torch.compile().
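The defaults of prefix and trajectory_key interact in a way that is easiest to read as code. The following sketch restates the rules documented above for choosing the key that holds trajectory ids. resolve_trajectory_key is a hypothetical helper that ignores done_key, and collapsing an empty prefix to the plain "traj_ids" key is an assumption of this sketch rather than confirmed torchrl behavior.

```python
from typing import Collection, Optional, Tuple, Union

# Hypothetical alias for this sketch: a key is a string or a tuple of strings.
NestedKey = Union[str, Tuple[str, ...]]


def resolve_trajectory_key(
    top_level_keys: Collection[str],
    prefix: Optional[NestedKey] = None,
    trajectory_key: Optional[NestedKey] = None,
) -> NestedKey:
    """Restate the documented defaults for locating trajectory ids.

    Hypothetical helper for illustration only; it ignores done_key.
    """
    # An explicit trajectory_key supersedes prefix.
    if trajectory_key is not None:
        return trajectory_key
    # prefix defaults to "collector" if the input has such an entry, () otherwise.
    if prefix is None:
        prefix = "collector" if "collector" in top_level_keys else ()
    # Build (prefix, "traj_ids"), flattening a tuple prefix into a single key.
    prefix_parts = (prefix,) if isinstance(prefix, str) else tuple(prefix)
    key = prefix_parts + ("traj_ids",)
    # Assumption: an empty prefix collapses to the plain "traj_ids" key.
    return key[0] if len(key) == 1 else key


print(resolve_trajectory_key({"obs", "collector"}))  # ('collector', 'traj_ids')
print(resolve_trajectory_key({"obs"}))  # traj_ids
```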
- Returns:
A new tensordict with a leading dimension corresponding to the trajectory. A "mask" boolean entry, sharing the trajectory_key prefix and the tensordict shape, is also added; it indicates the valid elements of the tensordict. A "traj_ids" entry is added as well if trajectory_key could not be found.
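When no trajectory ids are found, they have to be inferred from the done signal. A minimal sketch, assuming that a new trajectory begins right after each step flagged as done; traj_ids_from_done is a hypothetical helper on plain lists, not the function torchrl uses internally.

```python
from typing import List, Sequence


def traj_ids_from_done(done: Sequence[bool]) -> List[int]:
    """Give every step an integer trajectory id, starting a new one after each done step.

    Hypothetical helper for illustration only; not torchrl's implementation.
    """
    ids: List[int] = []
    current = 0
    for is_done in done:
        # A done step still belongs to the trajectory it terminates.
        ids.append(current)
        if is_done:
            current += 1
    return ids


# Same done pattern as in the Examples section: 15 steps, with step 9 flagged as done.
done = [False] * 15
done[9] = True
print(traj_ids_from_done(done))
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
```

For this done pattern the inferred ids coincide with the "trajectory" entry, which mirrors the assertion performed on the real function in the Examples section.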
Examples
>>> from tensordict import TensorDict
>>> import torch
>>> from torchrl.collectors.utils import split_trajectories
>>> obs = torch.cat([torch.arange(10), torch.arange(5)])
>>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)])
>>> done = torch.zeros(15, dtype=torch.bool)
>>> done[9] = True
>>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32),
...                            torch.ones(5, dtype=torch.int32)])
>>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done,
...                    "trajectory": trajectory_id}, batch_size=[15])
>>> data_split = split_trajectories(data, done_key="done")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
>>> # check that split_trajectories got the trajectories right with the done signal
>>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
>>> print(data_split["mask"])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False, False, False, False]])
>>> data_split = split_trajectories(data, trajectory_key="trajectory")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
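Because the padded tail of each row is filled with zeros, statistics computed per trajectory should only use the entries flagged by "mask". The sketch below lays out obs as in the example (values 0 to 9 for the first trajectory and 0 to 4 for the second, zero-padded to T = 10) using plain lists; masked_row_means is a hypothetical helper written for illustration.

```python
from typing import List


def masked_row_means(padded: List[List[int]], mask: List[List[bool]]) -> List[float]:
    """Average each trajectory over its valid steps only, ignoring the padding.

    Hypothetical helper on plain lists; not part of torchrl.
    """
    means = []
    for row, row_mask in zip(padded, mask):
        # Keep only the steps the mask marks as valid.
        valid = [value for value, ok in zip(row, row_mask) if ok]
        means.append(sum(valid) / len(valid))
    return means


# obs and mask laid out as in the example output (2 trajectories, T = 10).
obs = [list(range(10)), [0, 1, 2, 3, 4, 0, 0, 0, 0, 0]]
mask = [[True] * 10, [True] * 5 + [False] * 5]
print(masked_row_means(obs, mask))  # [4.5, 2.0]
```

Averaging the second row without the mask would give 1.0 instead of 2.0, since the five zero-padded steps would be counted as real data.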