PreferenceDataset
class torchtune.datasets.PreferenceDataset(*, source: str, message_transform: Transform, tokenizer: ModelTokenizer, filter_fn: Optional[Callable] = None, **load_dataset_kwargs: Dict[str, Any])
Primary class for fine-tuning via preference modelling techniques (e.g. training a preference model for RLHF, or directly optimizing a model through DPO) on a preference dataset sourced from the Hugging Face Hub, local files, or remote files. This class requires the dataset to have “chosen” and “rejected” model responses. These are typically either full conversations between user and assistant, stored in separate columns:

| chosen | rejected |
|--------|----------|
| [{"role": "user", "content": Q1}, {"role": "assistant", "content": A1}] | [{"role": "user", "content": Q1}, {"role": "assistant", "content": A2}] |

or a user prompt column with separate chosen and rejected assistant responses:

| prompt | chosen | rejected |
|--------|--------|----------|
| Q1     | A1     | A2       |

When the data is in the prompt-chosen-rejected format, only single-turn interactions are supported.
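To make the prompt-chosen-rejected layout concrete, the sketch below turns one such row into the two conversations described above. It is a minimal illustration, not torchtune's own transform: Message here is a hypothetical stand-in dataclass for torchtune's Message, and PromptChosenRejectedToMessages is a hypothetical name. Because the row holds a single prompt, each resulting conversation has exactly one user turn, which is why this format is limited to single-turn data.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass(frozen=True)
class Message:
    """Hypothetical stand-in for torchtune's Message (role and content only)."""

    role: str
    content: str


class PromptChosenRejectedToMessages:
    """Hypothetical message_transform for single-turn prompt/chosen/rejected rows."""

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, List[Message]]:
        # The single prompt becomes the one user turn shared by both conversations.
        prompt = Message(role="user", content=sample["prompt"])
        # The two conversations differ only in the assistant response.
        return {
            "chosen": [prompt, Message(role="assistant", content=sample["chosen"])],
            "rejected": [prompt, Message(role="assistant", content=sample["rejected"])],
        }


row = {"prompt": "Q1", "chosen": "A1", "rejected": "A2"}
messages = PromptChosenRejectedToMessages()(row)
print([(m.role, m.content) for m in messages["chosen"]])
# [('user', 'Q1'), ('assistant', 'A1')]
print([(m.role, m.content) for m in messages["rejected"]])
# [('user', 'Q1'), ('assistant', 'A2')]
```

Sharing the frozen prompt object between both lists is safe here because it can never be mutated.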
At a high level, this class loads the data from source and applies the following pre-processing steps when a sample is retrieved:
1. Dataset-specific transform. This is typically unique to each dataset and extracts the necessary prompt and chosen/rejected columns into torchtune’s Message format, a standardized API for all model tokenizers.
2. Tokenization, with an optional prompt template if one is configured.
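The two steps above can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not torchtune's implementation: Message, ToyTokenizer, conversation_transform and prepare_sample are hypothetical, the tokenizer splits on whitespace, no prompt template is applied, tokenize_messages is assumed to return a (tokens, mask) pair, and the output key names are illustrative. The only contract taken from the text is that the tokenizer exposes a tokenize_messages method and that the transform returns "chosen" and "rejected" message lists.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass(frozen=True)
class Message:
    """Hypothetical stand-in for torchtune's Message (role and content only)."""

    role: str
    content: str


class ToyTokenizer:
    """Hypothetical whitespace tokenizer exposing a tokenize_messages method."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}

    def _encode(self, text: str) -> List[int]:
        # Give every previously unseen word the next free integer id.
        return [self.vocab.setdefault(word, len(self.vocab)) for word in text.split()]

    def tokenize_messages(self, messages: List[Message]) -> Tuple[List[int], List[bool]]:
        tokens: List[int] = []
        mask: List[bool] = []
        for msg in messages:
            ids = self._encode(msg.content)
            tokens.extend(ids)
            # Simplification: mask (True) every token that is not from the assistant.
            mask.extend([msg.role != "assistant"] * len(ids))
        return tokens, mask


def conversation_transform(sample: Dict[str, Any]) -> Dict[str, List[Message]]:
    # The columns already hold lists of {"role": ..., "content": ...} dicts.
    return {key: [Message(**turn) for turn in sample[key]] for key in ("chosen", "rejected")}


def prepare_sample(
    sample: Dict[str, Any],
    message_transform: Callable[[Dict[str, Any]], Dict[str, List[Message]]],
    tokenizer: ToyTokenizer,
) -> Dict[str, List[Any]]:
    # Step 1: dataset-specific transform into chosen/rejected message lists.
    transformed = message_transform(sample)
    out: Dict[str, List[Any]] = {}
    # Step 2: tokenize the chosen and rejected conversations independently.
    for key in ("chosen", "rejected"):
        ids, mask = tokenizer.tokenize_messages(transformed[key])
        out[f"{key}_input_ids"] = ids
        out[f"{key}_mask"] = mask
    return out


sample = {
    "chosen": [{"role": "user", "content": "what is 2+2"}, {"role": "assistant", "content": "4"}],
    "rejected": [{"role": "user", "content": "what is 2+2"}, {"role": "assistant", "content": "5"}],
}
result = prepare_sample(sample, conversation_transform, ToyTokenizer())
print(result["chosen_input_ids"], result["rejected_input_ids"])  # [0, 1, 2, 3] [0, 1, 2, 4]
```

Both conversations share the same prompt tokens and differ only in the final assistant tokens, which is what a preference objective compares.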
All datasets are formatted into a list of Message objects because preference datasets can be viewed as chosen and rejected “conversations” with the model, or AI assistant. Thus, all text content can be standardized as messages in a conversation, each assigned to a role:
- "user" messages contain the input prompt to the model.
- "assistant" messages are the model's responses; these are what you actually want to train on and compute the loss against.
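Since the loss is computed only against assistant messages, prompt-token positions are usually replaced with an ignore value when labels are built. A minimal sketch, assuming a per-token boolean mask in which True marks tokens not to train on, and using -100, the default ignore_index of PyTorch's torch.nn.CrossEntropyLoss. build_labels is a hypothetical helper; in a preference setting it would be applied to the chosen and rejected sequences separately.

```python
from typing import List

# Default ignore_index of torch.nn.CrossEntropyLoss.
IGNORE_INDEX = -100


def build_labels(tokens: List[int], mask: List[bool]) -> List[int]:
    """Hypothetical helper: copy tokens into labels, hiding masked (non-assistant) positions."""
    if len(tokens) != len(mask):
        raise ValueError("tokens and mask must have the same length")
    return [IGNORE_INDEX if masked else tok for tok, masked in zip(tokens, mask)]


# Toy example: three user-prompt tokens followed by two assistant tokens.
tokens = [11, 12, 13, 21, 22]
mask = [True, True, True, False, False]
print(build_labels(tokens, mask))  # [-100, -100, -100, 21, 22]
```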
The Message forms the core data unit that all tokenizer APIs expect. The key component of this class that ensures any dataset is transformed into this format is the message_transform. This is a callable class that takes in a sample dictionary (typically a single row from the source dataset) and processes it in any configurable way to output lists of messages, one for the chosen conversation and one for the rejected conversation, each of the form:

[
    Message(
        role=<system|user|assistant|ipython>,
        content=<message>,
    ),
    ...
]
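When writing a custom message_transform, it can help to sanity-check what it returns. The helper below is hypothetical and only a sketch: it checks for the "chosen" and "rejected" keys described under the message_transform parameter and for the four roles listed above. It also flags conversations that do not end with an assistant turn, which is an assumption of this sketch about well-formed preference data rather than a documented torchtune requirement. Message is again a hypothetical stand-in dataclass.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

# The four roles listed in the Message format above.
VALID_ROLES = {"system", "user", "assistant", "ipython"}


@dataclass(frozen=True)
class Message:
    """Hypothetical stand-in for torchtune's Message (role and content only)."""

    role: str
    content: str


def find_problems(output: Dict[str, Any]) -> List[str]:
    """Hypothetical check of a message_transform's output; [] means no problems found."""
    problems: List[str] = []
    for key in ("chosen", "rejected"):
        messages = output.get(key)
        if not messages:
            problems.append(f"{key}: missing or empty")
            continue
        for msg in messages:
            if msg.role not in VALID_ROLES:
                problems.append(f"{key}: unknown role {msg.role}")
        # Assumption of this sketch: each conversation ends with the response being compared.
        if messages[-1].role != "assistant":
            problems.append(f"{key}: last message is not from the assistant")
    return problems


good = {
    "chosen": [Message("user", "Q1"), Message("assistant", "A1")],
    "rejected": [Message("user", "Q1"), Message("assistant", "A2")],
}
bad = {"chosen": [Message("user", "Q1"), Message("bot", "A1")]}
print(find_problems(good))  # []
print(find_problems(bad))
# ['chosen: unknown role bot', 'chosen: last message is not from the assistant',
#  'rejected: missing or empty']
```

Returning a list of problems rather than raising on the first one makes it easy to report every issue in a malformed sample at once.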
For any custom dataset, place all pre-processing in the message_transform so that it returns these lists of messages.

Parameters:
- source (str) – path to the dataset repository on Hugging Face. For local datasets, define source as the data file type (e.g. “json”, “csv”, “text”) and pass in the file path in data_files. See Hugging Face’s load_dataset for more details.
- message_transform (Transform) – callable that keys into the desired fields in the sample and converts text content to a list of Message. The final lists of messages are expected to be stored under the "chosen" and "rejected" keys.
- tokenizer (ModelTokenizer) – tokenizer used by the model that implements the tokenize_messages method. Since PreferenceDataset only supports text data, it requires a ModelTokenizer instead of the model_transform used in SFTDataset.
- filter_fn (Optional[Callable]) – callable used to filter the dataset prior to any pre-processing. See the Hugging Face docs for more details.
- **load_dataset_kwargs (Dict[str, Any]) – additional keyword arguments to pass to load_dataset. See Hugging Face’s API reference for more details.
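As an illustration of filter_fn, the sketch below drops rows whose chosen and rejected responses are identical, since such pairs carry no preference signal. drop_ties is a hypothetical name, and the list comprehension only stands in for Hugging Face's Dataset.filter, which keeps the rows for which the function returns True and, as stated above, runs before any pre-processing.

```python
from typing import Any, Dict, List


def drop_ties(sample: Dict[str, Any]) -> bool:
    """Hypothetical filter_fn: return True (keep) only if chosen and rejected differ."""
    return sample["chosen"] != sample["rejected"]


# Plain-Python stand-in for the filtering step, applied to raw rows
# before any message_transform or tokenization runs.
rows: List[Dict[str, Any]] = [
    {"prompt": "Q1", "chosen": "A1", "rejected": "A2"},
    {"prompt": "Q2", "chosen": "B1", "rejected": "B1"},  # identical: no preference signal
]
kept = [row for row in rows if drop_ties(row)]
print([row["prompt"] for row in kept])  # ['Q1']
```

With torchtune installed, the same function would be passed as filter_fn=drop_ties. For a local JSON file, the description of source above suggests a call along the lines of PreferenceDataset(source="json", data_files="my_prefs.json", message_transform=..., tokenizer=..., filter_fn=drop_ties), where the file name is hypothetical.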