
Source code for torchrl.data.rlhf.prompt

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Optional

import torch
from tensordict import tensorclass, TensorDict

from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader

DEFAULT_DATASET = "CarperAI/openai_summarize_tldr"


@tensorclass
class PromptData:
    """A prompt dataset."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    prompt_rindex: torch.Tensor
    labels: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    loss: Optional[torch.Tensor] = None

    def mask_label(self, pad_token_id=50256):
        _, block_size = self.input_ids.shape
        attention_mask = (
            torch.arange(block_size, device=self.prompt_rindex.device)
            < self.prompt_rindex[:, None]
        ).to(torch.int64)
        input_ids = torch.where(attention_mask == 1, self.input_ids, pad_token_id)
        return self.__class__(
            input_ids=input_ids,
            attention_mask=attention_mask,
            prompt_rindex=self.prompt_rindex,
            loss=self.loss,
            batch_size=[],
        )

    @classmethod
    def from_dataset(
        cls,
        split,
        dataset_name=None,
        max_length=550,
        root_dir=None,
        from_disk=False,
        num_workers: int | None = None,
    ):
        """Returns a :class:`PromptData` from a dataset name.

        Args:
            split (str): ``"train"`` or ``"valid"`` depending on the data split needed.
            dataset_name (str, optional): name of the dataset to be processed.
                Defaults to ``"CarperAI/openai_summarize_tldr"``.
            max_length (int, optional): maximum length of the dataset sequences.
                Defaults to 550.
            root_dir (path, optional): the path where the datasets are stored.
                Defaults to ``"$HOME/.cache/torchrl/data"``.
            from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
                will be used. Otherwise, :func:`datasets.load_dataset` will be used.
                Defaults to ``False``.
            num_workers (int, optional): number of workers for
                :meth:`datasets.dataset.map` which is called during tokenization.
                Defaults to ``max(os.cpu_count() // 2, 1)``.

        Returns: a :class:`PromptData` instance containing a memory-mapped
            version of the required dataset.

        Examples:
            >>> data = PromptData.from_dataset("train")
            >>> print(data)
            PromptDataTLDR(
                attention_mask=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
                input_ids=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
                prompt_rindex=MemoryMappedTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
                labels=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
                logits=None,
                loss=None,
                batch_size=torch.Size([116722]),
                device=None,
                is_shared=False)
            >>> # data can be sampled from using regular indexing
            >>> sub_data = data[:3]

        """
        dataset_name = dataset_name if dataset_name is not None else DEFAULT_DATASET
        loader = TokenizedDatasetLoader(
            split,
            max_length,
            dataset_name,
            PromptTensorDictTokenizer,
            root_dir=root_dir,
            from_disk=from_disk,
            num_workers=num_workers,
        )
        data = loader.load()
        return cls(**data, labels=data["input_ids"], batch_size=data.shape)
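
The two methods above compose naturally: ``from_dataset`` materializes a memory-mapped ``PromptData`` that supports regular tensor indexing, and ``mask_label`` replaces everything from the prompt's right index onward with the pad token. A minimal sketch; the hand-built token values, the split, and the batch size are illustrative, not part of the original module (``50256`` is GPT-2's eos token, used here as padding):

import torch

# Hand-built toy instance: mask_label keeps positions strictly below
# prompt_rindex and replaces the rest with the pad token.
toy = PromptData(
    input_ids=torch.tensor([[11, 12, 13, 14, 15]]),
    attention_mask=torch.ones(1, 5, dtype=torch.int64),
    prompt_rindex=torch.tensor([3]),
    batch_size=[1],
)
masked = toy.mask_label()
# masked.input_ids == tensor([[11, 12, 13, 50256, 50256]])

# Dataset path: draw a random minibatch and hide the summaries,
# e.g. before generating completions from the prompts alone.
data = PromptData.from_dataset("train")
idx = torch.randint(data.shape[0], (8,))
prompts_only = data[idx].mask_label()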

class PromptTensorDictTokenizer(TensorDictTokenizer):
    """Tokenization recipe for prompt datasets.

    Returns a tokenizer function, which reads an example containing a prompt
    and a label and tokenizes them.

    Args:
        tokenizer (tokenizer from transformers library): the tokenizer to use.
        max_length (int): maximum length of the sequence.
        key (str, optional): the key where to find the text. Defaults to ``"prompt"``.
        padding (str, optional): type of padding. Defaults to ``"max_length"``.
        truncation (bool, optional): whether the sequences should be truncated
            to max_length.
        return_tensordict (bool, optional): if ``True``, a TensorDict is returned.
            Otherwise, the original data will be returned.
        device (torch.device, optional): the device where to store the data.
            This option is ignored if ``return_tensordict=False``.

    The :meth:`__call__` method of this class will execute the following
    operations:

    - Read the ``prompt`` string concatenated with the ``label`` string and
      tokenize them. The results will be stored in the ``"input_ids"``
      TensorDict entry.
    - Write a ``"prompt_rindex"`` entry with the index of the last valid
      token from the prompt.
    - Write a ``"valid_sample"`` entry which identifies which entries in the
      tensordict have enough tokens to meet the ``max_length`` criterion.
    - Return a :class:`tensordict.TensorDict` instance with tokenized inputs.

    The tensordict batch-size will match the batch-size of the input.

    Examples:
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> tokenizer.pad_token = tokenizer.eos_token
        >>> example = {
        ...     "prompt": ["This prompt is long enough to be tokenized.", "this one too!"],
        ...     "label": ["Indeed it is.", "It might as well be."],
        ... }
        >>> fn = PromptTensorDictTokenizer(tokenizer, 50)
        >>> print(fn(example))
        TensorDict(
            fields={
                attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
                input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
                prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
                valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        tokenizer,
        max_length,
        key="prompt",
        padding="max_length",
        truncation=True,
        return_tensordict=True,
        device=None,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.key = key
        self.padding = padding
        self.truncation = truncation
        self.return_tensordict = return_tensordict
        self.device = device

    def __call__(self, sample):
        tokenizer = self.tokenizer
        max_length = self.max_length
        tokenized_prompts = tokenizer(
            sample[self.key], max_length=max_length, truncation=True
        )
        prompt_rindex = [len(prompt) - 1 for prompt in tokenized_prompts["input_ids"]]
        tokenized_example = tokenizer(
            [
                prompt + label
                for prompt, label in zip(sample[self.key], sample["label"])
            ],
            max_length=max_length,
            padding=self.padding,
            truncation=self.truncation,
        )
        tokenized_example["prompt_rindex"] = prompt_rindex
        # drop any examples whose total length when tokenized exceeds block size
        # with recommended block size of 550, this is only ~0.1% of available examples.
        # NOTE: to mark as discarded we just save the mask as we cannot change
        # the shape here
        tokenized_example["valid_sample"] = [True] * len(
            tokenized_example["input_ids"]
        )
        for i, input_ids in enumerate(tokenized_example["input_ids"]):
            if input_ids[-1] != tokenizer.eos_token_id:
                tokenized_example["valid_sample"][i] = False
        if self.return_tensordict:
            return TensorDict.from_dict(dict(tokenized_example), device=self.device)
        return tokenized_example
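
As the trailing note in ``__call__`` explains, over-long examples are only flagged, not dropped, because the tokenizing map cannot change the batch shape. A small sketch of filtering on the ``"valid_sample"`` mask afterwards, reusing the toy inputs from the docstring above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
fn = PromptTensorDictTokenizer(tokenizer, 50)
td = fn(
    {
        "prompt": ["This prompt is long enough to be tokenized.", "this one too!"],
        "label": ["Indeed it is.", "It might as well be."],
    }
)
# boolean-mask indexing on the TensorDict keeps only the rows whose
# tokenized prompt+label fit within max_length
td = td[td.get("valid_sample")]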
