.. _program_listing_file_cpp_include_torch_tensorrt_ptq.h:

Program Listing for File ptq.h
==============================

|exhale_lsh| :ref:`Return to documentation for file <file_cpp_include_torch_tensorrt_ptq.h>` (``cpp/include/torch_tensorrt/ptq.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   /*
    * Copyright (c) NVIDIA Corporation.
    * All rights reserved.
    *
    * This library is licensed under the BSD-style license found in the
    * LICENSE file in the root directory of this source tree.
    */
   #pragma once

   #include <fstream>
   #include <iostream>
   #include <iterator>
   #include <memory>
   #include <sstream>
   #include <string>
   #include <vector>

   #include "NvInfer.h"
   #include "torch/torch.h"
   #include "torch_tensorrt/logging.h"
   #include "torch_tensorrt/macros.h"

   #ifndef DOXYGEN_SHOULD_SKIP_THIS
   namespace nvinfer1 {
   class IInt8Calibrator;
   class IInt8EntropyCalibrator2;
   } // namespace nvinfer1

   namespace torch_tensorrt {
   namespace ptq {
   TORCHTRT_API bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
   }
   } // namespace torch_tensorrt
   #endif // DOXYGEN_SHOULD_SKIP_THIS

   namespace torch_tensorrt {
   namespace ptq {

   template <typename Algorithm, typename DataLoaderUniquePtr>
   class Int8Calibrator : Algorithm {
     using DataLoader = typename DataLoaderUniquePtr::element_type;
     using Batch = typename DataLoader::super::BatchType;

    public:
     Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
         : dataloader_(dataloader.get()), cache_file_path_(cache_file_path), use_cache_(use_cache) {
       for (auto batch : *dataloader_) {
         batched_data_.push_back(batch.data);
       }
       it_ = batched_data_.begin();
     }

     int getBatchSize() const noexcept override {
       // HACK: Torch-TensorRT only uses explicit batch sizing; the INT8 calibrator does not
       // work when reporting the batch size here while also using explicit batching.
       // So we just report batch size 1 (warnings will still be printed out).
       return 1;
       // return static_cast<int>(dataloader_->options().batch_size);
     }

     bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
       if (it_ != batched_data_.end()) {
         auto status = get_batch_impl(bindings, names, nbBindings, *it_);
         it_ = ++it_;
         return status;
       } else {
         // Reset the iterator in case the calibrator is going to be used again
         it_ = batched_data_.begin();
         return false;
       }
     }

     const void* readCalibrationCache(size_t& length) noexcept override {
       if (use_cache_) {
         std::stringstream ss;
         ss << "Reading Calibration Cache from " << cache_file_path_;
         logging::log(logging::Level::kINFO, ss.str());

         cache_.clear();
         std::ifstream input(cache_file_path_, std::ios::binary);
         input >> std::noskipws;
         if (input.good()) {
           std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(cache_));
           logging::log(logging::Level::kDEBUG, "Cache read");
         }
         length = cache_.size();
         return length ? cache_.data() : nullptr;
       }
       return nullptr;
     }

     void writeCalibrationCache(const void* cache, size_t length) noexcept override {
       std::ofstream cache_file(cache_file_path_, std::ios::binary);
       cache_file.write(reinterpret_cast<const char*>(cache), length);

       std::stringstream ss;
       ss << "Saved Calibration Cache to " << cache_file_path_;
       logging::log(logging::Level::kINFO, ss.str());
     }

     operator nvinfer1::IInt8Calibrator*() {
       return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
     }

    private:
     DataLoader* dataloader_;
     const std::string& cache_file_path_;
     size_t cache_size_ = 0;
     bool use_cache_;
     std::vector<char> cache_;
     std::vector<torch::Tensor> batched_data_;
     std::vector<torch::Tensor>::iterator it_;
   };

   template <typename Algorithm>
   class Int8CacheCalibrator : Algorithm {
    public:
     Int8CacheCalibrator(const std::string& cache_file_path) : cache_file_path_(cache_file_path) {}

     int getBatchSize() const noexcept override {
       // HACK: Torch-TensorRT only uses explicit batch sizing; the INT8 calibrator does not
       // work when reporting the batch size here while also using explicit batching.
       // So we just report batch size 1 (warnings will still be printed out).
       return 1;
     }

     bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
       return false;
     }

     const void* readCalibrationCache(size_t& length) noexcept override {
       std::stringstream ss;
       ss << "Reading Calibration Cache from " << cache_file_path_;
       logging::log(logging::Level::kINFO, ss.str());

       cache_.clear();
       std::ifstream input(cache_file_path_, std::ios::binary);
       input >> std::noskipws;
       if (input.good()) {
         std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(cache_));
         logging::log(logging::Level::kDEBUG, "Cache read");
       }
       length = cache_.size();
       return length ? cache_.data() : nullptr;
     }

     void writeCalibrationCache(const void* cache, size_t length) noexcept override {
       std::ofstream cache_file(cache_file_path_, std::ios::binary);
       cache_file.write(reinterpret_cast<const char*>(cache), length);

       std::stringstream ss;
       ss << "Saved Calibration Cache to " << cache_file_path_;
       logging::log(logging::Level::kINFO, ss.str());
     }

     operator nvinfer1::IInt8Calibrator*() {
       return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
     }

    private:
     const std::string& cache_file_path_;
     size_t cache_size_ = 0;
     std::vector<char> cache_;
   };

   template <typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
   inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(
       DataLoader dataloader,
       const std::string& cache_file_path,
       bool use_cache) {
     return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
   }

   template <typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
   inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
     return Int8CacheCalibrator<Algorithm>(cache_file_path);
   }

   } // namespace ptq
   } // namespace torch_tensorrt
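
The calibrator factories declared above are normally consumed through the TorchScript frontend's ``CompileSpec``. The sketch below is a minimal, illustrative example of that wiring and is not part of this header: ``RandomCalibrationDataset``, the input shape, and the cache path are placeholder assumptions, while ``make_int8_calibrator``, ``CompileSpec::ptq_calibrator``, ``enabled_precisions``, and ``torch_tensorrt::ts::compile`` are the Torch-TensorRT C++ API entry points being exercised.

.. code-block:: cpp

   #include "torch/script.h"
   #include "torch/torch.h"
   #include "torch_tensorrt/ptq.h"
   #include "torch_tensorrt/torch_tensorrt.h"

   // Placeholder dataset (an assumption for this sketch): any torch::data dataset whose
   // batches expose a `.data` tensor (e.g. after transforms::Stack<>) works with Int8Calibrator.
   struct RandomCalibrationDataset : torch::data::datasets::Dataset<RandomCalibrationDataset> {
     torch::data::Example<> get(size_t) override {
       return {torch::rand({3, 224, 224}), torch::zeros({1})};
     }
     torch::optional<size_t> size() const override {
       return 64;
     }
   };

   int main() {
     auto calibration_dataloader = torch::data::make_data_loader(
         RandomCalibrationDataset().map(torch::data::transforms::Stack<>()),
         torch::data::DataLoaderOptions().batch_size(32).workers(2));

     // Keep the cache path alive: the calibrator stores a reference to this string.
     std::string calibration_cache_file = "/tmp/ptq_calibration.cache";

     // First run calibrates from the dataloader and writes the cache; use_cache = true
     // on later runs reads the cache back instead of recalibrating.
     auto calibrator = torch_tensorrt::ptq::make_int8_calibrator(
         std::move(calibration_dataloader), calibration_cache_file, /*use_cache=*/false);

     torch::jit::Module mod = torch::jit::load("model.ts");

     auto compile_spec = torch_tensorrt::ts::CompileSpec({{1, 3, 224, 224}});
     compile_spec.enabled_precisions.insert(torch::kI8);
     compile_spec.ptq_calibrator = calibrator; // implicit conversion to nvinfer1::IInt8Calibrator*

     auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);
     return 0;
   }

Once a cache file exists, ``make_int8_cache_calibrator`` (or ``use_cache = true``) calibrates from the cache alone, so no dataloader needs to be constructed at all.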