ChatDataset
- class torchtune.datasets.ChatDataset(*, tokenizer: ModelTokenizer, source: str, convert_to_messages: Callable[[Mapping[str, Any]], List[Message]], chat_format: Optional[ChatFormat] = None, max_seq_len: int, train_on_input: bool = False, **load_dataset_kwargs: Dict[str, Any])
Class that supports any custom dataset with multiturn conversations.
The general flow from loading a sample to tokenized prompt is: load sample -> apply transform -> foreach turn{format into template -> tokenize}
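The flow above can be sketched in a few lines. This is only an illustration under stated assumptions, not torchtune's implementation: Message here is a local stand-in for the library's class, and toy_tokenize, format_turn, and prepare_sample are hypothetical helpers invented for this example.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Mapping


@dataclass
class Message:
    """Local stand-in for torchtune's Message (assumption: role + content only)."""
    role: str
    content: str


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical tokenizer: emits one fake id (the word length) per word."""
    return [len(word) for word in text.split()]


def format_turn(message: Message) -> Message:
    """Hypothetical template step: wrap user turns in [INST] tags as plain text."""
    if message.role == "user":
        return Message(message.role, f"[INST] {message.content} [/INST]")
    return message


def prepare_sample(
    sample: Mapping[str, Any],
    convert_to_messages: Callable[[Mapping[str, Any]], List[Message]],
) -> List[int]:
    """load sample -> apply transform -> foreach turn {format -> tokenize}."""
    messages = convert_to_messages(sample)  # apply transform
    tokens: List[int] = []
    for message in messages:  # foreach turn
        formatted = format_turn(message)  # format into template
        tokens.extend(toy_tokenize(formatted.content))  # tokenize
    return tokens


# Usage with a made-up sample layout
sample = {"dialogue": [("user", "hi there"), ("assistant", "hello")]}
to_messages = lambda s: [Message(role, text) for role, text in s["dialogue"]]
tokens = prepare_sample(sample, to_messages)
```

Note that the template text is tokenized like any other text in this sketch, which mirrors the description of chat_format adding structured text rather than special tokens.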
If the column/key names differ from the expected names in the ChatFormat, then the column_map argument can be used to provide this mapping.

Use convert_to_messages to prepare your dataset into the Llama2 chat format and roles:

    [
        Message(
            role=<system|user|assistant>,
            content=<message>,
        ),
        ...
    ]
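As an illustration, a convert_to_messages function might look like the sketch below. It assumes a hypothetical sample layout with a "conversations" list whose entries carry "from" and "value" keys; your dataset's keys will likely differ. Message is again a local stand-in for torchtune's class so that the snippet runs on its own.

```python
from dataclasses import dataclass
from typing import Any, List, Mapping


@dataclass
class Message:
    """Local stand-in for torchtune's Message (assumption: role + content only)."""
    role: str
    content: str


# Hypothetical mapping from this dataset's speaker labels to the expected roles.
ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}


def conversations_to_messages(sample: Mapping[str, Any]) -> List[Message]:
    """Key into the 'conversations' field and build one Message per turn."""
    return [
        Message(role=ROLE_MAP[turn["from"]], content=turn["value"])
        for turn in sample["conversations"]
    ]


# Usage with a made-up sample
example = {
    "conversations": [
        {"from": "system", "value": "You are a helpful assistant."},
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "4"},
    ]
}
messages = conversations_to_messages(example)
```

A function of this shape could then be passed as the convert_to_messages argument; an unknown speaker label raises a KeyError here, which surfaces schema mismatches early.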
This class supports multi-turn conversations. If a tokenized sample with multiple turns does not fit within max_seq_len, it is truncated.
- Parameters:
tokenizer (ModelTokenizer) – Tokenizer used by the model that implements the tokenize_messages method.
source (str) – Path string of the dataset; anything supported by Hugging Face’s load_dataset (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path).
convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]) – Function that keys into the desired field in the sample and converts it to a list of Message objects that follow the Llama format with the expected keys.
chat_format (Optional[ChatFormat]) – Template used to format the chat. This is used to add structured text around the actual messages, such as the [INST] tags in Llama2 and in Mistral. The extra text will still get tokenized as normal text, not as special tokens. In models like Llama3 where the tokenizer adds tags as special tokens, chat_format is not needed, unless you want to structure messages in a particular way for inference. If the placeholder variable names in the template do not match the column/key names in the dataset, use column_map to map them. For a list of all possible chat formats, check out Text templates. Default: None.
max_seq_len (int) – Maximum number of tokens in the returned input and label token id lists.
train_on_input (bool) – Whether the model is trained on the prompt or not. Default is False.
**load_dataset_kwargs (Dict[str, Any]) – Additional keyword arguments to pass to load_dataset.
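The interaction between max_seq_len and train_on_input can be sketched as follows. This is a simplified illustration, not the library's code: build_tokens_and_labels is a hypothetical helper, the per-turn token ids are made up, and the sketch assumes that prompt tokens are masked in the labels with -100, the default ignore_index of torch.nn.CrossEntropyLoss. The real class may handle details such as end-of-sequence tokens differently.

```python
from typing import List, Tuple

# Assumption: masked label positions use the default ignore_index of
# torch.nn.CrossEntropyLoss, so they do not contribute to the loss.
IGNORE_INDEX = -100


def build_tokens_and_labels(
    turns: List[Tuple[str, List[int]]],
    max_seq_len: int,
    train_on_input: bool = False,
) -> Tuple[List[int], List[int]]:
    """Concatenate per-turn token ids, mask prompt labels, then truncate."""
    tokens: List[int] = []
    labels: List[int] = []
    for role, ids in turns:
        tokens.extend(ids)
        is_prompt = role != "assistant"
        if is_prompt and not train_on_input:
            # Prompt tokens stay in the input but are excluded from the loss.
            labels.extend([IGNORE_INDEX] * len(ids))
        else:
            labels.extend(ids)
    # Both lists are cut to at most max_seq_len entries.
    return tokens[:max_seq_len], labels[:max_seq_len]


# Usage with made-up token ids for a two-round conversation
turns = [
    ("user", [1, 2, 3]),
    ("assistant", [4, 5]),
    ("user", [6]),
    ("assistant", [7, 8]),
]
tokens, labels = build_tokens_and_labels(turns, max_seq_len=6)
full_tokens, full_labels = build_tokens_and_labels(
    turns, max_seq_len=6, train_on_input=True
)
```

With train_on_input left at False, only assistant tokens carry real labels; setting it to True makes the labels identical to the input tokens.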