get_dataloader
- torchrl.data.get_dataloader(batch_size: int, block_size: int, tensorclass_type: Type, device: torch.device, dataset_name: str | None = None, infinite: bool = True, prefetch: int = 0, split: str = 'train', root_dir: str | None = None, from_disk: bool = False, num_workers: int | None = None)
Creates a dataset and returns a dataloader from it.
- Parameters:
batch_size (int) – the batch size of the dataloader samples.
block_size (int) – the maximum length of a sequence in the dataloader.
tensorclass_type (tensorclass class) – a tensorclass with a from_dataset() method that must accept three keyword arguments: split (see below); max_length, the block size to be used for training; and dataset_name, a string indicating the dataset. The root_dir and from_disk arguments should also be supported.
device (torch.device or equivalent) – the device to which the samples should be cast.
dataset_name (str, optional) – the dataset name. If not provided and if the tensorclass supports it, a default dataset name will be gathered for the tensorclass being used.
infinite (bool, optional) – if True, the iteration will be infinite, such that next(iterator) will always return a value. Defaults to True.
prefetch (int, optional) – the number of items to prefetch when multithreaded data loading is used.
split (str, optional) – the data split, either "train" or "valid". Defaults to "train".
root_dir (path, optional) – the path where the datasets are stored. Defaults to "$HOME/.cache/torchrl/data".
from_disk (bool, optional) – if True, datasets.load_from_disk() will be used; otherwise, datasets.load_dataset() will be used. Defaults to False.
num_workers (int, optional) – the number of workers for datasets.Dataset.map(), which is called during tokenization. Defaults to max(os.cpu_count() // 2, 1).
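The from_dataset() contract for tensorclass_type can be checked before calling get_dataloader. The sketch below is a plain-Python helper, not part of torchrl: missing_from_dataset_kwargs, GoodData and PartialData are hypothetical names, and the toy classes are ordinary classes rather than real tensorclasses. The helper inspects a class's from_dataset() signature and reports which of the five keyword arguments listed above (split, max_length, dataset_name, root_dir, from_disk) it cannot accept. It only checks the signature; it does not verify that the method actually builds a valid tensorclass.

```python
import inspect

# The keyword arguments that the tensorclass_type.from_dataset() contract asks for.
REQUIRED_KWARGS = ("split", "max_length", "dataset_name", "root_dir", "from_disk")


def missing_from_dataset_kwargs(cls) -> list:
    """Return the required keyword arguments that cls.from_dataset() cannot accept.

    Hypothetical helper: it only inspects the signature, nothing is executed.
    """
    method = getattr(cls, "from_dataset", None)
    if method is None:
        # No from_dataset() at all: every required argument is missing.
        return list(REQUIRED_KWARGS)
    # For a classmethod accessed on the class, the signature excludes `cls`.
    params = inspect.signature(method).parameters
    # A **kwargs catch-all accepts any keyword argument.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return [kw for kw in REQUIRED_KWARGS if kw not in accepted]


class GoodData:
    # Toy stand-in (not a real tensorclass) whose from_dataset() accepts all five arguments.
    @classmethod
    def from_dataset(cls, split="train", max_length=550, dataset_name=None,
                     root_dir=None, from_disk=False):
        raise NotImplementedError


class PartialData:
    # Toy stand-in that forgets dataset_name, root_dir and from_disk.
    @classmethod
    def from_dataset(cls, split="train", max_length=550):
        raise NotImplementedError


print(missing_from_dataset_kwargs(GoodData))     # []
print(missing_from_dataset_kwargs(PartialData))  # ['dataset_name', 'root_dir', 'from_disk']
```

Checking the signature with inspect avoids calling from_dataset(), which in practice may download or tokenize a whole dataset.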
Examples
>>> from torchrl.data import get_dataloader
>>> from torchrl.data.rlhf.reward import PairwiseDataset
>>> dataloader = get_dataloader(
...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
>>> for d in dataloader:
...     print(d)
...     break
PairwiseDataset(
    chosen_data=RewardData(
        attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        rewards=None,
        end_scores=None,
        batch_size=torch.Size([256]),
        device=cpu,
        is_shared=False),
    rejected_data=RewardData(
        attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        rewards=None,
        end_scores=None,
        batch_size=torch.Size([256]),
        device=cpu,
        is_shared=False),
    batch_size=torch.Size([256]),
    device=cpu,
    is_shared=False)
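With the default infinite=True, iteration does not end on its own, which is why the loop in the example exits with break. The generator below is a minimal sketch of those semantics, not torchrl's implementation: infinite_iter is a hypothetical helper that re-iterates a finite, re-iterable source (for example a list or a torch DataLoader) each time it is exhausted, so next() always returns a value.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def infinite_iter(iterable: Iterable[T]) -> Iterator[T]:
    """Yield items from `iterable` forever, restarting it whenever it is exhausted.

    Hypothetical helper: `iterable` must be non-empty and re-iterable
    (a list or a DataLoader works; a one-shot generator does not).
    """
    while True:
        empty = True
        for item in iterable:  # each pass calls iter() on the source again
            empty = False
            yield item
        if empty:
            # Without this guard, an empty or exhausted one-shot source would spin forever.
            raise ValueError("iterable yielded no items; it must be non-empty and re-iterable")


batches = [[1, 2], [3, 4]]  # stands in for a finite dataloader with two batches
it = infinite_iter(batches)
print([next(it) for _ in range(5)])  # [[1, 2], [3, 4], [1, 2], [3, 4], [1, 2]]
```

Re-iterating the source, rather than using itertools.cycle, means each pass draws items afresh (for instance, reshuffled batches from a shuffling loader), whereas itertools.cycle replays copies of the items it saved during the first pass.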