Source code for torchtune.datasets._chat

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, Mapping, Optional

import numpy as np

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.config._utils import _get_component_from_path
from torchtune.data import (
    ChatFormat,
    CROSS_ENTROPY_IGNORE_IDX,
    get_openai_messages,
    get_sharegpt_messages,
    Message,
    validate_messages,
)
from torchtune.datasets._packed import PackedDataset
from torchtune.modules.tokenizers import ModelTokenizer


class ChatDataset(Dataset):
    """
    Class that supports any custom dataset with multiturn conversations.

    The general flow from loading a sample to tokenized prompt is:
    load sample -> apply transform -> foreach turn{format into template -> tokenize}

    If the column/key names differ from the expected names in the
    :class:`~torchtune.data.ChatFormat`, then the ``column_map`` argument can be
    used to provide this mapping.

    Use ``convert_to_messages`` to prepare your dataset into the Llama2 chat format
    and roles::

        [
            Message(
                role=<system|user|assistant>,
                content=<message>,
            ),
            ...
        ]

    This class supports multi-turn conversations. If a tokenized sample with multiple
    turns does not fit within ``max_seq_len`` then it is truncated.

    Args:
        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the
            ``tokenize_messages`` method.
        source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
        convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into
            the desired field in the sample and converts to a list of :class:`~torchtune.data.Message`
            that follows the Llama format with the expected keys
        chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add
            structured text around the actual messages, such as the [INST] tags in Llama2 and in
            Mistral. The extra text will still get tokenized as normal text, not as special tokens.
            In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format``
            is not needed, unless you want to structure messages in a particular way for inference.
            If the placeholder variable names in the template do not match the column/key names
            in the dataset, use ``column_map`` to map them. For a list of all possible chat formats,
            check out :ref:`chat_formats`. Default: None.
        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
    """

    def __init__(
        self,
        *,
        tokenizer: ModelTokenizer,
        source: str,
        convert_to_messages: Callable[[Mapping[str, Any]], List[Message]],
        chat_format: Optional[ChatFormat] = None,
        max_seq_len: int,
        train_on_input: bool = False,
        **load_dataset_kwargs: Dict[str, Any],
    ) -> None:
        if chat_format is not None and not isinstance(chat_format(), ChatFormat):
            raise ValueError(
                f"chat_format must be a ChatFormat class, not {type(chat_format())}"
            )
        self._tokenizer = tokenizer
        self._data = load_dataset(source, **load_dataset_kwargs)
        self._convert_to_messages = convert_to_messages
        self.chat_format = chat_format
        self.max_seq_len = max_seq_len
        self.train_on_input = train_on_input

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        sample = self._data[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
        messages = self._convert_to_messages(sample, self.train_on_input)
        if self.chat_format is not None:
            messages = self.chat_format.format(messages)
        validate_messages(messages)
        tokens, mask = self._tokenizer.tokenize_messages(
            messages, max_seq_len=self.max_seq_len
        )
        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
        labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
        assert len(tokens) == len(labels)

        return {"tokens": tokens, "labels": labels}
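
# Usage sketch: driving ChatDataset directly with a custom ``convert_to_messages``.
# The dataset path "my_org/my_chat_data" and its "prompt"/"response" columns are
# hypothetical placeholders, and ``tokenizer`` is any already-constructed
# ModelTokenizer; this is an illustration, not part of the module.
#
# from torchtune.data import Llama2ChatFormat, Message
#
# def my_convert_to_messages(sample, train_on_input):
#     # Mask the user turn from the loss unless train_on_input is set; the mask
#     # is what _prepare_sample consumes via tokenize_messages.
#     return [
#         Message(role="user", content=sample["prompt"], masked=not train_on_input),
#         Message(role="assistant", content=sample["response"]),
#     ]
#
# ds = ChatDataset(
#     tokenizer=tokenizer,
#     source="my_org/my_chat_data",
#     convert_to_messages=my_convert_to_messages,
#     chat_format=Llama2ChatFormat,  # pass the class itself; __init__ instantiates it to validate
#     max_seq_len=2048,
# )
# # Masked positions in "labels" hold CROSS_ENTROPY_IGNORE_IDX (-100).
# print(ds[0]["tokens"][:8], ds[0]["labels"][:8])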
def chat_dataset(
    *,
    tokenizer: ModelTokenizer,
    source: str,
    conversation_style: str,
    chat_format: Optional[str] = None,
    max_seq_len: int,
    train_on_input: bool = False,
    packed: bool = False,
    **load_dataset_kwargs: Dict[str, Any],
) -> ChatDataset:
    """
    Build a configurable dataset with conversations. This method should be
    used to configure a custom chat dataset from the yaml config instead of
    using :class:`~torchtune.datasets.ChatDataset` directly, as it is made to
    be config friendly.

    Args:
        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the
            ``tokenize_messages`` method.
        source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
        conversation_style (str): string specifying expected style of conversations in the dataset
            for automatic conversion to the :class:`~torchtune.data.Message` structure.
            Supported styles are: "sharegpt", "openai"
        chat_format (Optional[str]): full import path of :class:`~torchtune.data.ChatFormat` class
            used to format the messages. See the description in
            :class:`~torchtune.datasets.ChatDataset` for more details. For a list of all possible
            chat formats, check out :ref:`chat_formats`. Default: None.
        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
        packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training.
            Default is False.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.

    Examples:
        >>> from torchtune.datasets import chat_dataset
        >>> dataset = chat_dataset(
        ...     tokenizer=tokenizer,
        ...     source="HuggingFaceH4/no_robots",
        ...     conversation_style="sharegpt",
        ...     chat_format="torchtune.data.ChatMLFormat",
        ...     max_seq_len=2096,
        ...     train_on_input=True
        ... )

    This can also be accomplished via the yaml config::

        dataset:
          _component_: torchtune.datasets.chat_dataset
          source: HuggingFaceH4/no_robots
          conversation_style: sharegpt
          chat_format: torchtune.data.ChatMLFormat
          max_seq_len: 2096
          train_on_input: True

    Returns:
        ChatDataset or PackedDataset: the configured :class:`~torchtune.datasets.ChatDataset`
            or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``

    Raises:
        ValueError: if the conversation format is not supported
    """
    if conversation_style == "sharegpt":
        convert_to_messages = get_sharegpt_messages
    elif conversation_style == "openai":
        convert_to_messages = get_openai_messages
    else:
        raise ValueError(f"Unsupported conversation style: {conversation_style}")

    ds = ChatDataset(
        tokenizer=tokenizer,
        source=source,
        convert_to_messages=convert_to_messages,
        chat_format=_get_component_from_path(chat_format)
        if chat_format is not None
        else None,
        max_seq_len=max_seq_len,
        train_on_input=train_on_input,
        **load_dataset_kwargs,
    )
    return (
        PackedDataset(ds, max_seq_len=max_seq_len, padding_idx=tokenizer.pad_id)
        if packed
        else ds
    )
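
# Usage sketch: wiring ``chat_dataset`` into a DataLoader. The tokenizer path is
# a placeholder, and ``padded_collate`` from ``torchtune.utils`` is assumed as
# the collate function (the utility torchtune's recipes use for unpacked data);
# this is an illustration, not part of the module.
#
# from functools import partial
# from torch.utils.data import DataLoader
# from torchtune.models.llama2 import llama2_tokenizer
# from torchtune.utils import padded_collate
#
# tokenizer = llama2_tokenizer("/path/to/tokenizer.model")
# ds = chat_dataset(
#     tokenizer=tokenizer,
#     source="HuggingFaceH4/no_robots",
#     conversation_style="sharegpt",
#     chat_format="torchtune.data.ChatMLFormat",
#     max_seq_len=2096,
#     train_on_input=True,
# )
# # With packed=True the returned PackedDataset already yields fixed-length
# # sequences, so a padding collate function would be unnecessary.
# loader = DataLoader(
#     ds,
#     batch_size=8,
#     collate_fn=partial(padded_collate, padding_idx=tokenizer.pad_id),
# )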
