
PackedDataset

class torchtune.datasets.PackedDataset(ds: Dataset, *, max_seq_len: int, padding_idx: int = 0, max_packs: Optional[int] = None, split_across_pack: bool = False)

Performs greedy sample packing on a provided dataset. This is done as a single preprocessing step before training begins. Shuffling is done outside of this class on packed samples with a Sampler as part of the dataloader. Currently, this only supports in-memory map-style datasets.

The class loads, tokenizes, and packs examples on initialization, so no tokenization is done during training.

The general flow on initialization is: load a tokenized sample -> add it to a buffer -> when the buffer is long enough, add it to self.packs.
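The flow above can be sketched in plain Python as follows. This is a simplified illustration, not torchtune's implementation: the helper names greedy_pack and _pad are hypothetical, samples are bare lists of token ids (labels are ignored), whole samples are moved to the next pack (the split_across_pack=False behavior), and raising ValueError for a sample longer than max_seq_len is an assumption of this sketch.

```python
from typing import List, Optional


def _pad(seq: List[int], length: int, padding_idx: int) -> List[int]:
    # Hypothetical helper: right-pad a token list to the given length.
    return seq + [padding_idx] * (length - len(seq))


def greedy_pack(
    samples: List[List[int]],
    max_seq_len: int,
    padding_idx: int = 0,
    max_packs: Optional[int] = None,
) -> List[List[int]]:
    """Hypothetical helper: greedily pack whole samples into fixed-length packs."""
    packs: List[List[int]] = []
    buffer: List[int] = []
    for tokens in samples:
        if len(tokens) > max_seq_len:
            # Assumption: a sample that can never fit in one pack is an error.
            raise ValueError(f"sample length {len(tokens)} > max_seq_len {max_seq_len}")
        if len(buffer) + len(tokens) > max_seq_len:
            # The next sample does not fit: pad the buffer and store it as a pack.
            packs.append(_pad(buffer, max_seq_len, padding_idx))
            buffer = []
            if max_packs is not None and len(packs) >= max_packs:
                return packs
        buffer.extend(tokens)
    # Whatever remains in the buffer becomes the final, padded pack.
    if buffer and (max_packs is None or len(packs) < max_packs):
        packs.append(_pad(buffer, max_seq_len, padding_idx))
    return packs
```

With the samples S1..S4 written as 1..4 and pad as 0, greedy_pack([[1, 1, 1], [2, 2], [3, 3], [4, 4]], max_seq_len=6) returns [[1, 1, 1, 2, 2, 0], [3, 3, 4, 4, 0, 0]], matching the tokens example below.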

During training, indexing the dataset returns self.packs[idx] as input tokens, labels, an attention mask, and position ids. The attention mask is a lower-triangular block mask (block-causal) that prevents samples within the same pack from attending to one another. The position ids give the position of each token relative to the start of its own sample within the pack. All of these are padded to max_seq_len, so a batch-wise collator is not needed.

A packed sample is made up of several shorter samples concatenated into a single sequence of at most max_seq_len tokens. For example, if max_seq_len is 6 and the samples have varying lengths:

tokens = [
    [S1, S1, S1, S2, S2, pad],
    [S3, S3, S4, S4, pad, pad],
    ...,
]

To prevent cross-contamination, the following mask would be returned for the first pack in the example:

mask = [
    [1, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0],
    [1, 1, 1, 0, 0, 0],
    [0, 0, 0, 1, 0, 0],
    [0, 0, 0, 1, 1, 0],
    [0, 0, 0, 0, 0, 1],
]
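The mask above can be built from the lengths of the samples in a pack. The sketch below is a hypothetical helper (block_causal_mask is not a torchtune function) that uses plain nested lists, with 1 meaning "may attend", as in the example. Each sample gets a causal lower-triangular block, and every remaining (pad) position gets an identity row.

```python
from typing import List


def block_causal_mask(seq_lens: List[int], max_seq_len: int) -> List[List[int]]:
    """Hypothetical helper: block-causal attention mask for one pack."""
    if sum(seq_lens) > max_seq_len:
        raise ValueError("sample lengths exceed max_seq_len")
    mask = [[0] * max_seq_len for _ in range(max_seq_len)]
    start = 0
    for length in seq_lens:
        # Token i of this sample attends to itself and earlier tokens of the same sample.
        for i in range(start, start + length):
            for j in range(start, i + 1):
                mask[i][j] = 1
        start += length
    # Remaining positions are padding: each pad token attends only to itself.
    for i in range(start, max_seq_len):
        mask[i][i] = 1
    return mask
```

For the first pack in the example, the sample lengths are [3, 2] (S1 and S2), and block_causal_mask([3, 2], 6) reproduces the mask shown above.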

The position ids would be:

input_pos = [
    [0, 1, 2, 0, 1, 2],
    [0, 1, 0, 1, 2, 3],
    ...,
]

For pad tokens, the mask uses the identity matrix instead of a causal mask. Pad tokens receive position ids that keep incrementing from the last position of the preceding sample.
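The position-id rule can be expressed in a few lines. This is a sketch only: packed_position_ids is a hypothetical helper that takes the sample lengths of one pack, restarts counting at 0 for each sample, and lets the trailing pad tokens continue counting from the last sample's final position.

```python
from typing import List


def packed_position_ids(seq_lens: List[int], max_seq_len: int) -> List[int]:
    """Hypothetical helper: position ids for one pack, including trailing padding."""
    pos: List[int] = []
    for length in seq_lens:
        # Positions restart at 0 for every sample in the pack.
        pos.extend(range(length))
    num_pad = max_seq_len - len(pos)
    # Pad tokens keep incrementing from the previous sample's last position.
    next_pos = pos[-1] + 1 if pos else 0
    pos.extend(range(next_pos, next_pos + num_pad))
    return pos
```

For the two packs in the example, packed_position_ids([3, 2], 6) and packed_position_ids([2, 2], 6) give the two rows of input_pos shown above.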

Parameters:
  • ds (Dataset) – dataset to sample pack. Each sample should be a dictionary with the fields “tokens” and “labels”, holding the token ids and the corresponding labels.

  • max_seq_len (int) – Maximum number of tokens in each pack.

  • padding_idx (int) – padding index for the tokenizer. Default is 0.

  • max_packs (Optional[int]) – Maximum number of packs. Default is None, which will create as many packs as possible.

  • split_across_pack (bool) – controls what happens when the last sample in a pack does not fit in the remaining space of max_seq_len. If True, the sample is split and its remainder continues at the start of the next pack; if False, the whole sample is moved to the beginning of the next pack. For pre-training, this is typically set to True for general text completion. For fine-tuning, it is typically set to False to avoid truncating sentences in instruct tuning. Default is False.
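To make the two split_across_pack settings concrete, the following sketch packs the same samples both ways. It is a hypothetical, simplified helper (pack_tokens is not part of torchtune) that works on bare token lists; raising ValueError for an oversized sample when splitting is disabled is an assumption of this sketch.

```python
from typing import List


def pack_tokens(
    samples: List[List[int]],
    max_seq_len: int,
    split_across_pack: bool,
    padding_idx: int = 0,
) -> List[List[int]]:
    """Hypothetical helper contrasting the two split_across_pack behaviors."""
    packs: List[List[int]] = []
    buffer: List[int] = []
    for tokens in samples:
        if split_across_pack:
            buffer.extend(tokens)
            # Emit every full pack; the overflow stays in the buffer for the next pack.
            while len(buffer) >= max_seq_len:
                packs.append(buffer[:max_seq_len])
                buffer = buffer[max_seq_len:]
        else:
            if len(tokens) > max_seq_len:
                # Assumption: a sample longer than a whole pack cannot be kept intact.
                raise ValueError(f"sample length {len(tokens)} > max_seq_len {max_seq_len}")
            if len(buffer) + len(tokens) > max_seq_len:
                # Pad the current pack and move the whole sample to the next one.
                packs.append(buffer + [padding_idx] * (max_seq_len - len(buffer)))
                buffer = []
            buffer.extend(tokens)
    if buffer:
        # The final, partially filled pack is padded to max_seq_len.
        packs.append(buffer + [padding_idx] * (max_seq_len - len(buffer)))
    return packs
```

With samples [[1, 1, 1], [2, 2], [3, 3, 3]] and max_seq_len=6, split_across_pack=False yields [[1, 1, 1, 2, 2, 0], [3, 3, 3, 0, 0, 0]], while split_across_pack=True yields [[1, 1, 1, 2, 2, 3], [3, 3, 0, 0, 0, 0]]: the third sample is cut, trading a truncated sample for less padding.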
