SFTDataset

class torchtune.datasets.SFTDataset(*, source: str, message_transform: Transform, model_transform: Transform, filter_fn: Optional[Callable] = None, **load_dataset_kwargs: Dict[str, Any])[source]

Primary class for creating any dataset for supervised fine-tuning from the Hugging Face Hub, local files, or remote files. This class supports instruct, chat, tool, and multimodal data. At a high level, it loads the data from source and applies the following pre-processing steps when a sample is retrieved:

  1. Dataset-specific transform. This is typically unique to each dataset and extracts the necessary columns into torchtune’s Message format, a standardized API for all model tokenizers.

  2. Model-specific transform or tokenization with optional prompt template

All datasets are formatted into a list of Message objects because, for fine-tuning, a dataset can be viewed as a “conversation” with the model, or AI assistant. This lets us standardize all text content as messages in a conversation, each assigned to a role:

  • "system" messages contain the system prompt

  • "user" messages contain the input prompt into the model

  • "assistant" messages are the response of the model and what you actually want to train for and compute loss directly against

  • "ipython" messages are the return from a tool call

Chat datasets consist of multiple rounds of user and assistant messages. Instruct datasets are typically a single round involving a specific instruction and the model’s response. Tool datasets are chat datasets that include ipython messages. Multimodal datasets are chat datasets that incorporate media into the user messages.
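For illustration, a short chat sample touching all four roles could be represented as follows (the content strings are invented, and the exact formatting of tool calls and returns is model-specific):

from torchtune.data import Message

conversation = [
    Message(role="system", content="You are a helpful assistant with access to a weather tool."),
    Message(role="user", content="What's the weather in Paris right now?"),
    Message(role="assistant", content='get_weather(city="Paris")'),
    Message(role="ipython", content='{"temperature_c": 18, "condition": "sunny"}'),
    Message(role="assistant", content="It is currently 18 degrees Celsius and sunny in Paris."),
]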

The Message is the core data unit that all tokenizer APIs expect. The key component of this class that ensures any dataset is transformed into this format is the message_transform. This is a callable class that takes in a sample dictionary, typically a single row from the source dataset, and processes it in any configurable way to output a list of messages:

[
    Message(
        role=<system|user|assistant|ipython>,
        content=<message>,
    ),
    ...
]

For any custom dataset, use message_transform to contain all the pre-processing needed to return the list of messages.
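As a sketch, a message_transform for a hypothetical dataset with "question" and "answer" columns could look like the following (the class and column names are placeholders, not part of torchtune):

from typing import Any, Mapping

from torchtune.data import Message


class QuestionAnswerToMessages:
    # Hypothetical message_transform: maps a raw sample with "question"
    # and "answer" columns into torchtune's Message format, stored under
    # the "messages" key.
    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        return {
            "messages": [
                Message(role="user", content=sample["question"]),
                Message(role="assistant", content=sample["answer"]),
            ]
        }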

Any model-specific pre-processing can be configured with the model_transform parameter. This is another callable class that contains custom logic tied to the model you are fine-tuning and will carry over to inference. For example, a text + image multimodal dataset requires processing the images in a way that is specific to the vision encoder used by the model and agnostic to the particular dataset.

Tokenization is handled by the model_transform. Any ModelTokenizer can be treated as a model_transform, since it uses the model-specific tokenizer to transform the list of messages output by the message_transform into the tokens used by the model for training. Text-only datasets simply pass the ModelTokenizer as the model_transform. Tokenizers handle prompt templating, if configured.
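For example, a text-only dataset could be built by passing a model tokenizer directly as the model_transform. The sketch below assumes a Llama3 tokenizer checkpoint at a placeholder path, a hypothetical Hugging Face repository id, and the QuestionAnswerToMessages transform sketched above:

from torchtune.datasets import SFTDataset
from torchtune.models.llama3 import llama3_tokenizer

# Placeholder path; point this at your own tokenizer.model checkpoint.
tokenizer = llama3_tokenizer(path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model")

ds = SFTDataset(
    source="my_org/my_qa_dataset",                 # hypothetical Hub repository id
    message_transform=QuestionAnswerToMessages(),
    model_transform=tokenizer,                     # the ModelTokenizer acts as the model_transform
)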

Parameters:
  • source (str) – path to dataset repository on Hugging Face. For local datasets, define source as the data file type (e.g. “json”, “csv”, “text”) and pass in the filepath in data_files. See Hugging Face’s load_dataset for more details.

  • message_transform (Transform) – callable that keys into the desired fields in the sample and converts text content to a list of Message. It is expected that the final list of messages is stored in the "messages" key.

  • model_transform (Transform) – callable that applies model-specific pre-processing to the sample after the list of messages is created from message_transform. This includes tokenization and any modality-specific transforms. It is expected to return at minimum "tokens" and "mask" keys.

  • filter_fn (Optional[Callable]) – callable used to filter the dataset prior to any pre-processing. See the Hugging Face docs for more details.

  • **load_dataset_kwargs (Dict[str, Any]) – additional keyword arguments to pass to load_dataset. See Hugging Face’s API ref for more details.
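Putting the parameters together, the sketch below loads a local JSON file instead of a Hub repository, filters out samples with empty answers before pre-processing, and forwards data_files and split through **load_dataset_kwargs (the file path and column names are placeholders, reusing the hypothetical transform and tokenizer from above):

ds = SFTDataset(
    source="json",                                        # local file type rather than a Hub repo id
    message_transform=QuestionAnswerToMessages(),
    model_transform=tokenizer,
    filter_fn=lambda sample: len(sample["answer"]) > 0,   # drop rows with empty answers
    data_files="data/my_dataset.json",                    # forwarded to Hugging Face's load_dataset
    split="train",
)

sample = ds[0]  # a dict that includes at least the "tokens" and "mask" keys from the model_transform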
