# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, Mapping, Optional

import numpy as np

from datasets import load_dataset
from torch.utils.data import Dataset

from torchtune.config._utils import _get_component_from_path
from torchtune.data import (
    ChatFormat,
    CROSS_ENTROPY_IGNORE_IDX,
    get_openai_messages,
    get_sharegpt_messages,
    Message,
    validate_messages,
)
from torchtune.datasets._packed import PackedDataset
from torchtune.modules.tokenizers import ModelTokenizer
class ChatDataset(Dataset):
    """
    Class that supports any custom dataset with multiturn conversations.

    The general flow from loading a sample to tokenized prompt is:
    load sample -> apply transform -> foreach turn {format into template -> tokenize}

    If the column/key names differ from the expected names in the
    :class:`~torchtune.data.ChatFormat`, then the ``column_map`` argument can
    be used to provide this mapping.

    Use ``convert_to_messages`` to prepare your dataset into the Llama2 chat
    format and roles::

        [
            Message(
                role=<system|user|assistant>,
                content=<message>,
            ),
            ...
        ]

    This class supports multi-turn conversations. If a tokenized sample with
    multiple turns does not fit within ``max_seq_len`` then it is truncated.

    Args:
        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the
            ``tokenize_messages`` method.
        source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
        convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into
            the desired field in the sample and converts to a list of :class:`~torchtune.data.Message`
            that follows the Llama format with the expected keys
        chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add
            structured text around the actual messages, such as the [INST] tags in Llama2 and in
            Mistral. The extra text will still get tokenized as normal text, not as special tokens.
            In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format``
            is not needed, unless you want to structure messages in a particular way for inference.
            If the placeholder variable names in the template do not match the column/key names in
            the dataset, use ``column_map`` to map them. For a list of all possible chat formats,
            check out :ref:`chat_formats`. Default: None.
        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
    """

    def __init__(
        self,
        *,
        tokenizer: ModelTokenizer,
        source: str,
        convert_to_messages: Callable[[Mapping[str, Any]], List[Message]],
        chat_format: Optional[ChatFormat] = None,
        max_seq_len: int,
        train_on_input: bool = False,
        **load_dataset_kwargs: Dict[str, Any],
    ) -> None:
        if chat_format is not None and not isinstance(chat_format(), ChatFormat):
            raise ValueError(
                f"chat_format must be a ChatFormat class, not {type(chat_format())}"
            )
        self._tokenizer = tokenizer
        self._data = load_dataset(source, **load_dataset_kwargs)
        self._convert_to_messages = convert_to_messages
        self.chat_format = chat_format
        self.max_seq_len = max_seq_len
        self.train_on_input = train_on_input

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        sample = self._data[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
        messages = self._convert_to_messages(sample, self.train_on_input)
        if self.chat_format is not None:
            messages = self.chat_format.format(messages)
        validate_messages(messages)
        tokens, mask = self._tokenizer.tokenize_messages(
            messages, max_seq_len=self.max_seq_len
        )
        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
        labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
        assert len(tokens) == len(labels)

        return {"tokens": tokens, "labels": labels}
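# NOTE: A minimal usage sketch (not part of the original module) showing how
# ``ChatDataset`` can be used directly with a hand-written
# ``convert_to_messages``. The "dialogue"/"role"/"content" keys, the local
# JSON file, and the tokenizer path are assumptions for illustration only.
#
#   from torchtune.data import Message
#   from torchtune.models.llama2 import llama2_tokenizer
#
#   def custom_to_messages(sample, train_on_input):
#       # Assumes each turn already uses the roles Message expects
#       # (system|user|assistant). Non-assistant turns are masked out of the
#       # loss unless train_on_input=True.
#       return [
#           Message(
#               role=turn["role"],
#               content=turn["content"],
#               masked=(turn["role"] != "assistant" and not train_on_input),
#           )
#           for turn in sample["dialogue"]
#       ]
#
#   tokenizer = llama2_tokenizer("/path/to/tokenizer.model")
#   ds = ChatDataset(
#       tokenizer=tokenizer,
#       source="json",
#       convert_to_messages=custom_to_messages,
#       max_seq_len=2048,
#       data_files="my_chats.json",  # forwarded to load_dataset
#       split="train",
#   )
#   sample = ds[0]  # {"tokens": [...], "labels": [...]}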
def chat_dataset(
    *,
    tokenizer: ModelTokenizer,
    source: str,
    conversation_style: str,
    chat_format: Optional[str] = None,
    max_seq_len: int,
    train_on_input: bool = False,
    packed: bool = False,
    **load_dataset_kwargs: Dict[str, Any],
) -> ChatDataset:
    """
    Build a configurable dataset with conversations. This method should be
    used to configure a custom chat dataset from the yaml config instead of
    using :class:`~torchtune.datasets.ChatDataset` directly, as it is made to
    be config friendly.

    Args:
        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the
            ``tokenize_messages`` method.
        source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
        conversation_style (str): string specifying expected style of conversations in the dataset
            for automatic conversion to the :class:`~torchtune.data.Message` structure.
            Supported styles are: "sharegpt", "openai"
        chat_format (Optional[str]): full import path of :class:`~torchtune.data.ChatFormat` class
            used to format the messages. See the description in
            :class:`~torchtune.datasets.ChatDataset` for more details. For a list of all possible
            chat formats, check out :ref:`chat_formats`. Default: None.
        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
        packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training.
            Default is False.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.

    Examples:
        >>> from torchtune.datasets import chat_dataset
        >>> dataset = chat_dataset(
        ...     tokenizer=tokenizer,
        ...     source="HuggingFaceH4/no_robots",
        ...     conversation_style="sharegpt",
        ...     chat_format="torchtune.data.ChatMLFormat",
        ...     max_seq_len=2096,
        ...     train_on_input=True
        ... )

    This can also be accomplished via the yaml config::

        dataset:
          _component_: torchtune.datasets.chat_dataset
          source: HuggingFaceH4/no_robots
          conversation_style: sharegpt
          chat_format: torchtune.data.ChatMLFormat
          max_seq_len: 2096
          train_on_input: True

    Returns:
        ChatDataset or PackedDataset: the configured :class:`~torchtune.datasets.ChatDataset`
            or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``

    Raises:
        ValueError: if the conversation format is not supported
    """
    if conversation_style == "sharegpt":
        convert_to_messages = get_sharegpt_messages
    elif conversation_style == "openai":
        convert_to_messages = get_openai_messages
    else:
        raise ValueError(f"Unsupported conversation style: {conversation_style}")

    ds = ChatDataset(
        tokenizer=tokenizer,
        source=source,
        convert_to_messages=convert_to_messages,
        chat_format=_get_component_from_path(chat_format)
        if chat_format is not None
        else None,
        max_seq_len=max_seq_len,
        train_on_input=train_on_input,
        **load_dataset_kwargs,
    )
    return (
        PackedDataset(ds, max_seq_len=max_seq_len, padding_idx=tokenizer.pad_id)
        if packed
        else ds
    )
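# NOTE: A minimal end-to-end sketch (not part of the original module),
# assuming a Llama2 tokenizer on disk; any ModelTokenizer exposing
# ``tokenize_messages`` and ``pad_id`` should work the same way.
#
#   from torch.utils.data import DataLoader
#   from torchtune.datasets import chat_dataset
#   from torchtune.models.llama2 import llama2_tokenizer
#
#   tokenizer = llama2_tokenizer("/path/to/tokenizer.model")
#   ds = chat_dataset(
#       tokenizer=tokenizer,
#       source="HuggingFaceH4/no_robots",
#       conversation_style="sharegpt",
#       chat_format="torchtune.data.ChatMLFormat",
#       max_seq_len=2096,
#       packed=True,  # wrap in PackedDataset: fixed-length, densely packed rows
#       split="train",
#   )
#   loader = DataLoader(ds, batch_size=2)
#   batch = next(iter(loader))
#   # Label positions equal to CROSS_ENTROPY_IGNORE_IDX are ignored by the
#   # loss (e.g. prompt tokens when train_on_input=False).
#   print(batch["tokens"].shape, batch["labels"].shape)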