pad_sequence

tensordict.pad_sequence(list_of_tensordicts: Sequence[T], pad_dim: int = 0, padding_value: float = 0.0, out: Optional[T] = None, device: Optional[Union[device, str, int]] = None, return_mask: bool | tensordict._nestedkey.NestedKey = False)

Pads a list of tensordicts along pad_dim so that they can be stacked together in a contiguous format.
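
To make the padding step concrete, here is a minimal sketch of what happens to a single entry when pad_dim=0, using plain Python lists in place of tensors. The helper pad_and_stack is hypothetical and is not part of the tensordict API or its implementation; it assumes non-empty sequences whose rows all share the same width.

```python
from typing import List, Sequence

# Hypothetical illustration, not tensordict code: a "tensor" of shape
# [seq_len, width] is represented as a list of rows, each a list of floats.
Row = List[float]


def pad_and_stack(seqs: Sequence[List[Row]], padding_value: float = 0.0) -> List[List[Row]]:
    """Pad every sequence along its first dimension (pad_dim=0) to the longest
    length, then stack the results into a [len(seqs), max_len, width] nested list.

    Assumes every sequence is non-empty and that all rows share the same width.
    """
    max_len = max(len(s) for s in seqs)
    width = len(seqs[0][0])
    stacked = []
    for s in seqs:
        # Fill the missing positions with rows of padding_value.
        padding = [[padding_value] * width for _ in range(max_len - len(s))]
        stacked.append([list(row) for row in s] + padding)
    return stacked


# Same lengths as entry "a" in the Examples section: 3 and 5 rows of width 8.
a0 = [[1.0] * 8 for _ in range(3)]
a1 = [[1.0] * 8 for _ in range(5)]
out = pad_and_stack([a0, a1])
print(len(out), len(out[0]), len(out[0][0]))  # 2 5 8
```

The shorter sequence gains two rows of padding_value, so both items share the same length and can be stacked along a new leading dimension.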

Parameters:
  • list_of_tensordicts (List[TensorDictBase]) – the list of instances to pad and stack.

  • pad_dim (int, optional) – the dimension along which every key of the tensordict is padded. Defaults to 0.

  • padding_value (number, optional) – the padding value. Defaults to 0.0.

  • out (TensorDictBase, optional) – if provided, the destination where the data will be written.

  • return_mask (bool or NestedKey, optional) – if True, a “masks” entry is added to the output. If return_mask is a nested key (a string or a tuple of strings), the masks are returned as well and stored under that key instead. The masks entry is a tensordict with the same structure as the stacked tensordict, in which every entry contains the mask of valid (non-padded) values, with size torch.Size([stack_len, *new_shape]), where new_shape[pad_dim] = max_seq_length and the remaining dimensions of new_shape match the original shape of the contained tensors. Defaults to False.
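
The meaning of the mask can be illustrated in the same spirit. The hypothetical helpers valid_mask and lengths_from_mask below are not part of tensordict; they work on plain Python lists and represent only the padded dimension of a mask: a position is True when it holds real data and False when it holds padding.

```python
from typing import List, Sequence


def valid_mask(lengths: Sequence[int]) -> List[List[bool]]:
    """Build a [stack_len, max_len] boolean mask from the original sequence lengths.

    Position i of row k is True when i < lengths[k], i.e. when it holds real
    data rather than padding. Only the padded dimension is represented here.
    """
    max_len = max(lengths)
    return [[i < n for i in range(max_len)] for n in lengths]


def lengths_from_mask(mask: List[List[bool]]) -> List[int]:
    """Recover each sequence's original length by counting its True entries."""
    return [sum(row) for row in mask]


mask = valid_mask([3, 5])
print(mask[0])  # [True, True, True, False, False]
print(lengths_from_mask(mask))  # [3, 5]
```

Counting the True entries per row is a common way to recover the unpadded lengths after stacking.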

Examples

>>> import torch
>>> from tensordict import TensorDict, pad_sequence
>>> list_td = [
...     TensorDict({"a": torch.zeros((3, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
...     TensorDict({"a": torch.zeros((5, 8)), "b": torch.zeros((6, 8))}, batch_size=[]),
...     ]
>>> padded_td = pad_sequence(list_td, return_mask=True)
>>> print(padded_td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 4, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 5, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        masks: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.bool, is_shared=False),
                b: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)