truncate_sequence_at_first_stop_token¶
- torchtune.rlhf.truncate_sequence_at_first_stop_token(sequences: Tensor, stop_tokens: Tensor, fill_value: int = 0) → Tuple[Tensor, Tensor] [source]¶
Truncates sequence(s) after the first stop token and pads with fill_value.
- Parameters:
sequences (torch.Tensor) – tensor of shape [batch_size, sequence_length] or [sequence_length].
stop_tokens (torch.Tensor) – tensor containing stop tokens.
fill_value (int) – value to pad the sequence with after the first stop token, usually pad_id.
- Returns:
- A tuple of two tensors with the same shape as sequences:
  - padding_mask (torch.Tensor): a bool tensor where True indicates the token has been truncated.
  - sequences (torch.Tensor): a tensor of truncated and padded sequences.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
Example
>>> stop_token_ids = torch.tensor([2, 869])
>>> fill_value = 0
>>> sequences = torch.tensor(
>>>     [
>>>         [869, 30, 869],
>>>         [2, 30, 869],
>>>         [869, 30, 2],
>>>         [50, 30, 869],
>>>         [13, 30, 2],
>>>         [13, 30, 5],
>>>         [13, 2, 20],
>>>         [13, 2, 2],
>>>         [2, 2, 2],
>>>     ]
>>> )
>>> eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token(
>>>     sequences, stop_token_ids, fill_value
>>> )
>>> eos_mask
>>> torch.tensor([
>>>     [False, True, True],
>>>     [False, True, True],
>>>     [False, True, True],
>>>     [False, False, False],
>>>     [False, False, False],
>>>     [False, False, False],
>>>     [False, False, True],
>>>     [False, False, True],
>>>     [False, True, True],
>>> ]
>>> )
>>> truncated_sequences
>>> torch.tensor([
>>>     [869, 0, 0],
>>>     [2, 0, 0],
>>>     [869, 0, 0],
>>>     [50, 30, 869],
>>>     [13, 30, 2],
>>>     [13, 30, 5],
>>>     [13, 2, 0],
>>>     [13, 2, 0],
>>>     [2, 0, 0],
>>> ]
>>> )
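A minimal sketch of how the two outputs might be combined after generation in an RLHF loop, assuming hypothetical responses and per-token logprobs tensors (these names and the masking step are illustrative, not part of the torchtune API):

>>> import torch
>>> from torchtune import rlhf
>>> # hypothetical generated token ids and per-token log-probabilities
>>> responses = torch.tensor([[50, 2, 7], [13, 30, 2]])
>>> logprobs = torch.randn(2, 3)
>>> stop_token_ids = torch.tensor([2])
>>> pad_id = 0
>>> # truncate everything after the first stop token and pad with pad_id
>>> padding_mask, responses = rlhf.truncate_sequence_at_first_stop_token(
>>>     responses, stop_token_ids, fill_value=pad_id
>>> )
>>> # zero out log-probabilities at truncated positions so they do not
>>> # contribute to downstream loss terms
>>> logprobs = logprobs.masked_fill(padding_mask, 0.0)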