torchrec.datasets¶
Torchrec Datasets
Torchrec contains two popular RecSys datasets: the Kaggle/Criteo Display Advertising Dataset and the MovieLens 20M Dataset.
Additionally, it contains a RandomRecDataset, which is useful for generating random data in the same format as the above.
Lastly, it contains scripts and utilities for pre-processing, loading, etc.
Example:
from torchrec.datasets.criteo import criteo_terabyte
import torch.utils.data.datapipes as dp  # assumed: `dp` refers to the PyTorch datapipes namespace

datapipe = criteo_terabyte(
    ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
)
datapipe = dp.iter.Batcher(datapipe, 100)
datapipe = dp.iter.Collator(datapipe)
batch = next(iter(datapipe))
torchrec.datasets.criteo¶
- class torchrec.datasets.criteo.BinaryCriteoUtils¶
Bases:
object
Utility functions used to preprocess, save, load, partition, etc. the Criteo dataset in a binary (numpy) format.
- static get_file_row_ranges_and_remainder(lengths: List[int], rank: int, world_size: int, start_row: int = 0, last_row: Optional[int] = None) Tuple[Dict[int, Tuple[int, int]], int] ¶
Given a rank, world_size, and the lengths (number of rows) for a list of files, return which files and which portions of those files (represented as row ranges - all range indices are inclusive) should be handled by the rank. Each rank will be assigned the same number of rows.
The ranges are determined in such a way that each rank deals with large continuous ranges of files. This enables each rank to reduce the amount of data it needs to read while avoiding seeks.
- Parameters:
lengths (List[int]) – A list of row counts for each file.
rank (int) – rank.
world_size (int) – world size.
- Returns:
The first item is a mapping from file index to the range of rows in that file to be handled by the rank. The second item is the remainder of dataset length / world size.
- Return type:
output (Tuple[Dict[int, Tuple[int, int]], int])
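The assignment can be pictured as slicing the concatenated rows of all files into equal contiguous chunks, one per rank. The sketch below is an illustration of that idea (ignoring the start_row and last_row arguments), not the library implementation:

from typing import Dict, List, Tuple

def row_ranges_for_rank(
    lengths: List[int], rank: int, world_size: int
) -> Tuple[Dict[int, Tuple[int, int]], int]:
    # Assign each rank a contiguous slice of the concatenated rows, then
    # translate that slice back into per-file (start, end) ranges with
    # inclusive indices.
    total = sum(lengths)
    rows_per_rank = total // world_size
    remainder = total % world_size
    global_start = rank * rows_per_rank
    global_end = global_start + rows_per_rank - 1  # inclusive

    ranges: Dict[int, Tuple[int, int]] = {}
    file_offset = 0
    for file_idx, length in enumerate(lengths):
        lo = max(global_start, file_offset)
        hi = min(global_end, file_offset + length - 1)
        if lo <= hi:
            # Convert back to row indices local to this file.
            ranges[file_idx] = (lo - file_offset, hi - file_offset)
        file_offset += length
    return ranges, remainder

# Two files of 100 rows each split across 4 ranks: rank 1 reads rows 50-99 of file 0.
print(row_ranges_for_rank([100, 100], rank=1, world_size=4))  # ({0: (50, 99)}, 0)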
- static get_shape_from_npy(path: str, path_manager_key: str = 'torchrec') Tuple[int, ...] ¶
Returns the shape of an npy file using only its header.
- Parameters:
path (str) – Input npy file path.
path_manager_key (str) – Path manager key used to load from different filesystems.
- Returns:
Shape tuple.
- Return type:
shape (Tuple[int, …])
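The library routes file access through its PathManager abstraction; for local files, a plain-numpy equivalent (a sketch, not the library code) can read just the header:

import numpy as np

def npy_shape(path: str):
    # Read only the .npy header to obtain the array shape without loading the data.
    with open(path, "rb") as f:
        version = np.lib.format.read_magic(f)
        if version == (1, 0):
            shape, _fortran_order, _dtype = np.lib.format.read_array_header_1_0(f)
        else:
            shape, _fortran_order, _dtype = np.lib.format.read_array_header_2_0(f)
    return shape

print(npy_shape("/home/datasets/criteo/day_0_dense.npy"))  # e.g. (num_rows, 13) for a dense file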
- static load_npy_range(fname: str, start_row: int, num_rows: int, path_manager_key: str = 'torchrec', mmap_mode: bool = False) ndarray ¶
Load part of an npy file.
NOTE: Assumes npy represents a numpy array of ndim 2.
- Parameters:
fname (str) – path string to npy file.
start_row (int) – starting row from the npy file.
num_rows (int) – number of rows to get from the npy file.
path_manager_key (str) – Path manager key used to load from different filesystems.
- Returns:
numpy array with the desired range of data from the supplied npy file.
- Return type:
output (np.ndarray)
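For local files, the same effect can be approximated (a sketch under that assumption, not the library implementation) by memory-mapping the array and copying out only the requested rows:

import numpy as np

def load_rows(fname: str, start_row: int, num_rows: int) -> np.ndarray:
    # Memory-map the 2-D array and materialize only the requested row range.
    arr = np.load(fname, mmap_mode="r")
    return np.array(arr[start_row : start_row + num_rows])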
- static shuffle(input_dir_labels_and_dense: str, input_dir_sparse: str, output_dir_shuffled: str, rows_per_day: Dict[int, int], output_dir_full_set: Optional[str] = None, days: int = 24, int_columns: int = 13, sparse_columns: int = 26, path_manager_key: str = 'torchrec', random_seed: int = 0) None ¶
Shuffle the dataset. Expects the files to be in .npy format and the data to be split by day into dense, sparse, and label data. Dense data must be in day_x_dense.npy, sparse data in day_x_sparse.npy, and label data in day_x_labels.npy.
The dataset will be reconstructed, shuffled and then split back into separate dense, sparse and labels files.
This will only shuffle the first DAYS-1 days, which form the training set. The final day will remain untouched, as it is used for the validation and test sets.
- Parameters:
input_dir_labels_and_dense (str) – Input directory of labels and dense npy files.
input_dir_sparse (str) – Input directory of sparse npy files.
output_dir_shuffled (str) – Output directory for shuffled labels, dense and sparse npy files.
rows_per_day (Dict[int, int]) – Number of rows in each file.
output_dir_full_set (str) – Output directory of the full dataset, if desired.
days (int) – Number of day files.
int_columns (int) – Number of columns with dense features.
sparse_columns (int) – Total number of categorical columns.
path_manager_key (str) – Path manager key used to load from different filesystems.
random_seed (int) – Random seed used for the random.shuffle operator.
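A minimal call sketch, assuming the per-day .npy files produced by tsv_to_npys already live in the input directories (paths and row counts below are hypothetical placeholders):

from torchrec.datasets.criteo import BinaryCriteoUtils

BinaryCriteoUtils.shuffle(
    input_dir_labels_and_dense="/home/datasets/criteo/npy",    # hypothetical path
    input_dir_sparse="/home/datasets/criteo/npy",              # hypothetical path
    output_dir_shuffled="/home/datasets/criteo/npy_shuffled",  # hypothetical path
    rows_per_day={day: 1_000_000 for day in range(24)},        # placeholder; use the real per-day row totals
)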
- static sparse_to_contiguous(in_files: List[str], output_dir: str, frequency_threshold: int = 3, columns: int = 26, path_manager_key: str = 'torchrec', output_file_suffix: str = '_contig_freq.npy') None ¶
Convert all sparse .npy files to have contiguous integers and store the results in separate .npy files. All input files must be processed together because a column can have matching IDs across files, so they must be transformed together. Note that the transformed IDs are not unique between columns. IDs that appear fewer than frequency_threshold times will be remapped to a value of 1.
Example transformation with a frequency_threshold of 2:

day_0_sparse.npy
| col_0 | col_1 |
|-------|-------|
| abc   | xyz   |
| iop   | xyz   |

day_1_sparse.npy
| col_0 | col_1 |
|-------|-------|
| iop   | tuv   |
| lkj   | xyz   |

day_0_sparse_contig.npy
| col_0 | col_1 |
|-------|-------|
| 1     | 2     |
| 2     | 2     |

day_1_sparse_contig.npy
| col_0 | col_1 |
|-------|-------|
| 2     | 1     |
| 1     | 2     |
- Parameters:
in_files (List[str]) – Input sparse npy file paths.
output_dir (str) – Output directory of processed npy files.
frequency_threshold (int) – IDs occurring fewer than this many times will be remapped to a value of 1.
path_manager_key (str) – Path manager key used to load from different filesystems.
- Returns:
None.
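The remapping itself is simple to picture. The sketch below (an illustration, not the library code) reproduces the example above for a single column gathered from every day's file: infrequent IDs collapse to 1, and frequent IDs receive contiguous integers starting at 2, which is the numbering convention implied by the example tables.

from collections import Counter
from typing import List

def remap_column(per_file_columns: List[List[str]], frequency_threshold: int) -> List[List[int]]:
    # Count IDs across all files, then remap: rare IDs -> 1, others -> 2, 3, ...
    counts = Counter(v for col in per_file_columns for v in col)
    mapping = {}
    next_id = 2
    remapped = []
    for col in per_file_columns:
        out = []
        for v in col:
            if counts[v] < frequency_threshold:
                out.append(1)
            else:
                if v not in mapping:
                    mapping[v] = next_id
                    next_id += 1
                out.append(mapping[v])
        remapped.append(out)
    return remapped

# col_0 of day_0 and day_1 from the example above:
print(remap_column([["abc", "iop"], ["iop", "lkj"]], frequency_threshold=2))  # [[1, 2], [2, 1]]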
- static tsv_to_npys(in_file: str, out_dense_file: str, out_sparse_file: str, out_labels_file: str, dataset_name: str = 'criteo_1tb', path_manager_key: str = 'torchrec') None ¶
Convert one Criteo tsv file to three npy files: one for dense (np.float32), one for sparse (np.int32), and one for labels (np.int32).
The tsv file is expected to be part of the Criteo 1TB Click Logs Dataset (“criteo_1tb”) or the Criteo Kaggle Display Advertising Challenge dataset (“criteo_kaggle”).
For the “criteo_kaggle” test set, we set the labels to -1 representing filler data, because label data is not included in the “criteo_kaggle” test set.
- Parameters:
in_file (str) – Input tsv file path.
out_dense_file (str) – Output dense npy file path.
out_sparse_file (str) – Output sparse npy file path.
out_labels_file (str) – Output labels npy file path.
dataset_name (str) – The dataset name. “criteo_1tb” or “criteo_kaggle” is expected.
path_manager_key (str) – Path manager key used to load from different filesystems.
- Returns:
None.
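A minimal usage sketch, with one call per day file (the paths are hypothetical placeholders):

from torchrec.datasets.criteo import BinaryCriteoUtils

BinaryCriteoUtils.tsv_to_npys(
    in_file="/home/datasets/criteo/day_0.tsv",
    out_dense_file="/home/datasets/criteo/day_0_dense.npy",
    out_sparse_file="/home/datasets/criteo/day_0_sparse.npy",
    out_labels_file="/home/datasets/criteo/day_0_labels.npy",
    dataset_name="criteo_1tb",
)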
- class torchrec.datasets.criteo.CriteoIterDataPipe(paths: ~typing.Iterable[str], *, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw)¶
Bases:
IterDataPipe
IterDataPipe that can be used to stream either the Criteo 1TB Click Logs Dataset (https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/) or the Kaggle/Criteo Display Advertising Dataset (https://www.kaggle.com/c/criteo-display-ad-challenge/) from the source TSV files.
- Parameters:
paths (Iterable[str]) – local paths to TSV files that constitute the Criteo dataset.
row_mapper (Optional[Callable[[List[str]], Any]]) – function to apply to each split TSV line.
open_kw – options to pass to underlying invocation of iopath.common.file_io.PathManager.open.
Example:
datapipe = CriteoIterDataPipe(
    ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
)
datapipe = dp.iter.Batcher(datapipe, 100)
datapipe = dp.iter.Collator(datapipe)
batch = next(iter(datapipe))
- class torchrec.datasets.criteo.InMemoryBinaryCriteoIterDataPipe(stage: str, dense_paths: List[str], sparse_paths: List[str], labels_paths: List[str], batch_size: int, rank: int, world_size: int, drop_last: Optional[bool] = False, shuffle_batches: bool = False, shuffle_training_set: bool = False, shuffle_training_set_random_seed: int = 0, mmap_mode: bool = False, hashes: Optional[List[int]] = None, path_manager_key: str = 'torchrec')¶
Bases:
IterableDataset
Datapipe designed to operate over binary (npy) versions of Criteo datasets. Loads the entire dataset into memory to prevent disk speed from affecting throughput. Each rank reads only the data for the portion of the dataset it is responsible for.
The torchrec/datasets/scripts/npy_preproc_criteo.py script can be used to convert the Criteo tsv files to the npy files expected by this dataset.
- Parameters:
stage (str) – “train”, “val”, or “test”.
dense_paths (List[str]) – List of path strings to dense npy files.
sparse_paths (List[str]) – List of path strings to sparse npy files.
labels_paths (List[str]) – List of path strings to labels npy files.
batch_size (int) – batch size.
rank (int) – rank.
world_size (int) – world size.
shuffle_batches (bool) – Whether to shuffle batches.
hashes (Optional[List[int]]) – List of max categorical feature value for each feature. Length of this list should be CAT_FEATURE_COUNT.
path_manager_key (str) – Path manager key used to load from different filesystems.
Example:
template = "/home/datasets/criteo/1tb_binary/day_{}_{}.npy"
datapipe = InMemoryBinaryCriteoIterDataPipe(
    stage="train",
    dense_paths=[template.format(0, "dense"), template.format(1, "dense")],
    sparse_paths=[template.format(0, "sparse"), template.format(1, "sparse")],
    labels_paths=[template.format(0, "labels"), template.format(1, "labels")],
    batch_size=1024,
    rank=torch.distributed.get_rank(),
    world_size=torch.distributed.get_world_size(),
)
batch = next(iter(datapipe))
- torchrec.datasets.criteo.criteo_kaggle(path: str, *, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw) IterDataPipe ¶
Kaggle/Criteo Display Advertising Dataset
- Parameters:
path (str) – local path to train or test dataset file.
row_mapper (Optional[Callable[[List[str]], Any]]) – function to apply to each split TSV line.
open_kw – options to pass to underlying invocation of iopath.common.file_io.PathManager.open.
Example:
train_datapipe = criteo_kaggle(
    "/home/datasets/criteo_kaggle/train.txt",
)
example = next(iter(train_datapipe))
test_datapipe = criteo_kaggle(
    "/home/datasets/criteo_kaggle/test.txt",
)
example = next(iter(test_datapipe))
- torchrec.datasets.criteo.criteo_terabyte(paths: ~typing.Iterable[str], *, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw) IterDataPipe ¶
Criteo 1TB Click Logs Dataset
- Parameters:
paths (Iterable[str]) – local paths to TSV files that constitute the Criteo 1TB dataset.
row_mapper (Optional[Callable[[List[str]], Any]]) – function to apply to each split TSV line.
open_kw – options to pass to underlying invocation of iopath.common.file_io.PathManager.open.
Example:
datapipe = criteo_terabyte(
    ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
)
datapipe = dp.iter.Batcher(datapipe, 100)
datapipe = dp.iter.Collator(datapipe)
batch = next(iter(datapipe))
torchrec.datasets.movielens¶
- torchrec.datasets.movielens.movielens_20m(root: str, *, include_movies_data: bool = False, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw) IterDataPipe ¶
MovieLens 20M Dataset
- Parameters:
root (str) – local path to root directory containing MovieLens 20M dataset files.
include_movies_data (bool) – if True, adds movies data to each line.
row_mapper (Optional[Callable[[List[str]], Any]]) – function to apply to each split line.
open_kw – options to pass to underlying invocation of iopath.common.file_io.PathManager.open.
Example:
datapipe = movielens_20m("/home/datasets/ml-20")
datapipe = dp.iter.Batcher(datapipe, 100)
datapipe = dp.iter.Collator(datapipe)
batch = next(iter(datapipe))
- torchrec.datasets.movielens.movielens_25m(root: str, *, include_movies_data: bool = False, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw) IterDataPipe ¶
MovieLens 25M Dataset
- Parameters:
root (str) – local path to root directory containing MovieLens 25M dataset files.
include_movies_data (bool) – if True, adds movies data to each line.
row_mapper (Optional[Callable[[List[str]], Any]]) – function to apply to each split line.
open_kw – options to pass to underlying invocation of iopath.common.file_io.PathManager.open.
Example:
datapipe = movielens_25m("/home/datasets/ml-25")
datapipe = dp.iter.Batcher(datapipe, 100)
datapipe = dp.iter.Collator(datapipe)
batch = next(iter(datapipe))
torchrec.datasets.random¶
- class torchrec.datasets.random.RandomRecDataset(keys: List[str], batch_size: int, hash_size: Optional[int] = 100, hash_sizes: Optional[List[int]] = None, ids_per_feature: Optional[int] = 2, ids_per_features: Optional[List[int]] = None, num_dense: int = 50, manual_seed: Optional[int] = None, num_batches: Optional[int] = None, num_generated_batches: int = 10, min_ids_per_feature: Optional[int] = None, min_ids_per_features: Optional[List[int]] = None)¶
Bases:
IterableDataset[Batch]
Random iterable dataset used to generate batches for recommender systems (RecSys). Currently produces unweighted sparse features only. TODO: Add weighted sparse features.
- Parameters:
keys (List[str]) – List of feature names for sparse features.
batch_size (int) – batch size.
hash_size (Optional[int]) – Max sparse id value. All sparse IDs will be taken modulo this value.
hash_sizes (Optional[List[int]]) – Max sparse id value per feature in keys. Each sparse ID will be taken modulo the corresponding value from this argument. Note, if this is used, hash_size will be ignored.
ids_per_feature (Optional[int]) – Number of IDs per sparse feature.
ids_per_features (Optional[List[int]]) – Number of IDs per sparse feature in each key. Note, if this is used, ids_per_feature will be ignored.
num_dense (int) – Number of dense features.
manual_seed (Optional[int]) – Seed for deterministic behavior.
num_batches (Optional[int]) – Number of batches to generate before raising StopIteration.
num_generated_batches (int) – Number of batches to cache. If num_batches > num_generated_batches, the generated batches will be cycled. If this value is negative, batches will be generated on the fly.
min_ids_per_feature (Optional[int]) – Minimum number of IDs per feature.
Example:
dataset = RandomRecDataset(
    keys=["feat1", "feat2"],
    batch_size=16,
    hash_size=100_000,
    ids_per_feature=1,
    num_dense=13,
)
example = next(iter(dataset))
torchrec.datasets.utils¶
- class torchrec.datasets.utils.Batch(dense_features: torch.Tensor, sparse_features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor, labels: torch.Tensor)¶
Bases:
Pipelineable
- dense_features: Tensor¶
- labels: Tensor¶
- record_stream(stream: Stream) None ¶
See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
- sparse_features: KeyedJaggedTensor¶
- to(device: device, non_blocking: bool = False) Batch ¶
Please be aware that, according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, to might return self or a copy of self. So please remember to use to with the assignment operator, for example, batch = batch.to(new_device).
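A minimal sketch of constructing a Batch by hand (two samples, two sparse keys) and moving it to a device, assuming the standard KeyedJaggedTensor constructor with keys, values, and lengths:

import torch
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["feat1", "feat2"],
    values=torch.tensor([10, 11, 12, 13]),
    lengths=torch.tensor([1, 1, 1, 1]),  # one id per (key, sample) pair
)
batch = Batch(
    dense_features=torch.rand(2, 13),
    sparse_features=kjt,
    labels=torch.randint(0, 2, (2,)),
)
# to() may return a copy, so reassign the result.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
batch = batch.to(device, non_blocking=True)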
- class torchrec.datasets.utils.Limit(datapipe: IterDataPipe, limit: int)¶
Bases:
IterDataPipe
- class torchrec.datasets.utils.LoadFiles(datapipe: Iterable[str], mode: str = 'b', length: int = - 1, path_manager_key: str = 'torchrec', **open_kw)¶
Bases:
IterDataPipe[Tuple[str, IOBase]]
Taken and adapted from torch.utils.data.datapipes.iter.LoadFilesFromDisk
TODO: Merge this back or replace this with something in core Datapipes lib
- class torchrec.datasets.utils.ParallelReadConcat(*datapipes: ~torch.utils.data.datapipes.datapipe.IterDataPipe, dp_selector: ~typing.Callable[[~typing.Sequence[~torch.utils.data.datapipes.datapipe.IterDataPipe]], ~typing.Sequence[~torch.utils.data.datapipes.datapipe.IterDataPipe]] = <function _default_dp_selector>)¶
Bases:
IterDataPipe
Iterable DataPipe that concatenates multiple Iterable DataPipes. When used with a DataLoader, assigns a subset of datapipes to each DataLoader worker to allow for parallel reading.
- Parameters:
datapipes – IterDataPipe instances to read from.
dp_selector – function that each DataLoader worker uses to determine the subset of datapipes to read from.
Example:
datapipes = [
    criteo_terabyte(
        (f"/home/local/datasets/criteo/shard_{idx}.tsv",),
    )
    .batch(100)
    .collate()
    for idx in range(4)
]
dataloader = DataLoader(
    ParallelReadConcat(*datapipes), num_workers=4, batch_size=None
)
- class torchrec.datasets.utils.ReadLinesFromCSV(datapipe: IterDataPipe[Tuple[str, IOBase]], skip_first_line: bool = False, **kw)¶
Bases:
IterDataPipe
- torchrec.datasets.utils.idx_split_train_val(datapipe: ~torch.utils.data.datapipes.datapipe.IterDataPipe, train_perc: float, decimal_places: int = 2, key_fn: ~typing.Callable[[int], int] = <function _default_key_fn>) Tuple[IterDataPipe, IterDataPipe] ¶
- torchrec.datasets.utils.rand_split_train_val(datapipe: IterDataPipe, train_perc: float, random_seed: int = 0) Tuple[IterDataPipe, IterDataPipe] ¶
Via uniform random sampling, generates two IterDataPipe instances representing disjoint train and val splits of the given IterDataPipe.
- Parameters:
datapipe (IterDataPipe) – datapipe to split.
train_perc (float) – value in range (0.0, 1.0) specifying target proportion of datapipe samples to include in train split. Note that the actual proportion is not guaranteed to match train_perc exactly.
random_seed (int) – determines split membership for a given sample and train_perc. Use the same value across calls to generate consistent splits.
Example:
datapipe = criteo_terabyte(
    ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
)
train_datapipe, val_datapipe = rand_split_train_val(datapipe, 0.75)
train_batch = next(iter(train_datapipe))
val_batch = next(iter(val_datapipe))
- torchrec.datasets.utils.safe_cast(val: T, dest_type: Callable[[T], T], default: T) T ¶
- torchrec.datasets.utils.train_filter(key_fn: Callable[[int], int], train_perc: float, decimal_places: int, idx: int) bool ¶
- torchrec.datasets.utils.val_filter(key_fn: Callable[[int], int], train_perc: float, decimal_places: int, idx: int) bool ¶