Shortcuts

Source code for torch_xla.distributed.parallel_loader

import itertools
import queue
import threading
import torch
import torch_xla
import torch_xla.debug.profiler as xp
import torch_xla.utils.keyd_queue as kq
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm


class PerDeviceQueue(object):
  """Holds the pair of queues that feed a single device.

  Attributes:
    device: The device served by these queues.
    cpu_loader_queue: Host-side items waiting to be transferred, bounded by
      `loader_prefetch_size`.
    queue: Items that have already been sent to `device`, bounded by
      `device_prefetch_size`.
    close_queue_count: Counter that each transfer worker advances as it
      exits, so that only the last one closes `queue` for writing.
  """

  def __init__(self, device, loader_prefetch_size, device_prefetch_size):
    self.device = device
    self.close_queue_count = itertools.count()
    self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size)
    self.queue = kq.Queue(maxsize=device_prefetch_size)


class PerDeviceLoader(object):
  """Python iterator over the device-resident data of a single device.

  Instances are created by `ParallelLoader.per_device_loader`. A call to
  `next` may first issue `xm.mark_step()`; this happens once every
  `loader.batches_per_execution` calls. It then fetches the next item for
  `device` from the loader. When the loader returns `None`, any error recorded
  by the loader's background threads is re-raised. Otherwise a final step is
  marked and `StopIteration` is raised.
  """

  def __init__(self, loader, device):
    self._loader = loader
    self._device = device
    # Threshold for `_batches_yielded`. With batches_per_execution == 1 it is
    # 0, so every call to `next` marks a step.
    self._mark_step_batch_count = loader.batches_per_execution - 1
    self._batches_yielded = 0

  def __iter__(self):
    return self

  def __next__(self):
    return self.next()

  def __len__(self):
    return self._loader.per_device_samples()

  def next(self):
    """Returns the next item for this device, or raises StopIteration."""
    if xp.get_tracer_marked_step():
      # NOTE(review): this flag appears to mean a step was already marked
      # elsewhere (profiler tracer). It is cleared and counted as a batch
      # without issuing another mark_step. Confirm against torch_xla.debug.
      xp.set_tracer_marked_step(False)
      self._batches_yielded += 1
    elif self._batches_yielded >= self._mark_step_batch_count:
      self._batches_yielded = 0
      xm.mark_step()
    else:
      self._batches_yielded += 1

    item = self._loader.next_item(self._device)
    if item is not None:
      return item

    # End of data: surface a background-thread failure before reporting a
    # normal end of iteration.
    pending_errors = self._loader._exception_queue
    if not pending_errors.empty():
      raise pending_errors.get()
    xm.mark_step()
    raise StopIteration


