torch
Tensors
torch.is_tensor(obj)
Returns True if obj is a PyTorch tensor.
Parameters: obj (Object) – Object to test
torch.is_storage(obj)
Returns True if obj is a PyTorch storage object.
Parameters: obj (Object) – Object to test
torch.set_default_dtype(d)
Sets the default floating point dtype to d. This type will be used as the default floating point type for type inference in torch.tensor().
The default floating point dtype is initially torch.float32.
Parameters: d (torch.dtype) – the floating point dtype to make the default
Example:
>>> torch.tensor([1.2, 3]).dtype  # initial default for floating point is torch.float32
torch.float32
>>> torch.set_default_dtype(torch.float64)
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
torch.get_default_dtype() → torch.dtype
Gets the current default floating point torch.dtype.
Example:
>>> torch.get_default_dtype()  # initial default for floating point is torch.float32
torch.float32
>>> torch.set_default_dtype(torch.float64)
>>> torch.get_default_dtype()  # default is now changed to torch.float64
torch.float64
>>> torch.set_default_tensor_type(torch.FloatTensor)  # setting tensor type also affects this
>>> torch.get_default_dtype()  # changed to torch.float32, the dtype for torch.FloatTensor
torch.float32
torch.set_default_tensor_type(t)
Sets the default torch.Tensor type to floating point tensor type t. This type will also be used as the default floating point type for type inference in torch.tensor().
The default floating point tensor type is initially torch.FloatTensor.
Parameters: t (type or string) – the floating point tensor type or its name
Example:
>>> torch.tensor([1.2, 3]).dtype  # initial default for floating point is torch.float32
torch.float32
>>> torch.set_default_tensor_type(torch.DoubleTensor)
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
torch.numel(input) → int
Returns the total number of elements in the input tensor.
Parameters: input (Tensor) – the input tensor
Example:
>>> a = torch.randn(1, 2, 3, 4, 5)
>>> torch.numel(a)
120
>>> a = torch.zeros(4, 4)
>>> torch.numel(a)
16
torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None)
Sets options for printing. The options are taken from NumPy's print options.
Parameters:
- precision – Number of digits of precision for floating point output (default = 4).
- threshold – Total number of array elements which trigger summarization rather than full repr (default = 1000).
- edgeitems – Number of array items in summary at beginning and end of each dimension (default = 3).
- linewidth – The number of characters per line for the purpose of inserting line breaks (default = 80). Thresholded matrices will ignore this parameter.
- profile – Sane defaults for pretty printing (any one of default, short, full). Any of the options above that are passed as well override the profile's values.
torch.set_flush_denormal(mode) → bool
Disables denormal floating point numbers on CPU.
Returns True if your system supports flushing denormal numbers and it successfully configures flush denormal mode. set_flush_denormal() is only supported on x86 architectures supporting SSE3.
Parameters: mode (bool) – Controls whether to enable flush denormal mode or not
Example:
>>> torch.set_flush_denormal(True)
True
>>> torch.tensor([1e-323], dtype=torch.float64)
tensor([ 0.], dtype=torch.float64)
>>> torch.set_flush_denormal(False)
True
>>> torch.tensor([1e-323], dtype=torch.float64)
tensor(9.88131e-324 *
       [ 1.0000], dtype=torch.float64)
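Denormal (subnormal) floats are non-zero values whose magnitude is below the smallest normal float. As a minimal pure-Python sketch of what flush mode does to a float64 value, the hypothetical helpers is_denormal and flush_denormal below compare against sys.float_info.min. They do not call PyTorch, do not change any CPU flags, and ignore the sign of the flushed zero.

```python
import sys

SMALLEST_NORMAL = sys.float_info.min  # smallest positive normal float64, about 2.2e-308


def is_denormal(x: float) -> bool:
    # Subnormal: non-zero, but smaller in magnitude than the smallest normal value.
    return x != 0.0 and abs(x) < SMALLEST_NORMAL


def flush_denormal(x: float) -> float:
    # Hypothetical sketch of flush-to-zero: subnormal inputs become 0.0
    # (the sign of the zero is ignored here), everything else passes through.
    return 0.0 if is_denormal(x) else x


print(is_denormal(1e-323))     # True: 1e-323 is subnormal in float64
print(flush_denormal(1e-323))  # 0.0
print(flush_denormal(1.5))     # 1.5
```

This mirrors the example above, where 1e-323 prints as 0 with flushing enabled and as roughly 9.88e-324 without it.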
Creation Ops
Note
Random sampling creation ops are listed under Random sampling and include:
- torch.rand()
- torch.rand_like()
- torch.randn()
- torch.randn_like()
- torch.randint()
- torch.randint_like()
- torch.randperm()
You may also use torch.empty() with the In-place random sampling methods to create torch.Tensor objects with values sampled from a broader range of distributions.
torch.tensor(data, dtype=None, device=None, requires_grad=False) → Tensor
Constructs a tensor with data.
Warning
torch.tensor() always copies data. If you have a Tensor data and want to avoid a copy, use torch.Tensor.requires_grad_() or torch.Tensor.detach(). If you have a NumPy ndarray and want to avoid a copy, use torch.from_numpy().
Parameters:
- data (array_like) – Initial data for the tensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
- dtype (torch.dtype, optional) – the desired data type of returned tensor. Default: if None, infers data type from data.
- device (torch.device, optional) – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type (see torch.set_default_tensor_type()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
tensor([[ 0.1000,  1.2000],
        [ 2.2000,  3.1000],
        [ 4.9000,  5.2000]])
>>> torch.tensor([0, 1])  # Type inference on data
tensor([ 0,  1])
>>> torch.tensor([[0.11111, 0.222222, 0.3333333]],
...              dtype=torch.float64,
...              device=torch.device('cuda:0'))  # creates a torch.cuda.DoubleTensor
tensor([[ 0.1111,  0.2222,  0.3333]], dtype=torch.float64, device='cuda:0')
>>> torch.tensor(3.14159)  # Create a scalar (zero-dimensional tensor)
tensor(3.1416)
>>> torch.tensor([])  # Create an empty tensor (of size (0,))
tensor([])
torch.from_numpy(ndarray) → Tensor
Creates a Tensor from a numpy.ndarray.
The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.
Example:
>>> a = numpy.array([1, 2, 3])
>>> t = torch.from_numpy(a)
>>> t
tensor([ 1,  2,  3])
>>> t[0] = -1
>>> a
array([-1,  2,  3])
torch.zeros(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument sizes.
Parameters:
- sizes (int...) – a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.zeros(2, 3)
tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])
>>> torch.zeros(5)
tensor([ 0.,  0.,  0.,  0.,  0.])
torch.zeros_like(input, dtype=None, layout=None, device=None, requires_grad=False) → Tensor
Returns a tensor filled with the scalar value 0, with the same size as input. torch.zeros_like(input) is equivalent to torch.zeros(input.size(), dtype=input.dtype, layout=input.layout, device=input.device).
Warning
As of 0.4, this function does not support an out keyword. As an alternative, the old torch.zeros_like(input, out=output) is equivalent to torch.zeros(input.size(), out=output).
Parameters:
- input (Tensor) – the size of input will determine size of the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned Tensor.
- layout (torch.layout, optional) – the desired layout of returned tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> input = torch.empty(2, 3)
>>> torch.zeros_like(input)
tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])
torch.ones(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument sizes.
Parameters:
- sizes (int...) – a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.ones(2, 3)
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> torch.ones(5)
tensor([ 1.,  1.,  1.,  1.,  1.])
torch.ones_like(input, dtype=None, layout=None, device=None, requires_grad=False) → Tensor
Returns a tensor filled with the scalar value 1, with the same size as input. torch.ones_like(input) is equivalent to torch.ones(input.size(), dtype=input.dtype, layout=input.layout, device=input.device).
Warning
As of 0.4, this function does not support an out keyword. As an alternative, the old torch.ones_like(input, out=output) is equivalent to torch.ones(input.size(), out=output).
Parameters:
- input (Tensor) – the size of input will determine size of the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned Tensor.
- layout (torch.layout, optional) – the desired layout of returned tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> input = torch.empty(2, 3)
>>> torch.ones_like(input)
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
torch.arange(start=0, end, step=1, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a 1-D tensor of size $\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil$ with values from the interval [start, end) taken with common difference step beginning from start.
Note that non-integer step is subject to floating point rounding errors when comparing against end; to avoid inconsistency, we advise adding a small epsilon to end in such cases.
$$\text{out}_{i+1} = \text{out}_i + \text{step}$$
Parameters:
- start (float) – the starting value for the set of points. Default: 0.
- end (float) – the ending value for the set of points
- step (float) – the gap between each pair of adjacent points. Default: 1.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.arange(5)
tensor([ 0.,  1.,  2.,  3.,  4.])
>>> torch.arange(1, 4)
tensor([ 1.,  2.,  3.])
>>> torch.arange(1, 2.5, 0.5)
tensor([ 1.0000,  1.5000,  2.0000])
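As a minimal sketch of how many values arange produces, assuming the count is ceil((end - start) / step) so that a range like arange(0, 5, 2) yields [0, 2, 4], the hypothetical helper arange_values below reproduces the example values with plain Python numbers. It ignores dtype, device, and the epsilon caveat above, and computes each value as start + i * step rather than summing step repeatedly; that is a choice of this sketch, not a statement about PyTorch's kernel.

```python
import math


def arange_values(start, end=None, step=1):
    """Hypothetical sketch of torch.arange's values as a Python list."""
    if end is None:
        # Single-argument form arange(end): start defaults to 0.
        start, end = 0, start
    # Number of points in [start, end) with common difference `step`.
    n = max(math.ceil((end - start) / step), 0)
    # Compute each value directly instead of accumulating `step`.
    return [start + i * step for i in range(n)]


print(arange_values(5))            # [0, 1, 2, 3, 4]
print(arange_values(1, 4))         # [1, 2, 3]
print(arange_values(1, 2.5, 0.5))  # [1.0, 1.5, 2.0]
print(arange_values(0, 5, 2))      # [0, 2, 4]
```

The last call returns three values: ceil(5 / 2) = 3, whereas floor(5 / 2) = 2 would drop the final 4.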
torch.range(start=0, end, step=1, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a 1-D tensor of size $\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1$ with values from start to end with step step. Step is the gap between two values in the tensor.
$$\text{out}_{i+1} = \text{out}_i + \text{step}$$
Warning
This function is deprecated in favor of torch.arange().
Parameters:
- start (float) – the starting value for the set of points. Default: 0.
- end (float) – the ending value for the set of points
- step (float) – the gap between each pair of adjacent points. Default: 1.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.range(1, 4)
tensor([ 1.,  2.,  3.,  4.])
>>> torch.range(1, 4, 0.5)
tensor([ 1.0000,  1.5000,  2.0000,  2.5000,  3.0000,  3.5000,  4.0000])
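The "+ 1" in the size formula is what makes end inclusive when it is reached exactly. A minimal pure-Python sketch of that rule, using a hypothetical helper range_values (ignoring dtype and floating point rounding), reproduces both examples:

```python
import math


def range_values(start, end, step=1):
    """Hypothetical sketch of torch.range: end is included when it is reached exactly."""
    # floor((end - start) / step) gaps between points, hence one more point.
    n = max(math.floor((end - start) / step) + 1, 0)
    return [start + i * step for i in range(n)]


print(range_values(1, 4))       # [1, 2, 3, 4]
print(range_values(1, 4, 0.5))  # [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
```

When migrating to torch.arange, the half-open interval means end must be pushed past the last wanted value; for integer steps, arange(start, end + step, step) gives the same points.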
torch.linspace(start, end, steps=100, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a one-dimensional tensor of steps equally spaced points between start and end.
The output tensor is 1-D of size steps.
Parameters:
- start (float) – the starting value for the set of points
- end (float) – the ending value for the set of points
- steps (int) – number of points to sample between start and end. Default: 100.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.linspace(3, 10, steps=5)
tensor([  3.0000,   4.7500,   6.5000,   8.2500,  10.0000])
>>> torch.linspace(-10, 10, steps=5)
tensor([-10.,  -5.,   0.,   5.,  10.])
>>> torch.linspace(start=-10, end=10, steps=5)
tensor([-10.,  -5.,   0.,   5.,  10.])
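Because both endpoints are included, steps points leave steps - 1 gaps, so the spacing is (end - start) / (steps - 1). The hypothetical helper linspace_values below is a minimal pure-Python sketch of that arithmetic; returning [start] when steps == 1 is an assumption of the sketch.

```python
def linspace_values(start, end, steps=100):
    """Hypothetical sketch of torch.linspace: `steps` evenly spaced points, endpoints included."""
    if steps == 1:
        # Assumption of this sketch: a single point is just `start`.
        return [float(start)]
    spacing = (end - start) / (steps - 1)  # steps points -> steps - 1 gaps
    return [start + i * spacing for i in range(steps)]


print(linspace_values(3, 10, steps=5))    # [3.0, 4.75, 6.5, 8.25, 10.0]
print(linspace_values(-10, 10, steps=5))  # [-10.0, -5.0, 0.0, 5.0, 10.0]
```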
torch.logspace(start, end, steps=100, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a one-dimensional tensor of steps points logarithmically spaced between $10^{\text{start}}$ and $10^{\text{end}}$.
The output tensor is 1-D of size steps.
Parameters:
- start (float) – the starting value for the set of points
- end (float) – the ending value for the set of points
- steps (int) – number of points to sample between start and end. Default: 100.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.logspace(start=-10, end=10, steps=5)
tensor([ 1.0000e-10,  1.0000e-05,  1.0000e+00,  1.0000e+05,  1.0000e+10])
>>> torch.logspace(start=0.1, end=1.0, steps=5)
tensor([  1.2589,   2.1135,   3.5481,   5.9566,  10.0000])
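Logarithmic spacing means the exponents, not the values, are evenly spaced: each output is 10 raised to a point of a linear grid from start to end. A minimal pure-Python sketch with a hypothetical helper logspace_values (base 10 assumed, as in the description above) reproduces the second example to four decimals:

```python
def logspace_values(start, end, steps=100):
    """Hypothetical sketch of torch.logspace: 10 ** (evenly spaced exponents)."""
    if steps == 1:
        exponents = [start]  # assumption of this sketch
    else:
        spacing = (end - start) / (steps - 1)
        exponents = [start + i * spacing for i in range(steps)]
    return [10.0 ** e for e in exponents]


print([round(v, 4) for v in logspace_values(0.1, 1.0, steps=5)])
# [1.2589, 2.1135, 3.5481, 5.9566, 10.0]
```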
torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
Parameters:
- n (int) – the number of rows
- m (int, optional) – the number of columns with default being n
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Returns: A 2-D tensor with ones on the diagonal and zeros elsewhere
Return type: Tensor
Example:
>>> torch.eye(3)
tensor([[ 1.,  0.,  0.],
        [ 0.,  1.,  0.],
        [ 0.,  0.,  1.]])
torch.empty(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a tensor filled with uninitialized data. The shape of the tensor is defined by the variable argument sizes.
Parameters:
- sizes (int...) – a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.empty(2, 3)
tensor(1.00000e-08 *
       [[ 6.3984,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]])
torch.empty_like(input, dtype=None, layout=None, device=None, requires_grad=False) → Tensor
Returns an uninitialized tensor with the same size as input. torch.empty_like(input) is equivalent to torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device).
Parameters:
- input (Tensor) – the size of input will determine size of the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned Tensor.
- layout (torch.layout, optional) – the desired layout of returned tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> input = torch.empty((2, 3), dtype=torch.int64)
>>> out = torch.empty_like(input)  # contents are uninitialized and will vary
>>> out.size(), out.dtype
(torch.Size([2, 3]), torch.int64)
torch.full(size, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a tensor of size size filled with fill_value.
Parameters:
- size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.
- fill_value – the number to fill the output tensor with.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.full((2, 3), 3.141592)
tensor([[ 3.1416,  3.1416,  3.1416],
        [ 3.1416,  3.1416,  3.1416]])
torch.full_like(input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Returns a tensor with the same size as input filled with fill_value. torch.full_like(input, fill_value) is equivalent to torch.full(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device).
Parameters:
- input (Tensor) – the size of input will determine size of the output tensor
- fill_value – the number to fill the output tensor with.
- dtype (torch.dtype, optional) – the desired data type of returned Tensor.
- layout (torch.layout, optional) – the desired layout of returned tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Indexing, Slicing, Joining, Mutating Ops
torch.cat(seq, dim=0, out=None) → Tensor
Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
torch.cat() can be best understood via examples.
Parameters:
- seq (sequence of Tensors) – any Python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
- dim (int, optional) – the dimension over which the tensors are concatenated
- out (Tensor, optional) – the output tensor
Example:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
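torch.chunk(tensor, chunks, dim=0) → List of Tensors
Splits a tensor into a specific number of chunks.
The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
Parameters:
- tensor (Tensor) – the tensor to split
- chunks (int) – number of chunks to return
- dim (int) – dimension along which to split the tensor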
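The chunk lengths can be sketched in pure Python. This assumes, as in PyTorch's implementation, that every chunk holds ceil(size / chunks) elements along dim and only the last one is short; a consequence is that fewer than chunks pieces can come back. chunk_sizes is a hypothetical helper that only computes lengths.

```python
def chunk_sizes(size, chunks):
    """Hypothetical sketch of the chunk lengths torch.chunk produces along one dimension."""
    # ceil(size / chunks) using integer arithmetic; at least 1 so range() below is valid.
    step = max((size + chunks - 1) // chunks, 1)
    # Full-size chunks, with the remainder (if any) in the last one.
    return [min(step, size - start) for start in range(0, size, step)]


print(chunk_sizes(6, 3))  # [2, 2, 2]
print(chunk_sizes(7, 3))  # [3, 3, 1]
print(chunk_sizes(5, 4))  # [2, 2, 1]  (only three chunks)
```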
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
If input is an n-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, x_i, x_{i+1}, \ldots, x_{n-1})$ and dim $= i$, then index must be an n-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, y, x_{i+1}, \ldots, x_{n-1})$ where $y \ge 1$, and out will have the same size as index.
Parameters:
- input (Tensor) – the source tensor
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to gather
- out (Tensor, optional) – the destination tensor
Example:
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])
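The indexing rule above translates directly into loops. The following is a minimal sketch restricted to 2-D nested lists; gather_2d is a hypothetical helper, not PyTorch code. The output takes the shape of index, and only the coordinate along dim is replaced by the looked-up index.

```python
def gather_2d(inp, dim, index):
    """Hypothetical pure-Python sketch of torch.gather for 2-D nested lists."""
    rows, cols = len(index), len(index[0])
    out = [[None] * cols for _ in range(rows)]  # same shape as `index`
    for i in range(rows):
        for j in range(cols):
            if dim == 0:
                out[i][j] = inp[index[i][j]][j]  # row comes from index, column kept
            elif dim == 1:
                out[i][j] = inp[i][index[i][j]]  # column comes from index, row kept
            else:
                raise ValueError("dim must be 0 or 1 for a 2-D input")
    return out


t = [[1, 2], [3, 4]]
print(gather_2d(t, 1, [[0, 0], [1, 0]]))  # [[1, 1], [4, 3]]
print(gather_2d(t, 0, [[1, 0], [0, 0]]))  # [[3, 2], [1, 2]]
```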
torch.index_select(input, dim, index, out=None) → Tensor
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index, which is a LongTensor.
The returned tensor has the same number of dimensions as the original tensor (input). The dim-th dimension has the same size as the length of index; other dimensions have the same size as in the original tensor.
Note
The returned tensor does not use the same storage as the original tensor. If out has a different shape than expected, we silently change it to the correct shape, reallocating the underlying storage if necessary.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension in which we index
- index (LongTensor) – the 1-D tensor containing the indices to index
- out (Tensor, optional) – the output tensor
Example:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])
torch.masked_select(input, mask, out=None) → Tensor
Returns a new 1-D tensor which indexes the input tensor according to the binary mask mask, which is a ByteTensor.
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.
Note
The returned tensor does not use the same storage as the original tensor.
Parameters:
- input (Tensor) – the input data
- mask (ByteTensor) – the tensor containing the binary mask to index with
- out (Tensor, optional) – the output tensor
Example:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
        [-1.2035,  1.2252,  0.5002,  0.6248],
        [ 0.1307, -2.0608,  0.1244,  2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[ 0,  0,  0,  0],
        [ 0,  1,  1,  1],
        [ 0,  0,  0,  1]], dtype=torch.uint8)
>>> torch.masked_select(x, mask)
tensor([ 1.2252,  0.5002,  0.6248,  2.0139])
torch.nonzero(input, out=None) → LongTensor
Returns a tensor containing the indices of all non-zero elements of input. Each row in the result contains the indices of a non-zero element in input.
If input has n dimensions, then the resulting indices tensor out is of size $(z \times n)$, where $z$ is the total number of non-zero elements in the input tensor.
Parameters:
- input (Tensor) – the input tensor
- out (LongTensor, optional) – the output tensor containing indices
Example:
>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]))
tensor([[ 0],
        [ 1],
        [ 2],
        [ 4]])
>>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]))
tensor([[ 0,  0],
        [ 1,  1],
        [ 2,  2],
        [ 3,  3]])
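To see where the $(z \times n)$ shape comes from, here is a minimal recursive sketch over nested Python lists; nonzero is a hypothetical stand-in, not PyTorch's implementation. Each non-zero leaf contributes one row holding its n coordinates, visited in row-major order.

```python
def nonzero(data, prefix=()):
    """Hypothetical sketch: index rows of the non-zero entries of a nested list, row-major."""
    if not isinstance(data, list):
        # Leaf (a scalar): emit its coordinate path if it is non-zero.
        return [list(prefix)] if data != 0 else []
    rows = []
    for i, item in enumerate(data):
        rows.extend(nonzero(item, prefix + (i,)))
    return rows


print(nonzero([1, 1, 1, 0, 1]))  # [[0], [1], [2], [4]]
print(nonzero([[0.6, 0.0], [0.0, -0.4]]))  # [[0, 0], [1, 1]]
```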
torch.reshape(input, shape) → Tensor
Returns a tensor with the same data and number of elements as input, but with the specified shape. When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
A single dimension may be -1, in which case it’s inferred from the remaining dimensions and the number of elements in input.
Parameters:
- input (Tensor) – the tensor to be reshaped
- shape (tuple of ints) – the new shape
Example:
>>> a = torch.arange(4)
>>> torch.reshape(a, (2, 2))
tensor([[ 0.,  1.],
        [ 2.,  3.]])
>>> b = torch.tensor([[0, 1], [2, 3]])
>>> torch.reshape(b, (-1,))
tensor([ 0,  1,  2,  3])
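Resolving the -1 entry is simple arithmetic: divide the element count by the product of the known dimensions. The hypothetical helper infer_shape below is a minimal sketch of that rule (shape only, no data), raising an error when the counts do not divide evenly; the error types are choices of this sketch.

```python
from math import prod


def infer_shape(shape, numel):
    """Hypothetical sketch: resolve a single -1 in `shape` for a tensor of `numel` elements."""
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(d for d in shape if d != -1)  # product of the explicit dimensions
    if -1 in shape:
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {shape} is invalid for input of size {numel}")
        return tuple(numel // known if d == -1 else d for d in shape)
    if known != numel:
        raise ValueError(f"shape {shape} is invalid for input of size {numel}")
    return tuple(shape)


print(infer_shape((-1,), 4))       # (4,)
print(infer_shape((2, -1, 3), 12)) # (2, 2, 3)
```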
torch.split(tensor, split_size_or_sections, dim=0)
Splits the tensor into chunks.
If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size_or_sections.
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
Parameters:
- tensor (Tensor) – tensor to split
- split_size_or_sections (int or list(int)) – size of a single chunk or list of sizes for each chunk
- dim (int) – dimension along which to split the tensor
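The two forms of split_size_or_sections can be sketched as a length computation. The hypothetical helper split_sizes below assumes, as PyTorch requires, that an explicit list of sections sums to the size of the split dimension; everything else follows the description above.

```python
def split_sizes(size, split_size_or_sections):
    """Hypothetical sketch of the chunk lengths torch.split produces along one dimension."""
    if isinstance(split_size_or_sections, int):
        step = split_size_or_sections
        if step <= 0:
            raise ValueError("split size must be positive")
        # Equal chunks of `step`; the last one holds whatever remains.
        return [min(step, size - start) for start in range(0, size, step)]
    sections = list(split_size_or_sections)
    if sum(sections) != size:
        raise ValueError("sections must sum to the size of the split dimension")
    return sections


print(split_sizes(5, 2))       # [2, 2, 1]
print(split_sizes(5, [1, 4]))  # [1, 4]
```

Unlike torch.chunk, the integer here is the size of each piece, not the number of pieces.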
torch.squeeze(input, dim=None, out=None) → Tensor
Returns a tensor with all the dimensions of input of size 1 removed.
For example, if input is of shape $(A \times 1 \times B \times C \times 1 \times D)$ then the out tensor will be of shape $(A \times B \times C \times D)$.
When dim is given, a squeeze operation is done only in the given dimension. If input is of shape $(A \times 1 \times B)$, squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape $(A \times B)$.
Note
As an exception to the above, a 1-dimensional tensor of size 1 will not have its dimensions changed.
Note
The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other.
Parameters:
- input (Tensor) – the input tensor
- dim (int, optional) – if given, the input will be squeezed only in this dimension
- out (Tensor, optional) – the output tensor
Example:
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
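The shape rule can be sketched without any tensor data. squeezed_shape is a hypothetical helper working on shape tuples; it reproduces the example sizes but deliberately ignores the 1-D special case mentioned in the first note.

```python
def squeezed_shape(shape, dim=None):
    """Hypothetical sketch: shape returned by torch.squeeze (ignores the 1-D special case)."""
    if dim is None:
        # Drop every size-1 dimension.
        return tuple(d for d in shape if d != 1)
    if dim < 0:
        dim += len(shape)  # negative dims count from the end
    if shape[dim] != 1:
        return tuple(shape)  # only a size-1 dimension is removed
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(squeezed_shape((2, 1, 2, 1, 2)))     # (2, 2, 2)
print(squeezed_shape((2, 1, 2, 1, 2), 0))  # (2, 1, 2, 1, 2)
print(squeezed_shape((2, 1, 2, 1, 2), 1))  # (2, 2, 1, 2)
```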
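torch.stack(seq, dim=0, out=None) → Tensor
Concatenates a sequence of tensors along a new dimension.
All tensors need to be of the same size.
Parameters:
- seq (sequence of Tensors) – sequence of tensors to concatenate
- dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated tensors (inclusive)
- out (Tensor, optional) – the output tensor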
torch.t(input, out=None) → Tensor
Expects input to be a matrix (2-D tensor) and transposes dimensions 0 and 1.
Can be seen as a short-hand function for transpose(input, 0, 1).
Parameters:
- input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.4875,  0.9158, -0.5872],
        [ 0.3938, -0.6929,  0.6932]])
>>> torch.t(x)
tensor([[ 0.4875,  0.3938],
        [ 0.9158, -0.6929],
        [-0.5872,  0.6932]])
torch.take(input, indices) → Tensor
Returns a new tensor with the elements of input at the given indices. The input tensor is treated as if it were viewed as a 1-D tensor. The result takes the same shape as the indices.
Parameters:
- input (Tensor) – the input tensor
- indices (LongTensor) – the indices into tensor
Example:
>>> src = torch.tensor([[4, 3, 5],
...                     [6, 7, 8]])
>>> torch.take(src, torch.tensor([0, 2, 5]))
tensor([ 4,  5,  8])
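"Treated as if it were viewed as a 1-D tensor" means indices refer to positions in the row-major flattening of input. The hypothetical helper take below is a minimal pure-Python sketch of that behavior on nested lists, with the result mirroring the shape of indices.

```python
def take(data, indices):
    """Hypothetical sketch of torch.take on nested lists."""
    def flatten(x):
        # Row-major flattening: the last index varies fastest.
        if isinstance(x, list):
            return [v for item in x for v in flatten(item)]
        return [x]

    flat = flatten(data)

    def pick(idx):
        # Walk `indices`, so the output has the same nesting (shape) as `indices`.
        if isinstance(idx, list):
            return [pick(i) for i in idx]
        return flat[idx]

    return pick(indices)


src = [[4, 3, 5], [6, 7, 8]]
print(take(src, [0, 2, 5]))         # [4, 5, 8]
print(take(src, [[0, 1], [4, 3]]))  # [[4, 3], [7, 6]]
```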
torch.transpose(input, dim0, dim1, out=None) → Tensor
Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.
The resulting out tensor shares its underlying storage with the input tensor, so changing the content of one would change the content of the other.
Parameters:
- input (Tensor) – the input tensor
- dim0 (int) – the first dimension to be transposed
- dim1 (int) – the second dimension to be transposed
Example:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893,  0.5809],
        [-0.1669,  0.7299,  0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
        [-0.9893,  0.7299],
        [ 0.5809,  0.4942]])
torch.unbind(tensor, dim=0)
Removes a tensor dimension.
Returns a tuple of all slices along a given dimension, already without it.
Parameters:
- tensor (Tensor) – the tensor to unbind
- dim (int) – dimension to remove
torch.unsqueeze(input, dim, out=None) → Tensor
Returns a new tensor with a dimension of size one inserted at the specified position.
The returned tensor shares the same underlying data with input.
A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. A negative dim corresponds to unsqueeze() applied at dim = dim + input.dim() + 1.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the index at which to insert the singleton dimension
- out (Tensor, optional) – the output tensor
Example:
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])
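The negative-dim rule is easiest to check on shapes alone. The hypothetical helper unsqueezed_shape below is a minimal sketch that validates dim against the range [-ndim - 1, ndim + 1), maps a negative dim to dim + ndim + 1, and inserts a 1 at that position.

```python
def unsqueezed_shape(shape, dim):
    """Hypothetical sketch of the shape returned by torch.unsqueeze(input, dim)."""
    ndim = len(shape)
    if not -ndim - 1 <= dim < ndim + 1:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D input")
    if dim < 0:
        dim = dim + ndim + 1  # negative dims count from the end of the *output* shape
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])


print(unsqueezed_shape((4,), 0))   # (1, 4)
print(unsqueezed_shape((4,), 1))   # (4, 1)
print(unsqueezed_shape((4,), -1))  # (4, 1)
```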
torch.where(condition, x, y) → Tensor
Returns a tensor of elements selected from either x or y, depending on condition.
The operation is defined as:
$$\text{out}_i = \begin{cases} \text{x}_i & \text{if } \text{condition}_i \\ \text{y}_i & \text{otherwise} \end{cases}$$
Note
The tensors condition, x, y must be broadcastable.
Parameters:
- condition (ByteTensor) – When True (nonzero), yield x, otherwise yield y
- x (Tensor) – values selected at indices where condition is True
- y (Tensor) – values selected at indices where condition is False
Returns: A tensor of shape equal to the broadcasted shape of condition, x, y
Return type: Tensor
Example:
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
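Elementwise, the definition is a per-position conditional expression. The hypothetical helper where below is a minimal pure-Python sketch on flat sequences of equal length; it does not implement broadcasting, which torch.where does, and raises instead when lengths differ.

```python
def where(condition, x, y):
    """Hypothetical sketch of torch.where on equal-length flat sequences (no broadcasting)."""
    if not len(condition) == len(x) == len(y):
        raise ValueError("this sketch needs equal lengths; torch.where broadcasts instead")
    # out_i = x_i if condition_i else y_i
    return [xi if ci else yi for ci, xi, yi in zip(condition, x, y)]


xs = [-0.4620, 0.3139, 0.3898, -0.7197]
print(where([v > 0 for v in xs], xs, [1.0] * 4))  # [1.0, 0.3139, 0.3898, 1.0]
```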
Random sampling
torch.manual_seed(seed)
Sets the seed for generating random numbers. Returns a torch._C.Generator object.
Parameters: seed (int) – The desired seed.
torch.initial_seed()
Returns the initial seed for generating random numbers as a Python long.
torch.set_rng_state(new_state)
Sets the random number generator state.
Parameters: new_state (torch.ByteTensor) – The desired state
torch.default_generator = <torch._C.Generator object>
torch.bernoulli(input, out=None) → Tensor
Draws binary random numbers (0 or 1) from a Bernoulli distribution.
The input tensor should be a tensor containing probabilities to be used for drawing the binary random number. Hence, all values in input have to be in the range $0 \le \text{input}_i \le 1$.
The $i^{\text{th}}$ element of the output tensor will draw a value 1 according to the $i^{\text{th}}$ probability value given in input.
$$\text{out}_i \sim \mathrm{Bernoulli}(p = \text{input}_i)$$
The returned out tensor only has values 0 or 1 and is of the same shape as input.
Parameters:
- input (Tensor) – the input tensor of probability values for the Bernoulli distribution
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.empty(3, 3).uniform_(0, 1)  # generate a uniform random matrix with range [0, 1]
>>> a
tensor([[ 0.1737,  0.0950,  0.3609],
        [ 0.7148,  0.0289,  0.2676],
        [ 0.9456,  0.8937,  0.7202]])
>>> torch.bernoulli(a)
tensor([[ 1.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 1.,  1.,  1.]])
>>> a = torch.ones(3, 3)  # probability of drawing "1" is 1
>>> torch.bernoulli(a)
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> a = torch.zeros(3, 3)  # probability of drawing "1" is 0
>>> torch.bernoulli(a)
tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])
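One standard way to draw $\text{out}_i \sim \mathrm{Bernoulli}(p = \text{input}_i)$ is to compare a uniform sample on [0, 1) against $p$. The hypothetical helper bernoulli below is a minimal pure-Python sketch of that technique on a flat list of probabilities; it is not how PyTorch's kernel is written, and it uses Python's random module rather than torch's generator.

```python
import random


def bernoulli(probs, rng=None):
    """Hypothetical sketch: draw out_i ~ Bernoulli(p = probs_i) for a flat list of probabilities."""
    rng = rng or random.Random()
    out = []
    for p in probs:
        if not 0.0 <= p <= 1.0:
            raise ValueError("probabilities must lie in [0, 1]")
        # rng.random() is uniform on [0, 1), so P(rng.random() < p) == p.
        out.append(1 if rng.random() < p else 0)
    return out


print(bernoulli([1.0, 1.0, 1.0]))  # [1, 1, 1]: probability of drawing 1 is 1
print(bernoulli([0.0, 0.0, 0.0]))  # [0, 0, 0]: probability of drawing 1 is 0
```

Passing a seeded random.Random makes the draws reproducible, much like calling torch.manual_seed before torch.bernoulli.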
torch.multinomial(input, num_samples, replacement=False, out=None) → LongTensor
Returns a tensor where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of tensor input.
Note
The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative and have a non-zero sum.
Indices are ordered from left to right according to when each was sampled (first samples are placed in first column).
If input is a vector, out is a vector of size num_samples.
If input is a matrix with $m$ rows, out is a matrix of shape $(m \times \text{num\_samples})$.
If replacement is True, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
This implies the constraint that num_samples must not exceed the length of input (or the number of columns of input if it is a matrix).
Parameters:
- input (Tensor) – the input tensor containing probabilities
- num_samples (int) – number of samples to draw
- replacement (bool, optional) – whether to draw with replacement or not
- out (Tensor, optional) – the output tensor
Example:
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)  # create a tensor of weights
>>> torch.multinomial(weights, 4)
tensor([ 1,  2,  0,  0])
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])
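The with/without-replacement distinction for a single row can be sketched in pure Python with random.choices, which accepts unnormalized weights. multinomial below is a hypothetical helper, not PyTorch's sampler. It zeroes a weight once its index is drawn (without replacement), records indices in draw order, and, unlike the 0.4-era example output above, refuses to draw once no category with non-zero weight remains.

```python
import random


def multinomial(weights, num_samples, replacement=False, rng=None):
    """Hypothetical sketch: draw category indices from one row of non-negative weights."""
    rng = rng or random.Random()
    if any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError("weights must be non-negative with a non-zero sum")
    if not replacement and num_samples > len(weights):
        raise ValueError("num_samples exceeds the number of categories")
    w = list(weights)
    out = []
    for _ in range(num_samples):
        if sum(w) <= 0:
            # Without replacement, every category with non-zero weight is used up.
            raise ValueError("no categories with non-zero weight remain")
        # Weights need not sum to one; random.choices scales by their total.
        idx = rng.choices(range(len(w)), weights=w)[0]
        out.append(idx)  # position in `out` records the order of drawing
        if not replacement:
            w[idx] = 0  # a drawn index cannot be drawn again for this row
    return out


print(multinomial([0, 10, 3, 0], 2, rng=random.Random(0)))  # some ordering of indices 1 and 2
```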
torch.normal(mean, std, out=None) → Tensor
Returns a tensor of random numbers drawn from separate normal distributions whose mean and standard deviation are given.
The mean is a tensor with the mean of each output element’s normal distribution.
The std is a tensor with the standard deviation of each output element’s normal distribution.
The shapes of mean and std don’t need to match, but the total number of elements in each tensor needs to be the same.
Note
When the shapes do not match, the shape of mean is used as the shape for the returned output tensor.
Parameters:
- mean (Tensor) – the tensor of per-element means
- std (Tensor) – the tensor of per-element standard deviations
- out (Tensor, optional) – the output tensor
Example:
>>> torch.normal(mean=torch.arange(1, 11), std=torch.arange(1, 0, -0.1))
tensor([  1.0425,   3.5672,   2.7969,   4.2925,   4.7229,   6.2134,
          8.0505,   8.1408,   9.0563,  10.0566])
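torch.normal(mean=0.0, std, out=None) → Tensor
Similar to the function above, but the means are shared among all drawn elements.
Parameters:
- mean (float, optional) – the mean for all distributions
- std (Tensor) – the tensor of per-element standard deviations
- out (Tensor, optional) – the output tensor
Example:
>>> torch.normal(mean=0.5, std=torch.arange(1, 6))
tensor([-1.2793, -1.0732, -2.0687,  5.1177, -1.2303])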
-
torch.
normal
(mean, std=1.0, out=None) → Tensor
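Similar to the function above, but the standard deviations are shared among all drawn elements.
Parameters: - mean (Tensor) – the tensor of per-element means
- std (float, optional) – the standard deviation for all distributions
- out (Tensor, optional) – the output tensor
Example:
>>> torch.normal(mean=torch.arange(1, 6))
tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361])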
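The three torch.normal overloads differ only in whether mean, std, or neither is shared across elements. A minimal plain-Python sketch of that broadcasting rule, assuming lists stand in for tensors (normal_samples is a hypothetical helper, not a PyTorch API), could look like this:

```python
import random


def normal_samples(mean, std, rng=None):
    """Draw one normal sample per element.

    mean and std are each either a list (one value per element) or a float
    shared by every element; at least one of them must be a list.
    """
    rng = rng or random.Random()
    n = len(mean) if isinstance(mean, list) else len(std)
    means = mean if isinstance(mean, list) else [mean] * n
    stds = std if isinstance(std, list) else [std] * n
    if len(means) != len(stds):
        raise ValueError("mean and std must have the same number of elements")
    return [rng.gauss(m, s) for m, s in zip(means, stds)]


rng = random.Random(0)
print(normal_samples([1.0, 2.0, 3.0], [1.0, 0.9, 0.8], rng))  # per-element mean and std
print(normal_samples(0.5, [1.0, 2.0, 3.0], rng))              # shared mean
print(normal_samples([1.0, 2.0, 3.0], 1.0, rng))              # shared std
```

With a standard deviation of zero every sample equals its mean, which makes the element-wise pairing easy to check.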
-
torch.
rand
(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor¶ Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1).
The shape of the tensor is defined by the variable argument sizes.
Parameters: - sizes (int...) – a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.rand(4)
tensor([ 0.5204, 0.2503, 0.3525, 0.5673])
>>> torch.rand(2, 3)
tensor([[ 0.8237, 0.5781, 0.6879],
        [ 0.3816, 0.7249, 0.0998]])
-
torch.
rand_like
(input, dtype=None, layout=None, device=None, requires_grad=False) → Tensor¶ Returns a tensor with the same size as input that is filled with random numbers from a uniform distribution on the interval [0, 1). torch.rand_like(input) is equivalent to torch.rand(input.size(), dtype=input.dtype, layout=input.layout, device=input.device).
Parameters: - input (Tensor) – the size of input will determine size of the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned Tensor.
- layout (torch.layout, optional) – the desired layout of returned tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
-
torch.
randint
(low=0, high, size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor¶ Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive).
The shape of the tensor is defined by the tuple size.
Parameters: - low (int, optional) – Lowest integer to be drawn from the distribution. Default: 0.
- high (int) – One above the highest integer to be drawn from the distribution.
- size (tuple) – a tuple defining the shape of the output tensor.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.randint(3, 5, (3,))
tensor([ 4., 3., 4.])
>>> torch.randint(3, 10, (2,2), dtype=torch.long)
tensor([[ 8, 3],
        [ 3, 9]])
>>> torch.randint(3, 10, (2,2))
tensor([[ 4., 5.],
        [ 6., 7.]])
-
torch.
randint_like
(input, low=0, high, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor¶ Returns a tensor with the same shape as the tensor input, filled with random integers generated uniformly between low (inclusive) and high (exclusive).
Parameters: - input (Tensor) – the size of input will determine size of the output tensor
- low (int, optional) – Lowest integer to be drawn from the distribution. Default: 0.
- high (int) – One above the highest integer to be drawn from the distribution.
- dtype (torch.dtype, optional) – the desired data type of returned Tensor.
- layout (torch.layout, optional) – the desired layout of returned tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
-
torch.
randn
(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor¶ Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 (also called the standard normal distribution).
$$\text{out}_i \sim \mathcal{N}(0, 1)$$
The shape of the tensor is defined by the variable argument sizes.
Parameters: - sizes (int...) – a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.randn(4)
tensor([-2.1436, 0.9966, 2.3426, -0.6366])
>>> torch.randn(2, 3)
tensor([[ 1.5954, 2.8929, -1.0923],
        [ 1.1719, -0.4709, -0.1996]])
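One classic way to turn uniform samples on [0, 1) (as produced by torch.rand) into standard normal samples (as produced by torch.randn) is the Box–Muller transform. The sketch below uses only the Python standard library; it illustrates the distribution and does not claim to be how PyTorch generates normal numbers internally. The helper name box_muller is hypothetical.

```python
import math
import random


def box_muller(n, rng=None):
    """Return n standard normal samples built from uniform samples."""
    rng = rng or random.Random()
    out = []
    while len(out) < n:
        u1 = 1.0 - rng.random()  # in (0, 1], so log(u1) is finite
        u2 = rng.random()        # in [0, 1)
        radius = math.sqrt(-2.0 * math.log(u1))
        # each pair of uniforms yields two independent N(0, 1) samples
        out.append(radius * math.cos(2.0 * math.pi * u2))
        out.append(radius * math.sin(2.0 * math.pi * u2))
    return out[:n]


samples = box_muller(20000, random.Random(0))
mean = sum(samples) / len(samples)
var = sum((x - mean) ** 2 for x in samples) / len(samples)
print(round(mean, 2), round(var, 2))  # close to 0 and 1
```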
-
torch.
randn_like
(input, dtype=None, layout=None, device=None, requires_grad=False) → Tensor¶ Returns a tensor with the same size as input that is filled with random numbers from a normal distribution with mean 0 and variance 1. torch.randn_like(input) is equivalent to torch.randn(input.size(), dtype=input.dtype, layout=input.layout, device=input.device).
Parameters: - input (Tensor) – the size of input will determine size of the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned Tensor.
- layout (torch.layout, optional) – the desired layout of returned tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
-
torch.
randperm
(n, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False) → LongTensor¶ Returns a random permutation of integers from 0 to n - 1.
Parameters: - n (int) – the upper bound (exclusive)
- out (Tensor, optional) – the output tensor
- dtype (torch.dtype, optional) – the desired data type of returned tensor. Default: torch.int64.
- layout (torch.layout, optional) – the desired layout of returned Tensor.
- device (torch.device, optional) – the desired device of returned tensor.
- requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: False.
Example:
>>> torch.randperm(4)
tensor([ 2, 1, 0, 3])
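A uniformly random permutation of 0 to n - 1 can be produced with the Fisher–Yates shuffle. This plain-Python sketch shows that technique; whether torch.randperm uses the same algorithm is not stated here, and the name fisher_yates_permutation is hypothetical.

```python
import random


def fisher_yates_permutation(n, rng=None):
    """Return a uniformly random permutation of 0 .. n - 1 (Fisher-Yates)."""
    rng = rng or random.Random()
    perm = list(range(n))
    # walk from the end, swapping each slot with a random slot at or before it
    for i in range(n - 1, 0, -1):
        j = rng.randint(0, i)  # randint is inclusive on both ends
        perm[i], perm[j] = perm[j], perm[i]
    return perm


print(fisher_yates_permutation(4, random.Random(0)))
```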
In-place random sampling¶
There are a few more in-place random sampling functions defined on Tensors as well. Click through to refer to their documentation:
- torch.Tensor.bernoulli_() - in-place version of torch.bernoulli()
- torch.Tensor.cauchy_() - numbers drawn from the Cauchy distribution
- torch.Tensor.exponential_() - numbers drawn from the exponential distribution
- torch.Tensor.geometric_() - elements drawn from the geometric distribution
- torch.Tensor.log_normal_() - samples from the log-normal distribution
- torch.Tensor.normal_() - in-place version of torch.normal()
- torch.Tensor.random_() - numbers sampled from the discrete uniform distribution
- torch.Tensor.uniform_() - numbers sampled from the continuous uniform distribution
Serialization¶
-
torch.
save
(obj, f, pickle_module=pickle, pickle_protocol=2)[source]¶ Saves an object to a disk file.
See also: Recommended approach for saving a model
Parameters: - obj – saved object
- f – a file-like object (has to implement write and flush) or a string containing a file name
- pickle_module – module used for pickling metadata and objects
- pickle_protocol – can be specified to override the default protocol
Warning
If you are using Python 2, torch.save does NOT support StringIO.StringIO as a valid file-like object. This is because the write method should return the number of bytes written; StringIO.write() does not do this.
Please use something like io.BytesIO instead.
Example
>>> # Save to file
>>> x = torch.tensor([0, 1, 2, 3, 4])
>>> torch.save(x, 'tensor.pt')
>>> # Save to io.BytesIO buffer
>>> import io
>>> buffer = io.BytesIO()
>>> torch.save(x, buffer)
-
torch.
load
(f, map_location=None, pickle_module=pickle)[source]¶ Loads an object saved with torch.save() from a file.
torch.load() uses Python’s unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn’t have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.
If map_location is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to map_location. The built-in location tags are ‘cpu’ for CPU tensors and ‘cuda:device_id’ (e.g. ‘cuda:2’) for CUDA tensors. map_location should return either None or a storage. If map_location returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, torch.load will fall back to the default behavior, as if map_location wasn’t specified.
If map_location is a string, it should be a device tag, where all tensors should be loaded.
Otherwise, if map_location is a dict, it will be used to remap location tags appearing in the file (keys), to ones that specify where to put the storages (values).
User extensions can register their own location tags and tagging and deserialization methods using register_package.
Parameters: - f – a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name
- map_location – a function, string or a dict specifying how to remap storage locations
- pickle_module – module used for unpickling metadata and objects (has to match the pickle_module used to serialize file)
Example
>>> torch.load('tensors.pt')
>>> # Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location='cpu')
>>> # Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
>>> # Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
>>> # Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
>>> # Load tensor from io.BytesIO object
>>> with open('tensor.pt', 'rb') as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
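The three forms of map_location can be summarized as a rule for choosing a destination location tag. The following is a simplified plain-Python model, not torch.load itself: resolve_location is a hypothetical name, and, unlike the real callable (which receives the storage and the tag and returns a storage or None), the callable here receives only the tag and returns a new tag or None.

```python
def resolve_location(tag, map_location=None):
    """Return the location tag a storage saved with `tag` is restored to.

    Simplified model of the rules described above; only tags are modelled.
    """
    if map_location is None:
        return tag                         # default: back to the saving device
    if isinstance(map_location, str):
        return map_location                # one device tag for every storage
    if isinstance(map_location, dict):
        return map_location.get(tag, tag)  # remap only the listed tags
    if callable(map_location):
        result = map_location(tag)
        return tag if result is None else result  # None -> default behaviour
    raise TypeError("map_location must be None, a str, a dict or a callable")


print(resolve_location('cuda:1', {'cuda:1': 'cuda:0'}))
print(resolve_location('cuda:2', {'cuda:1': 'cuda:0'}))
```

In this model a dict only rewrites the tags it lists as keys; any other tag falls through unchanged.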
Parallelism¶
-
torch.
get_num_threads
() → int¶ Gets the number of OpenMP threads used for parallelizing CPU operations.
-
torch.
set_num_threads
(int)¶ Sets the number of OpenMP threads used for parallelizing CPU operations.
Locally disabling gradient computation¶
The context managers torch.no_grad(), torch.enable_grad(), and torch.set_grad_enabled() are helpful for locally disabling and enabling gradient computation. See Locally disabling gradient computation for more details on their usage.
Examples:
>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True) # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
Math operations¶
Pointwise Ops¶
-
torch.
abs
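(input, out=None) → Tensor¶ Computes the element-wise absolute value of the given input tensor.
$$\text{out}_i = |\text{input}_i|$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> torch.abs(torch.tensor([-1, -2, 3]))
tensor([ 1, 2, 3])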
-
torch.
acos
(input, out=None) → Tensor¶ Returns a new tensor with the arccosine of the elements of input.
$$\text{out}_i = \cos^{-1}(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.3348, -0.5889, 0.2005, -0.1584])
>>> torch.acos(a)
tensor([ 1.2294, 2.2004, 1.3690, 1.7298])
-
torch.
add
()¶ -
torch.
add
(input, value, out=None)
Adds the scalar value to each element of the input input and returns a new resulting tensor.
$$\text{out} = \text{input} + \text{value}$$
If input is of type FloatTensor or DoubleTensor, value must be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input tensor
- value (Number) – the number to be added to each element of input
Keyword Arguments: out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.0202, 1.0985, 1.3506, -0.6056])
>>> torch.add(a, 20)
tensor([ 20.0202, 21.0985, 21.3506, 19.3944])
-
torch.
add
(input, value=1, other, out=None)
Each element of the tensor other is multiplied by the scalar value and added to each element of the tensor input. The resulting tensor is returned.
The shapes of input and other must be broadcastable.
$$\text{out} = \text{input} + \text{value} \times \text{other}$$
If other is of type FloatTensor or DoubleTensor, value must be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the first input tensor
- value (Number) – the scalar multiplier for other
- other (Tensor) – the second input tensor
Keyword Arguments: out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-0.9732, -0.3497, 0.6245, 0.4022])
>>> b = torch.randn(4, 1)
>>> b
tensor([[ 0.3743],
        [-1.7724],
        [-0.5811],
        [-0.8017]])
>>> torch.add(a, 10, b)
tensor([[ 2.7695, 3.3930, 4.3672, 4.1450],
        [-18.6971, -18.0736, -17.0994, -17.3216],
        [ -6.7845, -6.1610, -5.1868, -5.4090],
        [ -8.9902, -8.3667, -7.3925, -7.6147]])
-
torch.
addcdiv
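(tensor, value=1, tensor1, tensor2, out=None) → Tensor¶ Performs the element-wise division of tensor1 by tensor2, multiplies the result by the scalar value and adds it to tensor.
$$\text{out}_i = \text{tensor}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i}$$
The shapes of tensor, tensor1, and tensor2 must be broadcastable.
For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.
Parameters: - tensor (Tensor) – the tensor to be added
- value (Number, optional) – multiplier for tensor1 / tensor2
- tensor1 (Tensor) – the numerator tensor
- tensor2 (Tensor) – the denominator tensor
- out (Tensor, optional) – the output tensor
Example:
>>> t = torch.randn(1, 3)
>>> t1 = torch.randn(3, 1)
>>> t2 = torch.randn(1, 3)
>>> torch.addcdiv(t, 0.1, t1, t2)
tensor([[-0.2312, -3.6496, 0.1312],
        [-1.0428, 3.4292, -0.1030],
        [-0.5369, -0.9829, 0.0430]])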
-
torch.
addcmul
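(tensor, value=1, tensor1, tensor2, out=None) → Tensor¶ Performs the element-wise multiplication of tensor1 by tensor2, multiplies the result by the scalar value and adds it to tensor.
$$\text{out}_i = \text{tensor}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i$$
The shapes of tensor, tensor1, and tensor2 must be broadcastable.
For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.
Parameters: - tensor (Tensor) – the tensor to be added
- value (Number, optional) – multiplier for tensor1 × tensor2
- tensor1 (Tensor) – the tensor to be multiplied
- tensor2 (Tensor) – the tensor to be multiplied
- out (Tensor, optional) – the output tensor
Example:
>>> t = torch.randn(1, 3)
>>> t1 = torch.randn(3, 1)
>>> t2 = torch.randn(1, 3)
>>> torch.addcmul(t, 0.1, t1, t2)
tensor([[-0.8635, -0.6391, 1.6174],
        [-0.7617, -0.5879, 1.7388],
        [-0.8353, -0.6249, 1.6511]])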
-
torch.
asin
(input, out=None) → Tensor¶ Returns a new tensor with the arcsine of the elements of input.
$$\text{out}_i = \sin^{-1}(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-0.5962, 1.4985, -0.4396, 1.4525])
>>> torch.asin(a)
tensor([-0.6387, nan, -0.4552, nan])
-
torch.
atan
(input, out=None) → Tensor¶ Returns a new tensor with the arctangent of the elements of input.
$$\text{out}_i = \tan^{-1}(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.2341, 0.2539, -0.6256, -0.6448])
>>> torch.atan(a)
tensor([ 0.2299, 0.2487, -0.5591, -0.5727])
-
torch.
atan2
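(input1, input2, out=None) → Tensor¶ Returns a new tensor with the element-wise arctangent of input1 / input2, using the signs of both arguments to select the correct quadrant.
$$\text{out}_i = \tan^{-1}\left(\frac{\text{input1}_i}{\text{input2}_i}\right)$$
The shapes of input1 and input2 must be broadcastable.
Parameters: - input1 (Tensor) – the first input tensor
- input2 (Tensor) – the second input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.9041, 0.0196, -0.3108, -2.4423])
>>> torch.atan2(a, torch.randn(4))
tensor([ 0.9833, 0.0811, -1.9743, -1.4151])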
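The quadrant handling is what distinguishes a two-argument arctangent from tan⁻¹ of the ratio. Python's math.atan2 follows the same convention, so a short stdlib sketch shows the difference: the points (x = -1, y = 1) and (x = 1, y = -1) share the ratio -1 but lie in different quadrants. compare_atan is a hypothetical helper.

```python
import math


def compare_atan(y, x):
    """Return (atan(y / x), atan2(y, x)) for one point with x != 0."""
    return math.atan(y / x), math.atan2(y, x)


# Same ratio y / x = -1, different quadrants:
for y, x in [(1.0, -1.0), (-1.0, 1.0)]:
    naive, quadrant_aware = compare_atan(y, x)
    print(f"y={y:+}, x={x:+}: atan={naive:+.4f}, atan2={quadrant_aware:+.4f}")
```

The naive ratio gives -π/4 for both points, while atan2 returns 3π/4 for the point in the second quadrant.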
-
torch.
ceil
(input, out=None) → Tensor¶ Returns a new tensor with the ceil of the elements of input, the smallest integer greater than or equal to each element.
$$\text{out}_i = \lceil \text{input}_i \rceil$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-0.6341, -1.4208, -1.0900, 0.5826])
>>> torch.ceil(a)
tensor([-0., -1., -1., 1.])
-
torch.
clamp
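(input, min, max, out=None) → Tensor¶ Clamps all elements in input into the range [min, max] and returns a resulting tensor:
$$y_i = \begin{cases} \text{min} & \text{if } x_i < \text{min} \\ x_i & \text{if } \text{min} \le x_i \le \text{max} \\ \text{max} & \text{if } x_i > \text{max} \end{cases}$$
If input is of type FloatTensor or DoubleTensor, args min and max must be real numbers, otherwise they should be integers.
Parameters: - input (Tensor) – the input tensor
- min (Number) – lower bound of the range to be clamped to
- max (Number) – upper bound of the range to be clamped to
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-1.7120, 0.1734, -0.0478, -0.0922])
>>> torch.clamp(a, min=-0.5, max=0.5)
tensor([-0.5000, 0.1734, -0.0478, -0.0922])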
-
torch.
clamp
(input, *, min, out=None) → Tensor
Clamps all elements in input to be larger than or equal to min.
If input is of type FloatTensor or DoubleTensor, min should be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input tensor
- min (Number) – minimal value of each element in the output
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-0.0299, -2.3184, 2.1593, -0.8883])
>>> torch.clamp(a, min=0.5)
tensor([ 0.5000, 0.5000, 2.1593, 0.5000])
-
torch.
clamp
(input, *, max, out=None) → Tensor
Clamps all elements in input to be smaller than or equal to max.
If input is of type FloatTensor or DoubleTensor, max should be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input tensor
- max (Number) – maximal value of each element in the output
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.0753, -0.4702, -0.4599, 0.1899])
>>> torch.clamp(a, max=0.5)
tensor([ 0.0753, -0.4702, -0.4599, 0.1899])
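The three clamp overloads share one piecewise rule in which either bound may be absent. A plain-Python sketch over lists, reusing the values from the examples above (clamp_list, min_val and max_val are hypothetical names chosen to avoid shadowing Python's built-in min and max), might be:

```python
def clamp_list(values, min_val=None, max_val=None):
    """Element-wise clamp of a list; either bound may be omitted."""
    if min_val is None and max_val is None:
        raise ValueError("at least one of min_val or max_val must be given")
    out = []
    for x in values:
        if min_val is not None and x < min_val:
            x = min_val   # below the range: raise to the lower bound
        if max_val is not None and x > max_val:
            x = max_val   # above the range: lower to the upper bound
        out.append(x)
    return out


print(clamp_list([-1.7120, 0.1734, -0.0478, -0.0922], min_val=-0.5, max_val=0.5))
print(clamp_list([-0.0299, -2.3184, 2.1593, -0.8883], min_val=0.5))
```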
-
torch.
cos
(input, out=None) → Tensor¶ Returns a new tensor with the cosine of the elements of input.
$$\text{out}_i = \cos(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 1.4309, 1.2706, -0.8562, 0.9796])
>>> torch.cos(a)
tensor([ 0.1395, 0.2957, 0.6553, 0.5574])
-
torch.
cosh
(input, out=None) → Tensor¶ Returns a new tensor with the hyperbolic cosine of the elements of input.
$$\text{out}_i = \cosh(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.1632, 1.1835, -0.6979, -0.7325])
>>> torch.cosh(a)
tensor([ 1.0133, 1.7860, 1.2536, 1.2805])
-
torch.
div
()¶ -
torch.
div
(input, value, out=None) → Tensor
Divides each element of the input input by the scalar value and returns a new resulting tensor.
$$\text{out}_i = \frac{\text{input}_i}{\text{value}}$$
If input is of type FloatTensor or DoubleTensor, value should be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input tensor
- value (Number) – the number by which each element of input is divided
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(5)
>>> a
tensor([ 0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
>>> torch.div(a, 0.5)
tensor([ 0.7620, 2.5548, -0.5944, -0.7439, 0.9275])
-
torch.
div
(input, other, out=None) → Tensor
Each element of the tensor input is divided by the corresponding element of the tensor other. The resulting tensor is returned. The shapes of input and other must be broadcastable.
$$\text{out}_i = \frac{\text{input}_i}{\text{other}_i}$$
Parameters: - input (Tensor) – the numerator tensor
- other (Tensor) – the denominator tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4, 4)
>>> a
tensor([[-0.3711, -1.9353, -0.4605, -0.2917],
        [ 0.1815, -1.0111, 0.9805, -1.5923],
        [ 0.1062, 1.4581, 0.7759, -1.2344],
        [-0.1830, -0.0313, 1.1908, -1.4757]])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8032, 0.2930, -0.8113, -0.2308])
>>> torch.div(a, b)
tensor([[-0.4620, -6.6051, 0.5676, 1.2637],
        [ 0.2260, -3.4507, -1.2086, 6.8988],
        [ 0.1322, 4.9764, -0.9564, 5.3480],
        [-0.2278, -0.1068, -1.4678, 6.3936]])
-
torch.
erf
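(tensor, out=None) → Tensor¶ Computes the error function of each element. The error function is defined as follows:
$$\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2}\, dt$$
Parameters: - tensor (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> torch.erf(torch.tensor([0, -1., 10.]))
tensor([ 0.0000, -0.8427, 1.0000])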
-
torch.
erfinv
(tensor, out=None) → Tensor¶ Computes the inverse error function of each element. The inverse error function is defined in the range (−1, 1) as:
$$\mathrm{erfinv}(\mathrm{erf}(x)) = x$$
Parameters: - tensor (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> torch.erfinv(torch.tensor([0, 0.5, -1.]))
tensor([ 0.0000, 0.4769, -inf])
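The defining identity erfinv(erf(x)) = x suggests a simple way to evaluate erfinv when only erf is available: solve erf(x) = y with Newton's method, using the derivative 2/√π · e^(−x²) that follows from the integral definition of erf above. This stdlib sketch is an illustration of that identity, not PyTorch's algorithm; the function name erfinv is local to the sketch, and the ±1 endpoints return ±inf to match the example.

```python
import math


def erfinv(y, tol=1e-15, max_iter=100):
    """Inverse error function on [-1, 1] via Newton's method on math.erf."""
    if y == 1.0:
        return math.inf
    if y == -1.0:
        return -math.inf
    if not -1.0 < y < 1.0:
        return math.nan
    x = 0.0
    for _ in range(max_iter):
        # d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
        step = (math.erf(x) - y) / (2.0 / math.sqrt(math.pi) * math.exp(-x * x))
        x -= step
        if abs(step) < tol:
            break
    return x


print(erfinv(0.0), erfinv(0.5), erfinv(-1.0))
```

Because erf is concave for positive arguments (and convex for negative ones), Newton iterates started at 0 approach the root monotonically from the origin side, so the loop does not overshoot.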
-
torch.
exp
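(tensor, out=None) → Tensor¶ Returns a new tensor with the exponential of the elements of the input tensor.
$$y_i = e^{x_i}$$
Parameters: - tensor (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> import math
>>> torch.exp(torch.tensor([0, math.log(2)]))
tensor([ 1., 2.])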
-
torch.
expm1
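(tensor, out=None) → Tensor¶ Returns a new tensor with the exponential of the elements of the input tensor, minus 1.
$$y_i = e^{x_i} - 1$$
Parameters: - tensor (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> import math
>>> torch.expm1(torch.tensor([0, math.log(2)]))
tensor([ 0., 1.])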
-
torch.
floor
(input, out=None) → Tensor¶ Returns a new tensor with the floor of the elements of input, the largest integer less than or equal to each element.
$$\text{out}_i = \lfloor \text{input}_i \rfloor$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-0.8166, 1.5308, -0.2530, -0.2091])
>>> torch.floor(a)
tensor([-1., 1., -1., -1.])
-
torch.
fmod
(input, divisor, out=None) → Tensor¶ Computes the element-wise remainder of division.
The dividend and divisor may contain both integer and floating point numbers. The remainder has the same sign as the dividend input.
When divisor is a tensor, the shapes of input and divisor must be broadcastable.
Parameters: - input (Tensor) – the dividend
- divisor (Tensor or float) – the divisor, which may be either a number or a tensor of the same shape as the dividend
- out (Tensor, optional) – the output tensor
Example:
>>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2)
tensor([-1., -0., -1., 1., 0., 1.])
>>> torch.fmod(torch.tensor([1., 2, 3, 4, 5]), 1.5)
tensor([ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000])
-
torch.
frac
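(tensor, out=None) → Tensor¶ Computes the fractional portion of each element in tensor. The result keeps the sign of the element.
$$\text{out}_i = \text{input}_i - \lfloor |\text{input}_i| \rfloor \times \operatorname{sgn}(\text{input}_i)$$
Parameters: - tensor (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> torch.frac(torch.tensor([1, 2.5, -3.2]))
tensor([ 0.0000, 0.5000, -0.2000])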
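The example's result for -3.2 (-0.2 rather than 0.8) shows that the fractional part is taken toward zero. Equivalently, out = x - trunc(x). A stdlib sketch (frac here is a local helper, not the torch function) contrasts this with the floor-based version:

```python
import math


def frac(x):
    """Fractional part that keeps the sign of x: x - trunc(x)."""
    return x - math.trunc(x)


for x in [1.0, 2.5, -3.2]:
    # the floor-based variant differs from frac only for negative inputs
    print(x, frac(x), x - math.floor(x))
```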
-
torch.
lerp
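(start, end, weight, out=None)¶ Does a linear interpolation of two tensors start and end based on a scalar weight and returns the resulting out tensor.
$$\text{out}_i = \text{start}_i + \text{weight} \times (\text{end}_i - \text{start}_i)$$
The shapes of start and end must be broadcastable.
Parameters: - start (Tensor) – the tensor with the starting points
- end (Tensor) – the tensor with the ending points
- weight (float) – the weight for the interpolation formula
- out (Tensor, optional) – the output tensor
Example:
>>> start = torch.arange(1, 5)
>>> end = torch.empty(4).fill_(10)
>>> start
tensor([ 1., 2., 3., 4.])
>>> end
tensor([ 10., 10., 10., 10.])
>>> torch.lerp(start, end, 0.5)
tensor([ 5.5000, 6.0000, 6.5000, 7.0000])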
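The interpolation formula can be checked by hand with a plain-Python version over lists of equal length (lerp_list is a hypothetical helper; broadcasting is not modelled). A weight of 0 returns start, a weight of 1 returns end, and 0.5 gives the midpoint, which reproduces the example output above.

```python
def lerp_list(start, end, weight):
    """Element-wise linear interpolation: start + weight * (end - start)."""
    if len(start) != len(end):
        raise ValueError("start and end must have the same length")
    return [s + weight * (e - s) for s, e in zip(start, end)]


print(lerp_list([1.0, 2.0, 3.0, 4.0], [10.0] * 4, 0.5))
```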
-
torch.
log
(input, out=None) → Tensor¶ Returns a new tensor with the natural logarithm of the elements of input.
$$y_i = \log_{e}(x_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(5)
>>> a
tensor([-0.7168, -0.5471, -0.8933, -1.4428, -0.1190])
>>> torch.log(a)
tensor([ nan, nan, nan, nan, nan])
-
torch.
log10
(input, out=None) → Tensor¶ Returns a new tensor with the logarithm to the base 10 of the elements of input.
$$y_i = \log_{10}(x_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.rand(5)
>>> a
tensor([ 0.5224, 0.9354, 0.7257, 0.1301, 0.2251])
>>> torch.log10(a)
tensor([-0.2820, -0.0290, -0.1392, -0.8857, -0.6476])
-
torch.
log1p
(input, out=None) → Tensor¶ Returns a new tensor with the natural logarithm of (1 + input).
$$y_i = \log_{e}(x_i + 1)$$
Note
This function is more accurate than torch.log() for small values of input.
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(5)
>>> a
tensor([-1.0090, -0.9923, 1.0249, -0.5372, 0.2492])
>>> torch.log1p(a)
tensor([ nan, -4.8653, 0.7055, -0.7705, 0.2225])
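The accuracy note on log1p (and the same motivation behind expm1) can be seen with Python's math module, which provides log1p and expm1 with the same purpose. For x = 1e-17, the sum 1.0 + x rounds to exactly 1.0 in double precision, so the naive formulas lose the whole answer:

```python
import math

x = 1e-17

naive_log = math.log(1.0 + x)  # 1.0 + 1e-17 rounds to 1.0, so this is 0.0
accurate_log = math.log1p(x)   # keeps the tiny argument: approximately 1e-17

naive_exp = math.exp(x) - 1.0  # exp(1e-17) rounds to 1.0, so this is 0.0
accurate_exp = math.expm1(x)   # approximately 1e-17

print(naive_log, accurate_log)
print(naive_exp, accurate_exp)
```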
-
torch.
log2
(input, out=None) → Tensor¶ Returns a new tensor with the logarithm to the base 2 of the elements of input.
$$y_i = \log_{2}(x_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.rand(5)
>>> a
tensor([ 0.8419, 0.8003, 0.9971, 0.5287, 0.0490])
>>> torch.log2(a)
tensor([-0.2483, -0.3213, -0.0042, -0.9196, -4.3504])
-
torch.
mul
()¶ -
torch.
mul
(input, value, out=None)
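Multiplies each element of the input input by the scalar value and returns a new resulting tensor.
$$\text{out}_i = \text{value} \times \text{input}_i$$
If input is of type FloatTensor or DoubleTensor, value should be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input tensor
- value (Number) – the number to be multiplied with each element of input
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(3)
>>> a
tensor([ 0.2015, -0.4255, 2.6087])
>>> torch.mul(a, 100)
tensor([ 20.1494, -42.5491, 260.8663])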
-
torch.
mul
(input, other, out=None)
Each element of the tensor input is multiplied by the corresponding element of the tensor other. The resulting tensor is returned.
The shapes of input and other must be broadcastable.
$$\text{out}_i = \text{input}_i \times \text{other}_i$$
Parameters: - input (Tensor) – the first multiplicand tensor
- other (Tensor) – the second multiplicand tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4, 1)
>>> a
tensor([[ 1.1207],
        [-0.3137],
        [ 0.0700],
        [ 0.8378]])
>>> b = torch.randn(1, 4)
>>> b
tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]])
>>> torch.mul(a, b)
tensor([[ 0.5767, 0.1363, -0.5877, 2.5083],
        [-0.1614, -0.0382, 0.1645, -0.7021],
        [ 0.0360, 0.0085, -0.0367, 0.1567],
        [ 0.4312, 0.1019, -0.4394, 1.8753]])
-
torch.
neg
(input, out=None) → Tensor¶ Returns a new tensor with the negative of the elements of input.
$$\text{out} = -1 \times \text{input}$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(5)
>>> a
tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940])
>>> torch.neg(a)
tensor([-0.0090, 0.2262, 0.0682, 0.2866, -0.3940])
-
torch.
pow
()¶ -
torch.
pow
(input, exponent, out=None) → Tensor
Takes the power of each element in input with exponent and returns a tensor with the result.
exponent can be either a single float number or a Tensor with the same number of elements as input.
When exponent is a scalar value, the operation applied is:
$$\text{out}_i = x_i^{\text{exponent}}$$
When exponent is a tensor, the operation applied is:
$$\text{out}_i = x_i^{\text{exponent}_i}$$
When exponent is a tensor, the shapes of input and exponent must be broadcastable.
Parameters: - input (Tensor) – the input tensor
- exponent (float or Tensor) – the exponent value
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.4331, 1.2475, 0.6834, -0.2791])
>>> torch.pow(a, 2)
tensor([ 0.1875, 1.5561, 0.4670, 0.0779])
>>> exp = torch.arange(1, 5)
>>> a = torch.arange(1, 5)
>>> a
tensor([ 1., 2., 3., 4.])
>>> exp
tensor([ 1., 2., 3., 4.])
>>> torch.pow(a, exp)
tensor([ 1., 4., 27., 256.])
-
torch.
pow
(base, input, out=None) → Tensor
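base is a scalar float value, and input is a tensor. The returned tensor out is of the same shape as input.
The operation applied is:
$$\text{out}_i = \text{base}^{\text{input}_i}$$
Parameters: - base (float) – the scalar base value for the power operation
- input (Tensor) – the exponent tensor
- out (Tensor, optional) – the output tensor
Example:
>>> exp = torch.arange(1, 5)
>>> base = 2
>>> torch.pow(base, exp)
tensor([ 2., 4., 8., 16.])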
-
torch.
reciprocal
(input, out=None) → Tensor¶ Returns a new tensor with the reciprocal of the elements of input.
$$\text{out}_i = \frac{1}{\text{input}_i}$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-0.4595, -2.1219, -1.4314, 0.7298])
>>> torch.reciprocal(a)
tensor([-2.1763, -0.4713, -0.6986, 1.3702])
-
torch.
remainder
(input, divisor, out=None) → Tensor¶ Computes the element-wise remainder of division.
The divisor and dividend may contain both integer and floating point numbers. The remainder has the same sign as the divisor.
When divisor is a tensor, the shapes of input and divisor must be broadcastable.
Parameters: - input (Tensor) – the dividend
- divisor (Tensor or float) – the divisor, which may be either a number or a tensor of the same shape as the dividend
- out (Tensor, optional) – the output tensor
Example:
>>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2)
tensor([ 1., 0., 1., 1., 0., 1.])
>>> torch.remainder(torch.tensor([1., 2, 3, 4, 5]), 1.5)
tensor([ 1.0000, 0.5000, 0.0000, 1.0000, 0.5000])
See also
torch.fmod(), which computes the element-wise remainder of division equivalently to the C library function fmod().
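The only difference between torch.fmod and torch.remainder is which operand the sign of the result follows. Python's standard library exposes both conventions (math.fmod follows the dividend, like C fmod(); the % operator follows the divisor), so this sketch reproduces the outputs of the two examples above without PyTorch:

```python
import math

dividends = [-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]

# Sign follows the dividend, as described for torch.fmod().
fmod_result = [math.fmod(x, 2.0) for x in dividends]

# Sign follows the divisor, as described for torch.remainder().
remainder_result = [x % 2.0 for x in dividends]

print(fmod_result)
print(remainder_result)
```

For positive dividends and divisors, such as the 1.5 examples, the two conventions agree.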
-
torch.
round
(input, out=None) → Tensor¶ Returns a new tensor with each of the elements of input rounded to the closest integer.
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.9920, 0.6077, 0.9734, -1.0362])
>>> torch.round(a)
tensor([ 1., 1., 1., -1.])
-
torch.
rsqrt
(input, out=None) → Tensor¶ Returns a new tensor with the reciprocal of the square root of each of the elements of input.
$$\text{out}_i = \frac{1}{\sqrt{\text{input}_i}}$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-0.0370, 0.2970, 1.5420, -0.9105])
>>> torch.rsqrt(a)
tensor([ nan, 1.8351, 0.8053, nan])
-
torch.
sigmoid
(input, out=None) → Tensor¶ Returns a new tensor with the sigmoid of the elements of input.
$$\text{out}_i = \frac{1}{1 + e^{-\text{input}_i}}$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
>>> torch.sigmoid(a)
tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
-
torch.
sign
(input, out=None) → Tensor¶ Returns a new tensor with the sign of the elements of input.
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 1.0382, -1.4526, -0.9709, 0.4542])
>>> torch.sign(a)
tensor([ 1., -1., -1., 1.])
-
torch.
sin
(input, out=None) → Tensor¶ Returns a new tensor with the sine of the elements of input.
$$\text{out}_i = \sin(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-0.5461, 0.1347, -2.7266, -0.2746])
>>> torch.sin(a)
tensor([-0.5194, 0.1343, -0.4032, -0.2711])
-
torch.
sinh
(input, out=None) → Tensor¶ Returns a new tensor with the hyperbolic sine of the elements of input.
$$\text{out}_i = \sinh(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.5380, -0.8632, -0.1265, 0.9399])
>>> torch.sinh(a)
tensor([ 0.5644, -0.9744, -0.1268, 1.0845])
-
torch.
sqrt
(input, out=None) → Tensor¶ Returns a new tensor with the square root of the elements of input.
$$\text{out}_i = \sqrt{\text{input}_i}$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-2.0755, 1.0226, 0.0831, 0.4806])
>>> torch.sqrt(a)
tensor([ nan, 1.0112, 0.2883, 0.6933])
-
torch.
tan
(input, out=None) → Tensor¶ Returns a new tensor with the tangent of the elements of input.
$$\text{out}_i = \tan(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([-1.2027, -1.7687, 0.4412, -1.3856])
>>> torch.tan(a)
tensor([-2.5930, 4.9859, 0.4722, -5.3366])
-
torch.
tanh
(input, out=None) → Tensor¶ Returns a new tensor with the hyperbolic tangent of the elements of input.
$$\text{out}_i = \tanh(\text{input}_i)$$
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.8986, -0.7279, 1.1745, 0.2611])
>>> torch.tanh(a)
tensor([ 0.7156, -0.6218, 0.8257, 0.2553])
-
torch.
trunc
(input, out=None) → Tensor¶ Returns a new tensor with the truncated integer values of the elements of input.
Parameters: - input (Tensor) – the input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 3.4742, 0.5466, -0.8008, -0.9079])
>>> torch.trunc(a)
tensor([ 3., 0., -0., -0.])
Reduction Ops¶
-
torch.
argmax
(input, dim=None, keepdim=False)[source]¶ Returns the indices of the maximum values of a tensor across a dimension.
This is the second value returned by torch.max(). See its documentation for the exact semantics of this method.
Parameters: - input (Tensor) – the input tensor
- dim (int) – the dimension to reduce. If None, the argmax of the flattened input is returned.
- keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.
Example:
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092, 0.5419, -0.2993, 0.3195]])
>>> torch.argmax(a, dim=1)
tensor([ 0, 2, 0, 1])
-
torch.
argmin
(input, dim=None, keepdim=False)[source]¶ Returns the indices of the minimum values of a tensor across a dimension.
This is the second value returned by torch.min(). See its documentation for the exact semantics of this method.
Parameters: - input (Tensor) – the input tensor
- dim (int) – the dimension to reduce. If None, the argmin of the flattened input is returned.
- keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.
Example:
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.1139, 0.2254, -0.1381, 0.3687],
        [ 1.0100, -1.1975, -0.0102, -0.4732],
        [-0.9240, 0.1207, -0.7506, -1.0213],
        [ 1.7809, -1.2960, 0.9384, 0.1438]])
>>> torch.argmin(a, dim=1)
tensor([ 2, 1, 3, 1])
-
torch.
cumprod
(input, dim, out=None) → Tensor¶ Returns the cumulative product of elements of input in the dimension dim.
For example, if input is a vector of size N, the result will also be a vector of size N, with elements:
$$y_i = x_1 \times x_2 \times x_3 \times \dots \times x_i$$
Parameters: - input (Tensor) – the input tensor
- dim (int) – the dimension to do the operation over
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(10)
>>> a
tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, -0.2129, -0.4206, 0.1968])
>>> torch.cumprod(a, dim=0)
tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, 0.0014, -0.0006, -0.0001])
>>> a[5] = 0.0
>>> torch.cumprod(a, dim=0)
tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, 0.0000, -0.0000, -0.0000])
-
torch.
cumsum
(input, dim, out=None) → Tensor¶ Returns the cumulative sum of elements of input in the dimension dim.
For example, if input is a vector of size N, the result will also be a vector of size N, with elements:
$$y_i = x_1 + x_2 + x_3 + \dots + x_i$$
Parameters: - input (Tensor) – the input tensor
- dim (int) – the dimension to do the operation over
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(10)
>>> a
tensor([-0.8286, -0.4890, 0.5155, 0.8443, 0.1865, -0.1752, -2.0595, 0.1850, -1.1571, -0.4243])
>>> torch.cumsum(a, dim=0)
tensor([-0.8286, -1.3175, -0.8020, 0.0423, 0.2289, 0.0537, -2.0058, -1.8209, -2.9780, -3.4022])
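For a one-dimensional input, the cumulative sum and product formulas above are running reductions. Python's itertools.accumulate computes exactly such running results, so a stdlib sketch over a plain list (not a tensor, and with no dim argument) is:

```python
import operator
from itertools import accumulate

x = [2.0, 3.0, 0.5, 4.0]
cumsum = list(accumulate(x))                 # y_i = x_1 + ... + x_i
cumprod = list(accumulate(x, operator.mul))  # y_i = x_1 * ... * x_i
print(cumsum)
print(cumprod)
```

As in the cumprod example above, once a zero enters the running product every later element of the product is zero.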
-
torch.
dist
(input, other, p=2) → Tensor¶ Returns the p-norm of (input - other).
The shapes of input and other must be broadcastable.
Parameters: - input (Tensor) – the input tensor
- other (Tensor) – the right-hand-side input tensor
- p (float, optional) – the norm to be computed
Example:
>>> x = torch.randn(4)
>>> x
tensor([-1.5393, -0.8675, 0.5916, 1.6321])
>>> y = torch.randn(4)
>>> y
tensor([ 0.0967, -1.0511, 0.6295, 0.8360])
>>> torch.dist(x, y, 3.5)
tensor(1.6727)
>>> torch.dist(x, y, 3)
tensor(1.6973)
>>> torch.dist(x, y, 0)
tensor(inf)
>>> torch.dist(x, y, 1)
tensor(2.6537)
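For p > 0 the p-norm of the difference is (Σ |xᵢ − yᵢ|ᵖ)^(1/p), and p = ∞ gives the largest absolute difference. This plain-Python sketch implements only those cases; it deliberately rejects p ≤ 0, so it does not try to reproduce the p = 0 row of the example. dist_list is a hypothetical helper.

```python
import math


def dist_list(x, y, p=2.0):
    """p-norm of the element-wise difference x - y, for p > 0 or p = inf."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    diffs = [abs(a - b) for a, b in zip(x, y)]
    if p == math.inf:
        return max(diffs)
    if p <= 0:
        raise ValueError("this sketch only handles p > 0 or p = inf")
    return sum(d ** p for d in diffs) ** (1.0 / p)


x = [-1.5393, -0.8675, 0.5916, 1.6321]
y = [0.0967, -1.0511, 0.6295, 0.8360]
print(dist_list(x, y, 1))  # compare with torch.dist(x, y, 1) above
```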
-
torch.
mean
()¶ -
torch.
mean
(input) → Tensor
Returns the mean value of all elements in the input tensor.
Parameters: input (Tensor) – the input tensor
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.2294, -0.5481, 1.3288]])
>>> torch.mean(a)
tensor(0.3367)
-
torch.
mean
(input, dim, keepdim=False, out=None) → Tensor
Returns the mean value of each row of the input tensor in the given dimension dim.
If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 fewer dimension.
Parameters: - input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool, optional) – whether the output tensor has dim retained or not
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4, 4)
>>> a
tensor([[-0.3841, 0.6320, 0.4254, -0.7384],
        [-0.9644, 1.0131, -0.6549, -1.4279],
        [-0.2951, -1.3350, -0.7694, 0.5600],
        [ 1.0842, -0.9580, 0.3623, 0.2343]])
>>> torch.mean(a, 1)
tensor([-0.0163, -0.5085, -0.4599, 0.1807])
>>> torch.mean(a, 1, True)
tensor([[-0.0163],
        [-0.5085],
        [-0.4599],
        [ 0.1807]])
-
torch.
median
()¶ -
torch.
median
(input) → Tensor
Returns the median value of all elements in the input tensor.
Parameters: input (Tensor) – the input tensor
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 1.5219, -1.5212, 0.2202]])
>>> torch.median(a)
tensor(0.2202)
-
torch.
median
(input, dim=-1, keepdim=False, values=None, indices=None) -> (Tensor, LongTensor)
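Returns the median value of each row of the input tensor in the given dimension dim. Also returns the index location of the median value as a LongTensor.
By default, dim is the last dimension of the input tensor.
If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having one fewer dimension than input.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensors have dim retained or not
- values (Tensor, optional) – the output tensor
- indices (Tensor, optional) – the output index tensor
Example: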
>>> a = torch.randn(4, 5)
>>> a
tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131],
        [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270],
        [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488],
        [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]])
>>> torch.median(a, 1)
(tensor([-0.3982, 0.2270, 0.2488, 0.4742]), tensor([ 1, 4, 4, 3]))
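A plain-Python sketch of the (values, indices) pair returned along a dimension, applied row by row to the data printed above. It sorts the indices of each row by value and picks the middle one. For rows with an even number of elements it returns the lower of the two middle values; that tie rule is an assumption of this sketch, since the text above does not state one. The helper name median_with_index is hypothetical.

```python
def median_with_index(row):
    """Return (median value, its index) for a 1-D list.

    For an even number of elements this picks the lower of the two middle
    values; that tie rule is an assumption of this sketch.
    """
    order = sorted(range(len(row)), key=lambda i: row[i])
    i = order[(len(row) - 1) // 2]
    return row[i], i


a = [[0.2505, -0.3982, -0.9948, 0.3518, -1.3131],
     [0.3180, -0.6993, 1.0436, 0.0438, 0.2270],
     [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488],
     [1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]
values, indices = zip(*(median_with_index(row) for row in a))
print(values, indices)
```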
-
torch.
mode
(input, dim=-1, keepdim=False, values=None, indices=None) -> (Tensor, LongTensor)
Returns the mode value of each row of the input tensor in the given dimension dim. Also returns the index location of the mode value as a LongTensor.
By default, dim is the last dimension of the input tensor.
If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having one fewer dimension than input.
Note
This function is not defined for torch.cuda.Tensor yet.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensors have dim retained or not
- values (Tensor, optional) – the output tensor
- indices (Tensor, optional) – the output index tensor
Example:
>>> a = torch.randn(4, 5)
>>> a
tensor([[-1.2808, -1.0966, -1.5946, -0.1148, 0.3631],
        [ 1.1395, 1.1452, -0.6383, 0.3667, 0.4545],
        [-0.4061, -0.3074, 0.4579, -1.3514, 1.2729],
        [-1.0130, 0.3546, -1.4689, -0.1254, 0.0473]])
>>> torch.mode(a, 1)
(tensor([-1.5946, -0.6383, -1.3514, -1.4689]), tensor([ 2, 2, 3, 2]))
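A plain-Python sketch of a row-wise mode with an index. Two choices here are assumptions rather than documented behaviour. First, ties between equally frequent values go to the smallest value, which is consistent with the example above, where every value occurs once and the smallest is returned. Second, the index reported for a repeated value is its last occurrence. The helper name mode_with_index is hypothetical.

```python
from collections import Counter


def mode_with_index(row):
    """Return (most frequent value, an index where it occurs) for a 1-D list."""
    counts = Counter(row)
    best = max(counts.values())
    # Ties between equally frequent values go to the smallest value,
    # consistent with the example above (every value there occurs once).
    value = min(v for v, c in counts.items() if c == best)
    # Which occurrence's index is reported is an assumption: the last one here.
    index = max(i for i, v in enumerate(row) if v == value)
    return value, index


print(mode_with_index([-1.2808, -1.0966, -1.5946, -0.1148, 0.3631]))
print(mode_with_index([3, 1, 2, 2, 1, 2]))
```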
-
torch.
norm
()
-
torch.
norm
(input, p=2) → Tensor
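Returns the p-norm of the input tensor.
\|x\|_p = \sqrt[p]{|x_1|^p + |x_2|^p + \dots + |x_N|^p}
Parameters:
- input (Tensor) – the input tensor
- p (float, optional) – the exponent value in the norm formulation
Example: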
>>> a = torch.randn(1, 3)
>>> a
tensor([[-0.5192, -1.0782, -1.0448]])
>>> torch.norm(a, 3)
tensor(1.3633)
-
torch.
norm
(input, p, dim, keepdim=False, out=None) → Tensor
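Returns the p-norm of each row of the input tensor in the given dimension dim.
If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having one fewer dimension than input.
Parameters:
- input (Tensor) – the input tensor
- p (float) – the exponent value in the norm formulation
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensor has dim retained or not
- out (Tensor, optional) – the output tensor
Example: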
>>> a = torch.randn(4, 2)
>>> a
tensor([[ 2.1983, 0.4141],
        [ 0.8734, 1.9710],
        [-0.7778, 0.7938],
        [-0.1342, 0.7347]])
>>> torch.norm(a, 2, 1)
tensor([ 2.2369, 2.1558, 1.1113, 0.7469])
>>> torch.norm(a, 0, 1, True)
tensor([[ 2.],
        [ 2.],
        [ 2.],
        [ 2.]])
-
torch.
prod
()
-
torch.
prod
(input) → Tensor
Returns the product of all elements in the input tensor.
Parameters: input (Tensor) – the input tensor
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[-0.8020, 0.5428, -1.5854]])
>>> torch.prod(a)
tensor(0.6902)
-
torch.
prod
(input, dim, keepdim=False, out=None) → Tensor
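Returns the product of each row of the input tensor in the given dimension dim.
If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having one fewer dimension than input.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensor has dim retained or not
- out (Tensor, optional) – the output tensor
Example: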
>>> a = torch.randn(4, 2)
>>> a
tensor([[ 0.5261, -0.3837],
        [ 1.1857, -0.2498],
        [-1.1646, 0.0705],
        [ 1.1131, -1.0629]])
>>> torch.prod(a, 1)
tensor([-0.2018, -0.2962, -0.0821, -1.1831])
-
torch.
std
()
-
torch.
std
(input, unbiased=True) → Tensor
Returns the standard deviation of all elements in the input tensor.
If unbiased is False, the standard deviation is calculated via the biased estimator. Otherwise, Bessel’s correction is used.
Parameters:
- input (Tensor) – the input tensor
- unbiased (bool) – whether to use the unbiased estimation or not
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[-0.8166, -1.3802, -0.3560]])
>>> torch.std(a)
tensor(0.5130)
-
torch.
std
(input, dim, keepdim=False, unbiased=True, out=None) → Tensor
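Returns the standard deviation of each row of the input tensor in the given dimension dim.
If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having one fewer dimension than input.
If unbiased is False, the standard deviation is calculated via the biased estimator. Otherwise, Bessel’s correction is used.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensor has dim retained or not
- unbiased (bool) – whether to use the unbiased estimation or not
- out (Tensor, optional) – the output tensor
Example: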
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.2035, 1.2959, 1.8101, -0.4644],
        [ 1.5027, -0.3270, 0.5905, 0.6538],
        [-1.5745, 1.3330, -0.5596, -0.6548],
        [ 0.1264, -0.5080, 1.6420, 0.1992]])
>>> torch.std(a, dim=1)
tensor([ 1.0311, 0.7477, 1.2204, 0.9087])
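The difference between the biased estimator and Bessel’s correction is only the divisor: n versus n - 1. The plain-Python sketch below assumes a 1-D list with at least two elements; the helper names variance and std are hypothetical. The same variance helper also illustrates torch.var further down. With the rounded inputs from the one-row examples, it reproduces tensor(0.5130) for std and tensor(0.2455) for var to about 1e-3.

```python
import math


def variance(values, unbiased=True):
    """Variance of a 1-D list; unbiased=True applies Bessel's correction."""
    n = len(values)
    mean = sum(values) / n
    squared_dev = sum((v - mean) ** 2 for v in values)
    # Bessel's correction: divide by n - 1 instead of n
    return squared_dev / (n - 1 if unbiased else n)


def std(values, unbiased=True):
    return math.sqrt(variance(values, unbiased))


data = [2, 4, 4, 4, 5, 5, 7, 9]
print(std(data, unbiased=False), std(data))
```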
-
torch.
sum
()
-
torch.
sum
(input) → Tensor
Returns the sum of all elements in the input tensor.
Parameters: input (Tensor) – the input tensor
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.1133, -0.9567, 0.2958]])
>>> torch.sum(a)
tensor(-0.5475)
-
torch.
sum
(input, dim, keepdim=False, out=None) → Tensor
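Returns the sum of each row of the input tensor in the given dimension dim.
If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having one fewer dimension than input.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensor has dim retained or not
- out (Tensor, optional) – the output tensor
Example: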
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0569, -0.2475, 0.0737, -0.3429],
        [-0.2993, 0.9138, 0.9337, -1.6864],
        [ 0.1132, 0.7892, -0.1003, 0.5688],
        [ 0.3637, -0.9906, -0.4752, -1.5197]])
>>> torch.sum(a, 1)
tensor([-0.4598, -0.1381, 1.3708, -2.6217])
-
torch.
unique
(input, sorted=False, return_inverse=False)
Returns the unique scalar elements of the input tensor as a 1-D tensor.
Parameters:
- input (Tensor) – the input tensor
- sorted (bool) – whether to sort the unique elements in ascending order before returning them as output
- return_inverse (bool) – whether to also return, for each element of the original input, its index in the returned unique list
Returns: A tensor or a tuple of tensors containing
- output (Tensor): the output list of unique scalar elements.
- inverse_indices (Tensor): (optional) if return_inverse is True, a second tensor (of the same shape as input) is returned, giving, for each element of the original input, its index in output; otherwise, this function returns only a single tensor.
Return type: (Tensor, Tensor (optional))
Example:
>>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
>>> output
tensor([ 2, 3, 1])
>>> output, inverse_indices = torch.unique(
...     torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
>>> output
tensor([ 1, 2, 3])
>>> inverse_indices
tensor([ 0, 2, 1, 2])
>>> output, inverse_indices = torch.unique(
...     torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
>>> output
tensor([ 1, 2, 3])
>>> inverse_indices
tensor([[ 0, 2],
        [ 1, 2]])
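A plain-Python sketch of the sorted=True, return_inverse=True behaviour for a 1-D list. The hypothetical helper unique_sorted builds the sorted unique values, then maps each original element to its position in that list, so output[inverse[j]] recovers the j-th input element. The order produced by sorted=False is not specified above, so the sketch only covers the sorted case. A 2-D input would be flattened first, with the inverse indices reshaped back afterwards.

```python
def unique_sorted(values, return_inverse=False):
    """Sorted unique elements of a 1-D list, optionally with inverse indices."""
    output = sorted(set(values))
    if not return_inverse:
        return output
    position = {v: i for i, v in enumerate(output)}
    # inverse[j] is the index in `output` of the original element values[j]
    inverse = [position[v] for v in values]
    return output, inverse


print(unique_sorted([1, 3, 2, 3], return_inverse=True))
```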
-
torch.
var
()
-
torch.
var
(input, unbiased=True) → Tensor
Returns the variance of all elements in the input tensor.
If unbiased is False, the variance is calculated via the biased estimator. Otherwise, Bessel’s correction is used.
Parameters:
- input (Tensor) – the input tensor
- unbiased (bool) – whether to use the unbiased estimation or not
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[-0.3425, -1.2636, -0.4864]])
>>> torch.var(a)
tensor(0.2455)
-
torch.
var
(input, dim, keepdim=False, unbiased=True, out=None) → Tensor
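Returns the variance of each row of the input tensor in the given dimension dim.
If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having one fewer dimension than input.
If unbiased is False, the variance is calculated via the biased estimator. Otherwise, Bessel’s correction is used.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensor has dim retained or not
- unbiased (bool) – whether to use the unbiased estimation or not
- out (Tensor, optional) – the output tensor
Example: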
>>> a = torch.randn(4, 4)
>>> a
tensor([[-0.3567, 1.7385, -1.3042, 0.7423],
        [ 1.3436, -0.1015, -0.9834, -0.8438],
        [ 0.6056, 0.1089, -0.3112, -1.4085],
        [-0.7700, 0.6074, -0.1469, 0.7777]])
>>> torch.var(a, 1)
tensor([ 1.7444, 1.1363, 0.7356, 0.5112])
-
Comparison Ops
-
torch.
eq
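(input, other, out=None) → Tensor
Computes element-wise equality.
The second argument can be a number or a tensor whose shape is broadcastable with the first argument.
Parameters:
- input (Tensor) – the tensor to compare
- other (Tensor or float) – the tensor or value to compare
- out (Tensor, optional) – the output tensor. Must be a ByteTensor
Returns: A torch.ByteTensor containing a 1 at each location where comparison is true
Return type: Tensor
Example: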
>>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 1, 0],
        [ 0, 1]], dtype=torch.uint8)
-
torch.
equal
(tensor1, tensor2) → bool
True if two tensors have the same size and elements, False otherwise.
Example:
>>> torch.equal(torch.tensor([1, 2]), torch.tensor([1, 2]))
True
-
torch.
ge
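(input, other, out=None) → Tensor
Computes input ≥ other element-wise.
The second argument can be a number or a tensor whose shape is broadcastable with the first argument.
Parameters:
- input (Tensor) – the tensor to compare
- other (Tensor or float) – the tensor or value to compare
- out (Tensor, optional) – the output tensor. Must be a ByteTensor
Returns: A torch.ByteTensor containing a 1 at each location where comparison is true
Return type: Tensor
Example: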
>>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 1, 1],
        [ 0, 1]], dtype=torch.uint8)
-
torch.
gt
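(input, other, out=None) → Tensor
Computes input > other element-wise.
The second argument can be a number or a tensor whose shape is broadcastable with the first argument.
Parameters:
- input (Tensor) – the tensor to compare
- other (Tensor or float) – the tensor or value to compare
- out (Tensor, optional) – the output tensor. Must be a ByteTensor
Returns: A torch.ByteTensor containing a 1 at each location where comparison is true
Return type: Tensor
Example: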
>>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 0, 1],
        [ 0, 0]], dtype=torch.uint8)
-
torch.
isnan
(tensor)
Returns a new tensor with boolean elements indicating whether each element is NaN.
Parameters: tensor (Tensor) – A tensor to check
Returns: A torch.ByteTensor containing a 1 at each location of NaN elements.
Return type: Tensor
Example:
>>> torch.isnan(torch.tensor([1, float('nan'), 2]))
tensor([ 0, 1, 0], dtype=torch.uint8)
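The 0/1 mask above can be reproduced in plain Python using the IEEE 754 property that NaN is the only floating-point value that does not compare equal to itself. This sketch works on a list of floats rather than a tensor, and the helper name isnan_mask is hypothetical. Infinity is not NaN, so it maps to 0.

```python
import math


def isnan_mask(values):
    """1 where an element is NaN, 0 elsewhere, for a 1-D list of floats."""
    # NaN is the only float value that does not compare equal to itself
    return [1 if v != v else 0 for v in values]


vals = [1.0, float("nan"), 2.0, float("inf")]
print(isnan_mask(vals))
print(isnan_mask(vals) == [int(math.isnan(v)) for v in vals])
```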
-
torch.
kthvalue
(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor)
Returns the k-th smallest element of the given input tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
A tuple of (values, indices) is returned, where indices holds the index of the k-th smallest element in the original input tensor in dimension dim.
If keepdim is True, both the values and indices tensors are the same size as input, except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in both the values and indices tensors having one fewer dimension than the input tensor.
Parameters:
- input (Tensor) – the input tensor
- k (int) – k for the k-th smallest element
- dim (int, optional) – the dimension to find the k-th value along
- keepdim (bool) – whether the output tensors have dim retained or not
- out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
Example:
>>> x = torch.arange(1, 6)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.kthvalue(x, 4)
(tensor(4.), tensor(3))
>>> x = torch.arange(1, 7).resize_(2, 3)
>>> x
tensor([[ 1., 2., 3.],
        [ 4., 5., 6.]])
>>> torch.kthvalue(x, 2, 0, True)
(tensor([[ 4., 5., 6.]]), tensor([[ 1, 1, 1]]))
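A plain-Python sketch of the (value, index) pair, with k counted from 1 as in the examples above. The hypothetical helper kthvalue handles a single 1-D list. Reducing a 2x3 matrix along dim 0 is shown by applying the helper to each column, which reproduces the values 4, 5, 6 at index 1 from the second example.

```python
def kthvalue(row, k):
    """Return (k-th smallest value, its index) for a 1-D list; k is 1-based."""
    if not 1 <= k <= len(row):
        raise ValueError("k must be between 1 and len(row)")
    order = sorted(range(len(row)), key=lambda i: row[i])
    i = order[k - 1]
    return row[i], i


print(kthvalue([1, 2, 3, 4, 5], 4))
# Reducing a 2x3 matrix along dim 0 means working on its columns.
x = [[1, 2, 3], [4, 5, 6]]
print([kthvalue(list(col), 2) for col in zip(*x)])
```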
-
torch.
le
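(input, other, out=None) → Tensor
Computes input ≤ other element-wise.
The second argument can be a number or a tensor whose shape is broadcastable with the first argument.
Parameters:
- input (Tensor) – the tensor to compare
- other (Tensor or float) – the tensor or value to compare
- out (Tensor, optional) – the output tensor. Must be a ByteTensor
Returns: A torch.ByteTensor containing a 1 at each location where comparison is true
Return type: Tensor
Example: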
>>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 1, 0],
        [ 1, 1]], dtype=torch.uint8)
-
torch.
lt
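(input, other, out=None) → Tensor
Computes input < other element-wise.
The second argument can be a number or a tensor whose shape is broadcastable with the first argument.
Parameters:
- input (Tensor) – the tensor to compare
- other (Tensor or float) – the tensor or value to compare
- out (Tensor, optional) – the output tensor. Must be a ByteTensor
Returns: A torch.ByteTensor containing a 1 at each location where comparison is true
Return type: Tensor
Example: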
>>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 0, 0],
        [ 1, 0]], dtype=torch.uint8)
-
torch.
max
()
-
torch.
max
(input) → Tensor
Returns the maximum value of all elements in the input tensor.
Parameters: input (Tensor) – the input tensor
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763, 0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
-
torch.
max
(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
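Returns the maximum value of each row of the input tensor in the given dimension dim. The second return value is the index location of each maximum value found (argmax).
If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having one fewer dimension than input.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensors have dim retained or not
- out (tuple, optional) – the result tuple of two output tensors (max, max_indices)
Example: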
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207, 0.1297, -1.8768],
        [-0.6172, 1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
(tensor([ 0.8475, 1.1949, 1.5717, 1.0036]), tensor([ 3, 0, 0, 1]))
-
torch.
max
(input, other, out=None) → Tensor
Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise maximum is taken.
The shapes of input and other don’t need to match, but they must be broadcastable.
out_i = \max(tensor_i, other_i)
Note
When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.
Parameters:
- input (Tensor) – the input tensor
- other (Tensor) – the second input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.2942, -0.7416, 0.2653, -0.1584])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
>>> torch.max(a, b)
tensor([ 0.8722, -0.7416, 0.2653, -0.1584])
-
torch.
min
()
-
torch.
min
(input) → Tensor
Returns the minimum value of all elements in the input tensor.
Parameters: input (Tensor) – the input tensor
Example:
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6750, 1.0857, 1.7197]])
>>> torch.min(a)
tensor(0.6750)
-
torch.
min
(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
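Returns the minimum value of each row of the input tensor in the given dimension dim. The second return value is the index location of each minimum value found (argmin).
If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having one fewer dimension than input.
Parameters:
- input (Tensor) – the input tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output tensors have dim retained or not
- out (tuple, optional) – the result tuple of two output tensors (min, min_indices)
Example: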
>>> a = torch.randn(4, 4)
>>> a
tensor([[-0.6248, 1.1334, -1.1899, -0.2803],
        [-1.4644, -0.2635, -0.3651, 0.6134],
        [ 0.2457, 0.0384, 1.0128, 0.7015],
        [-0.1153, 2.9849, 2.1458, 0.5788]])
>>> torch.min(a, 1)
(tensor([-1.1899, -1.4644, 0.0384, -0.1153]), tensor([ 2, 0, 1, 0]))
-
torch.
min
(input, other, out=None) → Tensor
Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise minimum is taken. The resulting tensor is returned.
The shapes of input and other don’t need to match, but they must be broadcastable.
out_i = \min(tensor_i, other_i)
Note
When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.
Parameters:
- input (Tensor) – the input tensor
- other (Tensor) – the second input tensor
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.8137, -1.1740, -0.6460, 0.6308])
>>> b = torch.randn(4)
>>> b
tensor([-0.1369, 0.1555, 0.4019, -0.1929])
>>> torch.min(a, b)
tensor([-0.1369, -1.1740, -0.6460, -0.1929])
-
torch.
ne
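(input, other, out=None) → Tensor
Computes input ≠ other element-wise.
The second argument can be a number or a tensor whose shape is broadcastable with the first argument.
Parameters:
- input (Tensor) – the tensor to compare
- other (Tensor or float) – the tensor or value to compare
- out (Tensor, optional) – the output tensor. Must be a ByteTensor
Returns: A torch.ByteTensor containing a 1 at each location where comparison is true.
Return type: Tensor
Example: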
>>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ 0, 1],
        [ 1, 0]], dtype=torch.uint8)
-
torch.
sort
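(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
Sorts the elements of the input tensor along a given dimension in ascending order by value.
If dim is not given, the last dimension of the input is chosen.
If descending is True, the elements are sorted in descending order by value.
A tuple of (sorted_tensor, sorted_indices) is returned, where the sorted_indices are the indices of the elements in the original input tensor.
Parameters:
- input (Tensor) – the input tensor
- dim (int, optional) – the dimension to sort along
- descending (bool, optional) – controls the sorting order (ascending or descending)
- out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
Example: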
>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
tensor([[-0.2162, 0.0608, 0.6719, 2.3332],
        [-0.5793, 0.0061, 0.6058, 0.9497],
        [-0.5071, 0.3343, 0.9553, 1.0960]])
>>> indices
tensor([[ 1, 0, 2, 3],
        [ 3, 1, 0, 2],
        [ 0, 3, 1, 2]])
>>> sorted, indices = torch.sort(x, 0)
>>> sorted
tensor([[-0.5071, -0.2162, 0.6719, -0.5793],
        [ 0.0608, 0.0061, 0.9497, 0.3343],
        [ 0.6058, 0.9553, 1.0960, 2.3332]])
>>> indices
tensor([[ 2, 0, 0, 1],
        [ 0, 1, 1, 2],
        [ 1, 2, 2, 0]])
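The (sorted values, original indices) pair is essentially an argsort followed by a gather. The plain-Python sketch below works on a single 1-D list (the helper name sort_with_indices is hypothetical). Its first call uses the first row of x, reconstructed from the sorted output and indices shown above, and returns indices [1, 0, 2, 3] as in the example.

```python
def sort_with_indices(row, descending=False):
    """Return (sorted values, original indices) for a 1-D list."""
    # argsort: positions of the elements, ordered by their values
    indices = sorted(range(len(row)), key=lambda i: row[i], reverse=descending)
    return [row[i] for i in indices], indices


print(sort_with_indices([0.0608, -0.2162, 0.6719, 2.3332]))
print(sort_with_indices([3, 1, 2], descending=True))
```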
-
torch.
topk
(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
Returns the k largest elements of the given input tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
If largest is False, the k smallest elements are returned.
A tuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.
The boolean option sorted, if True, makes sure that the returned k elements are themselves sorted.
Parameters:
- input (Tensor) – the input tensor
- k (int) – the k in “top-k”
- dim (int, optional) – the dimension to sort along
- largest (bool, optional) – controls whether to return largest or smallest elements
- sorted (bool, optional) – controls whether to return the elements in sorted order
- out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
Example:
>>> x = torch.arange(1, 6)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.topk(x, 3)
(tensor([ 5., 4., 3.]), tensor([ 4, 3, 2]))
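For a 1-D list, a top-k selection can be sketched with the standard library's heapq.nlargest and heapq.nsmallest applied to (index, value) pairs. This is an illustration only; the helper name topk is hypothetical, and the sketch always returns its results in sorted order, corresponding to sorted=True.

```python
import heapq


def topk(row, k, largest=True):
    """Return (values, indices) of the k largest (or smallest) elements of a 1-D list."""
    pick = heapq.nlargest if largest else heapq.nsmallest
    pairs = pick(k, enumerate(row), key=lambda pair: pair[1])
    # heapq.nlargest/nsmallest return their results already ordered
    return [v for _, v in pairs], [i for i, _ in pairs]


print(topk([1, 2, 3, 4, 5], 3))
print(topk([1, 2, 3, 4, 5], 2, largest=False))
```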
Spectral Ops
-
torch.
fft
(input, signal_ndim, normalized=False) → Tensor
Complex-to-complex Discrete Fourier Transform
This method computes the complex-to-complex discrete Fourier transform. Ignoring the batch dimensions, it computes the following expression:
X[\omega_1, \dots, \omega_d] = \sum_{n_1=0}^{N_1 - 1} \dots \sum_{n_d=0}^{N_d - 1} x[n_1, \dots, n_d] e^{-j\ 2 \pi \sum_{i=1}^d \frac{\omega_i n_i}{N_i}},
where d = signal_ndim is the number of dimensions for the signal, and N_i is the size of signal dimension i.
This method supports 1D, 2D and 3D complex-to-complex transforms, indicated by signal_ndim. input must be a tensor with last dimension of size 2, representing the real and imaginary components of complex numbers, and should have at least signal_ndim + 1 dimensions, optionally with an arbitrary number of leading batch dimensions. If normalized is set to True, this normalizes the result by dividing it by \sqrt{\prod_{i=1}^d N_i} so that the operator is unitary.
Returns the real and the imaginary parts together as one tensor of the same shape as input.
The inverse of this function is ifft().
Warning
For CPU tensors, this method is currently only available with MKL. Use torch.backends.mkl.is_available() to check if MKL is installed.
Parameters:
- input (Tensor) – the input tensor of at least signal_ndim + 1 dimensions
- signal_ndim (int) – the number of dimensions in each signal. signal_ndim can only be 1, 2 or 3
- normalized (bool, optional) – controls whether to return normalized results. Default: False
Returns: A tensor containing the complex-to-complex Fourier transform result
Return type: Tensor
Example:
>>> # unbatched 2D FFT
>>> x = torch.randn(4, 3, 2)
>>> torch.fft(x, 2)
tensor([[[-0.0876, 1.7835],
         [-2.0399, -2.9754],
         [ 4.4773, -5.0119]],
        [[-1.5716, 2.7631],
         [-3.8846, 5.2652],
         [ 0.2046, -0.7088]],
        [[ 1.9938, -0.5901],
         [ 6.5637, 6.4556],
         [ 2.9865, 4.9318]],
        [[ 7.0193, 1.1742],
         [-1.3717, -2.1084],
         [ 2.0289, 2.9357]]])
>>> # batched 1D FFT
>>> torch.fft(x, 1)
tensor([[[ 1.8385, 1.2827],
         [-0.1831, 1.6593],
         [ 2.4243, 0.5367]],
        [[-0.9176, -1.5543],
         [-3.9943, -2.9860],
         [ 1.2838, -2.9420]],
        [[-0.8854, -0.6860],
         [ 2.4450, 0.0808],
         [ 1.3076, -0.5768]],
        [[-0.1231, 2.7411],
         [-0.3075, -1.7295],
         [-0.5384, -2.0299]]])
>>> # arbitrary number of batch dimensions, 2D FFT
>>> x = torch.randn(3, 3, 5, 5, 2)
>>> y = torch.fft(x, 2)
>>> y.shape
torch.Size([3, 3, 5, 5, 2])
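The 1-D case of the DFT formula above (d = 1) can be written directly as a naive O(N^2) sum in plain Python. This is a sketch for checking the definition, not PyTorch's MKL/cuFFT-backed implementation, and the helper names are hypothetical. The from_pairs helper mirrors the input layout described above, where the last dimension of size 2 holds (real, imaginary) pairs. Passing normalized=True divides by sqrt(N).

```python
import cmath


def from_pairs(pairs):
    """Turn [[re, im], ...] (last dimension of size 2) into Python complex numbers."""
    return [complex(re, im) for re, im in pairs]


def dft(x, normalized=False):
    """Naive O(N^2) 1-D DFT: X[w] = sum_n x[n] * exp(-2j*pi*w*n/N)."""
    n_total = len(x)
    out = [
        sum(x[n] * cmath.exp(-2j * cmath.pi * w * n / n_total) for n in range(n_total))
        for w in range(n_total)
    ]
    if normalized:
        # divide by sqrt(N) so that the transform is unitary
        out = [v / n_total ** 0.5 for v in out]
    return out


print(dft(from_pairs([[1, 0], [0, 0], [0, 0], [0, 0]])))  # an impulse
print(dft([1, 1, 1, 1]))                                   # a constant signal
```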
-
torch.
ifft
(input, signal_ndim, normalized=False) → Tensor
Complex-to-complex Inverse Discrete Fourier Transform
This method computes the complex-to-complex inverse discrete Fourier transform. Ignoring the batch dimensions, it computes the following expression:
X[\omega_1, \dots, \omega_d] = \frac{1}{\prod_{i=1}^d N_i} \sum_{n_1=0}^{N_1 - 1} \dots \sum_{n_d=0}^{N_d - 1} x[n_1, \dots, n_d] e^{\ j\ 2 \pi \sum_{i=1}^d \frac{\omega_i n_i}{N_i}},
where d = signal_ndim is the number of dimensions for the signal, and N_i is the size of signal dimension i.
The argument specifications are almost identical to those of fft(). However, if normalized is set to True, this instead returns the results multiplied by \sqrt{\prod_{i=1}^d N_i}, so that the operator is unitary. Therefore, to invert an fft(), the normalized argument should be set identically for fft().
Returns the real and the imaginary parts together as one tensor of the same shape as input.
The inverse of this function is fft().
Warning
For CPU tensors, this method is currently only available with MKL. Use torch.backends.mkl.is_available() to check if MKL is installed.
Parameters:
- input (Tensor) – the input tensor of at least signal_ndim + 1 dimensions
- signal_ndim (int) – the number of dimensions in each signal. signal_ndim can only be 1, 2 or 3
- normalized (bool, optional) – controls whether to return normalized results. Default: False
Returns: A tensor containing the complex-to-complex inverse Fourier transform result
Return type: Tensor
Example:
>>> x = torch.randn(3, 3, 2)
>>> x
tensor([[[ 1.2766, 1.3680],
         [-0.8337, 2.0251],
         [ 0.9465, -1.4390]],
        [[-0.1890, 1.6010],
         [ 1.1034, -1.9230],
         [-0.9482, 1.0775]],
        [[-0.7708, -0.8176],
         [-0.1843, -0.2287],
         [-1.9034, -0.2196]]])
>>> y = torch.fft(x, 2)
>>> torch.ifft(y, 2)  # recover x
tensor([[[ 1.2766, 1.3680],
         [-0.8337, 2.0251],
         [ 0.9465, -1.4390]],
        [[-0.1890, 1.6010],
         [ 1.1034, -1.9230],
         [-0.9482, 1.0775]],
        [[-0.7708, -0.8176],
         [-0.1843, -0.2287],
         [-1.9034, -0.2196]]])
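Why normalized must match between fft() and ifft() can be checked with a naive 1-D transform in plain Python. In this sketch, the unnormalized pair puts the whole 1/N factor on the inverse, while the normalized pair puts 1/sqrt(N) on each side. Either pairing round-trips, and only the normalized forward transform preserves the signal energy (it is unitary). The helper name dft and its inverse/normalized arguments are hypothetical and only mimic the semantics described above.

```python
import cmath
import math


def dft(x, inverse=False, normalized=False):
    """Naive 1-D DFT (inverse=False) or inverse DFT (inverse=True)."""
    n_total = len(x)
    sign = 1 if inverse else -1
    out = [
        sum(x[n] * cmath.exp(sign * 2j * cmath.pi * w * n / n_total) for n in range(n_total))
        for w in range(n_total)
    ]
    if normalized:
        scale = 1 / math.sqrt(n_total)   # unitary scaling in both directions
    elif inverse:
        scale = 1 / n_total              # the 1 / prod(N_i) factor of the inverse
    else:
        scale = 1.0
    return [v * scale for v in out]


x = [1 + 2j, -0.5j, 3, 0.25 - 1j]
for norm in (False, True):
    back = dft(dft(x, normalized=norm), inverse=True, normalized=norm)
    print(norm, max(abs(a - b) for a, b in zip(x, back)))
```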
-
torch.
rfft
(input, signal_ndim, normalized=False, onesided=True) → Tensor
Real-to-complex Discrete Fourier Transform
This method computes the real-to-complex discrete Fourier transform. It is mathematically equivalent to fft(), differing only in the formats of the input and output.
This method supports 1D, 2D and 3D real-to-complex transforms, indicated by signal_ndim. input must be a tensor with at least signal_ndim dimensions, optionally with an arbitrary number of leading batch dimensions. If normalized is set to True, this normalizes the result by dividing it by \sqrt{\prod_{i=1}^d N_i} so that the operator is unitary, where N_i is the size of signal dimension i.
The real-to-complex Fourier transform results follow conjugate symmetry:
X[\omega_1, \dots, \omega_d] = X^*[N_1 - \omega_1, \dots, N_d - \omega_d],
where the index arithmetic is computed modulo the size of the corresponding dimension, \ ^* is the conjugate operator, and d = signal_ndim. The onesided flag controls whether to avoid redundancy in the output results. If set to True (default), the output is not the full complex result of shape (*, 2), where * is the shape of input; instead, the last dimension is halved to size \lfloor \frac{N_d}{2} \rfloor + 1.
The inverse of this function is irfft().
Warning
For CPU tensors, this method is currently only available with MKL. Use torch.backends.mkl.is_available() to check if MKL is installed.
Parameters:
- input (Tensor) – the input tensor of at least signal_ndim dimensions
- signal_ndim (int) – the number of dimensions in each signal. signal_ndim can only be 1, 2 or 3
- normalized (bool, optional) – controls whether to return normalized results. Default: False
- onesided (bool, optional) – controls whether to return half of results to avoid redundancy. Default: True
Returns: A tensor containing the real-to-complex Fourier transform result
Return type: Tensor
Example:
>>> x = torch.randn(5, 5)
>>> torch.rfft(x, 2).shape
torch.Size([5, 3, 2])
>>> torch.rfft(x, 2, onesided=False).shape
torch.Size([5, 5, 2])
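The conjugate symmetry and the floor(N/2) + 1 one-sided length can be checked for a real 1-D signal with a naive DFT in plain Python. The helper names dft and rfft_1d are hypothetical, and the sketch covers only signal_ndim = 1 without normalization. Both a length-5 and a length-4 signal keep 3 bins, which matches the [5, 3, 2] shape above.

```python
import cmath


def dft(x):
    """Naive 1-D DFT: X[w] = sum_n x[n] * exp(-2j*pi*w*n/N)."""
    n_total = len(x)
    return [
        sum(x[n] * cmath.exp(-2j * cmath.pi * w * n / n_total) for n in range(n_total))
        for w in range(n_total)
    ]


def rfft_1d(signal, onesided=True):
    """DFT of a real 1-D signal; onesided keeps only floor(N/2) + 1 bins."""
    full = dft([complex(v) for v in signal])
    return full[: len(signal) // 2 + 1] if onesided else full


signal = [0.3, -1.2, 0.7, 2.0, -0.4]
full = rfft_1d(signal, onesided=False)
n = len(full)
# Conjugate symmetry: X[w] == conj(X[(N - w) % N]) for a real input.
print(all(abs(full[w] - full[(n - w) % n].conjugate()) < 1e-9 for w in range(n)))
print(len(rfft_1d(signal)), len(rfft_1d(signal[:4])))
```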
-
torch.
irfft
(input, signal_ndim, normalized=False, onesided=True, signal_sizes=None) → Tensor
Complex-to-real Inverse Discrete Fourier Transform
This method computes the complex-to-real inverse discrete Fourier transform. It is mathematically equivalent to ifft(), differing only in the formats of the input and output.
The argument specifications are almost identical to those of ifft(). Similar to ifft(), if normalized is set to True, this normalizes the result by multiplying it by \sqrt{\prod_{i=1}^d N_i} so that the operator is unitary, where N_i is the size of signal dimension i.
Due to the conjugate symmetry, input does not need to contain the full complex frequency values. Roughly half of the values are sufficient, as is the case when input is produced by rfft(signal, onesided=True). In such a case, set the onesided argument of this method to True. Moreover, the original signal shape information can sometimes be lost; optionally set signal_sizes to the size of the original signal (without the batch dimensions if in batched mode) to recover it with the correct shape.
Therefore, to invert an rfft(), the normalized and onesided arguments should be set identically for irfft(), and preferably signal_sizes should be given to avoid a size mismatch. See the example below for a case of size mismatch.
See rfft() for details on conjugate symmetry.
The inverse of this function is rfft().
Warning
Generally speaking, the input of this function should contain values following conjugate symmetry. Note that even if onesided is True, symmetry on some part is often still needed. When this requirement is not satisfied, the behavior of irfft() is undefined. Since torch.autograd.gradcheck() estimates the numerical Jacobian with point perturbations, irfft() will almost certainly fail the check.
Warning
For CPU tensors, this method is currently only available with MKL. Use torch.backends.mkl.is_available() to check if MKL is installed.
Parameters:
- input (Tensor) – the input tensor of at least signal_ndim + 1 dimensions
- signal_ndim (int) – the number of dimensions in each signal. signal_ndim can only be 1, 2 or 3
- normalized (bool, optional) – controls whether to return normalized results. Default: False
- onesided (bool, optional) – controls whether input was halved to avoid redundancy, e.g., by rfft(). Default: True
- signal_sizes (list or torch.Size, optional) – the size of the original signal (without batch dimension). Default: None
Returns: A tensor containing the complex-to-real inverse Fourier transform result
Return type: Tensor
Example:
>>> x = torch.randn(4, 4)
>>> torch.rfft(x, 2, onesided=True).shape
torch.Size([4, 3, 2])
>>>
>>> # notice that with onesided=True, output size does not determine the original signal size
>>> x = torch.randn(4, 5)
>>> torch.rfft(x, 2, onesided=True).shape
torch.Size([4, 3, 2])
>>>
>>> # now we use the original shape to recover x
>>> x
tensor([[-0.8992, 0.6117, -1.6091, -0.4155, -0.8346],
        [-2.1596, -0.0853, 0.7232, 0.1941, -0.0789],
        [-2.0329, 1.1031, 0.6869, -0.5042, 0.9895],
        [-0.1884, 0.2858, -1.5831, 0.9917, -0.8356]])
>>> y = torch.rfft(x, 2, onesided=True)
>>> torch.irfft(y, 2, onesided=True, signal_sizes=x.shape)  # recover x
tensor([[-0.8992, 0.6117, -1.6091, -0.4155, -0.8346],
        [-2.1596, -0.0853, 0.7232, 0.1941, -0.0789],
        [-2.0329, 1.1031, 0.6869, -0.5042, 0.9895],
        [-0.1884, 0.2858, -1.5831, 0.9917, -0.8356]])
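Why signal_sizes matters can be seen in one dimension. Signals of length 4 and length 5 both produce a one-sided spectrum with 3 bins, so the half spectrum alone cannot tell which length to rebuild. The plain-Python sketch below uses hypothetical helpers (rfft_1d, irfft_1d) and is not PyTorch's implementation. It refills the missing bins from conjugate symmetry, X[w] = conj(X[N - w]), using the original length N passed as signal_size, then applies a naive inverse DFT with the 1/N factor.

```python
import cmath


def dft(x, inverse=False):
    """Naive, unscaled 1-D DFT (inverse=True flips the exponent sign)."""
    n_total = len(x)
    sign = 1 if inverse else -1
    return [
        sum(x[n] * cmath.exp(sign * 2j * cmath.pi * w * n / n_total) for n in range(n_total))
        for w in range(n_total)
    ]


def rfft_1d(signal):
    """One-sided spectrum: the first floor(N/2) + 1 DFT bins of a real signal."""
    return dft([complex(v) for v in signal])[: len(signal) // 2 + 1]


def irfft_1d(half, signal_size):
    """Rebuild a real signal of length signal_size from its one-sided spectrum."""
    # Missing bins come from conjugate symmetry: X[w] = conj(X[N - w]).
    full = [half[w] if w < len(half) else half[signal_size - w].conjugate()
            for w in range(signal_size)]
    return [(v / signal_size).real for v in dft(full, inverse=True)]


x4 = [-0.8992, 0.6117, -1.6091, -0.4155]
x5 = x4 + [-0.8346]
print(len(rfft_1d(x4)), len(rfft_1d(x5)))   # both one-sided spectra have 3 bins
print(irfft_1d(rfft_1d(x5), signal_size=5))
```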
-
torch.
stft
(signal, frame_length, hop, fft_size=None, normalized=False, onesided=True, window=None, pad_end=0) → Tensor
Short-time Fourier transform (STFT).
Ignoring the batch dimension, this method computes the following expression:
X[m, \omega] = \sum_{k = 0}^{\text{frame\_length} - 1} window[k]\ signal[m \times hop + k]\ e^{- j \frac{2 \pi \cdot \omega k}{\text{frame\_length}}},
where m is the index of the sliding window, and \omega is the frequency, with 0 \leq \omega < fft_size. When onesided is the default value True, only values for \omega in the range \left[0, 1, 2, \dots, \left\lfloor \frac{\text{fft\_size}}{2} \right\rfloor\right] are returned, because the real-to-complex transform satisfies the Hermitian symmetry, i.e., X[m, \omega] = X[m, \text{fft\_size} - \omega]^*.
The input signal must be a 1-D sequence of shape (T) or a 2-D batch of sequences of shape (N \times T). If fft_size is None, it defaults to the same value as frame_length. window can be a 1-D tensor of size frame_length, e.g., see torch.hann_window(). If window is the default value None, it is treated as if it were 1 everywhere in the frame. pad_end indicates the amount of zero padding at the end of signal before the STFT. If normalized is set to True, the function returns the normalized STFT results, i.e., multiplied by (\text{frame\_length})^{-0.5}.
Returns the real and the imaginary parts together as one tensor of size (* \times N \times 2), where * is the shape of the input signal, N is the number of \omega values considered depending on fft_size and onesided, and each pair in the last dimension represents a complex number as real part and imaginary part.
Parameters:
- signal (Tensor) – the input tensor
- frame_length (int) – the size of window frame and STFT filter
- hop (int) – the distance between neighboring sliding window frames
- fft_size (int, optional) – size of Fourier transform. Default: None
- normalized (bool, optional) – controls whether to return the normalized STFT results. Default: False
- onesided (bool, optional) – controls whether to return half of results to avoid redundancy. Default: True
- window (Tensor, optional) – the optional window function. Default: None
- pad_end (int, optional) – implicit zero padding at the end of signal. Default: 0
Returns: A tensor containing the STFT result
Return type: Tensor
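A naive plain-Python reading of the STFT formula for a single 1-D signal is sketched below. It assumes fft_size == frame_length, pad_end = 0 and no normalization. It also assumes that only frames lying entirely inside the signal are used, giving 1 + (T - frame_length) // hop frames; that frame count is this sketch's assumption, not something stated above. The helper name stft_1d is hypothetical.

```python
import cmath


def stft_1d(signal, frame_length, hop, window=None, onesided=True):
    """Naive STFT of a 1-D list, assuming fft_size == frame_length and pad_end == 0."""
    if window is None:
        window = [1.0] * frame_length            # "1 everywhere in the frame"
    # Assumption: only frames that fit entirely inside the signal are used.
    n_frames = 1 + (len(signal) - frame_length) // hop
    n_bins = frame_length // 2 + 1 if onesided else frame_length
    frames = []
    for m in range(n_frames):
        chunk = [window[k] * signal[m * hop + k] for k in range(frame_length)]
        frames.append([
            sum(chunk[k] * cmath.exp(-2j * cmath.pi * w * k / frame_length)
                for k in range(frame_length))
            for w in range(n_bins)
        ])
    return frames


spec = stft_1d([1.0] * 8, frame_length=4, hop=2)
print(len(spec), len(spec[0]))   # frames x one-sided frequency bins
```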
-
torch.
hann_window
(window_length, periodic=True, dtype=torch.float32)
Hann window function.
This method computes the Hann window function:
w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = \sin^2 \left( \frac{\pi n}{N - 1} \right),
where N is the full window size.
The input window_length is a positive integer controlling the returned window size. The periodic flag determines whether the returned window trims off the last duplicate value from the symmetric window and is ready to be used as a periodic window with functions like torch.stft(). Therefore, if periodic is true, the N in the above formula is in fact \text{window\_length} + 1. Also, torch.hann_window(L, periodic=True) is always equal to torch.hann_window(L + 1, periodic=False)[:-1].
Note
If window_length = 1, the returned window contains a single value 1.
Parameters:
- window_length (int) – the size of returned window
- periodic (bool, optional) – If True, returns a window to be used as periodic function. If False, returns a symmetric window.
- dtype (torch.dtype, optional) – the desired type of returned window. Default: torch.float32
Returns: A 1-D tensor of size (\text{window\_length},) containing the window
Return type: Tensor
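A plain-Python sketch of the Hann formula and the periodic/symmetric relationship described above. The helper name hann_window here is hypothetical and returns a Python list rather than a tensor. The periodic window of length L uses N = L + 1, so it equals the symmetric window of length L + 1 with its last value dropped. Length 1 is special-cased to [1.0] as in the Note.

```python
import math


def hann_window(window_length, periodic=True):
    """Hann window as a list of floats, following the formula above."""
    if window_length == 1:
        return [1.0]                                   # see the Note above
    n_full = window_length + 1 if periodic else window_length   # the N in the formula
    return [math.sin(math.pi * n / (n_full - 1)) ** 2 for n in range(window_length)]


print(hann_window(5, periodic=False))
print(hann_window(4))
```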
torch.hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, dtype=torch.float32)
Hamming window function.
This method computes the Hamming window function:
w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right),
where N is the full window size.
The input window_length is a positive integer controlling the returned window size. The periodic flag determines whether the returned window trims off the last duplicate value from the symmetric window and is ready to be used as a periodic window with functions like torch.stft(). Therefore, if periodic is true, N in the above formula is in fact \text{window_length} + 1. Also, torch.hamming_window(L, periodic=True) is always equal to torch.hamming_window(L + 1, periodic=False)[:-1].
Note
If window_length = 1, the returned window contains a single value 1.
Note
This is a generalized version of torch.hann_window().
Parameters:
- window_length (int) – the size of the returned window
- periodic (bool, optional) – If True, returns a window to be used as a periodic function. If False, returns a symmetric window.
- alpha (float, optional) – the coefficient \alpha in the equation above
- beta (float, optional) – the coefficient \beta in the equation above
- dtype (torch.dtype, optional) – the desired type of the returned window. Default: torch.float32
Returns: A 1-D tensor of size (\text{window_length},) containing the window
Return type: Tensor
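To show why this is a generalization of the Hann window, here is a hypothetical pure-Python sketch of the formula above (hamming_reference is an illustrative name, not a PyTorch function). Setting alpha = beta = 0.5 turns it into the Hann formula.

```python
import math


def hamming_reference(window_length, periodic=True, alpha=0.54, beta=0.46):
    """Hypothetical sketch of w[n] = alpha - beta * cos(2*pi*n / (N - 1))."""
    if window_length == 1:
        return [1.0]
    # Periodic windows use N = window_length + 1, as described above.
    N = window_length + 1 if periodic else window_length
    return [alpha - beta * math.cos(2 * math.pi * n / (N - 1))
            for n in range(window_length)]


# With the defaults, the symmetric window starts and ends at alpha - beta
# (about 0.08) and peaks at alpha + beta = 1.0 in the middle.
print(hamming_reference(5, periodic=False))

# alpha = beta = 0.5 gives the Hann window 0.5 - 0.5 * cos(...).
hann_like = hamming_reference(5, periodic=False, alpha=0.5, beta=0.5)
```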
torch.bartlett_window(window_length, periodic=True, dtype=torch.float32)
Bartlett window function.
This method computes the Bartlett window function:
\begin{split}w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ \end{cases},\end{split}
where N is the full window size.
The input window_length is a positive integer controlling the returned window size. The periodic flag determines whether the returned window trims off the last duplicate value from the symmetric window and is ready to be used as a periodic window with functions like torch.stft(). Therefore, if periodic is true, N in the above formula is in fact \text{window_length} + 1. Also, torch.bartlett_window(L, periodic=True) is always equal to torch.bartlett_window(L + 1, periodic=False)[:-1].
Note
If window_length = 1, the returned window contains a single value 1.
Parameters:
- window_length (int) – the size of the returned window
- periodic (bool, optional) – If True, returns a window to be used as a periodic function. If False, returns a symmetric window.
- dtype (torch.dtype, optional) – the desired type of the returned window. Default: torch.float32
Returns: A 1-D tensor of size (\text{window_length},) containing the window
Return type: Tensor
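The triangular shape is easy to see in a direct pure-Python transcription of the formula above. bartlett_reference is a hypothetical helper used only for illustration.

```python
def bartlett_reference(window_length, periodic=True):
    """Hypothetical sketch of w[n] = 1 - |2n / (N - 1) - 1|."""
    if window_length == 1:
        return [1.0]
    # Periodic windows use N = window_length + 1, as described above.
    N = window_length + 1 if periodic else window_length
    return [1 - abs(2 * n / (N - 1) - 1) for n in range(window_length)]


# Symmetric window of length 4: rises linearly from 0 to about 2/3 and back to 0.
print(bartlett_reference(4, periodic=False))
# Periodic window of length 4: the symmetric length-5 window without its last sample.
print(bartlett_reference(4))
```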
Other Operations
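torch.cross(input, other, dim=-1, out=None) → Tensor
Returns the cross product of vectors in dimension dim of input and other.
input and other must have the same size, and the size of their dim dimension should be 3.
If dim is not given, it defaults to the first dimension found with the size 3.
Parameters:
- input (Tensor) – the input tensor
- other (Tensor) – the second input tensor
- dim (int, optional) – the dimension to take the cross product in
- out (Tensor, optional) – the output tensor
Example: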
>>> a = torch.randn(4, 3)
>>> a
tensor([[-0.3956, 1.1455, 1.6895],
        [-0.5849, 1.3672, 0.3599],
        [-1.1626, 0.7180, -0.0521],
        [-0.1339, 0.9902, -2.0225]])
>>> b = torch.randn(4, 3)
>>> b
tensor([[-0.0257, -1.4725, -1.2251],
        [-1.1479, -0.7005, -1.9757],
        [-1.3904, 0.3726, -1.1836],
        [-0.9688, -0.7153, 0.2159]])
>>> torch.cross(a, b, dim=1)
tensor([[ 1.0844, -0.5281, 0.6120],
        [-2.4490, -1.5687, 1.9792],
        [-0.8304, -1.3037, 0.5650],
        [-1.2329, 1.9883, 1.0551]])
>>> torch.cross(a, b)
tensor([[ 1.0844, -0.5281, 0.6120],
        [-2.4490, -1.5687, 1.9792],
        [-0.8304, -1.3037, 0.5650],
        [-1.2329, 1.9883, 1.0551]])
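For 2-D inputs of shape (n, 3) with dim=1, each row of input is crossed with the matching row of other. The pure-Python sketch below spells that out. cross3 and cross_rows are hypothetical helpers. The final check reuses the first rows of a and b printed above.

```python
def cross3(u, v):
    """Cross product of two 3-element sequences."""
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]


def cross_rows(a, b):
    # Mirrors torch.cross(a, b, dim=1) for nested lists of shape (n, 3).
    return [cross3(u, v) for u, v in zip(a, b)]


print(cross3([1, 0, 0], [0, 1, 0]))  # x cross y = z -> [0, 0, 1]

# First rows of a and b from the example above; the result is close to
# the first row printed by torch.cross (1.0844, -0.5281, 0.6120).
print(cross3([-0.3956, 1.1455, 1.6895], [-0.0257, -1.4725, -1.2251]))
```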
torch.diag(input, diagonal=0, out=None) → Tensor
- If input is a vector (1-D tensor), then returns a 2-D square tensor with the elements of input as the diagonal.
- If input is a matrix (2-D tensor), then returns a 1-D tensor with the diagonal elements of input.
The argument diagonal controls which diagonal to consider:
- If diagonal = 0, it is the main diagonal.
- If diagonal > 0, it is above the main diagonal.
- If diagonal < 0, it is below the main diagonal.
Parameters:
- input (Tensor) – the input tensor
- diagonal (int, optional) – the diagonal to consider
- out (Tensor, optional) – the output tensor
See also
torch.diagonal() always returns the diagonal of its input.
torch.diagflat() always constructs a tensor with diagonal elements specified by the input.
Examples:
Get the square matrix where the input vector is the diagonal:
>>> a = torch.randn(3)
>>> a
tensor([ 0.5950,-0.0872, 2.3298])
>>> torch.diag(a)
tensor([[ 0.5950, 0.0000, 0.0000],
        [ 0.0000,-0.0872, 0.0000],
        [ 0.0000, 0.0000, 2.3298]])
>>> torch.diag(a, 1)
tensor([[ 0.0000, 0.5950, 0.0000, 0.0000],
        [ 0.0000, 0.0000,-0.0872, 0.0000],
        [ 0.0000, 0.0000, 0.0000, 2.3298],
        [ 0.0000, 0.0000, 0.0000, 0.0000]])
Get the k-th diagonal of a given matrix:
>>> a = torch.randn(3, 3)
>>> a
tensor([[-0.4264, 0.0255,-0.1064],
        [ 0.8795,-0.2429, 0.1374],
        [ 0.1029,-0.6482,-1.6300]])
>>> torch.diag(a, 0)
tensor([-0.4264,-0.2429,-1.6300])
>>> torch.diag(a, 1)
tensor([ 0.0255, 0.1374])
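Both behaviors of torch.diag come down to index arithmetic on i and i + diagonal. The sketch below shows them on nested lists. kth_diagonal and diag_from_vector are hypothetical helpers, not PyTorch functions.

```python
def kth_diagonal(matrix, k=0):
    """Return matrix[i][i + k] for every valid i (k > 0: above, k < 0: below)."""
    cols = len(matrix[0])
    return [row[i + k] for i, row in enumerate(matrix) if 0 <= i + k < cols]


def diag_from_vector(vec, k=0):
    """Square matrix of size len(vec) + |k| with vec placed on diagonal k."""
    n = len(vec) + abs(k)
    out = [[0] * n for _ in range(n)]
    for i, x in enumerate(vec):
        # Positive k shifts entries right (columns); negative k shifts them down (rows).
        row, col = (i, i + k) if k >= 0 else (i - k, i)
        out[row][col] = x
    return out


m = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
print(kth_diagonal(m, 1))           # [2, 6]
print(diag_from_vector([1, 2], 1))  # [[0, 1, 0], [0, 0, 2], [0, 0, 0]]
```

As in the torch.diag(a, 1) example above, a vector of length n placed on diagonal k produces an (n + |k|) x (n + |k|) matrix.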
torch.diagflat(input, offset=0) → Tensor
- If input is a vector (1-D tensor), then returns a 2-D square tensor with the elements of input as the diagonal.
- If input is a tensor with more than one dimension, then returns a 2-D tensor with diagonal elements equal to a flattened input.
The argument offset controls which diagonal to consider:
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
Parameters:
- input (Tensor) – the input tensor
- offset (int, optional) – the diagonal to consider. Default: 0 (main diagonal)
Examples:
>>> a = torch.randn(3)
>>> a
tensor([-0.2956, -0.9068, 0.1695])
>>> torch.diagflat(a)
tensor([[-0.2956, 0.0000, 0.0000],
        [ 0.0000, -0.9068, 0.0000],
        [ 0.0000, 0.0000, 0.1695]])
>>> torch.diagflat(a, 1)
tensor([[ 0.0000, -0.2956, 0.0000, 0.0000],
        [ 0.0000, 0.0000, -0.9068, 0.0000],
        [ 0.0000, 0.0000, 0.0000, 0.1695],
        [ 0.0000, 0.0000, 0.0000, 0.0000]])
>>> a = torch.randn(2, 2)
>>> a
tensor([[ 0.2094, -0.3018],
        [-0.1516, 1.9342]])
>>> torch.diagflat(a)
tensor([[ 0.2094, 0.0000, 0.0000, 0.0000],
        [ 0.0000, -0.3018, 0.0000, 0.0000],
        [ 0.0000, 0.0000, -0.1516, 0.0000],
        [ 0.0000, 0.0000, 0.0000, 1.9342]])
torch.diagonal(input, offset=0) → Tensor
Returns a 1-D tensor with the diagonal elements of input.
The argument offset controls which diagonal to consider:
- If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal.
Parameters:
- input (Tensor) – the input tensor
- offset (int, optional) – the diagonal to consider. Default: 0 (main diagonal)
Examples:
>>> a = torch.randn(3, 3)
>>> a
tensor([[-1.0854, 1.1431, -0.1752],
        [ 0.8536, -0.0905, 0.0360],
        [ 0.6927, -0.3735, -0.4945]])
>>> torch.diagonal(a, 0)
tensor([-1.0854, -0.0905, -0.4945])
>>> torch.diagonal(a, 1)
tensor([ 1.1431, 0.0360])
torch.einsum(equation, operands) → Tensor
This function provides a way of computing multilinear expressions (i.e. sums of products) using the Einstein summation convention.
Parameters:
- equation (string) – The equation is given in terms of lower case letters (indices) to be associated with each dimension of the operands and result. The left-hand side lists the operands' dimensions, separated by commas. There should be one index letter per tensor dimension. The right-hand side follows after -> and gives the indices for the output. If the -> and right-hand side are omitted, the output is implicitly defined as the alphabetically sorted list of all indices appearing exactly once in the left-hand side. The indices not appearing in the output are summed over after multiplying the operands' entries. einsum does not implement diagonals (multiple occurrences of a single index for one tensor, e.g. ii->i) or ellipses (...).
- operands (list of Tensors) – The operands to compute the Einstein sum of. Note that the operands are passed as a list, not as individual arguments.
Examples:
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', (x,y))  # outer product
tensor([[-0.0570, -0.0286, -0.0231, 0.0197],
        [ 1.2616, 0.6335, 0.5113, -0.4351],
        [ 1.4452, 0.7257, 0.5857, -0.4984],
        [-0.4647, -0.2333, -0.1883, 0.1603],
        [-1.1130, -0.5588, -0.4510, 0.3838]])
>>> A = torch.randn(3,5,4)
>>> l = torch.randn(2,5)
>>> r = torch.randn(2,4)
>>> torch.einsum('bn,anm,bm->ba', (l,A,r))  # compare torch.nn.functional.bilinear
tensor([[-0.3430, -5.2405, 0.4494],
        [ 0.3311, 5.5201, -3.0356]])
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', (As, Bs))  # batch matrix multiplication
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
         [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239, 0.3107, -0.5756, -0.2354],
         [-1.4558, -0.3460, 1.5087, -0.8530]],

        [[ 2.8153, 1.8787, -4.3839, -1.2112],
         [ 0.3728, -2.1131, 0.0921, 0.8305]]])
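The implicit-output rule in the equation description takes only a few lines to express. implicit_output is a hypothetical helper. It parses the index letters only and does not evaluate any sum.

```python
from collections import Counter


def implicit_output(equation):
    """Hypothetical helper: the output indices einsum infers when '->' is omitted.

    Following the rule described above, these are the indices that occur
    exactly once on the left-hand side, in alphabetical order.
    """
    counts = Counter(equation.replace(",", ""))
    return "".join(sorted(idx for idx, c in counts.items() if c == 1))


print(implicit_output("ij,jk"))  # ik  (j is summed: matrix product)
print(implicit_output("i,i"))    # prints an empty line: scalar result (dot product)
print(implicit_output("i,j"))    # ij  (outer product)
```

Because of the alphabetical ordering, an implicit equation such as "ji" yields output indices "ij", which is a transpose.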
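torch.histc(input, bins=100, min=0, max=0, out=None) → Tensor
Computes the histogram of a tensor.
The elements are sorted into equal width bins between min and max. If min and max are both zero, the minimum and maximum values of the data are used.
Parameters:
- input (Tensor) – the input tensor
- bins (int) – number of histogram bins
- min (int) – lower end of the range (inclusive)
- max (int) – upper end of the range (inclusive)
- out (Tensor, optional) – the output tensor
Returns: Histogram represented as a tensor
Return type: Tensor
Example: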
>>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3)
tensor([ 0., 2., 1., 0.])
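The binning rule can be sketched in pure Python. histc_reference is hypothetical, and its lo and hi arguments stand in for histc's min and max. The sketch makes two stated choices: elements outside [min, max] are skipped, and a value equal to max goes into the last bin. Widening a zero-width range by 1 on each side is also a choice of this sketch, not documented behavior.

```python
def histc_reference(values, bins=100, lo=0.0, hi=0.0):
    """Hypothetical pure-Python sketch of equal-width histogram binning.

    lo and hi stand in for histc's min and max arguments (renamed so the
    built-in min() and max() remain usable inside the function).
    """
    if lo == 0 and hi == 0:
        # Both limits zero: use the data's own minimum and maximum.
        lo, hi = min(values), max(values)
    if lo == hi:
        # Sketch-only choice to avoid zero-width bins for constant data.
        lo, hi = lo - 1, hi + 1
    width = (hi - lo) / bins
    counts = [0] * bins
    for v in values:
        if v < lo or v > hi:
            continue  # elements outside [lo, hi] are not counted
        # A value equal to hi would land one past the end; keep it in the last bin.
        counts[min(int((v - lo) / width), bins - 1)] += 1
    return counts


print(histc_reference([1.0, 2.0, 1.0], bins=4, lo=0, hi=3))  # [0, 2, 1, 0]
```

This reproduces the torch.histc example above: with four bins of width 0.75 over [0, 3], the two 1s fall in bin 1 and the 2 falls in bin 2.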
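torch.renorm(input, p, dim, maxnorm, out=None) → Tensor
Returns a tensor where each sub-tensor of input along dimension dim is normalized such that the p-norm of the sub-tensor is lower than the value maxnorm.
Note
If the norm of a row is lower than maxnorm, the row is unchanged.
Parameters:
- input (Tensor) – the input tensor
- p (float) – the power for the norm computation
- dim (int) – the dimension to slice over to get the sub-tensors
- maxnorm (float) – the maximum norm to keep each sub-tensor under
- out (Tensor, optional) – the output tensor
Example: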
>>> x = torch.ones(3, 3)
>>> x[1].fill_(2)
tensor([ 2., 2., 2.])
>>> x[2].fill_(3)
tensor([ 3., 3., 3.])
>>> x
tensor([[ 1., 1., 1.],
        [ 2., 2., 2.],
        [ 3., 3., 3.]])
>>> torch.renorm(x, 1, 0, 5)
tensor([[ 1.0000, 1.0000, 1.0000],
        [ 1.6667, 1.6667, 1.6667],
        [ 1.6667, 1.6667, 1.6667]])
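The example above can be reproduced by rescaling each row by maxnorm / norm whenever its p-norm exceeds maxnorm. renorm_rows is a hypothetical pure-Python sketch for the dim=0 case. PyTorch's kernel may differ in small numerical details (for example, how it guards against division by zero).

```python
def renorm_rows(rows, p, maxnorm):
    """Hypothetical sketch of torch.renorm(x, p, 0, maxnorm) on a nested list."""
    out = []
    for row in rows:
        # p-norm of this sub-tensor (a row when dim=0).
        norm = sum(abs(v) ** p for v in row) ** (1.0 / p)
        # Rows already within maxnorm are unchanged; larger ones are scaled
        # down so that their p-norm becomes maxnorm.
        scale = maxnorm / norm if norm > maxnorm else 1.0
        out.append([v * scale for v in row])
    return out


x = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
print(renorm_rows(x, p=1, maxnorm=5))
```

Here the first row has 1-norm 3, so it is unchanged. The other rows have 1-norms 6 and 9, so both are scaled down to a 1-norm of 5, which makes every entry 5/3 (about 1.6667).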
torch.trace(input) → Tensor
Returns the sum of the elements of the diagonal of the input 2-D matrix.
Example:
>>> x = torch.arange(1, 10).view(3, 3)
>>> x
tensor([[ 1., 2., 3.],
        [ 4., 5., 6.],
        [ 7., 8., 9.]])
>>> torch.trace(x)
tensor(15.)
torch.tril(input, diagonal=0, out=None) → Tensor
Returns the lower triangular part of the matrix (2-D tensor) input; the other elements of the result tensor out are set to 0.
The lower triangular part of the matrix is defined as the elements on and below the diagonal.
The argument diagonal controls which diagonal to consider. If diagonal = 0, all elements on and below the main diagonal are retained. A positive value includes just as many diagonals above the main diagonal, and similarly a negative value excludes just as many diagonals below the main diagonal. The main diagonal is the set of indices \lbrace (i, i) \rbrace for i \in [0, \min\{d_{1}, d_{2}\} - 1] where d_{1}, d_{2} are the dimensions of the matrix.
Parameters:
- input (Tensor) – the input tensor
- diagonal (int, optional) – the diagonal to consider
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(3, 3)
>>> a
tensor([[-1.0813, -0.8619, 0.7105],
        [ 0.0935, 0.1380, 2.2112],
        [-0.3409, -0.9828, 0.0289]])
>>> torch.tril(a)
tensor([[-1.0813, 0.0000, 0.0000],
        [ 0.0935, 0.1380, 0.0000],
        [-0.3409, -0.9828, 0.0289]])
>>> b = torch.randn(4, 6)
>>> b
tensor([[ 1.2219, 0.5653, -0.2521, -0.2345, 1.2544, 0.3461],
        [ 0.4785, -0.4477, 0.6049, 0.6368, 0.8775, 0.7145],
        [ 1.1502, 3.2716, -1.1243, -0.5413, 0.3615, 0.6864],
        [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0978]])
>>> torch.tril(b, diagonal=1)
tensor([[ 1.2219, 0.5653, 0.0000, 0.0000, 0.0000, 0.0000],
        [ 0.4785, -0.4477, 0.6049, 0.0000, 0.0000, 0.0000],
        [ 1.1502, 3.2716, -1.1243, -0.5413, 0.0000, 0.0000],
        [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0000]])
>>> torch.tril(b, diagonal=-1)
tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000],
        [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]])
torch.triu(input, diagonal=0, out=None) → Tensor
Returns the upper triangular part of the matrix (2-D tensor) input; the other elements of the result tensor out are set to 0.
The upper triangular part of the matrix is defined as the elements on and above the diagonal.
The argument diagonal controls which diagonal to consider. If diagonal = 0, all elements on and above the main diagonal are retained. A positive value excludes just as many diagonals above the main diagonal, and similarly a negative value includes just as many diagonals below the main diagonal. The main diagonal is the set of indices \lbrace (i, i) \rbrace for i \in [0, \min\{d_{1}, d_{2}\} - 1] where d_{1}, d_{2} are the dimensions of the matrix.
Parameters:
- input (Tensor) – the input tensor
- diagonal (int, optional) – the diagonal to consider
- out (Tensor, optional) – the output tensor
Example:
>>> a = torch.randn(3, 3)
>>> a
tensor([[ 0.2309, 0.5207, 2.0049],
        [ 0.2072, -1.0680, 0.6602],
        [ 0.3480, -0.5211, -0.4573]])
>>> torch.triu(a)
tensor([[ 0.2309, 0.5207, 2.0049],
        [ 0.0000, -1.0680, 0.6602],
        [ 0.0000, 0.0000, -0.4573]])
>>> torch.triu(a, diagonal=1)
tensor([[ 0.0000, 0.5207, 2.0049],
        [ 0.0000, 0.0000, 0.6602],
        [ 0.0000, 0.0000, 0.0000]])
>>> torch.triu(a, diagonal=-1)
tensor([[ 0.2309, 0.5207, 2.0049],
        [ 0.2072, -1.0680, 0.6602],
        [ 0.0000, -0.5211, -0.4573]])
>>> b = torch.randn(4, 6)
>>> b
tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235],
        [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857],
        [ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410],
        [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]])
>>> torch.triu(b, diagonal=1)
tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235],
        [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857],
        [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]])
>>> torch.triu(b, diagonal=-1)
tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235],
        [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857],
        [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410],
        [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]])
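The diagonal argument of tril and triu reduces to a single comparison on the column-minus-row offset j - i. tril_reference and triu_reference are hypothetical pure-Python helpers that illustrate the rule on nested lists.

```python
def tril_reference(m, diagonal=0):
    # Keep m[i][j] when j - i <= diagonal, i.e. on or below the chosen diagonal.
    return [[v if j - i <= diagonal else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]


def triu_reference(m, diagonal=0):
    # Keep m[i][j] when j - i >= diagonal, i.e. on or above the chosen diagonal.
    return [[v if j - i >= diagonal else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]


m = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
print(tril_reference(m))     # [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
print(triu_reference(m, 1))  # [[0, 2, 3], [0, 0, 6], [0, 0, 0]]
```

For any k, tril(m, k) and triu(m, k + 1) select complementary entries, so their sum is m.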
BLAS and LAPACK Operations
torch.addbmm(beta=1, mat, alpha=1, batch1, batch2, out=None) → Tensor
Performs a batch matrix-matrix product of matrices stored in batch1 and batch2, with a reduced add step (all matrix multiplications get accumulated along the first dimension). mat is added to the final result.
batch1 and batch2 must be 3-D tensors each containing the same number of matrices.
If batch1 is a (b \times n \times m) tensor and batch2 is a (b \times m \times p) tensor, then mat must be broadcastable with a (n \times p) tensor and out will be a (n \times p) tensor.
out = \beta\ mat + \alpha\ (\sum_{i=0}^{b-1} batch1_i \mathbin{@} batch2_i)
For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.
Parameters:
- beta (Number, optional) – multiplier for mat (\beta)
- mat (Tensor) – matrix to be added
- alpha (Number, optional) – multiplier for batch1 @ batch2 (\alpha)
- batch1 (Tensor) – the first batch of matrices to be multiplied
- batch2 (Tensor) – the second batch of matrices to be multiplied
- out (Tensor, optional) – the output tensor
Example:
>>> M = torch.randn(3, 5)
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> torch.addbmm(M, batch1, batch2)
tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653],
        [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743],
        [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]])
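The "reduced add step" means that the b per-batch products are summed into a single (n × p) matrix before scaling. addbmm_reference is a hypothetical nested-list sketch of the formula above. It assumes mat already has shape (n × p) and does not model broadcasting.

```python
def matmul2d(a, b):
    """Product of an (n x m) and an (m x p) nested-list matrix."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def addbmm_reference(mat, batch1, batch2, beta=1, alpha=1):
    """Hypothetical sketch of beta*mat + alpha * sum_i batch1[i] @ batch2[i]."""
    n, p = len(mat), len(mat[0])
    acc = [[0] * p for _ in range(n)]
    # Reduced add step: every per-batch product is accumulated into one matrix.
    for b1, b2 in zip(batch1, batch2):
        prod = matmul2d(b1, b2)
        for i in range(n):
            for j in range(p):
                acc[i][j] += prod[i][j]
    return [[beta * mat[i][j] + alpha * acc[i][j] for j in range(p)]
            for i in range(n)]


identity = [[1, 0], [0, 1]]
batch1 = [identity, identity]
batch2 = [[[1, 2], [3, 4]], identity]
print(addbmm_reference([[1, 1], [1, 1]], batch1, batch2, beta=2))  # [[4, 4], [5, 7]]
```

In contrast, torch.baddbmm keeps one result per batch entry instead of summing them.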
torch.addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) → Tensor
Performs a matrix multiplication of the matrices mat1 and mat2. The matrix mat is added to the final result.
If mat1 is a (n \times m) tensor and mat2 is a (m \times p) tensor, then mat must be broadcastable with a (n \times p) tensor and out will be a (n \times p) tensor.
alpha and beta are scaling factors on the matrix-matrix product between mat1 and mat2 and the added matrix mat respectively.
out = \beta\ mat + \alpha\ (mat1 \mathbin{@} mat2)
For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.
Parameters:
- beta (Number, optional) – multiplier for mat (\beta)
- mat (Tensor) – matrix to be added
- alpha (Number, optional) – multiplier for mat1 @ mat2 (\alpha)
- mat1 (Tensor) – the first matrix to be multiplied
- mat2 (Tensor) – the second matrix to be multiplied
- out (Tensor, optional) – the output tensor
Example:
>>> M = torch.randn(2, 3)
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.addmm(M, mat1, mat2)
tensor([[-4.8716, 1.4671, -1.3746],
        [ 0.7573, -3.9555, -2.8681]])
torch.addmv(beta=1, tensor, alpha=1, mat, vec, out=None) → Tensor
Performs a matrix-vector product of the matrix mat and the vector vec. The vector tensor is added to the final result.
If mat is a (n \times m) tensor and vec is a 1-D tensor of size m, then tensor must be broadcastable with a 1-D tensor of size n and out will be a 1-D tensor of size n.
alpha and beta are scaling factors on the matrix-vector product between mat and vec and the added tensor tensor respectively.
out = \beta\ tensor + \alpha\ (mat \mathbin{@} vec)
For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.
Parameters:
- beta (Number, optional) – multiplier for tensor (\beta)
- tensor (Tensor) – vector to be added
- alpha (Number, optional) – multiplier for mat @ vec (\alpha)
- mat (Tensor) – matrix to be multiplied
- vec (Tensor) – vector to be multiplied
- out (Tensor, optional) – the output tensor
Example:
>>> M = torch.randn(2)
>>> mat = torch.randn(2, 3)
>>> vec = torch.randn(3)
>>> torch.addmv(M, mat, vec)
tensor([-0.3768, -5.5565])
torch.addr(beta=1, mat, alpha=1, vec1, vec2, out=None) → Tensor
Performs the outer product of vectors vec1 and vec2 and adds it to the matrix mat.
Optional values beta and alpha are scaling factors on the outer product between vec1 and vec2 and the added matrix mat respectively.
out = \beta\ mat + \alpha\ (vec1 \otimes vec2)
If vec1 is a vector of size n and vec2 is a vector of size m, then mat must be broadcastable with a matrix of size (n \times m) and out will be a matrix of size (n \times m).
For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.
Parameters:
- beta (Number, optional) – multiplier for mat (\beta)
- mat (Tensor) – matrix to be added
- alpha (Number, optional) – multiplier for vec1 \otimes vec2 (\alpha)
- vec1 (Tensor) – the first vector of the outer product
- vec2 (Tensor) – the second vector of the outer product
- out (Tensor, optional) – the output tensor
Example:
>>> vec1 = torch.arange(1, 4)
>>> vec2 = torch.arange(1, 3)
>>> M = torch.zeros(3, 2)
>>> torch.addr(M, vec1, vec2)
tensor([[ 1., 2.],
        [ 2., 4.],
        [ 3., 6.]])
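Entry (i, j) of the result is beta * mat[i][j] + alpha * vec1[i] * vec2[j]. The hypothetical helper addr_reference below writes this out on nested lists (no broadcasting) and reproduces the example above.

```python
def addr_reference(mat, vec1, vec2, beta=1, alpha=1):
    """Hypothetical sketch of beta*mat + alpha*outer(vec1, vec2) for nested lists."""
    return [[beta * mat[i][j] + alpha * vec1[i] * vec2[j]
             for j in range(len(vec2))]
            for i in range(len(vec1))]


M = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
print(addr_reference(M, [1, 2, 3], [1, 2]))  # [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]
```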
torch.baddbmm(beta=1, mat, alpha=1, batch1, batch2, out=None) → Tensor
Performs a batch matrix-matrix product of matrices in batch1 and batch2. mat is added to the final result.
batch1 and batch2 must be 3-D tensors each containing the same number of matrices.
If batch1 is a (b \times n \times m) tensor and batch2 is a (b \times m \times p) tensor, then mat must be broadcastable with a (b \times n \times p) tensor and out will be a (b \times n \times p) tensor. Both alpha and beta mean the same as the scaling factors used in torch.addbmm().
out_i = \beta\ mat_i + \alpha\ (batch1_i \mathbin{@} batch2_i)
For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.
Parameters:
- beta (Number, optional) – multiplier for mat (\beta)
- mat (Tensor) – the tensor to be added
- alpha (Number, optional) – multiplier for batch1 @ batch2 (\alpha)
- batch1 (Tensor) – the first batch of matrices to be multiplied
- batch2 (Tensor) – the second batch of matrices to be multiplied
- out (Tensor, optional) – the output tensor
Example:
>>> M = torch.randn(10, 3, 5)
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> torch.baddbmm(M, batch1, batch2).size()
torch.Size([10, 3, 5])
torch.bmm(batch1, batch2, out=None) → Tensor
Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.
batch1 and batch2 must be 3-D tensors each containing the same number of matrices.
If batch1 is a (b \times n \times m) tensor and batch2 is a (b \times m \times p) tensor, out will be a (b \times n \times p) tensor.
out_i = batch1_i \mathbin{@} batch2_i
Note
This function does not broadcast. For broadcasting matrix products, see torch.matmul().
Parameters:
- batch1 (Tensor) – the first batch of matrices to be multiplied
- batch2 (Tensor) – the second batch of matrices to be multiplied
- out (Tensor, optional) – the output tensor
Example:
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(batch1, batch2)
>>> res.size()
torch.Size([10, 3, 5])
torch.btrifact(A, info=None, pivot=True)
Batch LU factorization.
Returns a tuple containing the LU factorization and pivots. Pivoting is done if pivot is set.
The optional argument info stores information about whether the factorization succeeded for each minibatch example. The info is provided as an IntTensor; its values will be filled from dgetrf, and a non-zero value indicates that an error occurred. Specifically, the values come from cuBLAS if CUDA is being used, otherwise from LAPACK.
Warning
The info argument is deprecated in favor of torch.btrifact_with_info().
Parameters:
- A (Tensor) – the tensor to factor
- info (IntTensor, optional) – (deprecated) an IntTensor to store values indicating whether factorization succeeds
- pivot (bool, optional) – controls whether pivoting is done
Returns: A tuple containing factorization and pivots.
Example:
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.btrifact(A)
>>> A_LU
tensor([[[ 1.3506, 2.5558, -0.0816],
         [ 0.1684, 1.1551, 0.1940],
         [ 0.1193, 0.6189, -0.5497]],

        [[ 0.4526, 1.2526, -0.3285],
         [-0.7988, 0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3, 3, 3],
        [ 3, 3, 3]], dtype=torch.int32)
torch.btrifact_with_info(A, pivot=True) -> (Tensor, IntTensor, IntTensor)
Batch LU factorization with additional error information.
This is a version of torch.btrifact() that always creates an info IntTensor, and returns it as the third return value.
Parameters:
- A (Tensor) – the tensor to factor
- pivot (bool, optional) – controls whether pivoting is done
Returns: A tuple containing factorization, pivots, and an IntTensor in which a non-zero value indicates that factorization failed for the corresponding minibatch sample.
Example:
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots, info = A.btrifact_with_info()
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!
torch.btrisolve(b, LU_data, LU_pivots) → Tensor
Batch LU solve.
Returns the LU solve of the linear system Ax = b.
Parameters:
- b (Tensor) – the RHS tensor
- LU_data (Tensor) – the pivoted LU factorization of A from btrifact().
- LU_pivots (IntTensor) – the pivots of the LU factorization
Example:
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3)
>>> A_LU = torch.btrifact(A)
>>> x = torch.btrisolve(b, *A_LU)
>>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
tensor(1.00000e-07 * 2.8312)
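To show what "LU solve" involves, here is a textbook sketch for a single, non-batched matrix. It uses Doolittle LU with partial pivoting, followed by forward and back substitution. This is not PyTorch's batched LAPACK/cuBLAS path. lu_factor and lu_solve are hypothetical helpers. The sketch records pivoting as a 0-based row permutation, which is a different form from the pivot tensor that btrifact returns.

```python
def lu_factor(a):
    """Doolittle LU with partial pivoting on a square nested list.

    Returns (lu, perm): lu packs the unit-lower L (below the diagonal) and
    U (on/above the diagonal); perm[i] is the original row now in row i.
    """
    n = len(a)
    lu = [row[:] for row in a]
    perm = list(range(n))
    for k in range(n):
        # Partial pivoting: bring the largest remaining entry of column k up.
        p = max(range(k, n), key=lambda r: abs(lu[r][k]))
        if lu[p][k] == 0:
            raise ValueError("matrix is singular")
        lu[k], lu[p] = lu[p], lu[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]  # multiplier stored in the L part
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm


def lu_solve(lu, perm, b):
    """Solve A x = b using the packed factors from lu_factor."""
    n = len(lu)
    y = [b[perm[i]] for i in range(n)]  # apply the row permutation to b
    for i in range(n):  # forward substitution; L has a unit diagonal
        y[i] -= sum(lu[i][j] * y[j] for j in range(i))
    x = [0.0] * n
    for i in reversed(range(n)):  # back substitution with U
        x[i] = (y[i] - sum(lu[i][j] * x[j] for j in range(i + 1, n))) / lu[i][i]
    return x


lu, perm = lu_factor([[2, 1, 1], [4, -6, 0], [-2, 7, 2]])
print(lu_solve(lu, perm, [5, -2, 9]))  # [1.0, 1.0, 2.0]
```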
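torch.btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True)
Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.
Returns a tuple of tensors as (the pivots, the L tensor, the U tensor).
Parameters:
- LU_data (Tensor) – the packed LU factorization data
- LU_pivots (Tensor) – the packed LU factorization pivots
- unpack_data (bool) – flag indicating whether the data should be unpacked
- unpack_pivots (bool) – flag indicating whether the pivots should be unpacked
Example: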
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = A.btrifact()
>>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
>>>
>>> # can recover A from factorization
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
torch.dot(tensor1, tensor2) → Tensor
Computes the dot product (inner product) of two tensors.
Note
This function does not broadcast.
Example:
>>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
tensor(7)
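torch.eig(a, eigenvectors=False, out=None) -> (Tensor, Tensor)
Computes the eigenvalues and eigenvectors of a real square matrix.
Parameters:
- a (Tensor) – the square matrix for which the eigenvalues and eigenvectors will be computed
- eigenvectors (bool) – True to compute both eigenvalues and eigenvectors; otherwise, only eigenvalues will be computed
- out (tuple, optional) – the output tensors
Returns: A tuple containing
- e (Tensor): the right eigenvalues of a
- v (Tensor): the eigenvectors of a if eigenvectors is True; otherwise an empty tensor
Return type: (Tensor, Tensor)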
torch.gels(B, A, out=None) -> (Tensor, Tensor)
Computes the solution to the least squares and least norm problems for a full rank matrix A of size (m \times n) and a matrix B of size (m \times k).
If m \geq n, gels() solves the least-squares problem:
\begin{array}{ll} \min_X & \|AX-B\|_2. \end{array}
If m < n, gels() solves the least-norm problem:
\begin{array}{ll} \min_X & \|X\|_2 & \mbox{subject to} & AX = B. \end{array}
The returned tensor X has shape (\max(m, n) \times k). The first n rows of X contain the solution. If m \geq n, the residual sum of squares for the solution in each column is given by the sum of squares of the elements in the remaining m - n rows of that column.
Parameters:
- B (Tensor) – the matrix B
- A (Tensor) – the m by n matrix A
- out (tuple, optional) – the optional destination tensor
Returns: A tuple containing:
- X (Tensor): the least squares solution
- qr (Tensor): the details of the QR factorization
Return type: (Tensor, Tensor)
Note
The returned matrices will always be transposed, irrespective of the strides of the input matrices. That is, they will have stride (1, m) instead of (m, 1).
Example:
>>> A = torch.tensor([[1., 1, 1], [2, 3, 4], [3, 5, 2], [4, 2, 5], [5, 4, 3]])
>>> B = torch.tensor([[-10., -3], [ 12, 14], [ 14, 12], [ 16, 16], [ 18, 16]])
>>> X, _ = torch.gels(B, A)
>>> X
tensor([[ 2.0000, 1.0000],
        [ 1.0000, 1.0000],
        [ 1.0000, 2.0000],
        [ 10.9635, 4.8501],
        [ 8.9332, 5.2418]])
torch.geqrf(input, out=None) -> (Tensor, Tensor)
This is a low-level function for calling LAPACK directly.
You'll generally want to use torch.qr() instead.
Computes a QR decomposition of input, but without constructing Q and R as explicit separate matrices.
Rather, this directly calls the underlying LAPACK function ?geqrf, which produces a sequence of 'elementary reflectors'.
See the LAPACK documentation for geqrf for further details.
Parameters:
- input (Tensor) – the input matrix
- out (tuple, optional) – the output tuple of (Tensor, Tensor)
torch.ger(vec1, vec2, out=None) → Tensor
Outer product of vec1 and vec2. If vec1 is a vector of size n and vec2 is a vector of size m, then out must be a matrix of size (n \times m).
Note
This function does not broadcast.
Parameters:
- vec1 (Tensor) – 1-D input vector
- vec2 (Tensor) – 1-D input vector
- out (Tensor, optional) – optional output matrix
Example:
>>> v1 = torch.arange(1, 5)
>>> v2 = torch.arange(1, 4)
>>> torch.ger(v1, v2)
tensor([[ 1., 2., 3.],
        [ 2., 4., 6.],
        [ 3., 6., 9.],
        [ 4., 8., 12.]])
torch.gesv(B, A, out=None) -> (Tensor, Tensor)
This function returns the solution to the system of linear equations represented by AX = B and the LU factorization of A, in order as a tuple X, LU.
LU contains the L and U factors of the LU factorization of A.
A has to be a square and non-singular matrix (2-D tensor). If A is an (m \times m) matrix and B is (m \times k), the result LU is (m \times m) and X is (m \times k).
Note
Irrespective of the original strides, the returned matrices X and LU will be transposed, i.e. with strides (1, m) instead of (m, 1).
Parameters:
- B (Tensor) – input matrix of (m \times k) dimensions
- A (Tensor) – input square matrix of (m \times m) dimensions
- out (Tensor, optional) – optional output matrix
Example:
>>> A = torch.tensor([[6.80, -2.11, 5.66, 5.97, 8.23], [-6.05, -3.30, 5.36, -4.44, 1.08], [-0.45, 2.58, -2.70, 0.27, 9.04], [8.32, 2.71, 4.35, -7.17, 2.14], [-9.67, -5.14, -7.26, 6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02, 6.19, -8.22, -7.57, -3.03], [-1.56, 4.00, -8.67, 1.75, 2.86], [9.81, -4.09, -4.57, -8.61, 8.99]]).t()
>>> X, LU = torch.gesv(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 * 7.0977)
torch.inverse(input, out=None) → Tensor
Takes the inverse of the square matrix input.
Note
Irrespective of the original strides, the returned matrix will be transposed, i.e. with strides (1, m) instead of (m, 1).
Parameters:
- input (Tensor) – the input 2-D square tensor
- out (Tensor, optional) – the optional output tensor
Example:
>>> x = torch.rand(4, 4)
>>> y = torch.inverse(x)
>>> z = torch.mm(x, y)
>>> z
tensor([[ 1.0000, -0.0000, -0.0000, 0.0000],
        [ 0.0000, 1.0000, 0.0000, 0.0000],
        [ 0.0000, 0.0000, 1.0000, 0.0000],
        [ 0.0000, -0.0000, -0.0000, 1.0000]])
>>> torch.max(torch.abs(z - torch.eye(4)))  # max absolute deviation from the identity
tensor(1.00000e-07 * 1.1921)
torch.det(A) → Tensor
Calculates the determinant of a 2-D square tensor.
Note
Backward through det() internally uses SVD results when A is not invertible. In this case, double backward through det() will be unstable when A doesn't have distinct singular values. See svd() for details.
Parameters: A (Tensor) – The input 2-D square tensor
Example:
>>> A = torch.randn(3, 3)
>>> torch.det(A)
tensor(3.7641)
torch.logdet(A) → Tensor
Calculates the log determinant of a 2-D square tensor.
Note
The result is -inf if A has zero determinant, and is nan if A has a negative determinant.
Note
Backward through logdet() internally uses SVD results when A is not invertible. In this case, double backward through logdet() will be unstable when A doesn't have distinct singular values. See svd() for details.
Parameters: A (Tensor) – The input 2-D square tensor
Example:
>>> A = torch.randn(3, 3)
>>> torch.det(A)
tensor(0.2611)
>>> torch.logdet(A)
tensor(-1.3430)
torch.slogdet(A) -> (Tensor, Tensor)
Calculates the sign and log value of a 2-D square tensor's determinant.
Note
If A has zero determinant, this returns (0, -inf).
Note
Backward through slogdet() internally uses SVD results when A is not invertible. In this case, double backward through slogdet() will be unstable when A doesn't have distinct singular values. See svd() for details.
Parameters: A (Tensor) – The input 2-D square tensor
Returns: A tuple containing the sign of the determinant, and the log value of the absolute determinant.
Example:
>>> A = torch.randn(3, 3)
>>> torch.det(A)
tensor(-4.8215)
>>> torch.logdet(A)
tensor(nan)
>>> torch.slogdet(A)
(tensor(-1.), tensor(1.5731))
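The relationship between det, logdet and slogdet can be illustrated on 3 × 3 matrices with a hypothetical pure-Python sketch. It computes the determinant by cofactor expansion and does not use the factorization-based route a real implementation would take. It shows why logdet returns nan for a negative determinant, while slogdet still returns a usable (sign, log|det|) pair.

```python
import math


def det3(m):
    """Determinant of a 3 x 3 nested list by cofactor expansion along row 0."""
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)


def slogdet_reference(m):
    """Return (sign, log|det|) the way slogdet reports it; (0, -inf) if singular."""
    d = det3(m)
    if d == 0:
        return 0.0, -math.inf
    return math.copysign(1.0, d), math.log(abs(d))


def logdet_reference(m):
    # log(det) is only real for positive determinants: nan for negative ones,
    # -inf for a zero determinant, matching the notes above.
    sign, logabs = slogdet_reference(m)
    if sign > 0:
        return logabs
    return -math.inf if sign == 0 else math.nan


neg = [[-2, 0, 0], [0, 1, 0], [0, 0, 3]]  # determinant -6
print(slogdet_reference(neg))  # sign -1.0 and log(6)
print(logdet_reference(neg))   # nan
```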
torch.matmul(tensor1, tensor2, out=None) → Tensor
Matrix product of two tensors.
The behavior depends on the dimensionality of the tensors as follows:
- If both tensors are 1-dimensional, the dot product (scalar) is returned.
- If both arguments are 2-dimensional, the matrix-matrix product is returned.
- If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
- If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
- If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). For example, if tensor1 is a (j \times 1 \times n \times m) tensor and tensor2 is a (k \times m \times p) tensor, out will be a (j \times k \times n \times p) tensor.
Note
The 1-dimensional dot product version of this function does not support an out parameter.
Parameters:
- tensor1 (Tensor) – the first tensor to be multiplied
- tensor2 (Tensor) – the second tensor to be multiplied
- out (Tensor, optional) – the output tensor
Example:
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
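The dimensionality rules above determine the output shape without touching any data. matmul_shape is a hypothetical helper that applies those rules step by step: promote 1-D operands, check the inner dimensions, broadcast the batch dimensions from the right, and drop the promoted dimensions.

```python
def matmul_shape(shape1, shape2):
    """Hypothetical helper: result shape of torch.matmul per the rules above."""
    if len(shape1) == 0 or len(shape2) == 0:
        raise ValueError("both arguments must be at least 1-dimensional")
    if len(shape1) == 1 and len(shape2) == 1:
        if shape1[0] != shape2[0]:
            raise ValueError("size mismatch")
        return ()  # dot product -> scalar
    # Promote 1-D operands to matrices, remembering which dims to drop later.
    drop_first = len(shape1) == 1
    drop_last = len(shape2) == 1
    s1 = (1,) + tuple(shape1) if drop_first else tuple(shape1)
    s2 = tuple(shape2) + (1,) if drop_last else tuple(shape2)
    n, m1 = s1[-2], s1[-1]
    m2, p = s2[-2], s2[-1]
    if m1 != m2:
        raise ValueError("inner matrix dimensions must match")
    # Broadcast the batch (non-matrix) dimensions, aligned from the right.
    b1, b2 = s1[:-2], s2[:-2]
    width = max(len(b1), len(b2))
    b1 = (1,) * (width - len(b1)) + b1
    b2 = (1,) * (width - len(b2)) + b2
    batch = []
    for x, y in zip(b1, b2):
        if x != y and 1 not in (x, y):
            raise ValueError("batch dimensions are not broadcastable")
        batch.append(max(x, y))
    rows = () if drop_first else (n,)
    cols = () if drop_last else (p,)
    return tuple(batch) + rows + cols


print(matmul_shape((10, 3, 4), (4,)))         # (10, 3)
print(matmul_shape((2, 1, 3, 4), (5, 4, 6)))  # (2, 5, 3, 6)
```

The second call is the (j × 1 × n × m) by (k × m × p) case from the description, with j = 2, n = 3, m = 4, k = 5 and p = 6.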
torch.mm(mat1, mat2, out=None) → Tensor
Performs a matrix multiplication of the matrices mat1 and mat2.
If mat1 is a (n \times m) tensor and mat2 is a (m \times p) tensor, out will be a (n \times p) tensor.
Note
This function does not broadcast. For broadcasting matrix products, see torch.matmul().
Parameters:
- mat1 (Tensor) – the first matrix to be multiplied
- mat2 (Tensor) – the second matrix to be multiplied
- out (Tensor, optional) – the output tensor
Example:
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851, 0.5037, -0.3633],
        [-0.0760, -3.6705, 2.4784]])
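torch.mv(mat, vec, out=None) → Tensor
Performs a matrix-vector product of the matrix mat and the vector vec.
If mat is a (n \times m) tensor and vec is a 1-D tensor of size m, out will be a 1-D tensor of size n.
Note
This function does not broadcast.
Parameters:
- mat (Tensor) – matrix to be multiplied
- vec (Tensor) – vector to be multiplied
- out (Tensor, optional) – the output tensor
Example: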
>>> mat = torch.randn(2, 3)
>>> vec = torch.randn(3)
>>> torch.mv(mat, vec)
tensor([ 1.0404, -0.6361])
torch.orgqr(a, tau) → Tensor
Computes the orthogonal matrix Q of a QR factorization, from the (a, tau) tuple returned by torch.geqrf().
This directly calls the underlying LAPACK function ?orgqr. See the LAPACK documentation for orgqr for further details.
Parameters:
- a (Tensor) – the a from torch.geqrf().
- tau (Tensor) – the tau from torch.geqrf().
torch.ormqr(a, tau, mat, left=True, transpose=False) → Tensor
Multiplies mat by the orthogonal Q matrix of the QR factorization formed by torch.geqrf() that is represented by (a, tau).
This directly calls the underlying LAPACK function ?ormqr. See the LAPACK documentation for ormqr for further details.
Parameters:
- a (Tensor) – the a from torch.geqrf().
- tau (Tensor) – the tau from torch.geqrf().
- mat (Tensor) – the matrix to be multiplied.
- left (bool, optional) – if True, Q is applied from the left (Q @ mat); otherwise from the right (mat @ Q)
- transpose (bool, optional) – if True, the transpose of Q is used instead of Q
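torch.potrf(a, upper=True, out=None) → Tensor
Computes the Cholesky decomposition of a symmetric positive-definite matrix A.
If upper is True, the returned matrix U is upper-triangular, and the decomposition has the form:
A = U^TU
If upper is False, the returned matrix L is lower-triangular, and the decomposition has the form:
A = LL^T
Parameters:
- a (Tensor) – the input 2-D tensor, a symmetric positive-definite matrix
- upper (bool, optional) – flag that indicates whether to return the upper or lower triangular matrix
- out (Tensor, optional) – the output matrix
Example: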
>>> a = torch.randn(3, 3)
>>> a = torch.mm(a, a.t())  # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a
tensor([[ 2.4112, -0.7486, 1.4551],
        [-0.7486, 1.3544, 0.1294],
        [ 1.4551, 0.1294, 1.6724]])
>>> u
tensor([[ 1.5528, -0.4821, 0.9371],
        [ 0.0000, 1.0592, 0.5486],
        [ 0.0000, 0.0000, 0.7023]])
>>> torch.mm(u.t(), u)
tensor([[ 2.4112, -0.7486, 1.4551],
        [-0.7486, 1.3544, 0.1294],
        [ 1.4551, 0.1294, 1.6724]])
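For readers who want to see where U comes from, the sketch below computes the lower factor L row by row (the Cholesky-Banachiewicz recurrence) and returns U = L^T, so that A = U^T U as with upper=True. cholesky_upper is a hypothetical pure-Python helper, not how potrf is implemented. It uses a small integer matrix whose factors are exact.

```python
import math


def cholesky_upper(a):
    """Return upper-triangular U with a = U^T U (pure-Python sketch).

    Computes the lower factor L row by row and returns its transpose;
    a must be symmetric positive-definite.
    """
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = a[i][i] - s
                if d <= 0:
                    raise ValueError("matrix is not positive-definite")
                L[i][j] = math.sqrt(d)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return [[L[j][i] for j in range(n)] for i in range(n)]  # U = L^T


A = [[4, 12, -16], [12, 37, -43], [-16, -43, 98]]
print(cholesky_upper(A))  # [[2.0, 6.0, -8.0], [0.0, 1.0, 5.0], [0.0, 0.0, 3.0]]
```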
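torch.potri(u, upper=True, out=None) → Tensor
Computes the inverse of a symmetric positive-definite matrix given its Cholesky factor u: returns the matrix inv.
If upper is True or not provided, u is upper triangular such that:
inv = (u^T u)^{-1}
If upper is False, u is lower triangular such that:
inv = (u u^T)^{-1}
Parameters:
- u (Tensor) – the input 2-D tensor, an upper or lower triangular Cholesky factor
- upper (bool, optional) – whether u is upper (default) or lower triangular
- out (Tensor, optional) – the output tensor for inv
Example: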
>>> a = torch.randn(3, 3)
>>> a = torch.mm(a, a.t())  # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a
tensor([[ 0.9935, -0.6353, 1.5806],
        [ -0.6353, 0.8769, -1.7183],
        [ 1.5806, -1.7183, 10.6618]])
>>> torch.potri(u)
tensor([[ 1.9314, 1.2251, -0.0889],
        [ 1.2251, 2.4439, 0.2122],
        [-0.0889, 0.2122, 0.1412]])
>>> a.inverse()
tensor([[ 1.9314, 1.2251, -0.0889],
        [ 1.2251, 2.4439, 0.2122],
        [-0.0889, 0.2122, 0.1412]])
-
torch.
potrs
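(b, u, upper=True, out=None) → Tensor¶ Solves a linear system of equations with a symmetric positive-definite coefficient matrix, given its Cholesky factor matrix u.
If upper is True or not provided, u is upper triangular and c is returned such that:
c = (u^T u)^{-1} b
If upper is False, u is lower triangular and c is returned such that:
c = (u u^T)^{-1} b
Note
b must always be a 2-D tensor; use b.unsqueeze(1) to convert a vector into a single column.
Parameters:
- b (Tensor) – the right-hand side 2-D tensor
- u (Tensor) – the input 2-D tensor, an upper or lower triangular Cholesky factor
- upper (bool, optional) – whether u is upper (default) or lower triangular. Default: True.
- out (Tensor, optional) – the output tensor for c
Example: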
>>> a = torch.randn(3, 3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a
tensor([[ 0.7747, -1.9549, 1.3086],
        [-1.9549, 6.7546, -5.4114],
        [ 1.3086, -5.4114, 4.8733]])
>>> b = torch.randn(3, 2)
>>> b
tensor([[-0.6355, 0.9891],
        [ 0.1974, 1.4706],
        [-0.4115, -0.6225]])
>>> torch.potrs(b,u)
tensor([[ -8.1625, 19.6097],
        [ -5.8398, 14.2387],
        [ -4.3771, 10.4173]])
>>> torch.mm(a.inverse(),b)
tensor([[ -8.1626, 19.6097],
        [ -5.8398, 14.2387],
        [ -4.3771, 10.4173]])
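Solving with the factor takes two triangular solves. To get c = (u^T u)^{-1} b, first solve u^T y = b by forward substitution, then u c = y by back substitution. The minimal pure-Python sketch below assumes the upper=True case and a 2-D right-hand side b given as a list of rows. solve_with_cholesky is a hypothetical name.

```python
def solve_with_cholesky(b, u):
    """Return c with (u^T u) c = b, for upper-triangular u and a 2-D right-hand side b."""
    n, ncols = len(u), len(b[0])
    c = [[0.0] * ncols for _ in range(n)]
    for col in range(ncols):
        # Forward substitution: u^T y = b, where (u^T)[i][k] = u[k][i] is lower triangular.
        y = [0.0] * n
        for i in range(n):
            y[i] = (b[i][col] - sum(u[k][i] * y[k] for k in range(i))) / u[i][i]
        # Back substitution: u c = y, filling entries from the bottom up.
        for i in reversed(range(n)):
            c[i][col] = (y[i] - sum(u[i][k] * c[k][col] for k in range(i + 1, n))) / u[i][i]
    return c
```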
-
torch.
pstrf
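(a, upper=True, out=None) -> (Tensor, Tensor)¶ Computes the pivoted Cholesky decomposition of a symmetric positive semidefinite matrix a, and returns matrices u and piv.
If upper is True or not provided, u is upper triangular such that a = p^T u^T u p, with p the permutation given by piv.
If upper is False, u is lower triangular such that a = p^T u u^T p.
Parameters:
- a (Tensor) – the input 2-D tensor
- upper (bool, optional) – whether to return an upper-triangular (default) or lower-triangular factor. Default: True.
- out (tuple, optional) – tuple of u and piv tensors
Example: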
>>> a = torch.randn(3, 3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> a
tensor([[ 3.5405, -0.4577, 0.8342],
        [-0.4577, 1.8244, -0.1996],
        [ 0.8342, -0.1996, 3.7493]])
>>> u,piv = torch.pstrf(a)
>>> u
tensor([[ 1.9363, 0.4308, -0.1031],
        [ 0.0000, 1.8316, -0.2256],
        [ 0.0000, 0.0000, 1.3277]])
>>> piv
tensor([ 2, 0, 1], dtype=torch.int32)
>>> p = torch.eye(3).index_select(0, piv.long()) # make pivot permutation matrix
>>> torch.mm(torch.mm(p.t(),torch.mm(u.t(),u)),p) # reconstruct
tensor([[ 3.5405, -0.4577, 0.8342],
        [-0.4577, 1.8244, -0.1996],
        [ 0.8342, -0.1996, 3.7493]])
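Here is a minimal pure-Python sketch of pivoted (diagonal-pivoting) Cholesky. It assumes the upper=True convention above, so u^T u equals a with its rows and columns reordered by piv, i.e. a[piv][:, piv]. At each step it pivots on the largest remaining diagonal entry of the Schur complement. It stops early once that entry falls below a tolerance, which is what lets it handle semidefinite input. The name pivoted_cholesky and its tol parameter are hypothetical. LAPACK's ?pstrf also reports the computed rank and uses its own tolerance rule.

```python
import math


def pivoted_cholesky(a, tol=1e-12):
    """Return (u, piv) with u upper triangular and u^T u == a[piv][:, piv]."""
    n = len(a)
    s = [[float(x) for x in row] for row in a]  # trailing block holds the Schur complement
    piv = list(range(n))
    u = [[0.0] * n for _ in range(n)]
    for k in range(n):
        # Pivot on the largest remaining diagonal entry.
        p = max(range(k, n), key=lambda i: s[i][i])
        if s[p][p] <= tol:
            break  # remaining block is numerically zero: the rank is k
        # Symmetric swap of rows/columns k and p, mirrored in piv and in the rows of u done so far.
        s[k], s[p] = s[p], s[k]
        for row in s:
            row[k], row[p] = row[p], row[k]
        piv[k], piv[p] = piv[p], piv[k]
        for r in range(k):
            u[r][k], u[r][p] = u[r][p], u[r][k]
        # Row k of the factor.
        u[k][k] = math.sqrt(s[k][k])
        for j in range(k + 1, n):
            u[k][j] = s[k][j] / u[k][k]
        # Update the trailing Schur complement.
        for i in range(k + 1, n):
            for j in range(k + 1, n):
                s[i][j] -= u[k][i] * u[k][j]
    return u, piv
```

On the matrix from the example above, this sketch picks the same pivot order, [2, 0, 1], as the torch.pstrf output.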
-
torch.
qr
(input, out=None) -> (Tensor, Tensor)¶ Computes the QR decomposition of a matrix input, and returns matrices Q and R such that \text{input} = Q R, with Q having orthonormal columns (an orthogonal matrix when input is square) and R being an upper triangular matrix. This returns the thin (reduced) QR factorization.
Note
Precision may be lost if the magnitudes of the elements of input are large.
Note
While it should always give you a valid decomposition, it may not give you the same one across platforms; the result depends on your LAPACK implementation.
Note
Irrespective of the original strides, the returned matrix Q will be transposed, i.e. with strides (1, m) instead of (m, 1).
Parameters:
- input (Tensor) – the input 2-D tensor
- out (tuple, optional) – tuple of Q and R tensors
Example:
>>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]])
>>> q, r = torch.qr(a)
>>> q
tensor([[-0.8571, 0.3943, 0.3314],
        [-0.4286, -0.9029, -0.0343],
        [ 0.2857, -0.1714, 0.9429]])
>>> r
tensor([[ -14.0000, -21.0000, 14.0000],
        [ 0.0000, -175.0000, 70.0000],
        [ 0.0000, 0.0000, -35.0000]])
>>> torch.mm(q, r).round()
tensor([[ 12., -51., 4.],
        [ 6., 167., -68.],
        [ -4., 24., -41.]])
>>> torch.mm(q.t(), q).round()
tensor([[ 1., 0., 0.],
        [ 0., 1., -0.],
        [ 0., -0., 1.]])
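The sketch below illustrates what a thin QR factorization computes. It uses modified Gram-Schmidt, a swapped-in technique: LAPACK, and hence torch.qr, uses Householder reflections instead. It assumes an m x n input with m >= n and full column rank. mgs_qr is a hypothetical name.

```python
import math


def mgs_qr(A):
    """Thin QR of an m x n matrix (m >= n, full column rank) by modified Gram-Schmidt."""
    m, n = len(A), len(A[0])
    cols = [[float(A[i][j]) for i in range(m)] for j in range(n)]  # work column by column
    R = [[0.0] * n for _ in range(n)]
    for j in range(n):
        R[j][j] = math.sqrt(sum(x * x for x in cols[j]))
        cols[j] = [x / R[j][j] for x in cols[j]]
        # "Modified" step: remove the new direction from every remaining column immediately.
        for k in range(j + 1, n):
            R[j][k] = sum(x * y for x, y in zip(cols[j], cols[k]))
            cols[k] = [y - R[j][k] * x for x, y in zip(cols[j], cols[k])]
    Q = [[cols[j][i] for j in range(n)] for i in range(m)]  # back to an m x n row-major matrix
    return Q, R
```

For the matrix in the example above, this sketch yields R with diagonal (14, 175, 35), whereas torch.qr shows (-14, -175, -35). Both are valid factorizations. Flipping the sign of a column of Q together with the matching row of R leaves Q R unchanged, which is the kind of variation the note about LAPACK implementations allows.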
-
torch.
svd
(input, some=True, out=None) -> (Tensor, Tensor, Tensor)¶ U, S, V = torch.svd(A) returns the singular value decomposition of a real matrix A of size (n \times m) such that A = U S V^T.
U is of shape (n \times n).
S is a diagonal matrix of shape (n \times m), represented as a vector of size \min(n, m) containing the non-negative diagonal entries.
V is of shape (m \times m).
If some is True (default), the returned U and V matrices contain only \min(n, m) orthonormal columns.
Note
Irrespective of the original strides, the returned matrix U will be transposed, i.e. with strides (1, n) instead of (n, 1).
Note
Extra care needs to be taken when backpropagating through the U and V outputs. Such an operation is only stable when input is full rank with all distinct singular values. Otherwise, NaN can appear because the gradients are not properly defined. Also note that double backward will usually do an additional backward through U and V even if the original backward is only on S.
Note
When some = False, the gradients on U[:, min(n, m):] and V[:, min(n, m):] will be ignored in backward, as those vectors can be arbitrary bases of the subspaces.
Parameters:
- input (Tensor) – the input 2-D tensor
- some (bool, optional) – controls the shape of the returned U and V. Default: True.
- out (tuple, optional) – the output tuple of tensors
Example:
>>> a = torch.tensor([[8.79, 6.11, -9.15, 9.57, -3.49, 9.84],
                      [9.93, 6.91, -7.93, 1.64, 4.02, 0.15],
                      [9.83, 5.04, 4.86, 8.83, 9.80, -8.99],
                      [5.45, -0.27, 4.85, 0.74, 10.00, -6.02],
                      [3.16, 7.98, 3.01, 5.80, 4.27, -5.31]]).t()
>>> u, s, v = torch.svd(a)
>>> u
tensor([[-0.5911, 0.2632, 0.3554, 0.3143, 0.2299],
        [-0.3976, 0.2438, -0.2224, -0.7535, -0.3636],
        [-0.0335, -0.6003, -0.4508, 0.2334, -0.3055],
        [-0.4297, 0.2362, -0.6859, 0.3319, 0.1649],
        [-0.4697, -0.3509, 0.3874, 0.1587, -0.5183],
        [ 0.2934, 0.5763, -0.0209, 0.3791, -0.6526]])
>>> s
tensor([ 27.4687, 22.6432, 8.5584, 5.9857, 2.0149])
>>> v
tensor([[-0.2514, 0.8148, -0.2606, 0.3967, -0.2180],
        [-0.3968, 0.3587, 0.7008, -0.4507, 0.1402],
        [-0.6922, -0.2489, -0.2208, 0.2513, 0.5891],
        [-0.3662, -0.3686, 0.3859, 0.4342, -0.6265],
        [-0.4076, -0.0980, -0.4933, -0.6227, -0.4396]])
>>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t()))
tensor(1.00000e-06 * 9.3738)
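The singular values can be illustrated without LAPACK using one-sided (Hestenes) Jacobi rotations. These orthogonalize the columns of A V until the column norms equal the singular values. This is a swapped-in algorithm, not what torch.svd calls internally. The sketch assumes n >= m and full column rank, and returns the reduced factors that correspond to some=True. The name jacobi_svd is hypothetical.

```python
import math


def jacobi_svd(A, max_sweeps=30, eps=1e-12):
    """Thin SVD of an n x m real matrix (n >= m, full column rank) by one-sided Jacobi.

    Returns (U, S, V): U is n x m with orthonormal columns, S holds the singular
    values in descending order, V is m x m orthogonal, and A = U diag(S) V^T.
    """
    n, m = len(A), len(A[0])
    W = [[float(x) for x in row] for row in A]  # becomes A V as rotations are applied
    V = [[float(i == j) for j in range(m)] for i in range(m)]
    for _ in range(max_sweeps):
        rotated = False
        for p in range(m - 1):
            for q in range(p + 1, m):
                alpha = sum(W[i][p] ** 2 for i in range(n))
                beta = sum(W[i][q] ** 2 for i in range(n))
                gamma = sum(W[i][p] * W[i][q] for i in range(n))
                if abs(gamma) <= eps * math.sqrt(alpha * beta):
                    continue  # columns p and q are already orthogonal enough
                rotated = True
                zeta = (beta - alpha) / (2.0 * gamma)
                t = math.copysign(1.0, zeta) / (abs(zeta) + math.sqrt(1.0 + zeta * zeta))
                c = 1.0 / math.sqrt(1.0 + t * t)
                s = c * t
                for M in (W, V):  # the same plane rotation is applied to W and to V
                    for row in M:
                        xp, xq = row[p], row[q]
                        row[p], row[q] = c * xp - s * xq, s * xp + c * xq
        if not rotated:
            break
    # Column norms of A V are the singular values; normalizing them gives U.
    S = [math.sqrt(sum(W[i][j] ** 2 for i in range(n))) for j in range(m)]
    U = [[W[i][j] / S[j] for j in range(m)] for i in range(n)]
    order = sorted(range(m), key=lambda j: -S[j])
    return ([[row[j] for j in order] for row in U],
            [S[j] for j in order],
            [[row[j] for j in order] for row in V])
```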
-
torch.
symeig
(input, eigenvectors=False, upper=True, out=None) -> (Tensor, Tensor)¶ This function returns the eigenvalues and eigenvectors of a real symmetric matrix input, represented by a tuple (e, V). input and V are (m \times m) matrices and e is an m-dimensional vector.
This function calculates all eigenvalues (and vectors) of input such that input = V diag(e) V^T.
The boolean argument eigenvectors determines whether eigenvectors are computed along with the eigenvalues. If it is False, only eigenvalues are computed. If it is True, both eigenvalues and eigenvectors are computed.
Since the input matrix input is assumed to be symmetric, only its upper triangular portion is used by default. If upper is False, the lower triangular portion is used instead.
Note
Irrespective of the original strides, the returned matrix V will be transposed, i.e. with strides (1, m) instead of (m, 1).
Parameters:
- input (Tensor) – the input symmetric matrix
- eigenvectors (bool, optional) – whether eigenvectors are computed. Default: False.
- upper (bool, optional) – whether to use the upper (default) or lower triangular portion of input. Default: True.
- out (tuple, optional) – the output tuple of (Tensor, Tensor)
Examples:
>>> a = torch.tensor([[ 1.96, 0.00, 0.00, 0.00, 0.00],
                      [-6.49, 3.80, 0.00, 0.00, 0.00],
                      [-0.47, -6.39, 4.17, 0.00, 0.00],
                      [-7.20, 1.50, -1.51, 5.70, 0.00],
                      [-0.65, -6.34, 2.67, 1.80, -7.10]]).t()
>>> e, v = torch.symeig(a, eigenvectors=True)
>>> e
tensor([-11.0656, -6.2287, 0.8640, 8.8655, 16.0948])
>>> v
tensor([[-0.2981, -0.6075, 0.4026, -0.3745, 0.4896],
        [-0.5078, -0.2880, -0.4066, -0.3572, -0.6053],
        [-0.0816, -0.3843, -0.6600, 0.5008, 0.3991],
        [-0.0036, -0.4467, 0.4553, 0.6204, -0.4564],
        [-0.8041, 0.4480, 0.1725, 0.3108, 0.1622]])
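The cyclic Jacobi eigenvalue method, sketched below in plain Python, gives a compact way to see the relation input = V diag(e) V^T. It is a swapped-in algorithm, since torch.symeig calls LAPACK, and it assumes a small dense symmetric input. Like symeig, it reads only one triangle, selected by an upper flag. jacobi_eigh is a hypothetical name, not a PyTorch function. Eigenvalues are sorted in ascending order to match the example output.

```python
import math


def jacobi_eigh(a, upper=True, max_sweeps=50, tol=1e-12):
    """Return (e, V), e ascending, with sym(a) = V diag(e) V^T, reading one triangle of a."""
    n = len(a)

    def entry(i, j):
        # Only the chosen triangle is read; the other half is mirrored from it.
        i, j = min(i, j), max(i, j)
        return float(a[i][j] if upper else a[j][i])

    A = [[entry(i, j) for j in range(n)] for i in range(n)]
    V = [[float(i == j) for j in range(n)] for i in range(n)]
    for _ in range(max_sweeps):
        off = math.sqrt(sum(A[i][j] ** 2 for i in range(n) for j in range(n) if i != j))
        if off < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if A[p][q] == 0.0:
                    continue
                # Rotation angle chosen so that the (p, q) entry becomes zero.
                theta = (A[q][q] - A[p][p]) / (2.0 * A[p][q])
                t = math.copysign(1.0, theta) / (abs(theta) + math.sqrt(theta * theta + 1.0))
                c = 1.0 / math.sqrt(t * t + 1.0)
                s = t * c
                # A <- A J and V <- V J (column update), then A <- J^T A (row update).
                for M in (A, V):
                    for row in M:
                        rp, rq = row[p], row[q]
                        row[p], row[q] = c * rp - s * rq, s * rp + c * rq
                for k in range(n):
                    apk, aqk = A[p][k], A[q][k]
                    A[p][k], A[q][k] = c * apk - s * aqk, s * apk + c * aqk
    order = sorted(range(n), key=lambda k: A[k][k])
    return [A[k][k] for k in order], [[row[k] for k in order] for row in V]
```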
-
torch.
trtrs
(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)¶ Solves a system of equations with a triangular coefficient matrix A and multiple right-hand sides b.
In particular, solves AX = b and assumes A is upper-triangular with the default keyword arguments.
This method is NOT implemented for CUDA tensors.
Parameters:
- A (Tensor) – the input triangular coefficient matrix
- b (Tensor) – multiple right-hand sides. Each column of b is a right-hand side for the system of equations.
- upper (bool, optional) – whether to solve the upper-triangular system of equations (default) or the lower-triangular system of equations. Default: True.
- transpose (bool, optional) – whether A should be transposed before being sent into the solver. Default: False.
- unitriangular (bool, optional) – whether A is unit triangular. If True, the diagonal elements of A are assumed to be 1 and not referenced from A. Default: False.
Returns: A tuple (X, M) where M is a clone of A and X is the solution to AX = b (or to the variant of the system of equations selected by the keyword arguments).
- Shape:
- A: (N, N)
- b: (N, C)
- output[0]: (N, C)
- output[1]: (N, N)
Examples:
>>> A = torch.randn(2, 2).triu()
>>> A
tensor([[ 1.1527, -1.0753],
        [ 0.0000, 0.7986]])
>>> b = torch.randn(2, 3)
>>> b
tensor([[-0.0210, 2.3513, -1.5492],
        [ 1.5429, 0.7403, -1.0243]])
>>> torch.trtrs(b, A)
(tensor([[ 1.7840, 2.9045, -2.5405],
         [ 1.9319, 0.9269, -1.2826]]),
 tensor([[ 1.1527, -1.0753],
         [ 0.0000, 0.7986]]))
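The three keyword arguments map onto plain forward and back substitution, as this minimal pure-Python sketch shows. It assumes A is square and nonsingular (unless unitriangular is True) and b is a list of rows. solve_triangular is a hypothetical name. As with trtrs, only the triangle selected by upper is read, and with unitriangular=True the diagonal of A is ignored.

```python
def solve_triangular(b, A, upper=True, transpose=False, unitriangular=False):
    """Solve op(A) X = b for triangular A, where op(A) is A or A^T and b is n x k."""
    n, k = len(A), len(b[0])
    M = [list(r) for r in zip(*A)] if transpose else [list(r) for r in A]
    # Transposing swaps the triangle: A^T of an upper-triangular A is lower triangular.
    is_upper = upper != transpose
    X = [[0.0] * k for _ in range(n)]
    order = reversed(range(n)) if is_upper else range(n)
    for i in order:
        # Already-solved unknowns: to the right for back substitution, to the left for forward.
        known = range(i + 1, n) if is_upper else range(i)
        diag = 1.0 if unitriangular else M[i][i]
        for col in range(k):
            X[i][col] = (b[i][col] - sum(M[i][j] * X[j][col] for j in known)) / diag
    return X
```

Transposing an upper-triangular A produces a lower-triangular system, which is why transpose=True switches from back substitution to forward substitution.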