#pragma once
#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#include <torch/csrc/api/include/torch/types.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
// NOLINTBEGIN(misc-unused-using-decls)
namespace torch {
// The declarations below expose selected at:: names under torch::, so callers
// can write torch::<name> instead of at::<name>.

// Type aliases for at::NoGradGuard and at::AutoGradMode.
// NOTE(review): presumably scoped guards that toggle gradient mode — confirm
// against their declarations.
using NoGradGuard = at::NoGradGuard;
using AutoGradMode = at::AutoGradMode;
// NOTE(review): presumably seeds the default random number generator(s) —
// confirm against the at::manual_seed declaration.
using at::manual_seed;
// Called during new thread initialization
using at::init_num_threads;
// Returns the number of threads used in parallel region.
using at::get_num_threads;
// Sets the number of threads to be used in parallel region.
using at::set_num_threads;
// Returns the number of threads used for inter-op parallelism.
using at::get_num_interop_threads;
// Sets the number of threads to be used for inter-op parallelism.
using at::set_num_interop_threads;
// Compares two tensors while tolerating undefined ones.
//
// Returns true when either:
//   * both t1 and t2 are undefined, or
//   * both are defined and torch::equal(t1, t2) returns true.
// Returns false when exactly one of the two is defined; torch::equal is never
// called in that case, nor when both are undefined.
inline bool equal_if_defined(const Tensor& t1, const Tensor& t2) {
  const bool t1_defined = t1.defined();
  const bool t2_defined = t2.defined();
  if (!t1_defined || !t2_defined) {
    // At least one tensor is undefined, so the flags can only match when both
    // are undefined.
    return t1_defined == t2_defined;
  }
  return torch::equal(t1, t2);
}
// RecordFunction API: using-declarations that expose these at:: names under
// torch::.
// NOTE(review): presumably all declared in ATen/record_function.h, which this
// header includes — confirm.
using at::addGlobalCallback;
using at::addThreadLocalCallback;
using at::CallbackHandle;
using at::clearCallbacks;
using at::clearGlobalCallbacks;
using at::clearThreadLocalCallbacks;
using at::DisableRecordFunctionGuard;
using at::enableRecordFunction;
using at::hasCallbacks;
using at::hasGlobalCallbacks;
using at::hasThreadLocalCallbacks;
using at::isRecordFunctionEnabled;
using at::RecordFunction;
using at::RecordFunctionCallback;
using at::RecordFunctionGuard;
using at::removeCallback;
} // namespace torch
// NOLINTEND(misc-unused-using-decls)