Llama2Tokenizer
- class torchtune.models.llama2.Llama2Tokenizer(path: str)
Llama2’s implementation of the SentencePiece tokenizer. Llama2Tokenizer does not include any additional special tokens. The prompt template described in https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2/ treats [INST], [/INST], <<SYS>>, and <</SYS>> as special tokens, but these are not registered as unique ids and are tokenized as normal text. When using this tokenizer with the pre-trained model for inference, it is strongly encouraged to apply Llama2ChatFormat to your data beforehand so that the [INST] and <<SYS>> tags are added, for optimal performance. For fine-tuning, this is not required. For more details, see https://pytorch.org/torchtune/main/tutorials/chat.html#tokenizing-prompt-templates-special-tokens.
- Parameters:
path (str) – Path to pretrained SentencePiece tokenizer file.
Examples
>>> tokenizer = Llama2Tokenizer("/path/to/spm_model")
>>> tokenized_text = tokenizer.encode("Hello world!", add_bos=True, add_eos=True)
>>> print(tokenized_text)
[1, 31587, 29644, 102, 2]
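Because the [INST] and <<SYS>> tags are plain text rather than special token ids, adding them is ordinary string formatting done before encoding. The following is a minimal sketch of that layout for a single turn, not torchtune's Llama2ChatFormat itself: the helper name `format_llama2_prompt` is hypothetical, and the exact whitespace around the tags is an assumption based on the prompt format page linked above.

```python
from typing import Optional

# Tag strings from the Llama 2 prompt format. They are ordinary text,
# not registered special token ids.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def format_llama2_prompt(user: str, system: Optional[str] = None) -> str:
    """Hypothetical helper: wrap one user turn (and an optional system prompt) in tags."""
    if system is not None:
        # The system prompt is folded into the first user turn.
        user = f"{B_SYS}{system}{E_SYS}{user}"
    return f"{B_INST} {user.strip()} {E_INST}"


prompt = format_llama2_prompt("What is SentencePiece?", system="Answer briefly.")
print(prompt)
```

The resulting string could then be passed to `tokenizer.encode` like any other text.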
- tokenize_messages(messages: List[Message], max_seq_len: Optional[int] = None) -> Tuple[List[int], List[bool]]
Tokenize a list of messages one at a time then concatenate them, returning a list of tokens and a list of masks.
Note: With the Llama2 SentencePiece model, encode(s1 + s2) != encode(s1) + encode(s2) in general, because of how whitespace is handled. We can get around this by prepending a known token to s2 before encoding and then slicing that token's ids off the beginning of the result.
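The whitespace problem and the prepend-and-slice workaround can be illustrated without a real model. The sketch below uses a toy encoder (`toy_encode`, hypothetical) that only mimics two SentencePiece behaviors: a dummy "▁" prefix at the start of every input and "▁" as the space marker. The choice of "\n" as the known prefix is an assumption made for illustration, and the trick lines up here because s1 itself ends in a newline.

```python
import re
from typing import List

PREFIX = "\n"  # assumed "known token" to prepend; chosen for illustration


def toy_encode(text: str) -> List[str]:
    """Toy stand-in for SentencePiece: dummy "▁" prefix, spaces become "▁"."""
    normalized = "▁" + text.replace(" ", "▁")
    # Pieces: a word with an optional leading "▁", a lone "▁", or a newline.
    return re.findall(r"▁?[^▁\n]+|▁|\n", normalized)


def toy_encode_continuation(text: str) -> List[str]:
    """Encode text as a continuation: prepend PREFIX, then slice its pieces off."""
    prefix_len = len(toy_encode(PREFIX))
    return toy_encode(PREFIX + text)[prefix_len:]


s1, s2 = "user prompt\n", "assistant response"
joint = toy_encode(s1 + s2)
naive = toy_encode(s1) + toy_encode(s2)
fixed = toy_encode(s1) + toy_encode_continuation(s2)

print(joint == naive)  # False: encoding s2 alone gains a spurious leading "▁"
print(joint == fixed)  # True: the prefix absorbs the dummy "▁" and is sliced away
```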
Example
>>> tokenizer = Llama2Tokenizer(tokenizer_path)
>>> messages = [
...     Message(role="system", content="system message\n", masked=True),
...     Message(role="user", content="user prompt\n", masked=True),
...     Message(role="assistant", content="assistant response\n"),
... ]
>>> # tokenize_messages encodes messages separately and concats
>>> tokenizer.tokenize_messages(messages, max_seq_len)[0]
[1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2]
>>> # Same result as encoding the full string in one go
>>> tokenizer.encode(''.join([message.content for message in messages]))
[1, 1788, 2643, 13, 1792, 9508, 13, 465, 22137, 2933, 2]
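The second return value, the mask, is not shown in the example. As a rough sketch of how per-message encoding could yield a token list and a parallel mask, the code below uses a hypothetical `Msg` dataclass in place of torchtune's Message and a made-up word-level `encode`. How BOS and EOS inherit their mask values, and truncation as a plain slice, are assumptions rather than the library's confirmed behavior.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

BOS_ID, EOS_ID = 1, 2  # same first and last ids as in the example output


@dataclass
class Msg:
    """Hypothetical stand-in for torchtune's Message."""
    role: str
    content: str
    masked: bool = False


def tokenize_messages_sketch(
    messages: List[Msg],
    encode: Callable[[str], List[int]],
    max_seq_len: Optional[int] = None,
) -> Tuple[List[int], List[bool]]:
    tokens: List[int] = [BOS_ID]
    # Assumption: BOS takes the mask value of the first message.
    mask: List[bool] = [messages[0].masked if messages else False]
    for msg in messages:
        ids = encode(msg.content)
        tokens.extend(ids)
        mask.extend([msg.masked] * len(ids))  # one mask entry per token
    tokens.append(EOS_ID)
    # Assumption: EOS takes the mask value of the last message.
    mask.append(messages[-1].masked if messages else False)
    if max_seq_len is not None:
        # Simplified truncation: plain slice of both lists.
        tokens, mask = tokens[:max_seq_len], mask[:max_seq_len]
    return tokens, mask


def toy_word_encode(text: str) -> List[int]:
    """Made-up encoder: one id per whitespace-separated word."""
    return [10 + len(word) for word in text.split()]


msgs = [
    Msg(role="system", content="system message\n", masked=True),
    Msg(role="user", content="user prompt\n", masked=True),
    Msg(role="assistant", content="assistant response\n"),
]
tokens, mask = tokenize_messages_sketch(msgs, toy_word_encode)
print(tokens)
print(mask)
```

In this sketch, mask entries are True for tokens from masked messages, so only the assistant tokens and the trailing EOS come out as False.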