Source code for torchtune.datasets._sft

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, Mapping, Optional

import numpy as np

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform


class SFTDataset(Dataset):
    """
    Primary class for creating any dataset for supervised fine-tuning, either
    from the Hugging Face Hub, local files, or remote files. This class supports
    instruct, chat, tool, or multimodal data for fine-tuning. At a high level,
    this class will load the data from source and apply the following
    pre-processing steps when a sample is retrieved:

    1. Dataset-specific transform. This is typically unique to each dataset and
       extracts the necessary columns into torchtune's
       :class:`~torchtune.data.Message` format, a standardized API for all model
       tokenizers.
    2. Model-specific transform or tokenization with optional prompt template.

    All datasets are formatted into a list of :class:`~torchtune.data.Message`
    because, for fine-tuning, datasets can be considered as "conversations" with
    the model, or AI assistant. Thus, we can standardize all text content as
    messages in a conversation assigned to a role:

    - ``"system"`` messages contain the system prompt
    - ``"user"`` messages contain the input prompt into the model
    - ``"assistant"`` messages are the response of the model and what you
      actually want to train for and compute loss directly against
    - ``"ipython"`` messages are the return from a tool call

    Chat datasets are multiple rounds of user-assistant messages. Instruct
    datasets are typically a single round involving a specific instruction and
    the model's response. Tool datasets are a type of chat dataset that includes
    ipython messages. Multimodal datasets are a type of chat dataset that
    incorporates media into the user messages.

    The :class:`~torchtune.data.Message` forms the core data unit that all
    tokenizer APIs expect. The key component of this class that ensures any
    dataset is transformed into this format is the ``message_transform``. This
    is a callable class that takes in a sample dictionary - typically a single
    row from the source dataset - and processes the sample in any configurable
    way to output a list of messages::

        [
            Message(
                role=<system|user|assistant|ipython>,
                content=<message>,
            ),
            ...
        ]

    For any custom dataset, use the ``message_transform`` to contain all
    pre-processing and return the list of messages.

    Any model-specific pre-processing that needs to happen can be configured
    with the ``model_transform`` parameter. This is another callable class that
    contains any custom logic tied to the model you are fine-tuning and will
    carry over to inference. For example, text + image multimodal datasets
    require processing the images in a way specific to the vision encoder being
    used by the model, which is agnostic to the specific dataset.

    Tokenization is handled by the ``model_transform``. Any
    :class:`~torchtune.modules.tokenizers.ModelTokenizer` can be treated as a
    ``model_transform`` since it uses the model-specific tokenizer to transform
    the list of messages outputted from the ``message_transform`` into tokens
    used by the model for training. Text-only datasets will simply pass the
    :class:`~torchtune.modules.tokenizers.ModelTokenizer` into
    ``model_transform``. Tokenizers handle prompt templating, if configured.

    Args:
        source (str): path to dataset repository on Hugging Face. For local
            datasets, define source as the data file type (e.g. "json", "csv",
            "text") and pass in the filepath in ``data_files``. See `Hugging
            Face's <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
            ``load_dataset`` for more details.
        message_transform (Transform): callable that keys into the desired
            fields in the sample and converts text content to a list of
            :class:`~torchtune.data.Message`. It is expected that the final list
            of messages is stored in the ``"messages"`` key.
        model_transform (Transform): callable that applies model-specific
            pre-processing to the sample after the list of messages is created
            from ``message_transform``. This includes tokenization and any
            modality-specific transforms. It is expected to return at minimum
            ``"tokens"`` and ``"mask"`` keys.
        filter_fn (Optional[Callable]): callable used to filter the dataset
            prior to any pre-processing. See the Hugging Face `docs
            <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_
            for more details.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to
            pass to ``load_dataset``. See Hugging Face's `API ref
            <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset>`_
            for more details.
    """

    def __init__(
        self,
        *,
        source: str,
        message_transform: Transform,
        model_transform: Transform,
        filter_fn: Optional[Callable] = None,
        **load_dataset_kwargs: Dict[str, Any],
    ) -> None:
        self._message_transform = message_transform
        self._model_transform = model_transform

        self._data = load_dataset(source, **load_dataset_kwargs)
        if filter_fn is not None:
            self._data = self._data.filter(filter_fn)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        sample = self._data[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
        transformed_sample = self._message_transform(sample)
        if "messages" in transformed_sample:
            validate_messages(transformed_sample["messages"])

        tokenized_dict = self._model_transform(transformed_sample)

        if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
            keys_str = ", ".join(tokenized_dict.keys())
            error_message = (
                "model_transform returned the following keys: "
                f"{keys_str}. Must return 'tokens' and 'mask' as keys."
            )
            raise ValueError(error_message)

        # Wherever mask == True, set the label to CROSS_ENTROPY_IGNORE_IDX.
        # Otherwise keep the token id so it contributes to the loss.
        tokenized_dict["labels"] = list(
            np.where(
                tokenized_dict["mask"],
                CROSS_ENTROPY_IGNORE_IDX,
                tokenized_dict["tokens"],
            )
        )
        assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])

        return tokenized_dict
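
A minimal usage sketch (illustrative, not part of the module source), assuming
source rows with ``"input"`` and ``"output"`` columns. ``ToMessages``, the file
path, and the ``tokenizer`` name are hypothetical stand-ins; any callables
satisfying the ``message_transform``/``model_transform`` contracts documented
above will work:

from torchtune.data import Message


class ToMessages:
    # Hypothetical message_transform: maps an {"input", "output"} row into
    # torchtune Messages stored under the expected "messages" key.
    def __call__(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
        return {
            "messages": [
                Message(role="user", content=sample["input"]),
                Message(role="assistant", content=sample["output"]),
            ]
        }


# ds = SFTDataset(
#     source="json",                  # local JSON, forwarded to load_dataset
#     data_files="data/train.json",   # hypothetical path
#     split="train",
#     message_transform=ToMessages(),
#     model_transform=tokenizer,      # any ModelTokenizer returning "tokens"/"mask"
# )
# sample = ds[0]  # dict with "tokens", "mask", and "labels"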

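For concreteness, a small worked example of the label construction in
``_prepare_sample`` (the token ids and mask below are made up;
``CROSS_ENTROPY_IGNORE_IDX`` is ``-100``, the index that PyTorch's
cross-entropy loss ignores by default):

import numpy as np

CROSS_ENTROPY_IGNORE_IDX = -100  # torchtune.data._common default

tokens = [1, 42, 87, 5, 2]                # hypothetical token ids
mask = [True, True, False, False, False]  # True marks prompt positions

labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
# labels == [-100, -100, 87, 5, 2]: only the unmasked (response) tokens
# contribute to the loss during training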