Program Listing for File torch_tensorrt.h
↰ Return to documentation for file (cpp/include/torch_tensorrt/torch_tensorrt.h)
/*
* Copyright (c) NVIDIA Corporation.
* All rights reserved.
*
* This library is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cuda_runtime.h>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "torch/custom_class.h"
#include "torch_tensorrt/macros.h"
// Just include the .h?
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace torch {
namespace jit {
struct Graph;
struct Module;
} // namespace jit
} // namespace torch
namespace c10 {
enum class DeviceType : int8_t;
enum class ScalarType : int8_t;
template <class>
class ArrayRef;
} // namespace c10
namespace nvinfer1 {
class IInt8Calibrator;
}
#endif // DOXYGEN_SHOULD_SKIP_THIS
namespace torch_tensorrt {
class DataType {
public:
enum Value : int8_t {
kLong,
kDouble,
kFloat,
kHalf,
kChar,
kInt,
kBool,
kUnknown
};
DataType() = default;
constexpr DataType(Value t) : value(t) {}
TORCHTRT_API DataType(c10::ScalarType t);
operator Value() const {
return value;
}
explicit operator bool() = delete;
constexpr bool operator==(DataType other) const {
return value == other.value;
}
constexpr bool operator==(DataType::Value other) const {
return value == other;
}
constexpr bool operator!=(DataType other) const {
return value != other.value;
}
constexpr bool operator!=(DataType::Value other) const {
return value != other;
}
private:
friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const DataType& dtype);
Value value;
};
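// Usage sketch (illustrative, not part of the original header): DataType
// converts implicitly to its Value enum, so it can be compared directly and
// stored in a std::set as CompileSpec does. Variable names are hypothetical.
//
//   torch_tensorrt::DataType dtype = torch_tensorrt::DataType::kHalf;
//   if (dtype == torch_tensorrt::DataType::kHalf) {
//     std::set<torch_tensorrt::DataType> precisions = {
//         dtype, torch_tensorrt::DataType::kFloat};
//   }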
struct Device {
class DeviceType {
public:
enum Value : int8_t {
kGPU,
kDLA,
};
DeviceType() = default;
constexpr DeviceType(Value t) : value(t) {}
DeviceType(c10::DeviceType t);
operator Value() const {
return value;
}
explicit operator bool() = delete;
constexpr bool operator==(DeviceType other) const {
return value == other.value;
}
constexpr bool operator!=(DeviceType other) const {
return value != other.value;
}
private:
Value value;
};
DeviceType device_type;
/*
* Target gpu id
*/
int64_t gpu_id;
/*
 * When using a DLA core on NVIDIA AGX platforms, gpu_id should be set to the Xavier device id
 */
int64_t dla_core;
bool allow_gpu_fallback;
Device() : device_type(DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
};
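// Usage sketch (illustrative, not part of the original header): configuring a
// Device to target a DLA core with GPU fallback. Field values are hypothetical.
//
//   torch_tensorrt::Device device;
//   device.device_type = torch_tensorrt::Device::DeviceType::kDLA;
//   device.gpu_id = 0;           // Xavier device id on AGX platforms
//   device.dla_core = 1;
//   device.allow_gpu_fallback = true;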
enum class EngineCapability : int8_t {
kSTANDARD,
kSAFETY,
kDLA_STANDALONE,
};
class TensorFormat {
public:
enum Value : int8_t {
kContiguous,
kChannelsLast,
kUnknown,
};
TensorFormat() = default;
constexpr TensorFormat(Value t) : value(t) {}
TORCHTRT_API TensorFormat(at::MemoryFormat t);
operator Value() const {
return value;
}
explicit operator bool() = delete;
constexpr bool operator==(TensorFormat other) const {
return value == other.value;
}
constexpr bool operator==(TensorFormat::Value other) const {
return value == other;
}
constexpr bool operator!=(TensorFormat other) const {
return value != other.value;
}
constexpr bool operator!=(TensorFormat::Value other) const {
return value != other;
}
private:
friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
Value value;
};
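// Usage sketch (illustrative, not part of the original header): TensorFormat
// behaves like DataType above and can also be constructed from an
// at::MemoryFormat.
//
//   torch_tensorrt::TensorFormat fmt = torch_tensorrt::TensorFormat::kChannelsLast;
//   torch_tensorrt::TensorFormat from_torch(at::MemoryFormat::ChannelsLast);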
struct Input : torch::CustomClassHolder {
std::vector<int64_t> min_shape;
std::vector<int64_t> opt_shape;
std::vector<int64_t> max_shape;
std::vector<int64_t> shape;
DataType dtype;
TensorFormat format;
std::vector<double> tensor_domain;
Input() {}
TORCHTRT_API Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
std::vector<int64_t> shape,
std::vector<double> tensor_domain,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
std::vector<int64_t> shape,
DataType dtype,
std::vector<double> tensor_domain,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
c10::ArrayRef<int64_t> shape,
std::vector<double> tensor_domain,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
c10::ArrayRef<int64_t> shape,
DataType dtype,
std::vector<double> tensor_domain,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
std::vector<double> tensor_domain,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
DataType dtype,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
std::vector<int64_t> min_shape,
std::vector<int64_t> opt_shape,
std::vector<int64_t> max_shape,
DataType dtype,
std::vector<double> tensor_domain,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
c10::ArrayRef<int64_t> min_shape,
c10::ArrayRef<int64_t> opt_shape,
c10::ArrayRef<int64_t> max_shape,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
c10::ArrayRef<int64_t> min_shape,
c10::ArrayRef<int64_t> opt_shape,
c10::ArrayRef<int64_t> max_shape,
std::vector<double> tensor_domain,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
c10::ArrayRef<int64_t> min_shape,
c10::ArrayRef<int64_t> opt_shape,
c10::ArrayRef<int64_t> max_shape,
DataType dtype,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(
c10::ArrayRef<int64_t> min_shape,
c10::ArrayRef<int64_t> opt_shape,
c10::ArrayRef<int64_t> max_shape,
DataType dtype,
std::vector<double> tensor_domain,
TensorFormat format = TensorFormat::kContiguous);
TORCHTRT_API Input(at::Tensor tensor);
private:
friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const Input& input);
bool input_is_dynamic;
};
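// Usage sketches (illustrative, not part of the original header): a static
// shape Input using the defaults (FP32, contiguous), and a dynamic shape Input
// with min/opt/max ranges and an explicit dtype. Shapes are hypothetical.
//
//   auto static_in = torch_tensorrt::Input({1, 3, 224, 224});
//
//   auto dynamic_in = torch_tensorrt::Input(
//       /*min_shape=*/{1, 3, 224, 224},
//       /*opt_shape=*/{8, 3, 224, 224},
//       /*max_shape=*/{16, 3, 224, 224},
//       torch_tensorrt::DataType::kHalf);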
struct GraphInputs {
torch::jit::IValue input_signature; // nested Input, full input spec
std::vector<Input> inputs; // flattened input spec
};
TORCHTRT_API std::string get_build_info();
TORCHTRT_API void dump_build_info();
TORCHTRT_API void set_device(const int gpu_id);
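// Usage sketch (illustrative, not part of the original header): querying the
// library build info and selecting the active CUDA device.
//
//   std::cout << torch_tensorrt::get_build_info() << std::endl;
//   torch_tensorrt::dump_build_info();
//   torch_tensorrt::set_device(0);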
namespace torchscript {
struct CompileSpec {
TORCHTRT_API CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
TORCHTRT_API CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
TORCHTRT_API CompileSpec(std::vector<Input> inputs);
TORCHTRT_API CompileSpec(torch::jit::IValue input_signature);
// Defaults should reflect TensorRT defaults for BuilderConfig
GraphInputs graph_inputs;
std::set<DataType> enabled_precisions = {DataType::kFloat};
bool disable_tf32 = false;
bool sparse_weights = false;
bool refit = false;
bool debug = false;
bool truncate_long_and_double = false;
bool allow_shape_tensors = false;
Device device;
EngineCapability capability = EngineCapability::kSTANDARD;
uint64_t num_avg_timing_iters = 1;
uint64_t workspace_size = 0;
uint64_t dla_sram_size = 1048576;
uint64_t dla_local_dram_size = 1073741824;
uint64_t dla_global_dram_size = 536870912;
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
bool require_full_compilation = false;
uint64_t min_block_size = 3;
std::vector<std::string> torch_executed_ops;
std::vector<std::string> torch_executed_modules;
};
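// Usage sketch (illustrative, not part of the original header): building a
// CompileSpec from a vector of Inputs and overriding a few of the defaults
// above. The values chosen here are hypothetical.
//
//   torch_tensorrt::torchscript::CompileSpec spec(
//       {torch_tensorrt::Input({1, 3, 224, 224})});
//   spec.enabled_precisions = {torch_tensorrt::DataType::kFloat,
//                              torch_tensorrt::DataType::kHalf};
//   spec.device.gpu_id = 0;
//   spec.workspace_size = 1 << 30; // 1 GiB
//   spec.require_full_compilation = false;
//   spec.min_block_size = 3;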
TORCHTRT_API bool check_method_operator_support(const torch::jit::Module& module, std::string method_name);
TORCHTRT_API torch::jit::Module compile(const torch::jit::Module& module, CompileSpec info);
TORCHTRT_API std::string convert_method_to_trt_engine(
const torch::jit::Module& module,
std::string method_name,
CompileSpec info);
TORCHTRT_API torch::jit::Module embed_engine_in_new_module(
const std::string& engine,
Device device,
const std::vector<std::string>& input_binding_names = std::vector<std::string>(),
const std::vector<std::string>& output_binding_names = std::vector<std::string>());
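// Usage sketch (illustrative, not part of the original header): end-to-end
// TorchScript compilation and standalone engine extraction. "model.pt" is a
// hypothetical path; torch::jit::load requires <torch/script.h>.
//
//   torch::jit::Module mod = torch::jit::load("model.pt");
//   torch_tensorrt::torchscript::CompileSpec spec(
//       {torch_tensorrt::Input({1, 3, 224, 224})});
//   if (torch_tensorrt::torchscript::check_method_operator_support(mod, "forward")) {
//     torch::jit::Module trt_mod = torch_tensorrt::torchscript::compile(mod, spec);
//     std::string engine =
//         torch_tensorrt::torchscript::convert_method_to_trt_engine(mod, "forward", spec);
//     torch::jit::Module wrapped = torch_tensorrt::torchscript::embed_engine_in_new_module(
//         engine, torch_tensorrt::Device());
//   }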
} // namespace torchscript
} // namespace torch_tensorrt