class ParallelLoader(object):
  """Wraps an existing PyTorch DataLoader with background data upload.

  One loader thread reads samples from `cpu_loader` and hands them out in
  rounds of `len(devices)` to per-device host queues. For each device,
  `host_to_device_transfer_threads` worker threads move those samples to the
  device and deposit them in the queue read by `PerDeviceLoader`. Samples from
  a trailing round that cannot fill every device are dropped. Errors raised in
  the background threads are stored in `_exception_queue` and re-raised by the
  per-device iterator once its data is exhausted.

  Args:
    cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
      wrapped.
    devices (`torch.device`...): The list of devices where the data has to be
      sent. The i-th sample returned by the `loader` will be sent to `devices[i
      % len(devices)]`.
    batchdim (int, optional): The dimension which is holding the batch size.
      Default: 0
    batches_per_execution (int, optional): The number of batches each
      per-device iterator yields between `xm.mark_step()` calls.
      Default: 1
    loader_prefetch_size (int, optional): The max capacity of the queue used by
      the thread which is reading samples from the `loader`, to be processed by
      the worker threads which upload data to the devices.
      Default: 16
    device_prefetch_size (int, optional): The max size of the per-device queues,
      where the worker threads deposit tensors which have already been sent to
      devices.
      Default: 8
    host_to_device_transfer_threads (int, optional): The number of threads that
      work in parallel to transfer data from loader queue to device queue.
      Default: 1
    input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding
      spec to apply to compatible input tensors after loading.
      Default: None
  """

  def __init__(self,
               cpu_loader,
               devices,
               batchdim=0,
               batches_per_execution=1,
               loader_prefetch_size=16,
               device_prefetch_size=8,
               host_to_device_transfer_threads=1,
               input_sharding=None):
    self._cpu_loader = cpu_loader
    self._devices = [torch.device(x) for x in devices]
    self._batchdim = batchdim
    self._batches_per_execution = batches_per_execution
    self._done = False
    self._queues = dict()
    # Background threads are daemons, so their exceptions cannot propagate to
    # the caller. They are parked here and re-raised by PerDeviceLoader.
    self._exception_queue = queue.Queue()
    self._input_sharding = input_sharding
    for device in self._devices:
      self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
                                            device_prefetch_size)
    thread = threading.Thread(target=self._loader_worker)
    thread.daemon = True
    thread.start()
    for dqueue in self._queues.values():
      for _ in range(host_to_device_transfer_threads):
        thread = threading.Thread(
            target=self._worker,
            args=(
                dqueue,
                host_to_device_transfer_threads,
            ))
        thread.daemon = True
        thread.start()

  def per_device_loader(self, device):
    """Retrieves the loader iterator object for the given device.

    Args:
      device (`torch.device`): The device whose loader is being requested.

    Returns:
      The loader iterator object for the `device`. This is not a
      `torch.utils.data.DataLoader` interface, but a Python iterator which
      returns the same tensor data structure as returned by the wrapped
      `torch.utils.data.DataLoader`, but residing on XLA devices.
    """
    return PerDeviceLoader(self, torch.device(device))

  def per_device_samples(self):
    """Returns the number of items each device receives from `cpu_loader`.

    Items are distributed in rounds of `len(devices)`, so a trailing partial
    round is not counted.
    """
    # The wrapped loader is stored as `_cpu_loader`; this class has no
    # `_loader` attribute.
    return len(self._cpu_loader) // len(self._devices)

  def next_item(self, device):
    """Returns the next device-resident item for `device`.

    Callers (`PerDeviceLoader.next`) treat a `None` result as the end of the
    data stream.
    """
    dqueue = self._queues[device]
    return dqueue.queue.get()

  def close(self):
    """Stops the loader thread and closes every per-device queue."""
    self._done = True
    for dqueue in self._queues.values():
      dqueue.queue.close()
      dqueue.cpu_loader_queue.close()

  @property
  def batches_per_execution(self):
    return self._batches_per_execution

  def _loader_worker(self):
    """Loader thread: distributes `cpu_loader` items to the host queues.

    Items are grouped in rounds of `len(devices)`, and the i-th item of a round
    goes to the i-th device queue. Any exception raised while iterating the
    wrapped loader is forwarded to `_exception_queue`. The host queues are
    always closed for writing, so the transfer workers and consumers observe
    the end of data instead of blocking forever.
    """
    queues = list(self._queues.values())
    batch = []
    try:
      data_iter = iter(self._cpu_loader)
      while not self._done:
        try:
          data = next(data_iter)
        except StopIteration:
          break
        batch.append(data)
        if len(batch) == len(self._devices):
          for queue_no, device_batch in enumerate(batch):
            queues[queue_no].cpu_loader_queue.put(device_batch)
          batch = []
    except Exception as e:
      # This runs in a daemon thread, so raising here would only kill the
      # thread. Hand the error to the consumer side instead.
      self._exception_queue.put(e)
    finally:
      for dqueue in queues:
        dqueue.cpu_loader_queue.close_write()

  def _get_batch(self, dqueue):
    """Drains up to `dqueue.queue.max_size()` items from the host queue.

    Stops early when the host queue yields `None`, which marks the end of
    data. An empty list therefore means nothing is left to transfer.
    """
    batch = []
    while len(batch) < dqueue.queue.max_size():
      item = dqueue.cpu_loader_queue.get()
      if item is None:
        break
      batch.append(item)
    return batch

  def send_cpu_data_to_device(self, batches, device):
    """Move batches to device.

    Args:
      batches -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input
        batches present in the cpu memory.
      device: XLA device where the batches should be moved.

    Returns:
      An iterable with one device-resident entry per element of `batches`.
      When `input_sharding` is a dict, this is a list of dicts of XLA tensors
      keyed like the input. Otherwise the result of
      `xm.send_cpu_data_to_device` is returned unchanged.

    Raises:
      ValueError: `input_sharding` is a dict but the batches are not dicts.
      KeyError: A batch contains keys that are missing from `input_sharding`.
    """
    result = None
    if isinstance(self._input_sharding, dict):
      if not isinstance(batches[0], dict):
        raise ValueError(
            f"input batch should be a dict when input sharding is a dict.")
      result = []
      for batch in batches:
        xla_batch = {}
        missing_keys = []
        for key, tensor in batch.items():
          assert type(tensor) == torch.Tensor
          sharding_spec = None
          if self._input_sharding:
            if key not in self._input_sharding:
              missing_keys.append(key)
              continue
            sharding_spec = self._input_sharding[key]

          # xla_tensor is a list of tensors.
          xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec)
          xla_batch[key] = xla_tensor[0]
        if len(missing_keys) != 0:
          # Report every missing key at once. `_worker` catches this and
          # forwards it to `_exception_queue` so it surfaces in the consumer.
          raise KeyError(
              f"Keys: {missing_keys} are missing from input_sharding.")
        result.append(xla_batch)
    else:
      result = xm.send_cpu_data_to_device(batches, device, self._input_sharding)
    return result

  def _worker(self, dqueue, host_to_device_transfer_threads):
    """Transfer thread: moves host items of `dqueue` to its device.

    The last of the `host_to_device_transfer_threads` workers to exit closes
    the device queue for writing, which signals the end of data to the
    consumer.
    """
    device = torch.device(dqueue.device)
    while True:
      batch = self._get_batch(dqueue)
      if not batch:
        break
      try:
        batch = self.send_cpu_data_to_device(batch, device)
      except Exception as e:
        # _worker is being run in a daemon thread, raise the error
        # will not work. Put the error in an error queue instead.
        self._exception_queue.put(e)
        break
      for data in batch:
        dqueue.queue.put(data)
    close_queue_count = next(dqueue.close_queue_count)
    if close_queue_count == host_to_device_transfer_threads - 1:
      dqueue.queue.close_write()


class MpDeviceLoader(object):
  """Wraps an existing PyTorch DataLoader with background data upload.

  This class should only be used with multi-processing data parallelism. It
  wraps the dataloader passed in with ParallelLoader and returns the
  per_device_loader for the current device.

  Args:
    loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
      wrapped.
    device (`torch.device`...): The device where the data has to be sent.
    kwargs: Named arguments for the `ParallelLoader` constructor.

  Example:

    >>> device = torch_xla.device()
    >>> train_device_loader = MpDeviceLoader(train_loader, device)
  """

  def __init__(self, loader, device, **kwargs):
    self._loader = loader
    self._device = device
    self._parallel_loader_kwargs = kwargs

  def __iter__(self):
    # Each iteration builds a fresh ParallelLoader, which starts its own
    # background threads (typically once per epoch in a `for` loop).
    parallel_loader = ParallelLoader(self._loader, [self._device],
                                     **self._parallel_loader_kwargs)
    return parallel_loader.per_device_loader(self._device)

  def __len__(self):
    return len(self._loader)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources