Program Listing for File CUDAGuard.h
↰ Return to documentation for file (c10/cuda/CUDAGuard.h)
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/impl/CUDAGuardImpl.h>
namespace c10::cuda {
// This code is kind of boilerplatey. See Note [Whither the DeviceGuard
// boilerplate]
struct CUDAGuard {
explicit CUDAGuard() = delete;
explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {}
explicit CUDAGuard(Device device) : guard_(device) {}
// Copy is not allowed
CUDAGuard(const CUDAGuard&) = delete;
CUDAGuard& operator=(const CUDAGuard&) = delete;
// Move is not allowed (there is no uninitialized state)
CUDAGuard(CUDAGuard&& other) = delete;
CUDAGuard& operator=(CUDAGuard&& other) = delete;
~CUDAGuard() = default;
void set_device(Device device) {
guard_.set_device(device);
}
void reset_device(Device device) {
guard_.reset_device(device);
}
void set_index(DeviceIndex device_index) {
guard_.set_index(device_index);
}
Device original_device() const {
return guard_.original_device();
}
Device current_device() const {
return guard_.current_device();
}
private:
c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl> guard_;
};
struct OptionalCUDAGuard {
explicit OptionalCUDAGuard() : guard_() {}
explicit OptionalCUDAGuard(std::optional<Device> device_opt)
: guard_(device_opt) {}
explicit OptionalCUDAGuard(std::optional<DeviceIndex> device_index_opt)
: guard_(device_index_opt) {}
// Copy is not allowed
OptionalCUDAGuard(const OptionalCUDAGuard&) = delete;
OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
~OptionalCUDAGuard() = default;
void set_device(Device device) {
guard_.set_device(device);
}
void reset_device(Device device) {
guard_.reset_device(device);
}
void set_index(DeviceIndex device_index) {
guard_.set_index(device_index);
}
std::optional<Device> original_device() const {
return guard_.original_device();
}
std::optional<Device> current_device() const {
return guard_.current_device();
}
void reset() {
guard_.reset();
}
private:
c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_;
};
struct CUDAStreamGuard {
explicit CUDAStreamGuard() = delete;
explicit CUDAStreamGuard(Stream stream) : guard_(stream) {}
~CUDAStreamGuard() = default;
CUDAStreamGuard(const CUDAStreamGuard&) = delete;
CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;
CUDAStreamGuard(CUDAStreamGuard&& other) = delete;
CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete;
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
CUDAStream original_stream() const {
return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream());
}
CUDAStream current_stream() const {
return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream());
}
Device current_device() const {
return guard_.current_device();
}
Device original_device() const {
return guard_.original_device();
}
private:
c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_;
};
struct OptionalCUDAStreamGuard {
explicit OptionalCUDAStreamGuard() : guard_() {}
explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {}
explicit OptionalCUDAStreamGuard(std::optional<Stream> stream_opt)
: guard_(stream_opt) {}
OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete;
OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete;
~OptionalCUDAStreamGuard() = default;
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
std::optional<CUDAStream> original_stream() const {
auto r = guard_.original_stream();
if (r.has_value()) {
return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
} else {
return std::nullopt;
}
}
std::optional<CUDAStream> current_stream() const {
auto r = guard_.current_stream();
if (r.has_value()) {
return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
} else {
return std::nullopt;
}
}
void reset() {
guard_.reset();
}
private:
c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_;
};
struct CUDAMultiStreamGuard {
explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams)
: guard_(unwrapStreams(streams)) {}
CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete;
CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete;
~CUDAMultiStreamGuard() = default;
private:
c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_;
static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) {
std::vector<Stream> streams;
streams.reserve(cudaStreams.size());
for (const CUDAStream& cudaStream : cudaStreams) {
streams.push_back(cudaStream);
}
return streams;
}
};
} // namespace c10::cuda