SFTDataset
- class torchtune.datasets.SFTDataset(*, source: str, message_transform: Transform, model_transform: Transform, filter_fn: Optional[Callable] = None, **load_dataset_kwargs: Dict[str, Any])
Primary class for creating any dataset for supervised fine-tuning from the Hugging Face Hub, local files, or remote files. This class supports instruct, chat, tool, and multimodal data for fine-tuning. At a high level, this class will load the data from source and apply the following pre-processing steps when a sample is retrieved:
1. Dataset-specific transform. This is typically unique to each dataset and extracts the necessary columns into torchtune's Message format, a standardized API for all model tokenizers.
2. Model-specific transform or tokenization with optional prompt template.
All datasets are formatted into a list of Message because, for fine-tuning, datasets can be considered as "conversations" with the model, or AI assistant. Thus, we can standardize all text content as messages in a conversation, each assigned to a role (see the example after this list):
- "system" messages contain the system prompt
- "user" messages contain the input prompt into the model
- "assistant" messages are the response of the model, and what you actually want to train on and compute loss directly against
- "ipython" messages are the return from a tool call
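For instance, a single round of chat could be represented as the following list of Message objects (a minimal sketch; the content strings are purely illustrative):

from torchtune.data import Message

conversation = [
    Message(role="system", content="You are a helpful assistant."),
    Message(role="user", content="What is the capital of France?"),
    # Loss is computed against assistant messages during fine-tuning.
    Message(role="assistant", content="The capital of France is Paris."),
]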
Chat datasets are multiple rounds of user-assistant messages. Instruct datasets are typically a single round involving a specific instruction and the model’s response. Tool datasets are a type of chat dataset that includes ipython messages. Multimodal datasets are a type of chat dataset that incorporates media into the user messages.
The Message forms the core data unit that all tokenizer APIs expect. The key component of this class that ensures any dataset is transformed into this format is the message_transform. This is a callable class that takes in a sample dictionary - typically a single row from the source dataset - and processes it in any configurable way to output a list of messages:

[
    Message(
        role=<system|user|assistant|ipython>,
        content=<message>,
    ),
    ...
]
For any custom dataset, use the message_transform to contain all pre-processing needed to return the list of messages.

Any model-specific pre-processing can be configured with the model_transform parameter. This is another callable class that contains any custom logic tied to the model you are fine-tuning, and this logic will carry over to inference. For example, text + image multimodal datasets require processing the images in a way that is specific to the vision encoder used by the model and agnostic to the specific dataset.

Tokenization is handled by the model_transform. Any ModelTokenizer can be treated as a model_transform, since it uses the model-specific tokenizer to transform the list of messages outputted by the message_transform into tokens used by the model for training. Text-only datasets will simply pass the ModelTokenizer into model_transform. Tokenizers handle prompt templating, if configured.

- Parameters:
- source (str) – path to dataset repository on Hugging Face. For local datasets, define source as the data file type (e.g. "json", "csv", "text") and pass in the filepath in data_files. See Hugging Face's load_dataset for more details.
- message_transform (Transform) – callable that keys into the desired fields in the sample and converts text content to a list of Message. It is expected that the final list of messages is stored in the "messages" key.
- model_transform (Transform) – callable that applies model-specific pre-processing to the sample after the list of messages is created from message_transform. This includes tokenization and any modality-specific transforms. It is expected to return at minimum the "tokens" and "mask" keys.
- filter_fn (Optional[Callable]) – callable used to filter the dataset prior to any pre-processing. See the Hugging Face docs for more details.
- **load_dataset_kwargs (Dict[str, Any]) – additional keyword arguments to pass to load_dataset. See Hugging Face's API ref for more details.