• Docs >
  • Program Listing for File torch_tensorrt.h
Shortcuts

Program Listing for File torch_tensorrt.h

Return to the documentation page for file cpp/include/torch_tensorrt/torch_tensorrt.h

/*
 * Copyright (c) NVIDIA Corporation.
 * All rights reserved.
 *
 * This library is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cuda_runtime.h>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "torch/custom_class.h"

#include "torch_tensorrt/macros.h"

// TODO: Evaluate including the full torch/c10/nvinfer1 headers here instead of
// maintaining the forward declarations below.
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace torch {
namespace jit {
struct Graph;
struct Module;
} // namespace jit
} // namespace torch

namespace c10 {
enum class DeviceType : int8_t;
enum class ScalarType : int8_t;
template <class>
class ArrayRef;
} // namespace c10

namespace nvinfer1 {
class IInt8Calibrator;
}
#endif // DOXYGEN_SHOULD_SKIP_THIS

namespace torch_tensorrt {
/**
 * @brief Tensor data types usable for engine compilation.
 *
 * Wraps a raw enum in a class so values are scoped (DataType::kFloat)
 * while still converting implicitly for use in switch statements and
 * comparisons.
 */
class DataType {
 public:
  /**
   * Underlying enum. Names presumably mirror the corresponding
   * at::ScalarType values (kLong = int64, kHalf = fp16, ...) — see the
   * c10::ScalarType converting constructor below. kUnknown serves as an
   * "unset" sentinel.
   */
  enum Value : int8_t {
    kLong,
    kDouble,
    kFloat,
    kHalf,
    kChar,
    kInt,
    kBool,
    kUnknown
  };

  DataType() = default;
  // Implicit construction from the raw enum keeps call sites terse
  // (e.g. passing DataType::kFloat where a DataType is expected).
  constexpr DataType(Value t) : value(t) {}
  // Conversion from a PyTorch scalar type; defined out-of-line in the library.
  TORCHTRT_API DataType(c10::ScalarType t);
  // Implicit conversion back to the enum enables switch dispatch.
  operator Value() const {
    return value;
  }
  // Deleted so a DataType can never be accidentally tested for truthiness.
  explicit operator bool() = delete;
  constexpr bool operator==(DataType other) const {
    return value == other.value;
  }
  constexpr bool operator==(DataType::Value other) const {
    return value == other;
  }
  constexpr bool operator!=(DataType other) const {
    return value != other.value;
  }
  constexpr bool operator!=(DataType::Value other) const {
    return value != other;
  }

 private:
  // Stream-insertion for logging/printing; implementation lives in the library.
  friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const DataType& dtype);
  Value value;
};

/**
 * @brief Settings describing the target device an engine should be built
 * for and run on (GPU vs. DLA, device id, fallback policy).
 */
struct Device {
  /**
   * @brief Target hardware backend. Same scoped-enum-in-class pattern as
   * DataType, without the stream operator.
   */
  class DeviceType {
   public:
    enum Value : int8_t {
      kGPU, // target a GPU
      kDLA, // target a DLA core (NVIDIA embedded platforms)
    };

    DeviceType() = default;
    // Implicit construction from the raw enum for terse call sites.
    constexpr DeviceType(Value t) : value(t) {}
    // Conversion from the PyTorch device type; defined out-of-line.
    DeviceType(c10::DeviceType t);
    // Implicit conversion back to the enum enables switch dispatch.
    operator Value() const {
      return value;
    }
    // Deleted so a DeviceType is never accidentally used as a boolean.
    explicit operator bool() = delete;
    constexpr bool operator==(DeviceType other) const {
      return value == other.value;
    }
    constexpr bool operator!=(DeviceType other) const {
      return value != other.value;
    }

   private:
    Value value;
  };

  // Which backend (GPU or DLA) to target.
  DeviceType device_type;

  /*
   * Id of the target GPU (device index)
   */
  int64_t gpu_id;

  /*
   * Id of the DLA core to use. When using a DLA core on NVIDIA AGX
   * platforms, gpu_id should still be set to the integrated (Xavier) GPU's
   * device id — NOTE(review): verify against platform documentation.
   */
  int64_t dla_core;

