Source code for torch_xla.utils.cached_dataset

from __future__ import division
from __future__ import print_function

import io
import os
import torch
import torch_xla
import torch_xla.utils.gcsfs as gcs


def _index_split(index, split_size, split_count):
  # Computes the path components for a sample index. The leaf component is
  # the full index (so file names are unique across folders), while the
  # parent folders are the base-`split_size` digits of `index // split_size`,
  # which caps each folder at `split_size` files. Shorter paths are
  # left-padded with '0' folders so every sample sits at depth `split_count`.
  parts = []
  while index > 0:
    findex = index % split_size if parts else index
    parts.append(str(findex))
    index = index // split_size
  while len(parts) < split_count:
    parts.append('0')
  parts.reverse()
  return parts
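
# For illustration (not part of the original module): splitting sample index
# 12345 with at most 1000 files per folder and a fixed depth of 3 yields the
# components of the relative path `0/12/12345`:
#
#   >>> _index_split(12345, 1000, 3)
#   ['0', '12', '12345']
#   >>> _index_split(7, 1000, 3)
#   ['0', '0', '7']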


class CachedDataset(torch.utils.data.Dataset):
  """Wraps an existing `torch.utils.data.Dataset` by providing file caching.

  The `CachedDataset` can be used to trade the CPU/RAM resources required to
  process a raw dataset for storage/network resources.

  Example::

    train_dataset = datasets.MNIST(
        os.path.join(FLAGS.datadir, str(xm.get_ordinal())),
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    train_dataset = CachedDataset(train_dataset, FLAGS.dscache_dir)

  Args:
    data_set (torch.utils.data.Dataset): The raw `torch.utils.data.Dataset`
      to be cached.
    path (string): The path where the dataset samples should be stored/loaded.
      The `path` needs to be writeable, unless all the samples are already
      stored. The `path` can be a GCS path (prefixed with `gs://`).
    max_files_per_folder (int): The maximum number of files to be stored
      within a single folder. Default: 1000
    compress (bool): Whether the saved samples should be compressed.
      Compression saves space at the expense of CPU required to
      compress/decompress. Default: True
  """

  def __init__(self, data_set, path, max_files_per_folder=1000, compress=True):
    super(CachedDataset, self).__init__()
    self._data_set = data_set
    self._path = path
    self._max_files_per_folder = max_files_per_folder
    self._compress = compress
    self._count = len(data_set)
    # Depth of the folder tree needed to address the highest sample index.
    self._split_count = len(_index_split(self._count, max_files_per_folder, 0))

  def _index_path(self, index):
    # Maps a sample index to its storage path under `self._path`.
    return os.path.join(
        self._path,
        *_index_split(index, self._max_files_per_folder, self._split_count))

  def _save_sample(self, data, index):
    bio = io.BytesIO()
    torch.save(data, bio, _use_new_zipfile_serialization=self._compress)
    path = self._index_path(index)
    gcs.generic_write(bio.getvalue(), path, makedirs=True)

  def _load_sample(self, index):
    path = self._index_path(index)
    try:
      data = gcs.generic_read(path)
      return torch.load(io.BytesIO(data))
    except:
      # Cache miss (or unreadable file): fall through and return None so the
      # caller fetches the sample from the raw dataset.
      pass

  def warmup(self):
    # Eagerly processes and caches every sample of the wrapped dataset.
    for index in range(0, self._count):
      self.__getitem__(index)

  def __len__(self):
    return self._count

  def __getitem__(self, index):
    data = self._load_sample(index)
    if data is None:
      data = self._data_set[index]
      self._save_sample(data, index)
    return data
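
# Usage sketch (illustrative, not part of the module): wrap a raw dataset,
# optionally pre-populate the cache with `warmup()`, then index it as usual.
# `/tmp/mnist` and `/tmp/mnist_cache` are hypothetical local paths; a
# `gs://` path works as well.
#
#   from torchvision import datasets, transforms
#   from torch_xla.utils.cached_dataset import CachedDataset
#
#   raw = datasets.MNIST('/tmp/mnist', train=True, download=True,
#                        transform=transforms.ToTensor())
#   cached = CachedDataset(raw, '/tmp/mnist_cache')
#   cached.warmup()  # process and store every sample up front (optional)
#   sample, label = cached[0]  # later accesses read from the cache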
