Source code for torchtune.datasets._packed
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset

from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX, PACK_TYPE
from tqdm import tqdm


class PackedDataset(Dataset):
"""
Performs greedy sample packing on a provided dataset. This is done as a single
preprocessing step before training begins. Shuffling is done outside of this
class on packed samples with a ``Sampler`` as part of the dataloader. Currently,
this only supports in-memory map-style datasets.
The class loads, tokenizes, and packs examples on initialization - no tokenization is done during training.
The general flow on initialization is: load tokenized sample -> add to buffer ->
when buffer is long enough, add to ``self.packs``.

    During training, ``self.packs[idx]`` is returned as tokens, labels, position ids,
    and per-sample sequence lengths (``seq_lens``). The sequence lengths describe a
    lower triangular block attention mask that prevents samples from cross-attending
    within a pack, and the position ids indicate the position of each token relative
    to its sample within a pack. Tokens, labels, and position ids are all padded to
    max sequence length, so a batch-wise collator is not needed.

    A packed sample is made up of individual smaller sequence length samples jammed together
    within ``max_seq_len``. For example, if ``max_seq_len`` is 6 and there are varied
    length samples::

        tokens = [
            [S1, S1, S1, S2, S2, pad],
            [S3, S3, S4, S4, pad, pad],
            ...,
        ]

    To prevent cross-contamination, attention for the first pack in the example
    should use the following mask::

        mask = [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1],
        ]

    The position ids would be::

        input_pos = [
            [0, 1, 2, 0, 1, 2],
            [0, 1, 0, 1, 2, 3],
            ...,
        ]

    The identity matrix is used in the mask for pad tokens instead of a causal mask.
    For the position ids of pad tokens, we simply continue incrementing from the
    previous sample.
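
    Such a mask can be constructed downstream from ``seq_lens``. As a minimal sketch
    (not part of this class), assuming ``seq_lens`` is a plain list that already
    includes the final padding entry::

        # e.g. seq_lens = [3, 2, 1] for the first pack above (incl. the pad entry)
        mask = torch.block_diag(*[torch.tril(torch.ones(n, n)) for n in seq_lens])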

    Args:
        ds (Dataset): dataset to sample pack. This should return a dictionary with fields
            "tokens" and "labels" containing the tokenized and label samples.
        max_seq_len (int): Maximum number of tokens to pack.
        padding_idx (int): padding index for the tokenizer. Default is 0.
        max_packs (Optional[int]): Maximum number of packs. Default is None, which will create as many
            packs as possible.
        split_across_pack (bool): if the last sample in a pack does not fit in ``max_seq_len``,
            split the sample between this pack and the next (True), or move it entirely to the
            beginning of the next pack (False). For pre-training, this is typically set to True
            for general text completion. For fine-tuning, this is typically set to False to avoid
            truncating sentences in instruct tuning. Default is False.
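
    Example:
        A minimal usage sketch. Here ``my_dataset`` stands in for any map-style dataset
        whose samples are dicts with "tokens" and "labels" lists (hypothetical name)::

            packed_ds = PackedDataset(my_dataset, max_seq_len=2048)
            sample = packed_ds[0]
            # sample["tokens"], sample["labels"], and sample["input_pos"] are padded
            # to max_seq_len; sample["seq_lens"] sums to max_seq_len.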
"""
def __init__(
self,
ds: Dataset,
*,
max_seq_len: int,
padding_idx: int = 0,
max_packs: Optional[int] = None,
split_across_pack: bool = False,
) -> None:
self.ds = ds
self.max_seq_len = max_seq_len
self.padding_idx = padding_idx
self.max_packs = max_packs
self.split_across_pack = split_across_pack
# Where final samples will be held
self.packs: List[PACK_TYPE] = []
self.previous_sample_boundary: int = 0
        self._pack()

    def _pack(self) -> None:
"""Iterate through the dataset. Use a buffer to hold samples until max_seq_len,
then append the buffer to self.packs as a single "packed" sample. Continue
until max_packs or end of dataset."""
# Buffer to hold samples until they are long enough to be added to self.packs
current_pack = {
"tokens": [],
"labels": [],
"input_pos": [],
"seq_lens": [],
}
# Only show progress bar on rank 0
rank = (
torch.distributed.get_rank()
if torch.distributed.is_available() and torch.distributed.is_initialized()
else 0
)
if rank == 0:
pbar = tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True)
for sample in self.ds:
tokens, labels = sample["tokens"], sample["labels"]
# If the dataset outputs samples that are larger than the specified
# max_seq_len and we're unable to split it, user needs to modify
# one of the two parameters
seq_len = len(tokens)
if seq_len > self.max_seq_len and not self.split_across_pack:
raise ValueError(
f"Dataset sample is too long ({seq_len} > {self.max_seq_len}). "
"Please set `split_across_pack=True` or increase `max_seq_len`."
)
# Update the current pack
current_pack["tokens"] += tokens
current_pack["labels"] += labels
current_pack["input_pos"] += [x % self.max_seq_len for x in range(seq_len)]
current_pack["seq_lens"] += [seq_len]
# If the current pack is over the max_seq_len, add it to self.packs and
# retain any truncated or bumped samples for next pack
while (
len(current_pack["tokens"]) > self.max_seq_len
and not self._should_stop_packing()
):
current_pack = self._split_and_add_pack(current_pack)
if rank == 0:
pbar.update()
# Keep track of previous sample boundary
self.previous_sample_boundary = len(current_pack["tokens"])
if self._should_stop_packing():
break
# Handle the last pack if there's leftover and we haven't filled up the max packs
if len(current_pack["tokens"]) > 0 and (
self.max_packs is None or len(self.packs) < self.max_packs
):
# No need to handle splitting at this point so we can just add the current pack
            self._add_pack(current_pack)

    def _should_stop_packing(self) -> bool:
"""If max packs is set, stop packing when we reach that number."""
if self.max_packs is not None and len(self.packs) == self.max_packs:
return True
        return False

    def _split_and_add_pack(self, current_pack: PACK_TYPE) -> PACK_TYPE:
"""Splits the current pack at the boundary, processes it, adds it to ``self.packs`` and
returns the start of the next pack."""
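        # Illustrative sketch (hypothetical numbers): with max_seq_len=6 and buffered
        # seq_lens=[3, 5], split_across_pack=True cuts at the max_seq_len boundary
        # (this pack's seq_lens become [3, 3] and the remaining 2 tokens start the
        # next pack), while split_across_pack=False cuts at previous_sample_boundary
        # and the whole 5-token sample is moved to the next pack.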
if self.split_across_pack:
boundary = self.max_seq_len
# The last elem in ``seq_lens`` ensures that ``sum(seq_lens) == self.max_seq_len``
leftover_seq_len = self.max_seq_len - sum(current_pack["seq_lens"][:-1])
seq_len_padding = [leftover_seq_len] if leftover_seq_len > 0 else []
else:
boundary = self.previous_sample_boundary
# If we aren't splitting across packs, we leave out the last sample b/c
# it will go into the next pack
seq_len_padding = []
pack = {
"tokens": current_pack["tokens"][:boundary],
"labels": current_pack["labels"][:boundary],
"input_pos": current_pack["input_pos"][:boundary],
"seq_lens": current_pack["seq_lens"][:-1] + seq_len_padding,
}
# Process and add the pack
self._add_pack(pack)
# Return the length of the first sample in next pack if we are splitting across packs,
# otherwise return the length of the last sample in the current pack
next_seq_len = (
len(current_pack["tokens"][boundary:])
if self.split_across_pack
else current_pack["seq_lens"][-1]
)
return {
"tokens": current_pack["tokens"][boundary:],
"labels": current_pack["labels"][boundary:],
"input_pos": current_pack["input_pos"][boundary:],
"seq_lens": [next_seq_len],
        }

    def _add_pack(self, pack: PACK_TYPE) -> None:
"""Processes, pads and adds a pack to ``self.packs``."""
pack = self._convert_to_tensors(pack)
pack = self._pad_pack(pack, padding_idx=self.padding_idx)
        self.packs.append(pack)

    def _convert_to_tensors(self, pack: PACK_TYPE) -> PACK_TYPE:
"""Converts a pack into tensors. Pack comes in as a dict of lists and is converted to tensors."""
return {
"tokens": torch.tensor(pack["tokens"], dtype=torch.long),
"labels": torch.tensor(pack["labels"], dtype=torch.long),
"input_pos": torch.tensor(pack["input_pos"], dtype=torch.long),
"seq_lens": torch.tensor(pack["seq_lens"], dtype=torch.long),
        }

    def _pad_pack(self, pack: PACK_TYPE, padding_idx: int) -> PACK_TYPE:
"""Pads a pack to ``self.max_seq_len``."""
# Pad tokens
num_padding_tokens = self.max_seq_len - len(pack["tokens"])
padded_tokens = F.pad(
pack["tokens"],
(0, num_padding_tokens),
value=padding_idx,
)
# Pad labels
padded_labels = F.pad(
pack["labels"],
(0, self.max_seq_len - len(pack["labels"])),
value=CROSS_ENTROPY_IGNORE_IDX,
)
# Add padding tokens as a last seq len to ensure sum is max_seq_len
padded_seq_lens = (
torch.cat([pack["seq_lens"], torch.tensor([num_padding_tokens])])
if num_padding_tokens > 0
else pack["seq_lens"]
)
# Pad input_pos continuing the sequence from last value
# in input_pos
# e.g. [0 1 2] -> [0 1 2 3 4 5] for self.max_seq_len = 6
num_range = torch.arange(
pack["input_pos"][-1] + 1,
pack["input_pos"][-1] + self.max_seq_len - len(pack["input_pos"]) + 1,
)
# Clamp to max_seq_len - 1 to avoid out of bounds error
clamped_num_range = torch.clamp(num_range, 0, self.max_seq_len - 1)
padded_input_pos = torch.cat([pack["input_pos"], clamped_num_range])
return {
"tokens": padded_tokens,
"labels": padded_labels,
"input_pos": padded_input_pos,
"seq_lens": padded_seq_lens,
        }

    def __len__(self) -> int:
        return len(self.packs)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
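        """Returns the pre-padded pack at ``idx``: a dict with "tokens", "labels",
        "input_pos", and "seq_lens" tensors."""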
return self.packs[idx]