  // Presumably permits falling back to the GPU for operations not supported
  // on DLA — TODO confirm in the implementation.
  bool allow_gpu_fallback;

  // Defaults: GPU 0, DLA core 0, no GPU fallback.
  Device() : device_type(DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
};

/**
 * @brief Restriction level for the kernels an engine may select.
 * Presumably mirrors nvinfer1::EngineCapability (standard, safety-scoped,
 * DLA-standalone) — confirm the mapping in the implementation.
 */
enum class EngineCapability : int8_t {
  kSTANDARD,
  kSAFETY,
  kDLA_STANDALONE,
};

/**
 * @brief Memory layouts supported for input tensors.
 *
 * Same scoped-enum-in-class pattern as DataType: values are scoped
 * (TensorFormat::kContiguous) yet convert implicitly for switches and
 * comparisons.
 */
class TensorFormat {
 public:
  /**
   * Underlying enum: contiguous (NCHW-style) vs. channels-last (NHWC-style)
   * layouts — see the at::MemoryFormat converting constructor below.
   * kUnknown serves as an "unset" sentinel.
   */
  enum Value : int8_t {
    kContiguous,
    kChannelsLast,
    kUnknown,
  };

  TensorFormat() = default;
  // Implicit construction from the raw enum for terse call sites.
  constexpr TensorFormat(Value t) : value(t) {}
  // Conversion from the PyTorch memory format; defined out-of-line.
  TORCHTRT_API TensorFormat(at::MemoryFormat t);
  // Implicit conversion back to the enum enables switch dispatch.
  operator Value() const {
    return value;
  }
  // Deleted so a TensorFormat is never accidentally used as a boolean.
  explicit operator bool() = delete;
  constexpr bool operator==(TensorFormat other) const {
    return value == other.value;
  }
  constexpr bool operator==(TensorFormat::Value other) const {
    return value == other;
  }
  constexpr bool operator!=(TensorFormat other) const {
    return value != other.value;
  }
  constexpr bool operator!=(TensorFormat::Value other) const {
    return value != other;
  }

 private:
  // Stream-insertion for logging/printing; implementation lives in the library.
  friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
  Value value;
};

/**
 * @brief Specification of a single engine input: its shape (either one
 * static shape or a min/opt/max dynamic-shape range), data type, memory
 * format, and expected value domain.
 *
 * Inherits torch::CustomClassHolder so instances can be carried inside
 * TorchScript IValues (see GraphInputs::input_signature).
 */
struct Input : torch::CustomClassHolder {
  std::vector<int64_t> min_shape; // lower bound for dynamic-shape inputs
  std::vector<int64_t> opt_shape; // shape TensorRT optimizes for
  std::vector<int64_t> max_shape; // upper bound for dynamic-shape inputs
  std::vector<int64_t> shape; // static shape (single-shape constructors)
  DataType dtype; // expected tensor data type
  TensorFormat format; // expected tensor memory layout
  std::vector<double> tensor_domain; // allowed value range — presumably [min, max); confirm bounds convention in impl

  Input() {}

  // --- Static-shape constructors (std::vector<int64_t>) ------------------
  // Each overload optionally takes an explicit dtype (otherwise presumably
  // inferred — confirm in impl), a value domain, and a memory format.
  TORCHTRT_API Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      std::vector<int64_t> shape,
      std::vector<double> tensor_domain,
      TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      std::vector<int64_t> shape,
      DataType dtype,
      std::vector<double> tensor_domain,
      TensorFormat format = TensorFormat::kContiguous);

