Source code for torchtune.datasets._sft
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Mapping, Optional
import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform
class SFTDataset(Dataset):
"""
Primary class for creating any dataset for supervised fine-tuning either from
Hugging Face Hub, local files, or remote files. This class supports instruct,
chat, tool, or multimodal data for fine-tuning. At a high level, this class
will load the data from source and apply the following pre-processing steps
when a sample is retrieved:
1. Dataset-specific transform. This is typically unique to each dataset and extracts
the necessary columns into torchtune's :class:`~torchtune.data.Message` format,
a standardized API for all model tokenizers.
2. Model-specific transform or tokenization with optional prompt template
All datasets are formatted into a list of :class:`~torchtune.data.Message`
because, for fine-tuning, datasets can be considered "conversations" with the model
or AI assistant. Thus, we can standardize all text content as messages in a conversation,
each assigned to a role:
- ``"system"`` messages contain the system prompt
- ``"user"`` messages contain the input prompt into the model
- ``"assistant"`` messages are the model's responses and what you actually want
to train on and compute loss directly against
- ``"ipython"`` messages are the return from a tool call
Chat datasets are multiple rounds of user-assistant messages. Instruct datasets
are typically a single round involving a specific instruction and the model's response.
Tool datasets are a type of chat dataset that includes ipython messages. Multimodal
datasets are a type of chat dataset that incorporates media into the user messages.
The :class:`~torchtune.data.Message` forms the core data unit that all tokenizer
APIs expect. The key component of this class that ensures any dataset is transformed
into this format is the ``message_transform``. This is a callable class that takes
in a sample dictionary - typically a single row from the source dataset - and
processes the sample in any configurable way to output a list of messages::
[
Message(
role=<system|user|assistant|ipython>,
content=<message>,
),
...
]
For any custom dataset, use the ``message_transform`` to contain all pre-processing
needed to return the list of messages.
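A minimal sketch of such a transform, assuming the source dataset has hypothetical
"question" and "answer" columns::

    from typing import Any, Mapping

    from torchtune.data import Message

    class MyMessageTransform:
        def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
            # Map the raw columns of a single row to a user/assistant conversation
            return {
                "messages": [
                    Message(role="user", content=sample["question"]),
                    Message(role="assistant", content=sample["answer"]),
                ]
            }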
Any model-specific pre-processing that needs to happen can be configured with the ``model_transform``
parameter. This is another callable class that contains any custom logic tied to the
model you are fine-tuning and will carry over to inference. For example, text + image
multimodal datasets require processing the images in a way that is specific to the vision
encoder used by the model and agnostic to the specific dataset.
Tokenization is handled by the ``model_transform``. Any :class:`~torchtune.modules.tokenizers.ModelTokenizer`
can be treated as a ``model_transform`` since it uses the model-specific tokenizer to
transform the list of messages output by the ``message_transform`` into tokens
used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer`
into ``model_transform``. Tokenizers handle prompt templating, if configured.
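As a sketch, a model tokenizer can be applied directly to the output of the
``message_transform`` (the tokenizer builder and file path below are placeholders)::

    from torchtune.models.llama3 import llama3_tokenizer

    tokenizer = llama3_tokenizer("/path/to/tokenizer.model")  # placeholder path
    sample = MyMessageTransform()({"question": "What is 2+2?", "answer": "4"})
    tokenized = tokenizer(sample)
    # tokenized["tokens"] -> input token ids
    # tokenized["mask"]   -> True where loss should not be computed (e.g. prompt tokens)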
Args:
source (str): path to dataset repository on Hugging Face. For local datasets,
define source as the data file type (e.g. "json", "csv", "text") and pass
in the filepath in ``data_files``. See `Hugging Face's load_dataset
<https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
for more details.
message_transform (Transform): callable that keys into the desired fields in the sample
and converts text content to a list of :class:`~torchtune.data.Message`. It is expected that the final list
of messages are stored in the ``"messages"`` key.
model_transform (Transform): callable that applies model-specific pre-processing to the sample after the list of
messages is created from ``message_transform``. This includes tokenization and any modality-specific
transforms. It is expected to return at minimum ``"tokens"`` and ``"mask"`` keys.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
Face's `API ref <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset>`_
for more details.
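Example of loading a local JSON file, reusing the sketched transform and tokenizer
from above (the file name and columns are hypothetical)::

    ds = SFTDataset(
        source="json",
        data_files="my_data.json",
        split="train",
        message_transform=MyMessageTransform(),
        model_transform=tokenizer,
    )
    tokens, labels = ds[0]["tokens"], ds[0]["labels"]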
"""
def __init__(
self,
*,
source: str,
message_transform: Transform,
model_transform: Transform,
filter_fn: Optional[Callable] = None,
**load_dataset_kwargs: Dict[str, Any],
) -> None:
self._message_transform = message_transform
self._model_transform = model_transform
self._data = load_dataset(source, **load_dataset_kwargs)
if filter_fn is not None:
self._data = self._data.filter(filter_fn)
def __len__(self):
return len(self._data)
def __getitem__(self, index: int) -> Dict[str, Any]:
sample = self._data[index]
return self._prepare_sample(sample)
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
tokenized_dict = self._model_transform(transformed_sample)
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)
# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
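        # e.g. tokens=[5, 7, 9, 2], mask=[True, True, False, False]
        # -> labels=[CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX, 9, 2]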
tokenized_dict["labels"] = list(
np.where(
tokenized_dict["mask"],
CROSS_ENTROPY_IGNORE_IDX,
tokenized_dict["tokens"],
)
)
assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])
return tokenized_dict