
Source code for torchrl.data.rlhf.dataset

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
import os
from pathlib import Path

from typing import Sequence, Type

import torch

from tensordict import TensorDict, TensorDictBase

from tensordict.utils import NestedKey
from torchrl._utils import logger as torchrl_logger
from torchrl.data.replay_buffers import (
    SamplerWithoutReplacement,
    TensorDictReplayBuffer,
    TensorStorage,
)

_has_transformers = importlib.util.find_spec("transformers") is not None
_has_datasets = importlib.util.find_spec("datasets") is not None


class TokenizedDatasetLoader:
    """Loads a tokenized dataset and caches a memory-mapped copy of it.

    Args:
        split (str): One of ``"train"`` or ``"valid"``.
        max_length (int): the maximum sequence length.
        dataset_name (str): the name of the dataset.
        tokenizer_fn (callable): the tokenizing method constructor, such as
            :class:`torchrl.data.rlhf.TensorDictTokenizer`. When called,
            it should return a :class:`tensordict.TensorDict` instance
            or a dictionary-like structure with the tokenized data.
        pre_tokenization_hook (callable, optional): called on the Dataset
            before tokenization. It should return a modified Dataset object.
            The intended use is for carrying out tasks that require modifying
            the dataset as a whole as opposed to modifying individual
            datapoints, for example discarding certain datapoints based
            on a particular condition. Tokenization and other "elementwise"
            operations on the data are performed by the process function
            which is mapped over the dataset.
        root_dir (path, optional): the path where the datasets are stored.
            Defaults to ``"$HOME/.cache/torchrl/data"``.
        from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
            will be used. Otherwise, :func:`datasets.load_dataset` will be used.
            Defaults to ``False``.
        valid_size (int, optional): the size of the validation dataset (if split
            starts with ``"valid"``) will be truncated to this value.
            Defaults to 2000 items.
        num_workers (int, optional): number of workers for
            :meth:`datasets.dataset.map` which is called during tokenization.
            Defaults to ``max(os.cpu_count() // 2, 1)``.
        tokenizer_class (type, optional): A tokenizer class, such as
            :class:`~transformers.AutoTokenizer` (default).
        tokenizer_model_name (str, optional): The model from which the vocabulary
            should be gathered. Defaults to ``"gpt2"``.

    The dataset will be stored in ``<root_dir>/<dataset>/<split>/<max_length>/``.

    Examples:
        >>> from torchrl.data.rlhf import TensorDictTokenizer
        >>> from torchrl.data.rlhf.reward import pre_tokenization_hook
        >>> split = "train"
        >>> max_length = 550
        >>> dataset_name = "CarperAI/openai_summarize_comparisons"
        >>> loader = TokenizedDatasetLoader(
        ...     split,
        ...     max_length,
        ...     dataset_name,
        ...     TensorDictTokenizer,
        ...     pre_tokenization_hook=pre_tokenization_hook,
        ... )
        >>> dataset = loader.load()
        >>> print(dataset)
        TensorDict(
            fields={
                attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
                input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([185068]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        split,
        max_length,
        dataset_name,
        tokenizer_fn: Type[TensorDictTokenizer],
        pre_tokenization_hook=None,
        root_dir=None,
        from_disk=False,
        valid_size: int = 2000,
        num_workers: int = None,
        tokenizer_class=None,
        tokenizer_model_name=None,
    ):
        self.split = split
        self.max_length = max_length
        self.dataset_name = dataset_name
        self.tokenizer_fn = tokenizer_fn
        self.pre_tokenization_hook = pre_tokenization_hook
        self.root_dir = root_dir
        self.from_disk = from_disk
        self.valid_size = valid_size
        if num_workers is None:
            num_workers = max(os.cpu_count() // 2, 1)
        self.num_workers = num_workers
        if tokenizer_class is None:
            from transformers import AutoTokenizer

            tokenizer_class = AutoTokenizer
        if tokenizer_model_name is None:
            tokenizer_model_name = "gpt2"
        self.make_tokenizer(
            tokenizer_class=tokenizer_class, tokenizer_model_name=tokenizer_model_name
        )

    def make_tokenizer(self, *, tokenizer_class, tokenizer_model_name):
        tokenizer = tokenizer_class.from_pretrained(tokenizer_model_name)
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
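
    # Usage sketch (not part of the module, values assumed): ``make_tokenizer`` can be
    # called again after construction to swap the vocabulary, e.g. to a distilled GPT-2
    # checkpoint; any ``tokenizer_class`` exposing ``from_pretrained`` should work.
    # >>> from transformers import AutoTokenizer
    # >>> loader.make_tokenizer(
    # ...     tokenizer_class=AutoTokenizer, tokenizer_model_name="distilgpt2"
    # ... )
    # >>> loader.tokenizer.name_or_path
    # 'distilgpt2'
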
    def load(self):
        """Loads a pre-processed, memory-mapped dataset if it exists, and creates it otherwise."""
        root_dir = self.root_dir
        max_length = self.max_length
        split = self.split
        if root_dir is None:
            root_dir = Path(os.environ.get("HOME")) / ".cache/torchrl/data/"
        os.makedirs(root_dir, exist_ok=True)
        root_dir = Path(root_dir)
        data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0]
        data_dir_total = data_dir / split / str(max_length)
        # search for data
        torchrl_logger.info(f"Looking for data in {data_dir_total}")
        if os.path.exists(data_dir_total):
            dataset = TensorDict.load_memmap(data_dir_total)
            return dataset
        dataset = self._load_dataset()
        dataset = self._tokenize(dataset)
        prefix = (split, str(max_length))
        return self.dataset_to_tensordict(
            dataset, data_dir=data_dir, prefix=prefix, valid_mask_key="valid_sample"
        )[prefix]
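
    # Cache-layout sketch (not part of the module, paths assumed): ``load()`` checks for a
    # memory-mapped copy under ``<root_dir>/<dataset stem>/<split>/<max_length>`` and only
    # re-tokenizes on a cache miss, mirroring the path logic above.
    # >>> from pathlib import Path
    # >>> root_dir = Path("~/.cache/torchrl/data").expanduser()
    # >>> data_dir = root_dir / Path("CarperAI/openai_summarize_comparisons").name.split("-")[0]
    # >>> cache_dir = data_dir / "train" / "550"  # a hit here means TensorDict.load_memmap is used
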
    def _load_dataset(self):
        """Loads a text dataset from ``datasets``.

        Returns: a dataset of type ``datasets.Dataset``.
        """
        if not _has_datasets:
            raise ImportError(
                "preproc_data requires the datasets package to be installed."
            )
        from datasets import load_dataset, load_from_disk

        if self.from_disk:
            dataset = load_from_disk(str(self.dataset_name))[self.split]
        else:
            dataset = load_dataset(self.dataset_name, split=self.split)
        if self.split.startswith("valid"):
            # reduce size of validation dataset
            dataset = dataset.select(range(self.valid_size))
        if self.pre_tokenization_hook is not None:
            dataset = self.pre_tokenization_hook(dataset)
        return dataset

    def _tokenize(
        self,
        dataset,
        excluded_features: Sequence[str] | None = None,
    ):
        """Preprocesses a text dataset from ``datasets``.

        Args:
            dataset (datasets.Dataset): a dataset loaded using :meth:`~.load_dataset`.
            excluded_features (sequence of str, optional): the features to exclude
                once tokenization is complete. Defaults to
                ``{"text", "prompt", "label", "valid_sample"}``.

        Returns: a dataset of type ``datasets.Dataset``.
        """
        if not _has_transformers:
            raise ImportError("The transformers library is missing.")
        num_workers = self.num_workers
        if excluded_features is None:
            excluded_features = {"text", "prompt", "label", "valid_sample"}
        tokenizer = self.tokenizer
        # tokenize the dataset
        # TODO: replace this by TensorDict.map
        dataset = dataset.map(
            self.tokenizer_fn(
                tokenizer, max_length=self.max_length, return_tensordict=False
            ),
            desc="Tokenizing...",
            num_proc=num_workers,
            batched=True,
        )
        if not isinstance(dataset, TensorDictBase):
            dataset_dict = dataset.to_dict()
            if excluded_features:
                dataset_dict = {
                    key: value
                    for key, value in dataset_dict.items()
                    if key not in excluded_features
                }
            dataset = TensorDict.from_dict(dataset_dict)
        elif excluded_features:
            dataset = dataset.exclude(*excluded_features)
        # keep non empty rows (i.e. where at least one token is not eos)
        if "valid_sample" in dataset.keys():
            mask = dataset.get("valid_sample")
            dataset = dataset[mask]
        return dataset
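
    # Masking sketch (not part of the module, toy values assumed): the ``valid_sample``
    # filter above relies on TensorDict's boolean-mask indexing.
    # >>> td = TensorDict(
    # ...     {"input_ids": torch.zeros(4, 5, dtype=torch.int64),
    # ...      "valid_sample": torch.tensor([True, False, True, True])},
    # ...     batch_size=[4],
    # ... )
    # >>> td[td.get("valid_sample")].batch_size
    # torch.Size([3])
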
    @staticmethod
    def dataset_to_tensordict(
        dataset: "datasets.Dataset" | TensorDict,  # noqa: F821
        data_dir: Path,
        prefix: NestedKey = None,
        features: Sequence[str] = None,
        batch_dims=1,
        valid_mask_key=None,
    ):
        """Converts a dataset to a memory-mapped TensorDict.

        If the dataset is already a :class:`TensorDict` instance, it is simply
        converted to a memory-mapped TensorDict. Otherwise, the dataset is expected
        to have a ``features`` attribute which is a sequence of strings indicating
        the features that can be found in the dataset. If it does not, the
        ``features`` must be passed explicitly to this function.

        Args:
            dataset (datasets.Dataset, TensorDict or equivalent): a dataset to
                convert to a memory-mapped TensorDict. If ``features`` is ``None``,
                it must have a ``features`` attribute with the list of keys to
                write in the tensordict.
            data_dir (Path or equivalent): directory where the data should be written.
            prefix (NestedKey, optional): the prefix of the dataset location. This can
                be used to differentiate several copies of a same dataset that have
                undergone different preprocessings.
            features (sequence of str, optional): a sequence of str indicating the
                features that can be found in the dataset.
            batch_dims (int, optional): the number of batch_dimensions of the data
                (ie number of dimensions along which the tensordict can be indexed).
                Defaults to 1.
            valid_mask_key (NestedKey, optional): if provided, this entry will be
                tentatively gathered and used to filter the data. Defaults to
                ``None`` (ie, no filter key).

        Returns: a TensorDict containing memory-mapped tensors with the dataset.

        Examples:
            >>> from datasets import Dataset
            >>> import tempfile
            >>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
            >>> with tempfile.TemporaryDirectory() as tmpdir:
            ...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
            ...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
            ...     )
            ...     print(data_memmap)
            TensorDict(
                fields={
                    some: TensorDict(
                        fields={
                            prefix: TensorDict(
                                fields={
                                    labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                                    tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
                                batch_size=torch.Size([10]),
                                device=None,
                                is_shared=False)},
                        batch_size=torch.Size([]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False)
        """
        if not isinstance(dataset, TensorDict):
            if features is None:
                features = dataset.features
            if prefix is None:
                prefix = ()
            data_dict = {key: torch.as_tensor(dataset[key]) for key in features}
            out = TensorDict.from_dict(data_dict, batch_dims=batch_dims)
        else:
            out = dataset
        if valid_mask_key is not None and valid_mask_key in out.keys(
            include_nested=True
        ):
            out = out[out.get(valid_mask_key)]
        out = TensorDict({prefix: out}, [])
        out.memmap_(prefix=data_dir)
        return out

def create_infinite_iterator(iterator):
    """Iterates indefinitely over an iterator."""
    while True:
        yield from iterator
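
# Minimal sketch (not part of the module): ``create_infinite_iterator`` restarts the
# underlying iterable each time it is exhausted, so for a non-empty, re-iterable input
# ``next`` never raises StopIteration.
# >>> it = create_infinite_iterator(range(3))
# >>> [next(it) for _ in range(7)]
# [0, 1, 2, 0, 1, 2, 0]
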
def get_dataloader(
    batch_size: int,
    block_size: int,
    tensorclass_type: Type,
    device: torch.device,
    dataset_name: str | None = None,
    infinite: bool = True,
    prefetch: int = 0,
    split: str = "train",
    root_dir: str | None = None,
    from_disk: bool = False,
    num_workers: int | None = None,
):
    """Creates a dataset and returns a dataloader from it.

    Args:
        batch_size (int): the batch size of the dataloader samples.
        block_size (int): the maximum length of a sequence in the dataloader.
        tensorclass_type (tensorclass class): a tensorclass with a :meth:`from_dataset`
            method that must accept three keyword arguments: ``split`` (see below),
            ``max_length`` which is the block size to be used for training and
            ``dataset_name``, a string indicating the dataset. The ``root_dir`` and
            ``from_disk`` arguments should also be supported.
        device (torch.device or equivalent): the device where the samples should be cast.
        dataset_name (str, optional): the dataset name. If not provided and if the
            tensorclass supports it, a default dataset name will be gathered for the
            tensorclass being used.
        infinite (bool, optional): if ``True``, the iteration will be infinite such that
            ``next(iterator)`` will always return a value. Defaults to ``True``.
        prefetch (int, optional): the number of items to be prefetched if multithreaded
            dataloading is being used.
        split (str, optional): the data split. Either ``"train"`` or ``"valid"``.
            Defaults to ``"train"``.
        root_dir (path, optional): the path where the datasets are stored.
            Defaults to ``"$HOME/.cache/torchrl/data"``.
        from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk` will
            be used. Otherwise, :func:`datasets.load_dataset` will be used.
            Defaults to ``False``.
        num_workers (int, optional): number of workers for :meth:`datasets.dataset.map`
            which is called during tokenization.
            Defaults to ``max(os.cpu_count() // 2, 1)``.

    Examples:
        >>> from torchrl.data.rlhf.reward import PairwiseDataset
        >>> dataloader = get_dataloader(
        ...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
        >>> for d in dataloader:
        ...     print(d)
        ...     break
        PairwiseDataset(
            chosen_data=RewardData(
                attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
                input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
                rewards=None,
                end_scores=None,
                batch_size=torch.Size([256]),
                device=cpu,
                is_shared=False),
            rejected_data=RewardData(
                attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
                input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
                rewards=None,
                end_scores=None,
                batch_size=torch.Size([256]),
                device=cpu,
                is_shared=False),
            batch_size=torch.Size([256]),
            device=cpu,
            is_shared=False)
    """
    data = tensorclass_type.from_dataset(
        split=split,
        dataset_name=dataset_name,
        max_length=block_size,
        root_dir=root_dir,
        from_disk=from_disk,
        num_workers=num_workers,
    )
    out = TensorDictReplayBuffer(
        storage=TensorStorage(data),
        collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True),
        sampler=SamplerWithoutReplacement(drop_last=True),
        batch_size=batch_size,
        prefetch=prefetch,
    )
    if infinite:
        return create_infinite_iterator(out)
    return out
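
# Usage sketch (not part of the module): with ``infinite=True`` (the default) the returned
# object is a generator, so batches are typically drawn with ``next`` rather than a bounded
# ``for`` loop; ``PairwiseDataset`` is the same tensorclass used in the docstring example.
# >>> from torchrl.data.rlhf.reward import PairwiseDataset
# >>> dataloader = get_dataloader(
# ...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
# >>> batch = next(dataloader)  # never exhausts
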
class TensorDictTokenizer:
    """Factory for a process function that applies a tokenizer over a text example.

    Args:
        tokenizer (tokenizer from transformers library): the tokenizer to use.
        max_length (int): maximum length of the sequence.
        key (str, optional): the key where to find the text. Defaults to ``"text"``.
        padding (str, optional): type of padding. Defaults to ``"max_length"``.
        truncation (bool, optional): whether the sequences should be truncated to max_length.
        return_tensordict (bool, optional): if ``True``, a TensorDict is returned.
            Otherwise, the original data will be returned.
        device (torch.device, optional): the device where to store the data.
            This option is ignored if ``return_tensordict=False``.

    See the transformers library for more information about tokenizers:

        Padding and truncation: `<https://huggingface.co/docs/transformers/pad_truncation>`_

    Returns: a :class:`tensordict.TensorDict` instance with the same batch-size as the
        input data.

    Examples:
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> tokenizer.pad_token = 100
        >>> process = TensorDictTokenizer(tokenizer, max_length=10)
        >>> # example with a single input
        >>> example = {"text": "I am a little worried"}
        >>> process(example)
        TensorDict(
            fields={
                attention_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                input_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> # example with multiple inputs
        >>> example = {"text": ["Let me reassure you", "It will be ok"]}
        >>> process(example)
        TensorDict(
            fields={
                attention_mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
                input_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        tokenizer,
        max_length,
        key="text",
        padding="max_length",
        truncation=True,
        return_tensordict=True,
        device=None,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.key = key
        self.padding = padding
        self.truncation = truncation
        self.return_tensordict = return_tensordict
        self.device = device

    def __call__(self, sample):
        input = sample[self.key]
        tokenized_sample = self.tokenizer(
            input,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
        )
        batch_size = [] if isinstance(input, str) else [len(input)]
        if self.return_tensordict:
            return TensorDict.from_dict(
                dict(tokenized_sample), batch_size=batch_size, device=self.device
            )
        return tokenized_sample
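
# Integration sketch (not part of the module, dataset contents assumed): this mirrors how
# ``TokenizedDatasetLoader._tokenize`` uses the factory with ``datasets.Dataset.map``;
# ``return_tensordict=False`` keeps the output a plain dict-like batch, which ``map`` expects.
# >>> from datasets import Dataset
# >>> from transformers import AutoTokenizer
# >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
# >>> tokenizer.pad_token = tokenizer.eos_token
# >>> ds = Dataset.from_dict({"text": ["first sample", "second sample"]})
# >>> fn = TensorDictTokenizer(tokenizer, max_length=10, return_tensordict=False)
# >>> ds = ds.map(fn, batched=True)
# >>> sorted(ds.column_names)
# ['attention_mask', 'input_ids', 'text']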