  // --- Static-shape constructors (c10::ArrayRef<int64_t>) ----------------
  // Convenience overloads mirroring the vector overloads above.
  TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      c10::ArrayRef<int64_t> shape,
      std::vector<double> tensor_domain,
      TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      c10::ArrayRef<int64_t> shape,
      DataType dtype,
      std::vector<double> tensor_domain,
      TensorFormat format = TensorFormat::kContiguous);

  // --- Dynamic-shape constructors (std::vector<int64_t>) -----------------
  // Take a min/opt/max shape range describing the set of shapes the engine
  // must support.
  TORCHTRT_API Input(
      std::vector<int64_t> min_shape,
      std::vector<int64_t> opt_shape,
      std::vector<int64_t> max_shape,
      TensorFormat format = TensorFormat::kContiguous);
  TORCHTRT_API Input(
      std::vector<int64_t> min_shape,
      std::vector<int64_t> opt_shape,
      std::vector<int64_t> max_shape,
      std::vector<double> tensor_domain,
      TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      std::vector<int64_t> min_shape,
      std::vector<int64_t> opt_shape,
      std::vector<int64_t> max_shape,
      DataType dtype,
      TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      std::vector<int64_t> min_shape,
      std::vector<int64_t> opt_shape,
      std::vector<int64_t> max_shape,
      DataType dtype,
      std::vector<double> tensor_domain,
      TensorFormat format = TensorFormat::kContiguous);

  // --- Dynamic-shape constructors (c10::ArrayRef<int64_t>) ---------------
  // Convenience overloads mirroring the vector overloads above.
  TORCHTRT_API Input(
      c10::ArrayRef<int64_t> min_shape,
      c10::ArrayRef<int64_t> opt_shape,
      c10::ArrayRef<int64_t> max_shape,
      TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      c10::ArrayRef<int64_t> min_shape,
      c10::ArrayRef<int64_t> opt_shape,
      c10::ArrayRef<int64_t> max_shape,
      std::vector<double> tensor_domain,
      TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      c10::ArrayRef<int64_t> min_shape,
      c10::ArrayRef<int64_t> opt_shape,
      c10::ArrayRef<int64_t> max_shape,
      DataType dtype,
      TensorFormat format = TensorFormat::kContiguous);

  TORCHTRT_API Input(
      c10::ArrayRef<int64_t> min_shape,
      c10::ArrayRef<int64_t> opt_shape,
      c10::ArrayRef<int64_t> max_shape,
      DataType dtype,
      std::vector<double> tensor_domain,
      TensorFormat format = TensorFormat::kContiguous);

  // Derive the full spec (shape, dtype, format) from an example tensor.
  TORCHTRT_API Input(at::Tensor tensor);

 private:
  // Stream-insertion for logging/printing; implementation lives in the library.
  friend TORCHTRT_API std::ostream& operator<<(std::ostream& os, const Input& input);
  // Presumably set true by the min/opt/max constructors — confirm in impl.
  bool input_is_dynamic;
};

/**
 * @brief Container tying together the two views of a module's inputs: the
 * (possibly nested) TorchScript signature and the flattened list of specs.
 */
struct GraphInputs {
  torch::jit::IValue input_signature; // nested Input structure, full input spec
  std::vector<Input> inputs; // flattened input spec
};

/**
 * @brief Get a string describing the build configuration of the library.
 * @return Build information as a std::string.
 */
TORCHTRT_API std::string get_build_info();

/**
 * @brief Emit the build information — presumably to stdout or the log;
 * confirm in the implementation.
 */
TORCHTRT_API void dump_build_info();

/**
 * @brief Select the active GPU.
 * @param gpu_id Device id, matching Device::gpu_id.
 */
TORCHTRT_API void set_device(const int gpu_id);

namespace torchscript {
/**
 * @brief Settings controlling TorchScript-to-TensorRT compilation: input
 * specs, allowed precisions, target device, builder limits, and partial
 * compilation (fallback-to-Torch) policy.
 */
struct CompileSpec {
  // Convenience: one static shape per input, dtype/format defaulted.
  TORCHTRT_API CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);

  // Convenience: same as above, taking ArrayRef shapes.
  TORCHTRT_API CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

  // Full control: one Input spec per input.
  TORCHTRT_API CompileSpec(std::vector<Input> inputs);

  // Nested input signature (e.g. for modules taking tuples/lists of tensors).
  TORCHTRT_API CompileSpec(torch::jit::IValue input_signature);
  // Defaults should reflect TensorRT defaults for BuilderConfig

