
Source code for torch_xla.utils.cached_dataset

from __future__ import division
from __future__ import print_function

import json
import io
import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.gcsfs as gcs


def _index_split(index, split_size, split_count):
  parts = []
  while True:
    if parts:
      part = str(index % split_size)
    else:
      part = '{}.pt'.format(index)
    parts.append(part)
    index = index // split_size
    if index == 0:
      break
  while len(parts) < split_count:
    parts.append('0')
  parts.reverse()
  return parts
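
As a quick illustration (not part of the module), the split scheme places each sample file under a chain of bucket folders, while the file name keeps the full sample index. The values below assume the default `max_files_per_folder=1000` and a two-level split, and follow directly from the code above:

# Illustration only: how indices map to path components with
# split_size=1000 and split_count=2.
assert _index_split(7, 1000, 2) == ['0', '7.pt']        # -> <root>/0/7.pt
assert _index_split(1500, 1000, 2) == ['1', '1500.pt']  # -> <root>/1/1500.pt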


def _save_metadata(path, **kwargs):
  mpath = os.path.join(path, 'METADATA')
  jdata = json.dumps(kwargs)
  gcs.generic_write(jdata, mpath, makedirs=True)


def _load_metadata(path):
  mpath = os.path.join(path, 'METADATA')
  jdata = gcs.generic_read(mpath).decode()
  return json.loads(jdata)
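
For reference, the METADATA file is plain JSON. A sketch of what a cache created with the defaults over a 60000-sample dataset would hold (the values and the local cache root below are illustrative assumptions):

# Illustrative METADATA contents, as written by _save_metadata from
# CachedDataset.__init__ below:
#   {"count": 60000, "compress": true, "max_files_per_folder": 1000}
meta = _load_metadata('/tmp/dscache')  # hypothetical local cache root
print(meta['count'], meta['max_files_per_folder'], meta['compress'])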


class CachedDataset(torch.utils.data.Dataset):
  """Wraps an existing `torch.utils.data.Dataset` by providing file caching.

  The `CachedDataset` can be used to trade the CPU/RAM resources required to
  process a raw dataset for storage/network resources.

  Example::

    train_dataset = datasets.MNIST(
        FLAGS.datadir,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    train_dataset = CachedDataset(train_dataset, FLAGS.dscache_dir)

  The `CachedDataset` will transparently cache the original `Dataset` samples,
  so that every run after the first will not trigger any more CPU/RAM usage
  related to raw sample processing. Once a `CachedDataset` is fully cached, it
  can be exported (e.g., as a tar.gz archive) and used on different machines.
  Just unpack the tar.gz and pass `None` as the original `Dataset`:

  Example::

    train_dataset = CachedDataset(None, FLAGS.dscache_dir)

  To fully cache a `CachedDataset`, just run the `warmup()` API.
  A `CachedDataset` saved on GCS has the advantage of being usable from
  different machines without explicit exporting.

  Args:
    data_set (torch.utils.data.Dataset): The raw `torch.utils.data.Dataset` to
      be cached. It can be set to `None` in case all the input samples are
      already stored within the `path` folder.
    path (string): The path where the dataset samples should be stored/loaded.
      The `path` needs to be writeable, unless all the samples are already
      stored. The `path` can be a GCS path (prefixed with `gs://`).
    max_files_per_folder (int): The maximum number of files to be stored within
      a single folder. If `data_set` is `None` this value is ignored and taken
      from the cached metadata.
      Default: 1000
    compress (bool): Whether the saved samples should be compressed.
      Compression saves space at the expense of CPU required to
      compress/decompress. If `data_set` is `None` this value is ignored and
      taken from the cached metadata.
      Default: True
  """

  def __init__(self, data_set, path, max_files_per_folder=1000, compress=True):
    super(CachedDataset, self).__init__()
    self._data_set = data_set
    self._path = path
    if data_set is not None:
      self._max_files_per_folder = max_files_per_folder
      self._compress = compress
      self._count = len(data_set)
      # Only the master ordinal writes the metadata, to avoid multiple
      # workers racing on the same file.
      if xm.is_master_ordinal(local=not gcs.is_gcs_path(path)):
        _save_metadata(
            path,
            count=self._count,
            compress=self._compress,
            max_files_per_folder=self._max_files_per_folder)
    else:
      meta = _load_metadata(path)
      self._max_files_per_folder = meta['max_files_per_folder']
      self._count = meta['count']
      self._compress = meta['compress']
    self._split_count = len(
        _index_split(self._count, self._max_files_per_folder, 0))

  def _index_path(self, index):
    return os.path.join(
        self._path,
        *_index_split(index, self._max_files_per_folder, self._split_count))

  def _save_sample(self, data, path):
    bio = io.BytesIO()
    torch.save(data, bio, _use_new_zipfile_serialization=self._compress)
    gcs.generic_write(bio.getvalue(), path, makedirs=True)

  def _load_sample(self, path):
    try:
      data = gcs.generic_read(path)
      return torch.load(io.BytesIO(data))
    except Exception:
      # Treat any read/deserialization failure as a cache miss; the caller
      # will regenerate the sample from the source dataset (if available).
      return None

  def warmup(self):
    for index in range(self._count):
      self.__getitem__(index)

  def __len__(self):
    return self._count

  def __getitem__(self, index):
    path = self._index_path(index)
    data = self._load_sample(path)
    if data is None:
      if self._data_set is None:
        raise RuntimeError(
            'Source dataset not provided and sample {} is missing from cache '
            'folder {}'.format(index, self._path))
      data = self._data_set[index]
      self._save_sample(data, path)
    return data
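
A minimal end-to-end sketch of the intended workflow (the local paths and the MNIST pipeline below are illustrative assumptions, mirroring the docstring example, not part of this module):

import torch
import torchvision.transforms as transforms
from torchvision import datasets
from torch_xla.utils.cached_dataset import CachedDataset

# First run: wrap the raw dataset and materialize every sample on disk.
raw = datasets.MNIST(
    '/tmp/mnist_data',
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
cached = CachedDataset(raw, '/tmp/mnist_cache')
cached.warmup()  # cache every sample up front

# Later runs (or another machine, after copying the cache folder) no
# longer need the raw dataset at all.
cached = CachedDataset(None, '/tmp/mnist_cache')
loader = torch.utils.data.DataLoader(cached, batch_size=128, shuffle=True)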
