split_trajectories
- torchrl.collectors.utils.split_trajectories(rollout_tensordict: TensorDictBase, *, prefix=None, trajectory_key: NestedKey | None = None, done_key: NestedKey | None = None) → TensorDictBase
A utility function for trajectory separation.
Takes a tensordict with a key traj_ids that indicates the id of each trajectory.
From there, builds a B x T x … zero-padded tensordict with B batches (one per trajectory) of maximum duration T.
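The split-and-pad idea described above can be illustrated with a minimal, dependency-free sketch. It uses plain Python lists rather than tensors, and the helper name `split_and_pad` is hypothetical: this is an illustration of the concept under those assumptions, not the torchrl implementation.

```python
from itertools import groupby


def split_and_pad(values, traj_ids, pad_value=0):
    """Hypothetical helper: group adjacent steps by id and right-pad each group."""
    # Pair each value with its id, then group consecutive steps sharing the same id.
    groups = [
        [value for value, _ in steps]
        for _, steps in groupby(zip(values, traj_ids), key=lambda pair: pair[1])
    ]
    max_len = max(len(group) for group in groups)  # T: the longest trajectory
    # Zero-pad every trajectory to length T and record which entries are real data.
    padded = [group + [pad_value] * (max_len - len(group)) for group in groups]
    mask = [[True] * len(group) + [False] * (max_len - len(group)) for group in groups]
    return padded, mask


obs = list(range(10)) + list(range(5))
traj_ids = [0] * 10 + [1] * 5
padded, mask = split_and_pad(obs, traj_ids)
print(padded[1])  # [0, 1, 2, 3, 4, 0, 0, 0, 0, 0]
print(mask[1])    # [True, True, True, True, True, False, False, False, False, False]
```

Here B is 2 and T is 10: the second trajectory has only 5 valid steps, so its last 5 entries are padding and are flagged as such in the mask.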
- Parameters:
rollout_tensordict (TensorDictBase) – a rollout with adjacent trajectories along the last dimension.
prefix (NestedKey, optional) – the prefix used to read and write metadata, such as "traj_ids" (the optional integer id of each trajectory) and the "mask" entry indicating which data are valid and which aren't. Defaults to "collector" if the input has a "collector" entry, () (no prefix) otherwise. prefix is kept as a legacy feature and will be deprecated eventually. Prefer trajectory_key or done_key whenever possible.
trajectory_key (NestedKey, optional) – the key pointing to the trajectory ids. Supersedes done_key and prefix. If not provided, defaults to (prefix, "traj_ids").
done_key (NestedKey, optional) – the key pointing to the "done" signal, if the trajectory could not be directly recovered. Defaults to "done".
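When only a done signal is available, trajectory ids have to be derived from it. One plausible way to do this, shown here as a sketch with plain Python lists, is to count the done flags seen strictly before each step. The helper name `traj_ids_from_done` is hypothetical, and this is an assumption about how such a recovery could work rather than a description of the library's internals.

```python
from itertools import accumulate


def traj_ids_from_done(done):
    """Hypothetical helper: id of a step = number of done flags strictly before it."""
    if not done:
        return []
    # Shift the flags right by one step so a done at step t starts a new id at t + 1.
    shifted = [False] + list(done[:-1])
    # A running sum of the shifted flags yields a non-decreasing integer id per step.
    return list(accumulate(int(flag) for flag in shifted))


done = [False] * 15
done[9] = True
ids = traj_ids_from_done(done)
print(ids)  # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
```

The step carrying the done flag still belongs to the trajectory it terminates; the id only increments on the following step.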
- Returns:
A new tensordict with a leading dimension corresponding to the trajectory. A "mask" boolean entry sharing the trajectory_key prefix and the tensordict shape is also added. It indicates the valid elements of the tensordict. A "traj_ids" entry is added as well if trajectory_key could not be found.
Examples
>>> from tensordict import TensorDict
>>> import torch
>>> from torchrl.collectors.utils import split_trajectories
>>> obs = torch.cat([torch.arange(10), torch.arange(5)])
>>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)])
>>> done = torch.zeros(15, dtype=torch.bool)
>>> done[9] = True
>>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32),
...     torch.ones(5, dtype=torch.int32)])
>>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15])
>>> data_split = split_trajectories(data, done_key="done")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
>>> # check that split_trajectory got the trajectories right with the done signal
>>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
>>> print(data_split["mask"])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False, False, False, False]])
>>> data_split = split_trajectories(data, trajectory_key="trajectory")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
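Because the output is zero-padded, reductions over the time dimension should generally go through the "mask" entry. The following sketch shows the idea on plain nested lists, assuming the padded observations and the mask have been converted with `.tolist()`; the helper name `masked_mean` is hypothetical and not part of torchrl.

```python
def masked_mean(padded, mask):
    """Hypothetical helper: average each row over its valid steps only."""
    means = []
    for row, valid in zip(padded, mask):
        # Keep only the entries flagged as real data by the mask.
        kept = [value for value, keep in zip(row, valid) if keep]
        means.append(sum(kept) / len(kept))
    return means


# Values mirroring the two trajectories of lengths 10 and 5 from the example.
obs = [list(range(10)), [0, 1, 2, 3, 4, 0, 0, 0, 0, 0]]
mask = [[True] * 10, [True] * 5 + [False] * 5]

lengths = [sum(row) for row in mask]  # number of valid steps per trajectory
print(lengths)                 # [10, 5]
print(masked_mean(obs, mask))  # [4.5, 2.0]
```

Averaging the second row without the mask would give 1.0 instead of 2.0, since the five padded zeros would be counted as observations.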