

torchtune.rlhf.truncate_sequence_at_first_stop_token(sequences: Tensor, stop_tokens: Tensor, fill_value: int = 0) → Tuple[Tensor, Tensor]

Truncates each sequence after its first stop token and fills the remaining positions with fill_value. The first stop token itself is kept.

Parameters:

  • sequences (torch.Tensor) – tensor of shape [batch_size, sequence_length] or [sequence_length].

  • stop_tokens (torch.Tensor) – tensor containing the stop token IDs.

  • fill_value (int) – value written to every position after the first stop token, usually the tokenizer's pad_id. Default: 0.

Returns:

A tuple of two tensors, each with the same shape as sequences:
  • padding_mask (torch.Tensor) – a bool tensor where True marks a position that was truncated, i.e. one that comes after the first stop token.

  • sequences (torch.Tensor) – the truncated sequences, with every truncated position set to fill_value.

Return type:

Tuple[torch.Tensor, torch.Tensor]
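The truncation rule can be stated position by position: a token is kept if no stop token occurs strictly before it, so the first stop token survives and everything after it is replaced. The minimal pure-Python sketch below applies that rule to a single sequence. truncate_at_first_stop is a hypothetical helper written for illustration only; it is not part of torchtune and is not its implementation.

```python
from typing import List, Sequence, Set, Tuple


def truncate_at_first_stop(
    seq: Sequence[int], stop_tokens: Set[int], fill_value: int = 0
) -> Tuple[List[bool], List[int]]:
    """Hypothetical reference helper for one 1-D sequence.

    Returns (mask, truncated), where mask[i] is True if position i
    comes after the first stop token and was therefore replaced.
    """
    mask: List[bool] = []
    truncated: List[int] = []
    seen_stop = False
    for tok in seq:
        # Positions after the first stop token are masked and filled.
        mask.append(seen_stop)
        truncated.append(fill_value if seen_stop else tok)
        # The stop token itself is kept; only later tokens are affected.
        if tok in stop_tokens:
            seen_stop = True
    return mask, truncated


mask, truncated = truncate_at_first_stop([13, 2, 20], {2, 869})
print(mask)       # [False, False, True]
print(truncated)  # [13, 2, 0]
```

Applying the helper to each row of a batch yields the same values as the batched call shown in the example.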

Example:

>>> import torch
>>> from torchtune import rlhf
>>> stop_token_ids = torch.tensor([2, 869])
>>> fill_value = 0
>>> sequences = torch.tensor(
...     [
...         [869, 30, 869],
...         [2, 30, 869],
...         [869, 30, 2],
...         [50, 30, 869],
...         [13, 30, 2],
...         [13, 30, 5],
...         [13, 2, 20],
...         [13, 2, 2],
...         [2, 2, 2],
...     ]
... )
>>> padding_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token(
...     sequences, stop_token_ids, fill_value
... )
>>> padding_mask
tensor([[False,  True,  True],
        [False,  True,  True],
        [False,  True,  True],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False,  True],
        [False, False,  True],
        [False,  True,  True]])
>>> truncated_sequences
tensor([[869,   0,   0],
        [  2,   0,   0],
        [869,   0,   0],
        [ 50,  30, 869],
        [ 13,  30,   2],
        [ 13,  30,   5],
        [ 13,   2,   0],
        [ 13,   2,   0],
        [  2,   0,   0]])
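The mask in the example above can be computed without a Python loop by counting, for each position, how many stop tokens occur strictly before it: a position is truncated exactly when that count is positive. The sketch below illustrates this with an exclusive cumulative sum. It is an assumption-labeled illustration, not the torchtune source; truncate_after_first_stop is a hypothetical name, and it uses masked_fill so that its input tensor is left unchanged (whether the library function modifies its input is not stated here). Reducing over dim=-1 lets the same code handle both [batch_size, sequence_length] and [sequence_length] inputs.

```python
from typing import Tuple

import torch


def truncate_after_first_stop(
    sequences: torch.Tensor, stop_tokens: torch.Tensor, fill_value: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical vectorized sketch of first-stop-token truncation."""
    # True wherever the token is one of the stop tokens.
    is_stop = torch.isin(sequences, stop_tokens).long()
    # Exclusive count: number of stop tokens strictly before each position.
    stops_before = torch.cumsum(is_stop, dim=-1) - is_stop
    # Truncate every position that follows at least one stop token.
    padding_mask = stops_before > 0
    # masked_fill returns a new tensor; the input is not modified.
    truncated = sequences.masked_fill(padding_mask, fill_value)
    return padding_mask, truncated


sequences = torch.tensor([[13, 2, 20], [50, 30, 869]])
mask, truncated = truncate_after_first_stop(sequences, torch.tensor([2, 869]))
# Tokens kept per row, up to and including the first stop token.
kept_lengths = (~mask).sum(dim=-1)  # tensor([2, 3])
```

Inverting the mask, as in kept_lengths, is one way to recover per-sequence response lengths for later masking.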