
get_dataloader

torchrl.data.get_dataloader(batch_size: int, block_size: int, tensorclass_type: Type, device: torch.device, dataset_name: str | None = None, infinite: bool = True, prefetch: int = 0, split: str = 'train', root_dir: str | None = None, from_disk: bool = False, num_workers: int | None = None)

Creates a dataset and returns a dataloader from it.
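The two stages (build a dataset, then iterate over it in batches of batch_size, optionally forever when infinite is True) can be illustrated with a minimal pure-Python sketch. It is not torchrl's implementation: make_loader and batches are hypothetical helpers, samples are plain lists rather than tensorclass instances, device casting is omitted, and dropping an incomplete final batch is an assumption of the sketch.

```python
import itertools
import random
from typing import Iterator, List, Sequence

Batch = List[List[int]]


def batches(samples: Sequence[List[int]], batch_size: int, seed: int = 0) -> Iterator[Batch]:
    """Yield one shuffled pass of full batches; an incomplete final batch is dropped."""
    order = list(range(len(samples)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order) - batch_size + 1, batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]


def make_loader(samples: Sequence[List[int]], batch_size: int,
                infinite: bool = True) -> Iterator[Batch]:
    """Hypothetical helper: a single finite pass, or endless reshuffled passes."""
    if not 1 <= batch_size <= len(samples):
        # Also prevents the infinite loop below from spinning without yielding.
        raise ValueError("batch_size must be between 1 and the number of samples")
    if not infinite:
        return batches(samples, batch_size)

    def forever() -> Iterator[Batch]:
        # Start a new pass (with a new shuffle seed) whenever the previous one ends,
        # so next() on this iterator never raises StopIteration.
        for epoch in itertools.count():
            yield from batches(samples, batch_size, seed=epoch)

    return forever()
```

For example, with 10 samples and batch_size=3, a finite loader yields 3 full batches and then stops, while an infinite one keeps producing batches for as long as next() is called.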

Parameters:
  • batch_size (int) – the batch size of the dataloader samples.

  • block_size (int) – the maximum length of a sequence in the dataloader.

  • tensorclass_type (tensorclass class) – a tensorclass exposing a from_dataset() method that accepts the keyword arguments split (see below), max_length (the block size used for training) and dataset_name (a string identifying the dataset). It should also accept the root_dir and from_disk arguments.

  • device (torch.device or equivalent) – the device to which the samples should be cast.

  • dataset_name (str, optional) – the dataset name. If not provided and if the tensorclass supports it, the default dataset name associated with that tensorclass is used.

  • infinite (bool, optional) – if True, iteration never terminates, so next(iterator) always returns a value. Defaults to True.

  • prefetch (int, optional) – the number of items to prefetch when multithreaded data loading is used. Defaults to 0.

  • split (str, optional) – the data split. Either "train" or "valid". Defaults to "train".

  • root_dir (path, optional) – the path where the datasets are stored. Defaults to "$HOME/.cache/torchrl/data".

  • from_disk (bool, optional) – if True, datasets.load_from_disk() will be used. Otherwise, datasets.load_dataset() will be used. Defaults to False.

  • num_workers (int, optional) – the number of workers for datasets.Dataset.map(), which is called during tokenization. Defaults to max(os.cpu_count() // 2, 1).
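The contract that tensorclass_type must satisfy, together with the documented root_dir and num_workers defaults, can be sketched with a hypothetical stand-in class. ToyTokens, its DEFAULT_DATASET attribute and default_num_workers are invented for illustration; a real tensorclass_type such as PairwiseDataset is a tensordict tensorclass that loads and tokenizes an actual dataset. Forwarding num_workers to from_dataset is an assumption of this sketch, based on the description of that parameter.

```python
import os
from typing import List, Optional


def default_num_workers() -> int:
    # Mirrors the documented default max(os.cpu_count() // 2, 1); os.cpu_count()
    # may return None, which this sketch treats as a single CPU.
    return max((os.cpu_count() or 1) // 2, 1)


class ToyTokens:
    """Hypothetical stand-in for a tensorclass_type; NOT a real tensorclass."""

    DEFAULT_DATASET = "toy/dataset"  # hypothetical default dataset name

    def __init__(self, input_ids: List[List[int]], dataset_name: str, root_dir: str):
        self.input_ids = input_ids
        self.dataset_name = dataset_name
        self.root_dir = root_dir

    @classmethod
    def from_dataset(cls, split: str = "train", max_length: int = 550,
                     dataset_name: Optional[str] = None, root_dir: Optional[str] = None,
                     from_disk: bool = False, num_workers: Optional[int] = None):
        if split not in ("train", "valid"):
            raise ValueError(f"unknown split {split!r}")
        # Fall back to the class-level default name, as described for dataset_name.
        dataset_name = dataset_name or cls.DEFAULT_DATASET
        # Mirrors the documented default location $HOME/.cache/torchrl/data.
        root_dir = root_dir or os.path.join(os.path.expanduser("~"), ".cache", "torchrl", "data")
        if num_workers is None:
            num_workers = default_num_workers()
        # A real implementation would load (from disk if from_disk is True) and
        # tokenize the dataset here; the toy version builds 8 sequences of
        # increasing length, zero-padded or truncated to exactly max_length.
        raw = [[1] * (n + 1) for n in range(8)]
        ids = [(seq + [0] * max_length)[:max_length] for seq in raw]
        return cls(ids, dataset_name, root_dir)
```

Inside a get_dataloader-like function, these arguments would be forwarded as tensorclass_type.from_dataset(split=split, max_length=block_size, dataset_name=dataset_name, root_dir=root_dir, from_disk=from_disk), matching the keyword names listed above.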

Examples

>>> from torchrl.data import get_dataloader
>>> from torchrl.data.rlhf.reward import PairwiseDataset
>>> dataloader = get_dataloader(
...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
>>> for d in dataloader:
...     print(d)
...     break
PairwiseDataset(
    chosen_data=RewardData(
        attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        rewards=None,
        end_scores=None,
        batch_size=torch.Size([256]),
        device=cpu,
        is_shared=False),
    rejected_data=RewardData(
        attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        rewards=None,
        end_scores=None,
        batch_size=torch.Size([256]),
        device=cpu,
        is_shared=False),
    batch_size=torch.Size([256]),
    device=cpu,
    is_shared=False)
