split_trajectories
- torchrl.collectors.utils.split_trajectories(rollout_tensordict: TensorDictBase, prefix=None) → TensorDictBase
A utility function for trajectory separation.
Takes a tensordict with a key "traj_ids" that indicates the id of each trajectory.
From there, builds a B x T x … zero-padded tensordict, where B is the number of trajectories and T is the duration of the longest one.
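As a rough illustration of this splitting and padding, the following pure-Python sketch groups a flat rollout into trajectories and zero-pads them to a common length T. It is not torchrl's implementation: `split_and_pad` is a hypothetical helper, each step is a single float rather than a tensordict entry (so there are no trailing `…` dimensions), and it assumes, as stated for `rollout_tensordict`, that the steps of each trajectory are adjacent. It also returns a boolean mask marking real steps (True) versus padding (False), analogous to the "mask" entry described in the parameters.

```python
from typing import List, Optional, Sequence, Tuple


def split_and_pad(
    values: Sequence[float],
    traj_ids: Sequence[int],
    pad_value: float = 0.0,
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Hypothetical helper: group adjacent steps by trajectory id and pad.

    Returns a B x T nested list of values (B trajectories, T = length of the
    longest one) and a B x T mask that is True for real steps and False for
    padding.
    """
    if len(values) != len(traj_ids):
        raise ValueError("values and traj_ids must have the same length")

    # Start a new trajectory whenever the id differs from the previous step.
    trajectories: List[List[float]] = []
    previous_id: Optional[int] = None
    for value, traj_id in zip(values, traj_ids):
        if not trajectories or traj_id != previous_id:
            trajectories.append([])
        trajectories[-1].append(value)
        previous_id = traj_id

    # T is the duration of the longest trajectory (0 for an empty rollout).
    max_len = max((len(t) for t in trajectories), default=0)

    # Right-pad every trajectory with pad_value and record which steps are real.
    padded = [t + [pad_value] * (max_len - len(t)) for t in trajectories]
    mask = [[True] * len(t) + [False] * (max_len - len(t)) for t in trajectories]
    return padded, mask


# Usage: a rollout of 5 steps holding two trajectories of lengths 3 and 2.
padded, mask = split_and_pad([1.0, 2.0, 3.0, 4.0, 5.0], [0, 0, 0, 1, 1])
print(padded)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
print(mask)    # [[True, True, True], [True, True, False]]
```

Because this sketch splits wherever the id changes, rather than by id value, the adjacency assumption matters: an id that reappeared later in the rollout would be treated as a new trajectory.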
- Parameters:
rollout_tensordict (TensorDictBase) – a rollout with adjacent trajectories along the last dimension.
prefix (str or tuple of str, optional) – the prefix used to read and write metadata, such as "traj_ids" (the optional integer id of each trajectory) and the "mask" entry indicating which data are valid and which aren’t. Defaults to None (no prefix).
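The `prefix` description suggests that the metadata keys are formed by placing the prefix in front of "traj_ids" and "mask". The sketch below shows one plausible way that resolution could work, assuming tensordict-style nested keys (a plain string or a tuple of strings). `resolve_meta_keys` is a hypothetical helper written for illustration, not a torchrl function, and the prefix values used are arbitrary examples.

```python
from typing import Optional, Tuple, Union

# Assumed nested-key convention: a plain string or a tuple of strings.
NestedKey = Union[str, Tuple[str, ...]]


def resolve_meta_keys(
    prefix: Optional[Union[str, Tuple[str, ...]]] = None,
) -> Tuple[NestedKey, NestedKey]:
    """Hypothetical helper: return (traj_ids_key, mask_key) for a prefix."""
    if prefix is None:
        # No prefix: metadata is read from and written to the root level.
        return "traj_ids", "mask"
    if isinstance(prefix, str):
        # A single string becomes the first level of a nested key.
        return (prefix, "traj_ids"), (prefix, "mask")
    # A tuple of strings is unpacked into a deeper nested key.
    return (*prefix, "traj_ids"), (*prefix, "mask")


print(resolve_meta_keys())             # ('traj_ids', 'mask')
print(resolve_meta_keys("collector"))  # (('collector', 'traj_ids'), ('collector', 'mask'))
print(resolve_meta_keys(("a", "b")))   # (('a', 'b', 'traj_ids'), ('a', 'b', 'mask'))
```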