  // Input specifications (signature plus flattened list).
  GraphInputs graph_inputs;
  // Precisions the builder may select kernels from; fp32 only by default.
  std::set<DataType> enabled_precisions = {DataType::kFloat};

  // If true, presumably disables TF32 math for fp32 layers — confirm in impl.
  bool disable_tf32 = false;

  // Enable sparse-weight kernel tactics.
  bool sparse_weights = false;

  // Build a refittable engine.
  bool refit = false;

  // Build with debug information.
  bool debug = false;

  // Presumably truncates int64/double inputs and weights to int32/float,
  // which TensorRT supports more broadly — confirm in impl.
  bool truncate_long_and_double = false;

  // Allow shape tensors (inputs whose values participate in shape computation).
  bool allow_shape_tensors = false;

  // Target device for the compiled engine.
  Device device;

  // Kernel-selection restriction level (see EngineCapability).
  EngineCapability capability = EngineCapability::kSTANDARD;

  // Timing iterations averaged per kernel during tactic selection.
  uint64_t num_avg_timing_iters = 1;

  // Max builder workspace in bytes; 0 presumably means the TensorRT default.
  uint64_t workspace_size = 0;

  // Fast software-managed RAM used by DLA per layer; 1 MiB default.
  uint64_t dla_sram_size = 1048576;

  // Host RAM used by DLA to share intermediate data per operation; 1 GiB default.
  uint64_t dla_local_dram_size = 1073741824;

  // Host RAM used by DLA to store weights/metadata for execution; 512 MiB default.
  uint64_t dla_global_dram_size = 536870912;

  // Non-owning pointer to a calibrator used for int8 PTQ; may be nullptr.
  nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;

  // If true, compilation fails instead of falling back unsupported ops to Torch.
  bool require_full_compilation = false;

  // Minimum consecutive convertible ops required to form a TensorRT segment.
  uint64_t min_block_size = 3;

  // Op names (aten::...) forced to run in Torch rather than TensorRT.
  std::vector<std::string> torch_executed_ops;

  // Module names forced to run in Torch rather than TensorRT.
  std::vector<std::string> torch_executed_modules;
};

/**
 * @brief Check whether every operator used by a method can be converted.
 * @param module Module to inspect.
 * @param method_name Name of the method to check.
 * @return true if the method is fully convertible to TensorRT.
 */
TORCHTRT_API bool check_method_operator_support(const torch::jit::Module& module, std::string method_name);

/**
 * @brief Compile a TorchScript module for TensorRT execution.
 * @param module Module to compile.
 * @param info Compilation settings (see CompileSpec).
 * @return A new module whose supported portions run through TensorRT.
 */
TORCHTRT_API torch::jit::Module compile(const torch::jit::Module& module, CompileSpec info);

/**
 * @brief Convert one method of a module into a serialized TensorRT engine.
 * @param module Module containing the method.
 * @param method_name Name of the method to convert.
 * @param info Compilation settings (see CompileSpec).
 * @return Serialized engine as a std::string (raw bytes).
 */
TORCHTRT_API std::string convert_method_to_trt_engine(
    const torch::jit::Module& module,
    std::string method_name,
    CompileSpec info);

/**
 * @brief Wrap a pre-built serialized TensorRT engine in a new TorchScript module.
 * @param engine Serialized engine bytes (e.g. from convert_method_to_trt_engine).
 * @param device Device the engine should execute on.
 * @param input_binding_names Optional engine binding names for inputs, in
 *        Torch argument order; empty presumably means default binding order —
 *        confirm in impl.
 * @param output_binding_names Optional engine binding names for outputs, in
 *        Torch return order; same caveat as above.
 * @return A module whose forward invokes the embedded engine.
 */
TORCHTRT_API torch::jit::Module embed_engine_in_new_module(
    const std::string& engine,
    Device device,
    const std::vector<std::string>& input_binding_names = std::vector<std::string>(),
    const std::vector<std::string>& output_binding_names = std::vector<std::string>());
} // namespace torchscript
} // namespace torch_tensorrt

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources