Shortcuts

Program Listing for File CUDAGuard.h

Return to documentation for file (c10/cuda/CUDAGuard.h)

#pragma once

#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/impl/CUDAGuardImpl.h>

namespace c10::cuda {

// This code is kind of boilerplatey.  See Note [Whither the DeviceGuard
// boilerplate]

/// Device guard specialized for CUDA.
///
/// This is a thin wrapper: every operation below is forwarded to the wrapped
/// c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl>, which holds the actual
/// device-switching logic. The guard can be neither copied nor moved.
/// NOTE(review): presumably the original device is restored when the guard is
/// destroyed -- that behavior lives in InlineDeviceGuard; confirm there.
struct CUDAGuard {
  /// No default constructor: a CUDAGuard must always be given a device
  /// (there is no uninitialized state).
  explicit CUDAGuard() = delete;

  /// Constructs the wrapped guard from a device index.
  /// NOTE(review): presumably interpreted as a CUDA device index -- confirm
  /// against InlineDeviceGuard.
  explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {}

  /// Constructs the wrapped guard from a Device.
  /// NOTE(review): presumably `device` must be a CUDA device -- confirm
  /// against InlineDeviceGuard / CUDAGuardImpl.
  explicit CUDAGuard(Device device) : guard_(device) {}

  // Copy is not allowed
  CUDAGuard(const CUDAGuard&) = delete;
  CUDAGuard& operator=(const CUDAGuard&) = delete;

  // Move is not allowed (there is no uninitialized state)
  CUDAGuard(CUDAGuard&& other) = delete;
  CUDAGuard& operator=(CUDAGuard&& other) = delete;
  ~CUDAGuard() = default;

  /// Forwards to InlineDeviceGuard::set_device with `device`.
  /// See InlineDeviceGuard for how this differs from reset_device().
  void set_device(Device device) {
    guard_.set_device(device);
  }

  /// Forwards to InlineDeviceGuard::reset_device with `device`.
  /// See InlineDeviceGuard for how this differs from set_device().
  void reset_device(Device device) {
    guard_.reset_device(device);
  }

  /// Forwards to InlineDeviceGuard::set_index with `device_index`.
  void set_index(DeviceIndex device_index) {
    guard_.set_index(device_index);
  }

  /// Returns whatever the wrapped guard reports as its original device.
  /// NOTE(review): presumably the device that was current when the guard was
  /// constructed -- confirm against InlineDeviceGuard.
  Device original_device() const {
    return guard_.original_device();
  }

  /// Returns whatever the wrapped guard reports as its current device.
  /// NOTE(review): presumably the device most recently set through this
  /// guard -- confirm against InlineDeviceGuard.
  Device current_device() const {
    return guard_.current_device();
  }

 private:
  /// The guard that does the real work; this struct only forwards to it.
  c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl> guard_;
};

/// Variant of the CUDA device guard that may be constructed without a device
/// and can be reset() later.
///
/// Every operation is forwarded to the wrapped
/// c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl>. The guard can be
/// neither copied nor moved.
struct OptionalCUDAGuard {
  /// Creates the guard with a default-constructed wrapped guard, i.e. without
  /// naming any device.
  explicit OptionalCUDAGuard() : guard_() {}

  /// Constructs the wrapped guard from an optional Device.
  /// NOTE(review): presumably an empty optional leaves the guard
  /// uninitialized -- confirm against InlineOptionalDeviceGuard.
  explicit OptionalCUDAGuard(std::optional<Device> device_opt)
      : guard_(device_opt) {}

  /// Constructs the wrapped guard from an optional device index.
  /// NOTE(review): presumably an empty optional leaves the guard
  /// uninitialized -- confirm against InlineOptionalDeviceGuard.
  explicit OptionalCUDAGuard(std::optional<DeviceIndex> device_index_opt)
      : guard_(device_index_opt) {}

  // Copy is not allowed
  OptionalCUDAGuard(const OptionalCUDAGuard&) = delete;
  OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete;

  // See Note [Move construction for RAII guards is tricky]
  OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete;

  // See Note [Move assignment for RAII guards is tricky]
  OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
  ~OptionalCUDAGuard() = default;

  /// Forwards to InlineOptionalDeviceGuard::set_device with `device`.
  void set_device(Device device) {
    guard_.set_device(device);
  }

  /// Forwards to InlineOptionalDeviceGuard::reset_device with `device`.
  void reset_device(Device device) {
    guard_.reset_device(device);
  }

  /// Forwards to InlineOptionalDeviceGuard::set_index with `device_index`.
  void set_index(DeviceIndex device_index) {
    guard_.set_index(device_index);
  }

  /// Returns the wrapped guard's original device as an optional.
  /// NOTE(review): presumably std::nullopt while the guard is uninitialized
  /// -- confirm against InlineOptionalDeviceGuard.
  std::optional<Device> original_device() const {
    return guard_.original_device();
  }

  /// Returns the wrapped guard's current device as an optional.
  /// NOTE(review): presumably std::nullopt while the guard is uninitialized
  /// -- confirm against InlineOptionalDeviceGuard.
  std::optional<Device> current_device() const {
    return guard_.current_device();
  }

