InstructDataset

class torchtune.datasets.InstructDataset(tokenizer: ModelTokenizer, source: str, template: InstructTemplate, transform: Optional[Callable] = None, column_map: Optional[Dict[str, str]] = None, train_on_input: bool = False, max_seq_len: Optional[int] = None, **load_dataset_kwargs: Dict[str, Any])[source]

Class that supports any custom dataset with instruction-based prompts and a configurable template.

The general flow from loading a sample to the tokenized prompt is: load sample -> apply transform -> format into template -> tokenize.
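
A minimal sketch of this flow for a single sample is shown below. It assumes torchtune.data.AlpacaInstructTemplate is available in your installed version; the sample values and the transform are made up for illustration.

  from torchtune.data import AlpacaInstructTemplate

  # A raw sample as it might come out of load_dataset (made-up values)
  sample = {"instruction": "Name the capital of France.", "input": "", "output": "Paris"}

  # 1. apply the (optional) transform: here a trivial whitespace cleanup
  transformed = {k: v.strip() for k, v in sample.items()}

  # 2. format into the template to produce the prompt string
  prompt = AlpacaInstructTemplate.format(transformed)

  # 3. tokenize prompt and response with the model tokenizer (omitted here);
  #    truncation to max_seq_len and label masking happen at this stage
  print(prompt)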

If the column/key names differ from the expected names in the InstructTemplate, then the column_map argument can be used to provide this mapping.

Masking of the prompt during training is controlled by the train_on_input flag, which is set to False by default.
  • If train_on_input is True, the prompt is used during training and contributes to the loss.
  • If train_on_input is False, the prompt is masked out (its tokens are replaced with -100 in the labels).
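
The pure-Python sketch below illustrates the effect on the labels in the two cases. The token ids are invented; -100 is the ignore index used by the cross-entropy loss.

  # Invented token ids for a prompt and its response
  prompt_tokens = [1, 306, 4966, 1247]
  response_tokens = [3681, 2]

  input_ids = prompt_tokens + response_tokens

  train_on_input = False
  if train_on_input:
      # prompt tokens contribute to the loss
      labels = list(input_ids)
  else:
      # prompt tokens are ignored by the loss
      labels = [-100] * len(prompt_tokens) + response_tokens

  assert len(input_ids) == len(labels)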

Parameters:
  • tokenizer (ModelTokenizer) – Tokenizer used by the model that implements the tokenize_messages method.

  • source (str) – path string of the dataset; anything supported by Hugging Face’s load_dataset (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)

  • template (InstructTemplate) – template used to format the prompt. If the placeholder variable names in the template do not match the column/key names in the dataset, use column_map to map them.

  • transform (Optional[Callable]) – transform to apply to the sample before formatting to the template. Default is None.

  • column_map (Optional[Dict[str, str]]) – a mapping from the expected placeholder names in the template to the column/key names in the sample. If None, assume these are identical.

  • train_on_input (bool) – Whether the model is trained on the prompt or not. Default is False.

  • max_seq_len (Optional[int]) – Maximum number of tokens in the returned input and label token id lists. Default is None, which disables truncation. We recommend setting this to the highest value that fits in memory and is supported by the model. For example, Llama2-7B supports sequence lengths of up to 4096 tokens.

  • **load_dataset_kwargs (Dict[str, Any]) – additional keyword arguments to pass to load_dataset.
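
A hedged usage sketch follows. The dataset id, tokenizer path, and the QATemplate class are placeholders invented for illustration; the sketch assumes an InstructTemplate subclass exposes a classmethod format(sample, column_map) and that the response text is read from the sample's "output" column (remappable via column_map), which may differ in your torchtune version.

  from typing import Any, Dict, Mapping, Optional

  from torchtune.data import InstructTemplate
  from torchtune.datasets import InstructDataset
  from torchtune.models.llama2 import llama2_tokenizer


  class QATemplate(InstructTemplate):
      """Hypothetical template with a single 'question' placeholder."""

      template = "Question: {question}\nAnswer: "

      @classmethod
      def format(
          cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None
      ) -> str:
          key = (column_map or {}).get("question", "question")
          return cls.template.format(question=sample[key])


  tokenizer = llama2_tokenizer("/path/to/tokenizer.model")  # placeholder path

  ds = InstructDataset(
      tokenizer=tokenizer,
      source="my_org/my_qa_data",          # placeholder Hugging Face dataset id
      template=QATemplate,
      column_map={"question": "prompt", "output": "answer"},  # template placeholder -> dataset column
      train_on_input=False,
      max_seq_len=4096,
      split="train",                       # forwarded to load_dataset
  )

Iterating over the resulting dataset yields the tokenized prompt-plus-response ids together with the corresponding labels, with prompt positions masked when train_on_input is False.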
