torch.Tensor
A torch.Tensor is a multi-dimensional matrix containing elements of a single data type.
Data types
Torch defines tensor types with the following data types:
Data type | dtype
---|---
32-bit floating point | torch.float32 or torch.float
64-bit floating point | torch.float64 or torch.double
16-bit floating point [1] | torch.float16 or torch.half
16-bit floating point [2] | torch.bfloat16
32-bit complex | torch.complex32 or torch.chalf
64-bit complex | torch.complex64 or torch.cfloat
128-bit complex | torch.complex128 or torch.cdouble
8-bit integer (unsigned) | torch.uint8
16-bit integer (unsigned) | torch.uint16 [4]
32-bit integer (unsigned) | torch.uint32 [4]
64-bit integer (unsigned) | torch.uint64 [4]
8-bit integer (signed) | torch.int8
16-bit integer (signed) | torch.int16 or torch.short
32-bit integer (signed) | torch.int32 or torch.int
64-bit integer (signed) | torch.int64 or torch.long
Boolean | torch.bool
quantized 8-bit integer (unsigned) | torch.quint8
quantized 8-bit integer (signed) | torch.qint8
quantized 32-bit integer (signed) | torch.qint32
quantized 4-bit integer (unsigned) [3] | torch.quint4x2
8-bit floating point, e4m3 [5] | torch.float8_e4m3fn
8-bit floating point, e5m2 [5] | torch.float8_e5m2
- [1] Sometimes referred to as binary16: uses 1 sign, 5 exponent, and 10 significand bits. Useful when precision is important at the expense of range.
- [2] Sometimes referred to as Brain Floating Point: uses 1 sign, 8 exponent, and 7 significand bits. Useful when range is important, since it has the same number of exponent bits as float32.
- [3] quantized 4-bit integer is stored as an 8-bit signed integer. It is currently only supported in the EmbeddingBag operator.
- [4] Unsigned types aside from uint8 are currently planned to have only limited support in eager mode (they primarily exist to assist usage with torch.compile); if you need eager support and do not need the extra range, we recommend using their signed variants instead. See https://github.com/pytorch/pytorch/issues/58734 for more details.
- [5] torch.float8_e4m3fn and torch.float8_e5m2 implement the spec for 8-bit floating point types from https://arxiv.org/abs/2209.05433. The op support is very limited.
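As a concrete illustration of footnotes 1 and 2, torch.finfo reports the numeric limits of each floating-point dtype. A minimal sketch (exact reprs may vary slightly across PyTorch versions):
>>> torch.finfo(torch.float16).max, torch.finfo(torch.bfloat16).max
(65504.0, 3.3895313892515355e+38)
>>> torch.finfo(torch.float16).eps, torch.finfo(torch.bfloat16).eps
(0.0009765625, 0.0078125)
bfloat16 keeps roughly the range of float32 (8 exponent bits) but with coarser precision (a larger eps) than float16.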
For backwards compatibility, we support the following alternate class names for these data types:
Data type | CPU tensor | GPU tensor
---|---|---
32-bit floating point | torch.FloatTensor | torch.cuda.FloatTensor
64-bit floating point | torch.DoubleTensor | torch.cuda.DoubleTensor
16-bit floating point | torch.HalfTensor | torch.cuda.HalfTensor
16-bit floating point | torch.BFloat16Tensor | torch.cuda.BFloat16Tensor
8-bit integer (unsigned) | torch.ByteTensor | torch.cuda.ByteTensor
8-bit integer (signed) | torch.CharTensor | torch.cuda.CharTensor
16-bit integer (signed) | torch.ShortTensor | torch.cuda.ShortTensor
32-bit integer (signed) | torch.IntTensor | torch.cuda.IntTensor
64-bit integer (signed) | torch.LongTensor | torch.cuda.LongTensor
Boolean | torch.BoolTensor | torch.cuda.BoolTensor
However, to construct tensors, we recommend using factory functions such as torch.empty() with the dtype argument instead. The torch.Tensor constructor is an alias for the default tensor type (torch.FloatTensor).
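For example, a minimal sketch comparing the recommended factory-function style with the legacy class-name alias (torch.empty returns uninitialized memory, so its contents are arbitrary):
>>> torch.empty(2, 3, dtype=torch.float64).dtype   # recommended
torch.float64
>>> torch.DoubleTensor(2, 3).dtype                 # legacy alias, discouraged
torch.float64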
Initializing and basic operations¶
A tensor can be constructed from a Python list or sequence using the torch.tensor() constructor:
>>> torch.tensor([[1., -1.], [1., -1.]])
tensor([[ 1.0000, -1.0000],
        [ 1.0000, -1.0000]])
>>> torch.tensor(np.array([[1, 2, 3], [4, 5, 6]]))
tensor([[ 1,  2,  3],
        [ 4,  5,  6]])
Warning
torch.tensor() always copies data. If you have a Tensor data and just want to change its requires_grad flag, use requires_grad_() or detach() to avoid a copy. If you have a numpy array and want to avoid a copy, use torch.as_tensor().
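For instance (a minimal sketch; torch.as_tensor() shares memory with the NumPy array only when the dtype and device allow it):
>>> a = np.array([1, 2, 3])
>>> t = torch.as_tensor(a)   # no copy where possible; shares memory with `a`
>>> t[0] = -1
>>> a
array([-1,  2,  3])
>>> torch.tensor(a)          # always copies
tensor([-1,  2,  3])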
A tensor of a specific data type can be constructed by passing a torch.dtype and/or a torch.device to a constructor or tensor creation op:
>>> torch.zeros([2, 4], dtype=torch.int32)
tensor([[ 0,  0,  0,  0],
        [ 0,  0,  0,  0]], dtype=torch.int32)
>>> cuda0 = torch.device('cuda:0')
>>> torch.ones([2, 4], dtype=torch.float64, device=cuda0)
tensor([[ 1.0000,  1.0000,  1.0000,  1.0000],
        [ 1.0000,  1.0000,  1.0000,  1.0000]], dtype=torch.float64, device='cuda:0')
For more information about building Tensors, see Creation Ops.
The contents of a tensor can be accessed and modified using Python’s indexing and slicing notation:
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> print(x[1][2])
tensor(6)
>>> x[0][1] = 8
>>> print(x)
tensor([[ 1,  8,  3],
        [ 4,  5,  6]])
Use torch.Tensor.item() to get a Python number from a tensor containing a single value:
>>> x = torch.tensor([[1]])
>>> x
tensor([[ 1]])
>>> x.item()
1
>>> x = torch.tensor(2.5)
>>> x
tensor(2.5000)
>>> x.item()
2.5
For more information about indexing, see Indexing, Slicing, Joining, Mutating Ops.
A tensor can be created with requires_grad=True so that torch.autograd records operations on it for automatic differentiation.
>>> x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
>>> out = x.pow(2).sum()
>>> out.backward()
>>> x.grad
tensor([[ 2.0000, -2.0000],
        [ 2.0000,  2.0000]])
Each tensor has an associated torch.Storage, which holds its data. The tensor class also provides a multi-dimensional, strided view of a storage and defines numeric operations on it.
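A minimal sketch of the stride-based view (the stride values assume the default contiguous, row-major layout):
>>> x = torch.arange(6).reshape(2, 3)
>>> x.stride()                      # elements to step over per dimension
(3, 1)
>>> y = x.t()                       # a different view of the same storage
>>> y.stride()
(1, 3)
>>> x.data_ptr() == y.data_ptr()    # both views start at the same memory
True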
Note
For more information on tensor views, see Tensor Views.
Note
For more information on the torch.dtype, torch.device, and torch.layout attributes of a torch.Tensor, see Tensor Attributes.
Note
Methods which mutate a tensor are marked with an underscore suffix. For example, torch.FloatTensor.abs_() computes the absolute value in-place and returns the modified tensor, while torch.FloatTensor.abs() computes the result in a new tensor.
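A minimal sketch of the difference:
>>> x = torch.tensor([-1., -2.])
>>> x.abs()    # out-of-place: returns a new tensor, x is unchanged
tensor([1., 2.])
>>> x.abs_()   # in-place: modifies x and returns it
tensor([1., 2.])
>>> x
tensor([1., 2.])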
Note
To change an existing tensor's torch.device and/or torch.dtype, consider using the to() method on the tensor.
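For example (the last line assumes a CUDA-capable GPU is available):
>>> x = torch.tensor([1, 2, 3])
>>> x.to(torch.float64)
tensor([1., 2., 3.], dtype=torch.float64)
>>> x.to('cuda:0')
tensor([1, 2, 3], device='cuda:0')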
Warning
The current implementation of torch.Tensor introduces memory overhead, which can lead to unexpectedly high memory usage in applications with many tiny tensors. If this is your case, consider using one large structure.
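As a rough sketch of that suggestion, many tiny tensors of the same shape can often be packed into a single larger tensor and addressed through views, so the per-tensor overhead is paid only once:
>>> pieces = [torch.randn(3) for _ in range(1000)]   # 1000 tiny tensors
>>> packed = torch.stack(pieces)                     # one (1000, 3) tensor
>>> packed[0].shape                                  # each row is a view, not a new allocation
torch.Size([3])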
Tensor class reference
- class torch.Tensor
There are a few main ways to create a tensor, depending on your use case (each is sketched briefly after this list).
- To create a tensor with pre-existing data, use torch.tensor().
- To create a tensor with a specific size, use torch.* tensor creation ops (see Creation Ops).
- To create a tensor with the same size (and similar types) as another tensor, use torch.*_like tensor creation ops (see Creation Ops).
- To create a tensor with a similar type but different size as another tensor, use tensor.new_* creation ops.
- There is a legacy constructor torch.Tensor whose use is discouraged; use torch.tensor() instead.
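A minimal sketch of each creation style:
>>> torch.tensor([[0.1, 1.2], [2.2, 3.1]])        # from pre-existing data
tensor([[0.1000, 1.2000],
        [2.2000, 3.1000]])
>>> torch.zeros(2, 3)                             # specific size
tensor([[0., 0., 0.],
        [0., 0., 0.]])
>>> x = torch.ones(2, 3, dtype=torch.int32)
>>> torch.zeros_like(x)                           # same size and dtype as x
tensor([[0, 0, 0],
        [0, 0, 0]], dtype=torch.int32)
>>> x.new_zeros(5)                                # same dtype, different size
tensor([0, 0, 0, 0, 0], dtype=torch.int32)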
- Tensor.__init__(self, data)
This constructor is deprecated; we recommend using torch.tensor() instead. What this constructor does depends on the type of data.
- If data is a Tensor, returns an alias to the original Tensor. Unlike torch.tensor(), this tracks autograd and will propagate gradients to the original Tensor. The device kwarg is not supported for this data type.
- If data is a sequence or nested sequence, creates a tensor of the default dtype (typically torch.float32) whose data is the values in the sequences, performing coercions if necessary. Notably, this differs from torch.tensor() in that this constructor will always construct a float tensor, even if the inputs are all integers (see the sketch below).
- If data is a torch.Size, returns an empty tensor of that size.
This constructor does not support explicitly specifying the dtype or device of the returned tensor. We recommend using torch.tensor(), which provides this functionality.
- Args:
  data (array_like): The tensor to construct from.
- Keyword args:
  device (torch.device, optional): the desired device of the returned tensor. Default: if None, same torch.device as this tensor.
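A minimal sketch of the dtype behavior described above:
>>> torch.Tensor([1, 2, 3]).dtype     # legacy constructor: always the default float dtype
torch.float32
>>> torch.tensor([1, 2, 3]).dtype     # infers an integer dtype from the data
torch.int64
>>> torch.Tensor(torch.Size([2, 3])).shape
torch.Size([2, 3])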
- Tensor.T
Returns a view of this tensor with its dimensions reversed.
If n is the number of dimensions in x, x.T is equivalent to x.permute(n-1, n-2, ..., 0).
Warning
The use of Tensor.T() on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider mT to transpose batches of matrices or x.permute(*torch.arange(x.ndim - 1, -1, -1)) to reverse the dimensions of a tensor.
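For example, for a 2-D tensor:
>>> x = torch.arange(6).reshape(2, 3)
>>> x.T
tensor([[0, 3],
        [1, 4],
        [2, 5]])
>>> torch.equal(x.T, x.permute(1, 0))
True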
- Tensor.H
Returns a view of a matrix (2-D tensor) conjugated and transposed.
x.H is equivalent to x.transpose(0, 1).conj() for complex matrices and x.transpose(0, 1) for real matrices.
See also
mH: An attribute that also works on batches of matrices.
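For example, for a complex matrix:
>>> x = torch.tensor([[1 + 1j, 2 - 2j], [3 + 3j, 4 - 4j]])
>>> torch.equal(x.H, x.transpose(0, 1).conj())
True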
- Tensor.mT
Returns a view of this tensor with the last two dimensions transposed.
x.mT is equivalent to x.transpose(-2, -1).
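For example, for a batch of matrices:
>>> x = torch.randn(4, 2, 3)                       # a batch of four 2x3 matrices
>>> x.mT.shape
torch.Size([4, 3, 2])
>>> torch.equal(x.mT, x.transpose(-2, -1))
True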
Each torch.Tensor method and attribute is summarized in a per-method reference table with a one-line description; see each method's own documentation for full signatures and details.