  /// Forwards to InlineOptionalDeviceGuard::reset().
  /// NOTE(review): presumably restores the original device (if any) and
  /// returns the guard to the uninitialized state -- confirm.
  void reset() {
    guard_.reset();
  }

 private:
  /// The guard that does the real work; this struct only forwards to it.
  c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_;
};

/// Stream guard specialized for CUDA.
///
/// Every operation is forwarded to the wrapped
/// c10::impl::InlineStreamGuard<impl::CUDAGuardImpl>; the stream accessors
/// additionally re-wrap the generic Stream results as CUDAStream. The guard
/// can be neither copied nor moved.
struct CUDAStreamGuard {
  /// No default constructor: a stream must always be supplied.
  explicit CUDAStreamGuard() = delete;

  /// Constructs the wrapped guard from `stream`.
  /// NOTE(review): presumably `stream` must be a CUDA stream, and both the
  /// current stream and current device are switched -- confirm against
  /// InlineStreamGuard.
  explicit CUDAStreamGuard(Stream stream) : guard_(stream) {}
  ~CUDAStreamGuard() = default;

  // Copy is not allowed
  CUDAStreamGuard(const CUDAStreamGuard&) = delete;
  CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;

  // Move is not allowed
  CUDAStreamGuard(CUDAStreamGuard&& other) = delete;
  CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete;

  /// Forwards to InlineStreamGuard::reset_stream with `stream`.
  void reset_stream(Stream stream) {
    guard_.reset_stream(stream);
  }

  /// Returns the wrapped guard's original stream, converted to a CUDAStream
  /// via the CUDAStream::UNCHECKED constructor tag.
  /// NOTE(review): UNCHECKED presumably skips validating that the stream is a
  /// CUDA stream -- confirm against CUDAStream.
  CUDAStream original_stream() const {
    return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream());
  }

  /// Returns the wrapped guard's current stream, converted to a CUDAStream
  /// via the CUDAStream::UNCHECKED constructor tag.
  CUDAStream current_stream() const {
    return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream());
  }

  /// Returns whatever the wrapped guard reports as its current device.
  Device current_device() const {
    return guard_.current_device();
  }

  /// Returns whatever the wrapped guard reports as its original device.
  Device original_device() const {
    return guard_.original_device();
  }

 private:
  /// The guard that does the real work; this struct only forwards to it.
  c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_;
};

/// Variant of the CUDA stream guard that may be constructed without a stream
/// and can be reset() later.
///
/// All work is delegated to the wrapped
/// c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl>; the stream
/// accessors re-wrap its generic Stream results as CUDAStream. The guard can
/// be neither copied nor moved.
struct OptionalCUDAStreamGuard {
  /// Creates the guard with a default-constructed wrapped guard, i.e. without
  /// naming any stream.
  explicit OptionalCUDAStreamGuard() : guard_() {}

  /// Constructs the wrapped guard from `stream`.
  explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {}

  /// Constructs the wrapped guard from an optional stream.
  explicit OptionalCUDAStreamGuard(std::optional<Stream> stream_opt)
      : guard_(stream_opt) {}

  // Copy is not allowed
  OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete;
  OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete;

  // See Note [Move construction for RAII guards is tricky]
  OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete;

  // See Note [Move assignment for RAII guards is tricky]
  OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete;
  ~OptionalCUDAStreamGuard() = default;

  /// Forwards to InlineOptionalStreamGuard::reset_stream with `stream`.
  void reset_stream(Stream stream) {
    guard_.reset_stream(stream);
  }

  /// Returns the wrapped guard's original stream as a CUDAStream, or
  /// std::nullopt when the wrapped guard reports no original stream.
  std::optional<CUDAStream> original_stream() const {
    const auto saved = guard_.original_stream();
    if (!saved.has_value()) {
      return std::nullopt;
    }
    // Re-wrap the generic stream using the UNCHECKED constructor tag.
    return CUDAStream(CUDAStream::UNCHECKED, *saved);
  }

  /// Returns the wrapped guard's current stream as a CUDAStream, or
  /// std::nullopt when the wrapped guard reports no current stream.
  std::optional<CUDAStream> current_stream() const {
    const auto active = guard_.current_stream();
    if (!active.has_value()) {
      return std::nullopt;
    }
    // Re-wrap the generic stream using the UNCHECKED constructor tag.
    return CUDAStream(CUDAStream::UNCHECKED, *active);
  }

  /// Forwards to InlineOptionalStreamGuard::reset().
  void reset() {
    guard_.reset();
  }

 private:
  /// The guard that does the real work; this struct only forwards to it.
  c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_;
};

struct CUDAMultiStreamGuard {
  explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams)
      : guard_(unwrapStreams(streams)) {}

  CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete;
  CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete;

  // See Note [Move construction for RAII guards is tricky]
  CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete;

  // See Note [Move assignment for RAII guards is tricky]
  CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete;
  ~CUDAMultiStreamGuard() = default;

 private:
  c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_;

  static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) {
    std::vector<Stream> streams;
    streams.reserve(cudaStreams.size());
    for (const CUDAStream& cudaStream : cudaStreams) {
      streams.push_back(cudaStream);
    }
    return streams;
  }
};

} // namespace c10::cuda

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources