Program Listing for File ptq.h

Return to documentation for file (cpp/include/torch_tensorrt/ptq.h)

/*
 * Copyright (c) NVIDIA Corporation.
 * All rights reserved.
 *
 * This library is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "NvInfer.h"
#include "torch/torch.h"
#include "torch_tensorrt/logging.h"
#include "torch_tensorrt/macros.h"

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace nvinfer1 {
class IInt8Calibrator;
class IInt8EntropyCalibrator2;
} // namespace nvinfer1

namespace torch_tensorrt {
namespace ptq {
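// Implemented in the Torch-TensorRT runtime: fills the TensorRT input
// bindings from the current Torch tensor batch during calibration.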
TORCHTRT_API bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
}
} // namespace torch_tensorrt
#endif // DOXYGEN_SHOULD_SKIP_THIS

namespace torch_tensorrt {
namespace ptq {

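// Generic INT8 calibrator: owns the batches produced by a LibTorch DataLoader
// and serves them to TensorRT through whichever calibration Algorithm it is
// instantiated with (e.g. nvinfer1::IInt8EntropyCalibrator2). All batches are
// materialized eagerly in the constructor and replayed via getBatch().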
template <typename Algorithm, typename DataLoaderUniquePtr>
class Int8Calibrator : Algorithm {
  using DataLoader = typename DataLoaderUniquePtr::element_type;
  using Batch = typename DataLoader::super::BatchType;

 public:
  Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
      : dataloader_(dataloader.get()), cache_file_path_(cache_file_path), use_cache_(use_cache) {
    for (auto batch : *dataloader_) {
      batched_data_.push_back(batch.data);
    }
    it_ = batched_data_.begin();
  }

  int getBatchSize() const noexcept override {
    // HACK: Torch-TensorRT only uses explicit batch sizing; the INT8 calibrator
    // does not work when the real batch size is reported here while explicit
    // batching is in use, so we just report a batch size of 1 (warnings will
    // still be printed out).
    return 1;
    // return static_cast<int>(dataloader_->options().batch_size);
  }

  bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
    if (it_ != batched_data_.end()) {
      auto status = get_batch_impl(bindings, names, nbBindings, *it_);
      ++it_;
      return status;
    } else {
      // Reset the iterator in case the calibrator is going to be used again
      it_ = batched_data_.begin();
      return false;
    }
  }

  const void* readCalibrationCache(size_t& length) noexcept override {
    if (use_cache_) {
      std::stringstream ss;
      ss << "Reading Calibration Cache from " << cache_file_path_;
      logging::log(logging::Level::kINFO, ss.str());

      cache_.clear();
      std::ifstream input(cache_file_path_, std::ios::binary);
      input >> std::noskipws;
      if (input.good()) {
        std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(cache_));
        logging::log(logging::Level::kDEBUG, "Cache read");
      }
      length = cache_.size();
      return length ? cache_.data() : nullptr;
    }
    return nullptr;
  }

  void writeCalibrationCache(const void* cache, size_t length) noexcept override {
    std::ofstream cache_file(cache_file_path_, std::ios::binary);
    cache_file.write(reinterpret_cast<const char*>(cache), length);
    std::stringstream ss;
    ss << "Saved Calibration Cache to " << cache_file_path_;
    logging::log(logging::Level::kINFO, ss.str());
  }

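  // Implicit conversion so the calibrator can be handed directly to TensorRT
  // APIs that expect an nvinfer1::IInt8Calibrator*; this relies on Algorithm
  // deriving from that interface.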
  operator nvinfer1::IInt8Calibrator*() {
    return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
  }

 private:
  DataLoader* dataloader_;
  const std::string& cache_file_path_;
  size_t cache_size_ = 0;
  bool use_cache_;
  std::vector<char> cache_;
  std::vector<torch::Tensor> batched_data_;
  std::vector<torch::Tensor>::iterator it_;
};

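// Cache-only calibrator: serves calibration data from an existing cache file
// and never pulls batches, so no dataloader (or dataset) is required at
// compile time.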
template <typename Algorithm>
class Int8CacheCalibrator : Algorithm {
 public:
  Int8CacheCalibrator(const std::string& cache_file_path) : cache_file_path_(cache_file_path) {}

  int getBatchSize() const noexcept override {
    // HACK: Torch-TensorRT only uses explicit batch sizing; the INT8 calibrator
    // does not work when the real batch size is reported here while explicit
    // batching is in use, so we just report a batch size of 1 (warnings will
    // still be printed out).
    return 1;
  }

  bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
    return false;
  }

  const void* readCalibrationCache(size_t& length) noexcept override {
    std::stringstream ss;
    ss << "Reading Calibration Cache from " << cache_file_path_;
    logging::log(logging::Level::kINFO, ss.str());

    cache_.clear();
    std::ifstream input(cache_file_path_, std::ios::binary);
    input >> std::noskipws;
    if (input.good()) {
      std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(cache_));
      logging::log(logging::Level::kDEBUG, "Cache read");
    }
    length = cache_.size();
    return length ? cache_.data() : nullptr;
  }

  void writeCalibrationCache(const void* cache, size_t length) noexcept override {
    std::ofstream cache_file(cache_file_path_, std::ios::binary);
    cache_file.write(reinterpret_cast<const char*>(cache), length);
    std::stringstream ss;
    ss << "Saved Calibration Cache to " << cache_file_path_;
    logging::log(logging::Level::kINFO, ss.str());
  }

  operator nvinfer1::IInt8Calibrator*() {
    return reinterpret_cast<nvinfer1::IInt8Calibrator*>(this);
  }

 private:
  const std::string& cache_file_path_;
  size_t cache_size_ = 0;
  std::vector<char> cache_;
};

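// Convenience factory that deduces the calibrator type from the dataloader;
// Algorithm defaults to nvinfer1::IInt8EntropyCalibrator2.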
template <typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(
    DataLoader dataloader,
    const std::string& cache_file_path,
    bool use_cache) {
  return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
}

template <typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
  return Int8CacheCalibrator<Algorithm>(cache_file_path);
}
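
// Usage sketch (the cache path is a placeholder):
//   auto calibrator =
//       torch_tensorrt::ptq::make_int8_cache_calibrator("/tmp/int8_cache.calibration");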

} // namespace ptq
} // namespace torch_tensorrt
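
A minimal end-to-end usage sketch (not part of this header), assuming the LibTorch MNIST dataset wrapper and the torch_tensorrt::ts::CompileSpec interface from the PTQ walkthrough; the dataset path, cache path, batch size, and input shape below are placeholders:

#include "torch/script.h"
#include "torch/torch.h"
#include "torch_tensorrt/ptq.h"
#include "torch_tensorrt/torch_tensorrt.h"

torch::jit::Module compile_int8(torch::jit::Module& module) {
  // Dataloader over a representative calibration set (placeholder dataset).
  auto calibration_dataset =
      torch::data::datasets::MNIST("/tmp/mnist", torch::data::datasets::MNIST::Mode::kTest)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  auto calibration_dataloader = torch::data::make_data_loader(
      std::move(calibration_dataset),
      torch::data::DataLoaderOptions().batch_size(32).workers(2));

  // With use_cache=true, a valid cache file short-circuits recalibration
  // on subsequent runs.
  std::string cache_file = "/tmp/int8_cache.calibration";
  auto calibrator = torch_tensorrt::ptq::make_int8_calibrator(
      std::move(calibration_dataloader), cache_file, /*use_cache=*/true);

  // Request INT8 and attach the calibrator (implicit conversion to
  // nvinfer1::IInt8Calibrator*).
  torch_tensorrt::ts::CompileSpec compile_spec({{1, 1, 28, 28}});
  compile_spec.enabled_precisions.insert(torch::kI8);
  compile_spec.ptq_calibrator = calibrator;
  return torch_tensorrt::ts::compile(module, compile_spec);
}

Note that the calibrator must outlive the call to compile, since TensorRT queries it during engine construction.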
