Class Tensor
Defined in File TensorBody.h
Class Documentation
-
class Tensor : public TensorBase
Public Types
Public Functions
-
Tensor() = default
-
inline explicit Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
-
inline explicit Tensor(const TensorBase &base)
-
inline Tensor(TensorBase &&base)
-
inline c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format = MemoryFormat::Contiguous) const &
Should be used if *this can reasonably be expected to be contiguous and performance is important.
Compared to contiguous, it saves a reference count increment/decrement if *this is already contiguous, at the cost in all cases of an extra pointer of stack usage, an extra branch to access, and an extra branch at destruction time.
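The borrow-versus-own trade-off can be illustrated with a standard-library model. This is a hypothetical sketch, not c10::MaybeOwned. `Buffer`, `MaybeOwnedBuffer`, and `expect_contiguous_model` are invented names, and std::shared_ptr stands in for the intrusive reference count that Tensor actually uses.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor: data plus a contiguity flag.
struct Buffer {
  std::vector<float> data;
  bool contiguous;
};

// Hypothetical borrow-or-own wrapper, loosely modelled on c10::MaybeOwned.
class MaybeOwnedBuffer {
 public:
  // Borrow: keep a pointer to the caller's handle; the refcount is untouched.
  static MaybeOwnedBuffer borrowed(const std::shared_ptr<Buffer>& b) {
    MaybeOwnedBuffer m;
    m.borrowed_ = &b;
    return m;
  }
  // Own: hold a reference to a newly created buffer.
  static MaybeOwnedBuffer owned(std::shared_ptr<Buffer> b) {
    MaybeOwnedBuffer m;
    m.owned_ = std::move(b);
    return m;
  }
  bool is_borrowed() const { return borrowed_ != nullptr; }
  // Every access pays one branch to pick the right pointer.
  const Buffer& operator*() const { return borrowed_ ? **borrowed_ : *owned_; }

 private:
  MaybeOwnedBuffer() = default;
  const std::shared_ptr<Buffer>* borrowed_ = nullptr;  // the extra pointer
  std::shared_ptr<Buffer> owned_;
};

// Borrow if already contiguous; otherwise own a contiguous copy.
MaybeOwnedBuffer expect_contiguous_model(const std::shared_ptr<Buffer>& t) {
  if (t->contiguous) {
    return MaybeOwnedBuffer::borrowed(t);
  }
  return MaybeOwnedBuffer::owned(std::make_shared<Buffer>(Buffer{t->data, true}));
}
```

The borrowed case keeps a pointer to the caller's handle, so it is only safe while that handle outlives the result. That is presumably why the rvalue-qualified overload that follows is deleted.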
-
c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format = MemoryFormat::Contiguous) && = delete
-
inline DeprecatedTypeProperties &type() const
Deprecated: "Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device()."
-
inline bool is_variable() const noexcept
Deprecated: "Tensor.is_variable() is deprecated; everything is a variable now. (If you want to assert that variable has been appropriately handled already, use at::impl::variable_excluded_from_dispatch())"
-
template<typename T> inline T *data() const
Deprecated: "Tensor.data<T>() is deprecated. Please use Tensor.data_ptr<T>() instead."
-
template<typename T, size_t N, template<typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> inline GenericPackedTensorAccessor<T, N, PtrTraits, index_t> packed_accessor() const &
Deprecated: "packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead"
-
template<typename T, size_t N, template<typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> GenericPackedTensorAccessor<T, N, PtrTraits, index_t> packed_accessor() && = delete
Deprecated: "packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead"
-
inline void backward(const Tensor &gradient = {}, std::optional<bool> retain_graph = std::nullopt, bool create_graph = false, std::optional<TensorList> inputs = std::nullopt) const
Computes the gradient of the current tensor with respect to the graph leaves.
The graph is differentiated using the chain rule. If the tensor is non-scalar (i.e. its data has more than one element) and requires gradient, the function additionally requires specifying gradient. It should be a tensor of matching type and location that contains the gradient of the differentiated function w.r.t. this Tensor.
This function accumulates gradients in the leaves; you might need to zero them before calling it.
- Parameters
gradient – Gradient w.r.t. the tensor. It will be automatically converted to a Tensor that does not require grad unless create_graph is true. An undefined tensor (the default, {}) can be passed for scalar tensors or for tensors that don't require grad. If an undefined tensor would be acceptable, this argument is optional.
retain_graph – If false, the graph used to compute the grads will be freed. In nearly all cases setting this option to true is not needed, and it can often be worked around in a much more efficient way. Defaults to the value of create_graph.
create_graph – If true, the graph of the derivative will be constructed, allowing higher-order derivative products to be computed. Defaults to false.
inputs – Inputs w.r.t. which the gradient will be accumulated into at::Tensor::grad. All other Tensors will be ignored. If not provided, the gradient is accumulated into all the leaf Tensors that were used to compute the current tensor. When inputs are provided and a given input is not a leaf, the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients). This is an implementation detail on which the user should not rely. See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
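For a non-scalar output, the gradient argument plays the role of the vector v in a vector-Jacobian product: backward(v) adds J^T v into the grad of each leaf. The plain C++ sketch below works this out by hand for the elementwise function y_i = x_i^2, whose Jacobian is diag(2 x_i). It illustrates the math only, not ATen's autograd engine, and `backward_square` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical hand-written backward for the elementwise function y = x * x.
// The Jacobian is diagonal with entries dy_i/dx_i = 2 * x_i, so the
// vector-Jacobian product with the `gradient` argument v is 2 * x_i * v_i.
// The result is added (+=) into grad, as backward() accumulates into leaves.
void backward_square(const std::vector<double>& x, const std::vector<double>& v,
                     std::vector<double>& grad) {
  assert(x.size() == v.size() && x.size() == grad.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    grad[i] += 2.0 * x[i] * v[i];
  }
}
```

Passing a v filled with ones gives the same result that backward() would produce on the scalar sum(y). This is why a scalar output needs no explicit gradient.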
-
inline Tensor &mutable_grad() const
Return a mutable reference to the gradient.
This is conventionally used as t.mutable_grad() = x to set a gradient to a completely new tensor. Note that this function returns a non-const reference and is not thread safe.
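A const member function can hand out a mutable reference because Tensor is a handle to shared implementation state, and the constness of the handle does not propagate to what it points to. The model below shows that pattern with the standard library. `Impl`, `Handle`, and the use of a double for the gradient are hypothetical simplifications, not ATen types.

```cpp
#include <cassert>
#include <memory>

// Hypothetical shared state, standing in for TensorImpl and its autograd metadata.
struct Impl {
  double grad = 0.0;
};

// Hypothetical handle, standing in for Tensor: copying it shares the Impl.
class Handle {
 public:
  Handle() : impl_(std::make_shared<Impl>()) {}
  // const on the handle, yet returns a mutable reference into shared state.
  double& mutable_grad() const { return impl_->grad; }
  const double& grad() const { return impl_->grad; }

 private:
  std::shared_ptr<Impl> impl_;
};
```

Nothing in the model synchronizes concurrent writes through different handles, which mirrors the thread-safety caveat above.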
-
inline const Tensor &grad() const
This function returns an undefined tensor by default and returns a defined tensor the first time a call to backward() computes gradients for this Tensor. The attribute will then contain the gradients computed, and future calls to backward() will accumulate (add) gradients into it.
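The undefined-then-accumulate behaviour of grad() can be modelled with std::optional, where an empty optional stands in for an undefined tensor. `Leaf`, `accumulate`, and `zero_grad` are hypothetical names used only for this sketch.

```cpp
#include <cassert>
#include <optional>

// Hypothetical leaf; std::nullopt models an undefined gradient tensor.
struct Leaf {
  std::optional<double> grad;

  // What one backward() pass does to this leaf: define grad on the first
  // pass, then add to it on every later pass.
  void accumulate(double incoming) {
    if (grad) {
      *grad += incoming;
    } else {
      grad = incoming;
    }
  }

  // Clearing between iterations, modelled here as resetting to undefined.
  void zero_grad() { grad.reset(); }
};
```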
-
inline const Tensor &_fw_grad(uint64_t level) const
This function returns the forward gradient for this Tensor at the given level.
-
inline void _set_fw_grad(const TensorBase &new_grad, uint64_t level, bool is_inplace_op) const
This function can be used to set the value of the forward grad.
Note that the given new_grad might not be used directly if it has different metadata (size/stride/storage offset) compared to this Tensor. In that case, the content of new_grad will be copied into a new Tensor.
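Forward gradients are the tangents of forward-mode automatic differentiation. The dual-number sketch below shows the idea for a single level: every value carries a tangent, and each operation propagates it with the chain rule. It deliberately ignores levels and the metadata-copy rule described above, and `Dual` and its operators are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical dual number: `primal` is the value, `tangent` its forward grad.
struct Dual {
  double primal;
  double tangent;
};

// Sum rule: d(a + b) = a' + b'.
Dual operator+(Dual a, Dual b) { return {a.primal + b.primal, a.tangent + b.tangent}; }

// Product rule: d(ab) = a' b + a b'.
Dual operator*(Dual a, Dual b) {
  return {a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent};
}

// Chain rule for sin: d(sin a) = cos(a) a'.
Dual sin(Dual a) { return {std::sin(a.primal), std::cos(a.primal) * a.tangent}; }
```

Seeding x with tangent 1 makes the tangent of y = x * x + x equal to dy/dx, which is 2x + 1 = 7 at x = 3.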
-
inline void __dispatch__backward(at::TensorList inputs, const ::std::optional<at::Tensor> &gradient = {}, ::std::optional<bool> retain_graph = ::std::nullopt, bool create_graph = false) const
-
inline bool __dispatch_is_leaf() const
-
inline int64_t __dispatch_output_nr() const
-
inline int64_t __dispatch__version() const
-
inline void __dispatch_retain_grad() const
-
inline bool __dispatch_retains_grad() const
-
inline at::Tensor addmv(const at::Tensor &mat, const at::Tensor &vec, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline at::Tensor &addmv_(const at::Tensor &mat, const at::Tensor &vec, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline at::Tensor addr(const at::Tensor &vec1, const at::Tensor &vec2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline at::Tensor &addr_(const at::Tensor &vec1, const at::Tensor &vec2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline bool allclose(const at::Tensor &other, double rtol = 1e-05, double atol = 1e-08, bool equal_nan = false) const
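allclose, like isclose further down this list, treats two elements as close when |self - other| <= atol + rtol * |other|. allclose additionally requires this to hold for every element. The test is asymmetric, since rtol scales only |other|. The scalar sketch below uses the defaults from the signature; `isclose_scalar` and `allclose_vec` are hypothetical helpers that cover real values only.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of the isclose rule, using the defaults shown in the signature.
bool isclose_scalar(double self, double other, double rtol = 1e-05,
                    double atol = 1e-08, bool equal_nan = false) {
  if (self == other) return true;  // exact match, including equal infinities
  if (std::isnan(self) || std::isnan(other)) {
    return equal_nan && std::isnan(self) && std::isnan(other);
  }
  const double actual = std::fabs(self - other);
  const double allowed = atol + std::fabs(rtol * other);  // asymmetric in `other`
  return std::isfinite(actual) && actual <= allowed;
}

// allclose: every pair of elements must pass isclose.
bool allclose_vec(const std::vector<double>& self, const std::vector<double>& other,
                  double rtol = 1e-05, double atol = 1e-08, bool equal_nan = false) {
  assert(self.size() == other.size());
  for (std::size_t i = 0; i < self.size(); ++i) {
    if (!isclose_scalar(self[i], other[i], rtol, atol, equal_nan)) return false;
  }
  return true;
}
```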
-
inline at::Tensor argmax(::std::optional<int64_t> dim = ::std::nullopt, bool keepdim = false) const
-
inline at::Tensor argmin(::std::optional<int64_t> dim = ::std::nullopt, bool keepdim = false) const
-
inline at::Tensor as_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional<int64_t> storage_offset = ::std::nullopt) const
-
inline at::Tensor as_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset = ::std::nullopt) const
-
inline const at::Tensor &as_strided_(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional<int64_t> storage_offset = ::std::nullopt) const
-
inline const at::Tensor &as_strided__symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset = ::std::nullopt) const
-
inline at::Tensor baddbmm(const at::Tensor &batch1, const at::Tensor &batch2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline at::Tensor &baddbmm_(const at::Tensor &batch1, const at::Tensor &batch2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline at::Tensor &bernoulli_(const at::Tensor &p, ::std::optional<at::Generator> generator = ::std::nullopt) const
-
inline at::Tensor &bernoulli_(double p = 0.5, ::std::optional<at::Generator> generator = ::std::nullopt) const
-
inline at::Tensor bernoulli(double p, ::std::optional<at::Generator> generator = ::std::nullopt) const
-
inline at::Tensor bincount(const ::std::optional<at::Tensor> &weights = {}, int64_t minlength = 0) const
-
inline ::std::vector<at::Tensor> tensor_split_symint(c10::SymIntArrayRef indices, int64_t dim = 0) const
-
inline ::std::vector<at::Tensor> tensor_split(const at::Tensor &tensor_indices_or_sections, int64_t dim = 0) const
-
inline at::Tensor clamp(const ::std::optional<at::Scalar> &min, const ::std::optional<at::Scalar> &max = ::std::nullopt) const
-
inline at::Tensor clamp(const ::std::optional<at::Tensor> &min = {}, const ::std::optional<at::Tensor> &max = {}) const
-
inline at::Tensor &clamp_(const ::std::optional<at::Scalar> &min, const ::std::optional<at::Scalar> &max = ::std::nullopt) const
-
inline at::Tensor &clamp_(const ::std::optional<at::Tensor> &min = {}, const ::std::optional<at::Tensor> &max = {}) const
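The clamp overloads above, and the clip aliases that follow, apply the lower bound and then the upper bound. As a result, when min > max every element ends up equal to max. At least one bound must be given. A scalar sketch follows; `clamp_scalar` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>

// Scalar sketch of clamp with optional bounds.
double clamp_scalar(double x, std::optional<double> lo, std::optional<double> hi) {
  assert(lo || hi);              // at least one bound is required
  if (lo) x = std::max(x, *lo);  // apply the lower bound first
  if (hi) x = std::min(x, *hi);  // then the upper bound, so hi wins if lo > hi
  return x;
}
```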
-
inline at::Tensor clip(const ::std::optional<at::Scalar> &min, const ::std::optional<at::Scalar> &max = ::std::nullopt) const
-
inline at::Tensor clip(const ::std::optional<at::Tensor> &min = {}, const ::std::optional<at::Tensor> &max = {}) const
-
inline at::Tensor &clip_(const ::std::optional<at::Scalar> &min, const ::std::optional<at::Scalar> &max = ::std::nullopt) const
-
inline at::Tensor &clip_(const ::std::optional<at::Tensor> &min = {}, const ::std::optional<at::Tensor> &max = {}) const
-
inline at::Tensor __dispatch_contiguous(at::MemoryFormat memory_format = c10::MemoryFormat::Contiguous) const
-
inline at::Tensor cov(int64_t correction = 1, const ::std::optional<at::Tensor> &fweights = {}, const ::std::optional<at::Tensor> &aweights = {}) const
-
inline at::Tensor cumprod(int64_t dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor &cumprod_(int64_t dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor cumprod(at::Dimname dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor &cumprod_(at::Dimname dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor cumsum(int64_t dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor &cumsum_(int64_t dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor cumsum(at::Dimname dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor &cumsum_(at::Dimname dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
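cumsum and cumprod are inclusive scans along dim: element i of the result combines elements 0 through i of the input. If dtype is given, the input is cast to it before the scan, which can prevent integer overflow. The 1-D sketch below uses std::partial_sum and ignores dim and dtype; the function names are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// 1-D inclusive prefix sum: out[i] = x[0] + ... + x[i].
std::vector<long long> cumsum_1d(const std::vector<long long>& x) {
  std::vector<long long> out(x.size());
  std::partial_sum(x.begin(), x.end(), out.begin());
  return out;
}

// 1-D inclusive prefix product: out[i] = x[0] * ... * x[i].
std::vector<long long> cumprod_1d(const std::vector<long long>& x) {
  std::vector<long long> out(x.size());
  std::partial_sum(x.begin(), x.end(), out.begin(), std::multiplies<long long>());
  return out;
}
```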
-
inline at::Tensor diagonal(at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset = 0) const
-
inline at::Tensor diff(int64_t n = 1, int64_t dim = -1, const ::std::optional<at::Tensor> &prepend = {}, const ::std::optional<at::Tensor> &append = {}) const
-
inline at::Tensor div(const at::Tensor &other, ::std::optional<c10::string_view> rounding_mode) const
-
inline at::Tensor &div_(const at::Tensor &other, ::std::optional<c10::string_view> rounding_mode) const
-
inline at::Tensor div(const at::Scalar &other, ::std::optional<c10::string_view> rounding_mode) const
-
inline at::Tensor &div_(const at::Scalar &other, ::std::optional<c10::string_view> rounding_mode) const
-
inline at::Tensor divide(const at::Tensor &other, ::std::optional<c10::string_view> rounding_mode) const
-
inline at::Tensor &divide_(const at::Tensor &other, ::std::optional<c10::string_view> rounding_mode) const
-
inline at::Tensor divide(const at::Scalar &other, ::std::optional<c10::string_view> rounding_mode) const
-
inline at::Tensor &divide_(const at::Scalar &other, ::std::optional<c10::string_view> rounding_mode) const
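The rounding_mode argument of div and divide (divide is an alias for div) selects how the quotient is rounded. An empty optional gives true division, "trunc" rounds toward zero, and "floor" rounds toward negative infinity. The scalar sketch below simply rounds the true quotient. That is a simplification and is not meant to reproduce ATen's floating-point behaviour in edge cases. `div_scalar` is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <optional>
#include <stdexcept>
#include <string_view>

// Scalar sketch of div with an optional rounding mode.
double div_scalar(double a, double b, std::optional<std::string_view> rounding_mode) {
  const double q = a / b;
  if (!rounding_mode) return q;                         // true division
  if (*rounding_mode == "trunc") return std::trunc(q);  // toward zero
  if (*rounding_mode == "floor") return std::floor(q);  // toward -infinity
  throw std::invalid_argument("rounding_mode must be \"trunc\" or \"floor\"");
}
```

The two modes only disagree for negative non-integer quotients: -7 / 2 gives -3 with "trunc" but -4 with "floor".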
-
inline at::Tensor new_empty(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_empty_symint(c10::SymIntArrayRef size, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options = {}) const
-
inline at::Tensor new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options = {}) const
-
inline at::Tensor new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_full(at::IntArrayRef size, const at::Scalar &fill_value, at::TensorOptions options = {}) const
-
inline at::Tensor new_full(at::IntArrayRef size, const at::Scalar &fill_value, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_full_symint(c10::SymIntArrayRef size, const at::Scalar &fill_value, at::TensorOptions options = {}) const
-
inline at::Tensor new_full_symint(c10::SymIntArrayRef size, const at::Scalar &fill_value, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_zeros(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_zeros_symint(c10::SymIntArrayRef size, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_ones(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline at::Tensor new_ones_symint(c10::SymIntArrayRef size, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) const
-
inline const at::Tensor &resize_(at::IntArrayRef size, ::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) const
-
inline const at::Tensor &resize__symint(c10::SymIntArrayRef size, ::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) const
-
inline at::Tensor unflatten_symint(at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) const
-
inline at::Tensor &index_copy_(int64_t dim, const at::Tensor &index, const at::Tensor &source) const
-
inline at::Tensor &index_copy_(at::Dimname dim, const at::Tensor &index, const at::Tensor &source) const
-
inline at::Tensor index_copy(at::Dimname dim, const at::Tensor &index, const at::Tensor &source) const
-
inline at::Tensor &index_put_(const c10::List<::std::optional<at::Tensor>> &indices, const at::Tensor &values, bool accumulate = false) const
-
inline at::Tensor index_put(const c10::List<::std::optional<at::Tensor>> &indices, const at::Tensor &values, bool accumulate = false) const
-
inline at::Tensor isclose(const at::Tensor &other, double rtol = 1e-05, double atol = 1e-08, bool equal_nan = false) const
-
inline bool is_distributed() const
-
inline bool __dispatch_is_floating_point() const
-
inline bool __dispatch_is_complex() const
-
inline bool __dispatch_is_conj() const
-
inline bool __dispatch__is_zerotensor() const
-
inline bool __dispatch_is_neg() const
-
inline bool is_nonzero() const
-
inline bool __dispatch_is_signed() const
-
inline bool __dispatch_is_inference() const
-
inline ::std::tuple<at::Tensor, at::Tensor> kthvalue(int64_t k, int64_t dim = -1, bool keepdim = false) const
-
inline ::std::tuple<at::Tensor, at::Tensor> kthvalue(int64_t k, at::Dimname dim, bool keepdim = false) const
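kthvalue returns the k-th smallest element along dim, counting k from 1, together with its index. The 1-D sketch below uses std::nth_element on an index array. `kthvalue_1d` is hypothetical, and when values tie, the index it reports may differ from ATen's.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Returns {value, index} of the k-th smallest element, with k starting at 1.
std::pair<double, int64_t> kthvalue_1d(const std::vector<double>& x, int64_t k) {
  assert(k >= 1 && k <= static_cast<int64_t>(x.size()));
  std::vector<int64_t> idx(x.size());
  std::iota(idx.begin(), idx.end(), int64_t{0});
  // Partially order indices by value so that position k - 1 holds the answer.
  std::nth_element(idx.begin(), idx.begin() + (k - 1), idx.end(),
                   [&](int64_t a, int64_t b) { return x[a] < x[b]; });
  const int64_t i = idx[k - 1];
  return {x[i], i};
}
```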
-
inline at::Tensor nan_to_num(::std::optional<double> nan = ::std::nullopt, ::std::optional<double> posinf = ::std::nullopt, ::std::optional<double> neginf = ::std::nullopt) const
-
inline at::Tensor &nan_to_num_(::std::optional<double> nan = ::std::nullopt, ::std::optional<double> posinf = ::std::nullopt, ::std::optional<double> neginf = ::std::nullopt) const
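By default nan_to_num makes three replacements:

- NaN becomes zero.
- Positive infinity becomes the largest finite value of the tensor's dtype.
- Negative infinity becomes the lowest finite value of the tensor's dtype.

Each replacement can be overridden. A scalar sketch for double follows; `nan_to_num_scalar` is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <optional>

// Scalar sketch of nan_to_num for double.
double nan_to_num_scalar(double x,
                         std::optional<double> nan = std::nullopt,
                         std::optional<double> posinf = std::nullopt,
                         std::optional<double> neginf = std::nullopt) {
  if (std::isnan(x)) return nan.value_or(0.0);
  if (std::isinf(x)) {
    return x > 0 ? posinf.value_or(std::numeric_limits<double>::max())
                 : neginf.value_or(std::numeric_limits<double>::lowest());
  }
  return x;  // finite values pass through unchanged
}
```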
-
inline at::Tensor log_softmax(int64_t dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor log_softmax(at::Dimname dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline ::std::tuple<at::Tensor, at::Tensor> aminmax(::std::optional<int64_t> dim = ::std::nullopt, bool keepdim = false) const
-
inline at::Tensor mean(at::OptionalIntArrayRef dim, bool keepdim = false, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor mean(at::DimnameList dim, bool keepdim = false, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor nanmean(at::OptionalIntArrayRef dim = ::std::nullopt, bool keepdim = false, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline bool is_pinned(::std::optional<at::Device> device = ::std::nullopt) const
-
inline at::Tensor repeat_interleave(const at::Tensor &repeats, ::std::optional<int64_t> dim = ::std::nullopt, ::std::optional<int64_t> output_size = ::std::nullopt) const
-
inline at::Tensor repeat_interleave_symint(const at::Tensor &repeats, ::std::optional<int64_t> dim = ::std::nullopt, ::std::optional<c10::SymInt> output_size = ::std::nullopt) const
-
inline at::Tensor repeat_interleave(int64_t repeats, ::std::optional<int64_t> dim = ::std::nullopt, ::std::optional<int64_t> output_size = ::std::nullopt) const
-
inline at::Tensor repeat_interleave_symint(c10::SymInt repeats, ::std::optional<int64_t> dim = ::std::nullopt, ::std::optional<c10::SymInt> output_size = ::std::nullopt) const
-
inline at::Tensor _reshape_alias_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const
-
inline at::Tensor detach() const
Returns a new Tensor, detached from the current graph.
The result will never require gradient.
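The tensor returned by detach() shares storage with the original, so an in-place write through either one is visible through both. Only the autograd history is dropped. In the standard-library model below, `GradFn`, `ModelTensor`, and `detach_model` are hypothetical, and std::shared_ptr stands in for shared storage.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct GradFn {};  // hypothetical placeholder for a node in the autograd graph

// Hypothetical tensor model: shared storage plus autograd bookkeeping.
struct ModelTensor {
  std::shared_ptr<std::vector<float>> storage;
  std::shared_ptr<GradFn> grad_fn;  // null for leaves and detached tensors
  bool requires_grad = false;
};

// Shares storage, drops the history, and never requires grad.
ModelTensor detach_model(const ModelTensor& t) {
  return ModelTensor{t.storage, nullptr, false};
}
```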
-
inline at::Tensor &detach_() const
Detaches the Tensor from the graph that created it, making it a leaf.
Views cannot be detached in-place.
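detach_() makes the same change in place: the tensor loses its grad_fn and becomes a leaf that does not require grad, while views are rejected. `Node`, `ModelTensor`, `detach_in_place_model`, and the use of std::logic_error are illustrative assumptions, not ATen's types or error reporting.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

struct Node {};  // hypothetical autograd-graph node

// Hypothetical tensor model with just the fields detach_() touches.
struct ModelTensor {
  std::shared_ptr<Node> grad_fn;
  bool requires_grad = false;
  bool is_view = false;
};

ModelTensor& detach_in_place_model(ModelTensor& t) {
  if (t.is_view) {
    throw std::logic_error("views cannot be detached in-place");
  }
  t.grad_fn.reset();        // cut the link to the graph that produced it
  t.requires_grad = false;  // now a leaf that does not require grad
  return t;
}
```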
-
inline int64_t size(at::Dimname dim) const
-
inline at::Tensor slice(int64_t dim = 0, ::std::optional<int64_t> start = ::std::nullopt, ::std::optional<int64_t> end = ::std::nullopt, int64_t step = 1) const
-
inline at::Tensor slice_symint(int64_t dim = 0, ::std::optional<c10::SymInt> start = ::std::nullopt, ::std::optional<c10::SymInt> end = ::std::nullopt, c10::SymInt step = 1) const
-
inline at::Tensor slice_inverse(const at::Tensor &src, int64_t dim = 0, ::std::optional<int64_t> start = ::std::nullopt, ::std::optional<int64_t> end = ::std::nullopt, int64_t step = 1) const
-
inline at::Tensor slice_inverse_symint(const at::Tensor &src, int64_t dim = 0, ::std::optional<c10::SymInt> start = ::std::nullopt, ::std::optional<c10::SymInt> end = ::std::nullopt, c10::SymInt step = 1) const
-
inline at::Tensor slice_scatter(const at::Tensor &src, int64_t dim = 0, ::std::optional<int64_t> start = ::std::nullopt, ::std::optional<int64_t> end = ::std::nullopt, int64_t step = 1) const
-
inline at::Tensor slice_scatter_symint(const at::Tensor &src, int64_t dim = 0, ::std::optional<c10::SymInt> start = ::std::nullopt, ::std::optional<c10::SymInt> end = ::std::nullopt, c10::SymInt step = 1) const
-
inline at::Tensor select_scatter_symint(const at::Tensor &src, int64_t dim, c10::SymInt index) const
-
inline at::Tensor diagonal_scatter(const at::Tensor &src, int64_t offset = 0, int64_t dim1 = 0, int64_t dim2 = 1) const
-
inline at::Tensor as_strided_scatter(const at::Tensor &src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional<int64_t> storage_offset = ::std::nullopt) const
-
inline at::Tensor as_strided_scatter_symint(const at::Tensor &src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset = ::std::nullopt) const
-
inline at::Tensor softmax(int64_t dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor softmax(at::Dimname dim, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline ::std::vector<at::Tensor> unsafe_split_symint(c10::SymInt split_size, int64_t dim = 0) const
-
inline ::std::vector<at::Tensor> split_symint(c10::SymIntArrayRef split_size, int64_t dim = 0) const
-
inline ::std::vector<at::Tensor> unsafe_split_with_sizes(at::IntArrayRef split_sizes, int64_t dim = 0) const
-
inline ::std::vector<at::Tensor> unsafe_split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim = 0) const
-
inline ::std::vector<at::Tensor> split_with_sizes(at::IntArrayRef split_sizes, int64_t dim = 0) const
-
inline ::std::vector<at::Tensor> split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim = 0) const
-
inline at::Tensor sspaddmm(const at::Tensor &mat1, const at::Tensor &mat2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline at::Tensor stft(int64_t n_fft, ::std::optional<int64_t> hop_length, ::std::optional<int64_t> win_length, const ::std::optional<at::Tensor> &window, bool normalized, ::std::optional<bool> onesided = ::std::nullopt, ::std::optional<bool> return_complex = ::std::nullopt) const
-
inline at::Tensor stft(int64_t n_fft, ::std::optional<int64_t> hop_length = ::std::nullopt, ::std::optional<int64_t> win_length = ::std::nullopt, const ::std::optional<at::Tensor> &window = {}, bool center = true, c10::string_view pad_mode = "reflect", bool normalized = false, ::std::optional<bool> onesided = ::std::nullopt, ::std::optional<bool> return_complex = ::std::nullopt) const
-
inline at::Tensor istft(int64_t n_fft, ::std::optional<int64_t> hop_length = ::std::nullopt, ::std::optional<int64_t> win_length = ::std::nullopt, const ::std::optional<at::Tensor> &window = {}, bool center = true, bool normalized = false, ::std::optional<bool> onesided = ::std::nullopt, ::std::optional<int64_t> length = ::std::nullopt, bool return_complex = false) const
-
inline int64_t stride(at::Dimname dim) const
-
inline at::Tensor sum(at::OptionalIntArrayRef dim, bool keepdim = false, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor sum(at::DimnameList dim, bool keepdim = false, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor nansum(at::OptionalIntArrayRef dim = ::std::nullopt, bool keepdim = false, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor std(at::OptionalIntArrayRef dim = ::std::nullopt, const ::std::optional<at::Scalar> &correction = ::std::nullopt, bool keepdim = false) const
-
inline at::Tensor std(at::DimnameList dim, const ::std::optional<at::Scalar> &correction = ::std::nullopt, bool keepdim = false) const
-
inline at::Tensor prod(int64_t dim, bool keepdim = false, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor prod(at::Dimname dim, bool keepdim = false, ::std::optional<at::ScalarType> dtype = ::std::nullopt) const
-
inline at::Tensor var(at::OptionalIntArrayRef dim = ::std::nullopt, const ::std::optional<at::Scalar> &correction = ::std::nullopt, bool keepdim = false) const
-
inline at::Tensor var(at::DimnameList dim, const ::std::optional<at::Scalar> &correction = ::std::nullopt, bool keepdim = false) const
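The var and std overloads above divide the sum of squared deviations by N - correction, where N is the number of reduced elements. An empty correction means the default of 1 (Bessel's correction). A correction of 0 gives the population variance, and std is the square root of var. The 1-D sketch below reduces over all elements; `var_1d` and `std_1d` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <optional>
#include <vector>

// Variance over all elements with an optional correction (default 1).
double var_1d(const std::vector<double>& x, std::optional<double> correction = std::nullopt) {
  const double c = correction.value_or(1.0);  // Bessel's correction by default
  const double n = static_cast<double>(x.size());
  assert(n - c > 0);
  double mean = 0.0;
  for (double v : x) mean += v;
  mean /= n;
  double ss = 0.0;  // sum of squared deviations from the mean
  for (double v : x) ss += (v - mean) * (v - mean);
  return ss / (n - c);
}

// Standard deviation: square root of the variance with the same correction.
double std_1d(const std::vector<double>& x, std::optional<double> correction = std::nullopt) {
  return std::sqrt(var_1d(x, correction));
}
```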
-
inline at::Tensor norm(const ::std::optional<at::Scalar> &p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) const
-
inline at::Tensor norm(const ::std::optional<at::Scalar> &p, at::IntArrayRef dim, bool keepdim = false) const
-
inline at::Tensor norm(const ::std::optional<at::Scalar> &p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) const
-
inline at::Tensor norm(const ::std::optional<at::Scalar> &p, at::DimnameList dim, bool keepdim = false) const
-
inline const at::Tensor &resize_as_(const at::Tensor &the_template, ::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) const
-
inline at::Tensor addmm(const at::Tensor &mat1, const at::Tensor &mat2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline at::Tensor &addmm_(const at::Tensor &mat1, const at::Tensor &mat2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const
-
inline at::Tensor _addmm_activation(const at::Tensor &mat1, const at::Tensor &mat2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1, bool use_gelu = false) const
-
inline const at::Tensor &sparse_resize_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const
-
inline const at::Tensor &sparse_resize_and_clear_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const
-
inline at::Tensor _sparse_mask_projection(const at::Tensor &mask, bool accumulate_matches = false) const
-
inline at::Tensor to_dense(::std::optional<at::ScalarType> dtype = ::std::nullopt, ::std::optional<bool> masked_grad = ::std::nullopt) const
-
inline at::Tensor _to_dense(::std::optional<at::ScalarType> dtype = ::std::nullopt, ::std::optional<bool> masked_grad = ::std::nullopt) const
-
inline int64_t sparse_dim() const
-
inline int64_t _dimI() const
-
inline int64_t dense_dim() const
-
inline int64_t _dimV() const
-
inline int64_t _nnz() const
-
inline bool is_coalesced() const
-
inline at::Tensor to_sparse(::std::optional<at::Layout> layout = ::std::nullopt, at::OptionalIntArrayRef blocksize = ::std::nullopt, ::std::optional<int64_t> dense_dim = ::std::nullopt) const
-
inline at::Tensor _to_sparse(::std::optional<at::Layout> layout = ::std::nullopt, at::OptionalIntArrayRef blocksize = ::std::nullopt, ::std::optional<int64_t> dense_dim = ::std::nullopt) const
-
inline at::Tensor to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional<int64_t> dense_dim = ::std::nullopt) const
-
inline at::Tensor _to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional<int64_t> dense_dim = ::std::nullopt) const
-
inline at::Tensor to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional<int64_t> dense_dim = ::std::nullopt) const
-
inline at::Tensor _to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional<int64_t> dense_dim = ::std::nullopt) const
-
inline double q_scale() const
-
inline int64_t q_zero_point() const
-
inline int64_t q_per_channel_axis() const
-
inline at::QScheme qscheme() const
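For a per-tensor affine quantized tensor, q_scale() and q_zero_point() above are the parameters of the mapping x ≈ (q - zero_point) * scale. Per-channel quantization instead keeps one scale and zero point per slice along q_per_channel_axis(). The sketch below quantizes to the quint8 range [0, 255]. The function names are hypothetical, and rounding with std::nearbyint (round half to even in the default rounding mode) is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Quantize one value: q = clamp(round(x / scale) + zero_point, qmin, qmax).
int64_t quantize_affine(double x, double scale, int64_t zero_point,
                        int64_t qmin = 0, int64_t qmax = 255) {
  const int64_t q = static_cast<int64_t>(std::nearbyint(x / scale)) + zero_point;
  return std::clamp(q, qmin, qmax);  // saturate to the integer range
}

// Dequantize one value: x ~= (q - zero_point) * scale.
double dequantize_affine(int64_t q, double scale, int64_t zero_point) {
  return static_cast<double>(q - zero_point) * scale;
}
```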
-
inline at::Tensor _autocast_to_reduced_precision(bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) const¶
-
inline at::Tensor to(at::TensorOptions options = {}, bool non_blocking = false, bool copy = false, ::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) const¶
-
inline at::Tensor to(::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, bool non_blocking, bool copy, ::std::optional<at::MemoryFormat> memory_format) const¶
-
inline at::Tensor to(at::Device device, at::ScalarType dtype, bool non_blocking = false, bool copy = false, ::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) const¶
-
inline at::Tensor to(at::ScalarType dtype, bool non_blocking = false, bool copy = false, ::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) const¶
-
inline at::Tensor to(const at::Tensor &other, bool non_blocking = false, bool copy = false, ::std::optional<at::MemoryFormat> memory_format = ::std::nullopt) const¶
-
inline at::Scalar item() const¶
-
inline at::Tensor &set_(at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride = {}) const¶
-
inline at::Tensor &set__symint(at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride = {}) const¶
-
inline at::Tensor &set_(const at::Tensor &source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride = {}) const¶
-
inline at::Tensor &set__symint(const at::Tensor &source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride = {}) const¶
-
inline at::Tensor &put_(const at::Tensor &index, const at::Tensor &source, bool accumulate = false) const¶
-
inline at::Tensor put(const at::Tensor &index, const at::Tensor &source, bool accumulate = false) const¶
-
inline at::Tensor &index_add_(int64_t dim, const at::Tensor &index, const at::Tensor &source, const at::Scalar &alpha = 1) const¶
-
inline at::Tensor index_add(int64_t dim, const at::Tensor &index, const at::Tensor &source, const at::Scalar &alpha = 1) const¶
-
inline at::Tensor index_add(at::Dimname dim, const at::Tensor &index, const at::Tensor &source, const at::Scalar &alpha = 1) const¶
-
inline at::Tensor &index_reduce_(int64_t dim, const at::Tensor &index, const at::Tensor &source, c10::string_view reduce, bool include_self = true) const¶
-
inline at::Tensor index_reduce(int64_t dim, const at::Tensor &index, const at::Tensor &source, c10::string_view reduce, bool include_self = true) const¶
-
inline at::Tensor &index_fill_(int64_t dim, const at::Tensor &index, const at::Scalar &value) const¶
-
inline at::Tensor &index_fill_(int64_t dim, const at::Tensor &index, const at::Tensor &value) const¶
-
inline at::Tensor &index_fill_(at::Dimname dim, const at::Tensor &index, const at::Scalar &value) const¶
-
inline at::Tensor &index_fill_(at::Dimname dim, const at::Tensor &index, const at::Tensor &value) const¶
-
inline at::Tensor index_fill(at::Dimname dim, const at::Tensor &index, const at::Scalar &value) const¶
-
inline at::Tensor index_fill(at::Dimname dim, const at::Tensor &index, const at::Tensor &value) const¶
-
inline at::Tensor scatter(int64_t dim, const at::Tensor &index, const at::Tensor &src, c10::string_view reduce) const¶
-
inline at::Tensor &scatter_(int64_t dim, const at::Tensor &index, const at::Tensor &src, c10::string_view reduce) const¶
-
inline at::Tensor scatter(int64_t dim, const at::Tensor &index, const at::Scalar &value, c10::string_view reduce) const¶
-
inline at::Tensor &scatter_(int64_t dim, const at::Tensor &index, const at::Scalar &value, c10::string_view reduce) const¶
-
inline at::Tensor scatter_add(at::Dimname dim, const at::Tensor &index, const at::Tensor &src) const¶
-
inline at::Tensor scatter_reduce(int64_t dim, const at::Tensor &index, const at::Tensor &src, c10::string_view reduce, bool include_self = true) const¶
-
inline at::Tensor &scatter_reduce_(int64_t dim, const at::Tensor &index, const at::Tensor &src, c10::string_view reduce, bool include_self = true) const¶
-
inline at::Tensor &addbmm_(const at::Tensor &batch1, const at::Tensor &batch2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const¶
-
inline at::Tensor addbmm(const at::Tensor &batch1, const at::Tensor &batch2, const at::Scalar &beta = 1, const at::Scalar &alpha = 1) const¶
-
inline at::Tensor &random_(int64_t from, ::std::optional<int64_t> to, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
-
inline at::Tensor &random_(int64_t to, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
-
inline at::Tensor &uniform_(double from = 0, double to = 1, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
-
inline at::Tensor &cauchy_(double median = 0, double sigma = 1, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
-
inline at::Tensor &log_normal_(double mean = 1, double std = 2, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
-
inline at::Tensor &exponential_(double lambd = 1, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
-
inline at::Tensor &geometric_(double p, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
-
inline at::Tensor cross(const at::Tensor &other, ::std::optional<int64_t> dim = ::std::nullopt) const¶
-
inline at::Tensor take_along_dim(const at::Tensor &indices, ::std::optional<int64_t> dim = ::std::nullopt) const¶
-
inline at::Tensor addcmul(const at::Tensor &tensor1, const at::Tensor &tensor2, const at::Scalar &value = 1) const¶
-
inline at::Tensor &addcmul_(const at::Tensor &tensor1, const at::Tensor &tensor2, const at::Scalar &value = 1) const¶
-
inline at::Tensor addcdiv(const at::Tensor &tensor1, const at::Tensor &tensor2, const at::Scalar &value = 1) const¶
-
inline at::Tensor &addcdiv_(const at::Tensor &tensor1, const at::Tensor &tensor2, const at::Scalar &value = 1) const¶
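Elementwise, addcmul computes out = self + value * tensor1 * tensor2, and addcdiv computes out = self + value * tensor1 / tensor2. The trailing-underscore variants write the result into self. The sketch below works on equal-length std::vector<double> inputs, with broadcasting omitted. The _ref names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical elementwise reference: out = self + value * t1 * t2
std::vector<double> addcmul_ref(const std::vector<double>& self,
                                const std::vector<double>& t1,
                                const std::vector<double>& t2, double value = 1) {
  assert(self.size() == t1.size() && t1.size() == t2.size());
  std::vector<double> out(self.size());
  for (std::size_t i = 0; i < self.size(); ++i) out[i] = self[i] + value * t1[i] * t2[i];
  return out;
}

// Hypothetical elementwise reference: out = self + value * t1 / t2
std::vector<double> addcdiv_ref(const std::vector<double>& self,
                                const std::vector<double>& t1,
                                const std::vector<double>& t2, double value = 1) {
  assert(self.size() == t1.size() && t1.size() == t2.size());
  std::vector<double> out(self.size());
  for (std::size_t i = 0; i < self.size(); ++i) out[i] = self[i] + value * t1[i] / t2[i];
  return out;
}
```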
-
inline ::std::tuple<at::Tensor, at::Tensor> triangular_solve(const at::Tensor &A, bool upper = true, bool transpose = false, bool unitriangular = false) const¶
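Here self holds the right-hand side b and A is the triangular coefficient matrix, so the solution x satisfies A x = b. The returned tuple also carries a copy of A. For a single right-hand side with transpose = false, the solve is back substitution (upper = true) or forward substitution (upper = false), as in the sketch below. triangular_solve_ref is a hypothetical name, and the sketch returns only x.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Hypothetical reference for b.triangular_solve(A, upper, /*transpose=*/false, unitriangular)
// with one right-hand side. If unitriangular is true, the diagonal of A is
// treated as all ones and never read; otherwise it must be nonzero.
std::vector<double> triangular_solve_ref(const Matrix& A, const std::vector<double>& b,
                                         bool upper = true, bool unitriangular = false) {
  const std::size_t n = b.size();
  assert(A.size() == n);
  std::vector<double> x(n, 0.0);
  for (std::size_t step = 0; step < n; ++step) {
    // Upper: start from the last row (back substitution); lower: from the first.
    const std::size_t i = upper ? n - 1 - step : step;
    double s = b[i];
    for (std::size_t j = 0; j < n; ++j) {
      const bool solved = upper ? j > i : j < i;  // x[j] is already known
      if (solved) s -= A[i][j] * x[j];
    }
    x[i] = unitriangular ? s : s / A[i][i];
  }
  return x;
}
```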
-
inline ::std::tuple<at::Tensor, at::Tensor, at::Tensor> svd(bool some = true, bool compute_uv = true) const¶
-
inline at::Tensor ormqr(const at::Tensor &input2, const at::Tensor &input3, bool left = true, bool transpose = false) const¶
-
inline at::Tensor multinomial(int64_t num_samples, bool replacement = false, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
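multinomial draws num_samples category indices, with probabilities proportional to the weights in self. The weights must be non-negative but need not be normalized. The 1-D sketch below uses std::discrete_distribution and makes these assumptions:
- A std::mt19937_64 engine stands in for the optional at::Generator.
- multinomial_ref is a hypothetical name.
- Sampling without replacement is modelled by zeroing each drawn weight. This is one straightforward way to express the semantics, not necessarily how ATen implements it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical 1-D sketch of self.multinomial(num_samples, replacement, generator).
std::vector<int64_t> multinomial_ref(std::vector<double> weights, int64_t num_samples,
                                     bool replacement, std::mt19937_64& gen) {
  const auto positive = std::count_if(weights.begin(), weights.end(),
                                      [](double w) { return w > 0.0; });
  // Without replacement, each category can be drawn at most once.
  assert(positive > 0 && (replacement || num_samples <= positive));
  std::vector<int64_t> out;
  for (int64_t s = 0; s < num_samples; ++s) {
    std::discrete_distribution<int64_t> dist(weights.begin(), weights.end());
    const int64_t k = dist(gen);
    out.push_back(k);
    if (!replacement) weights[k] = 0.0;  // remove the drawn category
  }
  return out;
}
```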
-
inline at::Tensor histc(int64_t bins = 100, const at::Scalar &min = 0, const at::Scalar &max = 0) const¶
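histc splits [min, max] into bins equal-width bins and counts the elements that fall in each. With the default min = max = 0, the minimum and maximum of the data are used instead. Elements outside the range are ignored. The 1-D sketch below uses histc_ref as a hypothetical name. ATen's exact rounding at bin boundaries may differ from this formula.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical 1-D reference for self.histc(bins, min, max).
std::vector<double> histc_ref(const std::vector<double>& data, int64_t bins = 100,
                              double min = 0, double max = 0) {
  assert(bins > 0 && !data.empty());
  double lo = min, hi = max;
  if (lo == 0 && hi == 0) {  // defaults: take the range from the data
    lo = *std::min_element(data.begin(), data.end());
    hi = *std::max_element(data.begin(), data.end());
  }
  assert(lo < hi);  // a constant input is not handled in this sketch
  std::vector<double> hist(bins, 0.0);
  for (double v : data) {
    if (v < lo || v > hi) continue;  // outside [min, max]: ignored
    auto b = static_cast<int64_t>((v - lo) / (hi - lo) * bins);
    if (b == bins) b = bins - 1;  // v == max goes into the last bin
    hist[b] += 1.0;
  }
  return hist;
}
```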
-
inline ::std::tuple<at::Tensor, at::Tensor> histogram(const at::Tensor &bins, const ::std::optional<at::Tensor> &weight = {}, bool density = false) const¶
-
inline ::std::tuple<at::Tensor, at::Tensor> histogram(int64_t bins = 100, ::std::optional<at::ArrayRef<double>> range = ::std::nullopt, const ::std::optional<at::Tensor> &weight = {}, bool density = false) const¶
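Unlike histc, histogram returns both the counts and the bin edges. It can also take explicit edges, per-element weights, and a density flag. The 1-D sketch below makes these assumptions:
- It covers only the explicit-edges overload and returns only the counts.
- histogram_ref is a hypothetical name.
- A nullable pointer stands in for the optional weight tensor.
- The binning conventions are assumed to follow the usual NumPy-style definition: bins are half-open except the last, which is closed, and density divides each count by (total * bin width).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical 1-D reference for self.histogram(bins_tensor, weight, density), counts only.
std::vector<double> histogram_ref(const std::vector<double>& data,
                                  const std::vector<double>& edges,
                                  const std::vector<double>* weight = nullptr,
                                  bool density = false) {
  assert(edges.size() >= 2 && std::is_sorted(edges.begin(), edges.end()));
  const std::size_t nbins = edges.size() - 1;
  std::vector<double> hist(nbins, 0.0);
  for (std::size_t i = 0; i < data.size(); ++i) {
    const double v = data[i];
    if (v < edges.front() || v > edges.back()) continue;  // outside the edges
    // upper_bound finds the first edge > v; the bin is the one just before it.
    auto b = static_cast<std::size_t>(
                 std::upper_bound(edges.begin(), edges.end(), v) - edges.begin()) - 1;
    if (b == nbins) b = nbins - 1;  // v equals the rightmost edge
    hist[b] += weight ? (*weight)[i] : 1.0;
  }
  if (density) {
    double total = 0.0;
    for (double h : hist) total += h;
    for (std::size_t b = 0; b < nbins; ++b)
      hist[b] /= total * (edges[b + 1] - edges[b]);
  }
  return hist;
}
```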
-
inline at::Tensor quantile(const at::Tensor &q, ::std::optional<int64_t> dim = ::std::nullopt, bool keepdim = false, c10::string_view interpolation = "linear") const¶
-
inline at::Tensor quantile(double q, ::std::optional<int64_t> dim = ::std::nullopt, bool keepdim = false, c10::string_view interpolation = "linear") const¶
-
inline at::Tensor nanquantile(const at::Tensor &q, ::std::optional<int64_t> dim = ::std::nullopt, bool keepdim = false, c10::string_view interpolation = "linear") const¶
-
inline at::Tensor nanquantile(double q, ::std::optional<int64_t> dim = ::std::nullopt, bool keepdim = false, c10::string_view interpolation = "linear") const¶
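With interpolation = "linear", the q-th quantile of n sorted values sits at fractional position q * (n - 1). Its value is linearly interpolated between the two neighbouring elements. quantile returns NaN if the input contains a NaN, while nanquantile computes the result as if the NaNs were absent. The sketch below works over a flattened 1-D input (dim = std::nullopt). quantile_ref and its ignore_nan flag are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical reference: quantile (ignore_nan = false) or nanquantile (ignore_nan = true)
// over all elements, with linear interpolation.
double quantile_ref(std::vector<double> data, double q, bool ignore_nan = false) {
  assert(q >= 0.0 && q <= 1.0);
  auto is_nan = [](double v) { return std::isnan(v); };
  const bool has_nan = std::any_of(data.begin(), data.end(), is_nan);
  if (has_nan && !ignore_nan) return std::numeric_limits<double>::quiet_NaN();
  data.erase(std::remove_if(data.begin(), data.end(), is_nan), data.end());
  if (data.empty()) return std::numeric_limits<double>::quiet_NaN();
  std::sort(data.begin(), data.end());
  const double pos = q * static_cast<double>(data.size() - 1);  // fractional rank
  const auto lo = static_cast<std::size_t>(std::floor(pos));
  const auto hi = static_cast<std::size_t>(std::ceil(pos));
  return data[lo] + (pos - static_cast<double>(lo)) * (data[hi] - data[lo]);
}
```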
-
inline ::std::tuple<at::Tensor, at::Tensor> sort(::std::optional<bool> stable, int64_t dim = -1, bool descending = false) const¶
-
inline ::std::tuple<at::Tensor, at::Tensor> sort(::std::optional<bool> stable, at::Dimname dim, bool descending = false) const¶
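With stable = true, sort keeps equal elements in their original relative order. The returned indices map each sorted value back to its position in self. The 1-D sketch below applies std::stable_sort to an index array. sort_ref is hypothetical, and only 1-D input is modelled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <tuple>
#include <vector>

// Hypothetical 1-D reference for sort(stable=true, dim=-1, descending):
// returns (values, indices) with values[i] == self[indices[i]].
std::tuple<std::vector<double>, std::vector<int64_t>>
sort_ref(const std::vector<double>& self, bool descending = false) {
  std::vector<int64_t> idx(self.size());
  std::iota(idx.begin(), idx.end(), int64_t{0});
  // stable_sort keeps equal elements in their original relative order.
  std::stable_sort(idx.begin(), idx.end(), [&](int64_t a, int64_t b) {
    return descending ? self[a] > self[b] : self[a] < self[b];
  });
  std::vector<double> values;
  for (int64_t i : idx) values.push_back(self[i]);
  return std::make_tuple(values, idx);
}
```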
-
inline ::std::tuple<at::Tensor, at::Tensor> topk(int64_t k, int64_t dim = -1, bool largest = true, bool sorted = true) const¶
-
inline ::std::tuple<at::Tensor, at::Tensor> topk_symint(c10::SymInt k, int64_t dim = -1, bool largest = true, bool sorted = true) const¶
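topk returns the k largest elements along dim (or the k smallest, with largest = false), together with their indices. The 1-D sketch below uses std::partial_sort, which orders only the first k positions. topk_ref is hypothetical, and the sketch does not specify the order among tied values.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical 1-D reference for topk(k, dim=-1, largest, sorted=true).
std::pair<std::vector<double>, std::vector<int64_t>>
topk_ref(const std::vector<double>& self, int64_t k, bool largest = true) {
  assert(k >= 0 && k <= static_cast<int64_t>(self.size()));
  std::vector<int64_t> idx(self.size());
  std::iota(idx.begin(), idx.end(), int64_t{0});
  // Only the first k positions need to be in order.
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(), [&](int64_t a, int64_t b) {
    return largest ? self[a] > self[b] : self[a] < self[b];
  });
  idx.resize(static_cast<std::size_t>(k));
  std::vector<double> values;
  for (int64_t i : idx) values.push_back(self[i]);
  return {values, idx};
}
```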
-
inline at::Tensor &normal_(double mean = 0, double std = 1, ::std::optional<at::Generator> generator = ::std::nullopt) const¶
-
inline at::Tensor to_padded_tensor(double padding, at::OptionalIntArrayRef output_size = ::std::nullopt) const¶
-
inline at::Tensor to_padded_tensor_symint(double padding, at::OptionalSymIntArrayRef output_size = ::std::nullopt) const¶
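to_padded_tensor converts a nested tensor, whose components may have different sizes, into a regular tensor by filling the missing entries with padding. If output_size is given, it must be large enough to hold every component. The sketch below assumes 1-D components held in std::vector<double> and simplifies output_size to a single row length. to_padded_ref is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical reference for to_padded_tensor(padding, output_size) on a nested
// tensor of 1-D components. output_len stands in for output_size.
std::vector<std::vector<double>> to_padded_ref(
    const std::vector<std::vector<double>>& components, double padding,
    std::optional<std::size_t> output_len = std::nullopt) {
  std::size_t width = 0;
  for (const auto& c : components) width = std::max(width, c.size());
  if (output_len) {
    assert(*output_len >= width);  // must be large enough for every component
    width = *output_len;
  }
  std::vector<std::vector<double>> out;
  for (const auto& c : components) {
    std::vector<double> row(width, padding);  // start fully padded
    std::copy(c.begin(), c.end(), row.begin());
    out.push_back(row);
  }
  return out;
}
```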
-
inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking = false, bool copy = false) const¶
-
inline at::Tensor tensor_data() const¶
NOTE: This is similar to the legacy .data() function on Variable. It is intended for functions that need to access the Variable's equivalent Tensor, i.e. a Tensor that shares the same storage and tensor metadata with the Variable.
One notable difference from the legacy .data() function is that changes to the returned Tensor's tensor metadata (e.g. sizes / strides / storage / storage_offset) will not update the original Variable, because this function shallow-copies the Variable's underlying TensorImpl.
-
inline at::Tensor variable_data() const¶
NOTE: var.variable_data() in C++ has the same semantics as tensor.data in Python. It creates a new Variable that shares the same storage and tensor metadata with the original Variable, but with a completely new autograd history.
NOTE: If we change the tensor metadata (e.g. sizes / strides / storage / storage_offset) of a variable created from var.variable_data(), those changes will not update the original variable var. In .variable_data(), we set allow_tensor_metadata_change_ to false to make such changes explicitly illegal. This prevents users from changing the metadata of var.variable_data() and expecting the original variable var to be updated as well.
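The difference between tensor_data() and variable_data() comes down to which parts of the TensorImpl are shared. The toy model below is hypothetical: ToyImpl is not a real PyTorch type, and the model leaves out autograd history and version counters. A shallow copy shares the storage but owns a separate copy of the sizes, so data writes are visible through both objects while metadata changes are not. The variable_data() analogue also sets the metadata-change flag to false. Keeping the source's flag in the tensor_data() analogue is an assumption of the model.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Toy stand-in for a TensorImpl (hypothetical): shared storage plus
// per-object metadata and a flag guarding metadata changes.
struct ToyImpl {
  std::shared_ptr<std::vector<float>> storage;
  std::vector<int64_t> sizes;
  bool allow_tensor_metadata_change = true;

  void set_sizes(std::vector<int64_t> new_sizes) {
    if (!allow_tensor_metadata_change)
      throw std::runtime_error("metadata change is not allowed on this object");
    sizes = std::move(new_sizes);
  }
};

// Analogue of tensor_data(): copying the struct shares the storage (shared_ptr)
// but duplicates the sizes. The flag is kept (an assumption of this model).
ToyImpl toy_tensor_data(const ToyImpl& var) { return var; }

// Analogue of variable_data(): the same shallow copy, with metadata changes made illegal.
ToyImpl toy_variable_data(const ToyImpl& var) {
  ToyImpl copy = var;
  copy.allow_tensor_metadata_change = false;
  return copy;
}
```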
-
template<typename T>
hook_return_void_t<T> register_hook(T &&hook) const¶
Registers a backward hook.
The hook will be called every time a gradient with respect to the Tensor is computed. The hook should have one of the following signatures:
    hook(Tensor grad) -> Tensor
    hook(Tensor grad) -> void
The hook should not modify its argument, but it can optionally return a new gradient, which will be used in place of grad.
This function returns the index of the hook in the hook list. Pass that index to remove_hook() to remove the hook.
Example:
    auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
    auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
    v.backward(torch::tensor({1., 2., 3.}));
    // This prints:
    //  2
    //  4
    //  6
    // [ CPUFloatType{3} ]
    std::cout << v.grad() << std::endl;
    v.remove_hook(h);  // removes the hook
-
template<typename T>
hook_return_var_t<T> register_hook(T &&hook) const¶
-
void _backward(TensorList inputs, const std::optional<Tensor> &gradient, std::optional<bool> keep_graph, bool create_graph) const¶
-
template<typename T>
auto register_hook(T &&hook) const -> Tensor::hook_return_void_t<T>¶
Public Members
- N
- PtrTraits
Public Static Functions
Protected Functions
-
inline explicit Tensor(unsafe_borrow_t, const TensorBase &rhs)¶
Protected Attributes
- friend MaybeOwnedTraits< Tensor >
- friend OptionalTensorRef
- friend TensorRef
-
Tensor() = default¶