torch

The torch package contains data structures for multi-dimensional tensors and defines mathematical operations over these tensors. Additionally, it provides many utilities for efficient serialization of Tensors and arbitrary types, and other useful utilities.

It has a CUDA counterpart that enables you to run your tensor computations on an NVIDIA GPU with compute capability >= 3.0.

Tensors

torch.is_tensor(obj)

Returns True if obj is a PyTorch tensor.

Parameters: obj (Object) – Object to test
torch.is_storage(obj)

Returns True if obj is a PyTorch storage object.

Parameters: obj (Object) – Object to test
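torch.set_default_tensor_type(t)

Sets the default tensor type to t, which can be a tensor type or its fully qualified name as a string (e.g. 'torch.DoubleTensor').
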
torch.numel(input) → int

Returns the total number of elements in the input tensor.

Parameters:input (Tensor) – the input tensor

Example:

>>> a = torch.randn(1,2,3,4,5)
>>> torch.numel(a)
120
>>> a = torch.zeros(4,4)
>>> torch.numel(a)
16
torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None)

Sets options for printing. The options are borrowed from NumPy.

Parameters:
  • precision – Number of digits of precision for floating point output (default 4).
  • threshold – Total number of array elements which trigger summarization rather than full repr (default 1000).
  • edgeitems – Number of array items in summary at beginning and end of each dimension (default 3).
  • linewidth – The number of characters per line for the purpose of inserting line breaks (default 80). Thresholded matrices will ignore this parameter.
  • profile – Preset defaults for pretty printing; any of the above options overrides them. One of default, short, or full.

Creation Ops

torch.eye(n, m=None, out=None)

Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.

Parameters:
  • n (int) – the number of rows
  • m (int, optional) – the number of columns with default being n
  • out (Tensor, optional) – the output tensor
Returns:

A 2-D tensor with ones on the diagonal and zeros elsewhere

Return type:

Tensor

Example:

>>> torch.eye(3)
 1  0  0
 0  1  0
 0  0  1
[torch.FloatTensor of size 3x3]
torch.from_numpy(ndarray) → Tensor

Creates a Tensor from a numpy.ndarray.

The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.

Example:

>>> a = numpy.array([1, 2, 3])
>>> t = torch.from_numpy(a)
>>> t

 1
 2
 3
[torch.LongTensor of size 3]
>>> t[0] = -1
>>> a
array([-1,  2,  3])
torch.linspace(start, end, steps=100, out=None) → Tensor

Returns a one-dimensional tensor of steps equally spaced points between start and end, inclusive.

The output tensor is 1-D of size steps.

Parameters:
  • start (float) – the starting value for the set of points
  • end (float) – the ending value for the set of points
  • steps (int) – number of points to sample between start and end
  • out (Tensor, optional) – the output tensor
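The spacing rule can be illustrated with a small pure-Python sketch that needs no PyTorch install. The helper `linspace` below is hypothetical and only mirrors the documented behaviour (both endpoints included, steps points in total); it is not PyTorch's implementation and returns a plain list instead of a tensor.

```python
def linspace(start, end, steps=100):
    """Hypothetical pure-Python helper: `steps` evenly spaced points from start to end."""
    if steps < 2:
        raise ValueError("this sketch needs at least two points")
    step = (end - start) / (steps - 1)  # gap between neighbouring points
    return [start + i * step for i in range(steps)]


print(linspace(3, 10, steps=5))    # [3.0, 4.75, 6.5, 8.25, 10.0]
print(linspace(-10, 10, steps=5))  # [-10.0, -5.0, 0.0, 5.0, 10.0]
```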

Example:

>>> torch.linspace(3, 10, steps=5)

  3.0000
  4.7500
  6.5000
  8.2500
 10.0000
[torch.FloatTensor of size 5]

>>> torch.linspace(-10, 10, steps=5)

-10
 -5
  0
  5
 10
[torch.FloatTensor of size 5]

>>> torch.linspace(start=-10, end=10, steps=5)

-10
 -5
  0
  5
 10
[torch.FloatTensor of size 5]
torch.logspace(start, end, steps=100, out=None) → Tensor

Returns a one-dimensional tensor of steps points logarithmically spaced between $10^{\text{start}}$ and $10^{\text{end}}$.

The output is a 1-D tensor of size steps.

Parameters:
  • start (float) – the starting value for the set of points
  • end (float) – the ending value for the set of points
  • steps (int) – number of points to sample between start and end
  • out (Tensor, optional) – the output tensor
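Equivalently, logspace can be viewed as 10 raised to evenly spaced exponents. A minimal pure-Python sketch under that reading, using a hypothetical `logspace` helper that returns a plain list rather than a tensor:

```python
def logspace(start, end, steps=100):
    """Hypothetical helper: 10 ** x for `steps` evenly spaced exponents x in [start, end]."""
    if steps < 2:
        raise ValueError("this sketch needs at least two points")
    step = (end - start) / (steps - 1)  # spacing between consecutive exponents
    return [10.0 ** (start + i * step) for i in range(steps)]


print([round(v, 4) for v in logspace(0.1, 1.0, steps=5)])
# [1.2589, 2.1135, 3.5481, 5.9566, 10.0]
```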

Example:

>>> torch.logspace(start=-10, end=10, steps=5)

 1.0000e-10
 1.0000e-05
 1.0000e+00
 1.0000e+05
 1.0000e+10
[torch.FloatTensor of size 5]

>>> torch.logspace(start=0.1, end=1.0, steps=5)

  1.2589
  2.1135
  3.5481
  5.9566
 10.0000
[torch.FloatTensor of size 5]
torch.ones(*sizes, out=None) → Tensor

Returns a tensor filled with the scalar value 1, with the shape defined by the varargs sizes.

Parameters:
  • sizes (int...) – a set of integers defining the shape of the output tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.ones(2, 3)

 1  1  1
 1  1  1
[torch.FloatTensor of size 2x3]

>>> torch.ones(5)

 1
 1
 1
 1
 1
[torch.FloatTensor of size 5]
torch.ones_like(input, out=None) → Tensor

Returns a tensor filled with the scalar value 1, with the same size as input.

Parameters:
  • input (Tensor) – the size of input will determine the size of the output tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> input = torch.FloatTensor(2, 3)
>>> torch.ones_like(input)

 1  1  1
 1  1  1
[torch.FloatTensor of size 2x3]
torch.arange(start=0, end, step=1, out=None) → Tensor

Returns a 1-D tensor of size $\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil$ with values from the interval [start, end), spaced step apart and beginning at start.

Parameters:
  • start (float) – the starting value for the set of points
  • end (float) – the ending value for the set of points
  • step (float) – the gap between each pair of adjacent points
  • out (Tensor, optional) – the output tensor
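The size formula and the half-open interval can be checked with a pure-Python sketch. `arange` here is a hypothetical helper returning a list; it assumes a non-zero step and does not reproduce PyTorch's output type.

```python
import math


def arange(start, end=None, step=1):
    """Hypothetical helper: values in the half-open interval [start, end), `step` apart."""
    if end is None:  # single-argument form arange(end): start defaults to 0
        start, end = 0, start
    size = max(math.ceil((end - start) / step), 0)  # ceil((end - start) / step)
    return [start + i * step for i in range(size)]


print(arange(5))            # [0, 1, 2, 3, 4]
print(arange(1, 2.5, 0.5))  # [1.0, 1.5, 2.0]
print(arange(0, 5, 2))      # [0, 2, 4]
```

The last call shows why the size is rounded up: (5 - 0) / 2 = 2.5, and the interval [0, 5) contains three multiples of 2.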

Example:

>>> torch.arange(5)

 0
 1
 2
 3
 4
[torch.FloatTensor of size 5]

>>> torch.arange(1, 4)

 1
 2
 3
[torch.FloatTensor of size 3]

>>> torch.arange(1, 2.5, 0.5)

 1.0000
 1.5000
 2.0000
[torch.FloatTensor of size 3]
torch.range(start, end, step=1, out=None) → Tensor

Returns a 1-D tensor of size $\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1$ with values from start to end, spaced step apart. Step is the gap between two consecutive values in the tensor: $x_{i+1} = x_i + \text{step}$.

Warning

This function is deprecated in favor of torch.arange().

Parameters:
  • start (float) – the starting value for the set of points
  • end (float) – the ending value for the set of points
  • step (float) – the gap between each pair of adjacent points
  • out (Tensor, optional) – the output tensor
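For comparison with arange, this hypothetical `torch_range` helper (named to avoid shadowing Python's built-in `range`) sketches the inclusive size formula in plain Python:

```python
import math


def torch_range(start, end, step=1):
    """Hypothetical helper mimicking the deprecated range: like arange, but end is included."""
    size = max(math.floor((end - start) / step) + 1, 0)  # floor((end - start) / step) + 1
    return [start + i * step for i in range(size)]


print(torch_range(1, 4))       # [1, 2, 3, 4]
print(torch_range(1, 4, 0.5))  # [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
```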

Example:

>>> torch.range(1, 4)

 1
 2
 3
 4
[torch.FloatTensor of size 4]

>>> torch.range(1, 4, 0.5)

 1.0000
 1.5000
 2.0000
 2.5000
 3.0000
 3.5000
 4.0000
[torch.FloatTensor of size 7]
torch.zeros(*sizes, out=None) → Tensor

Returns a tensor filled with the scalar value 0, with the shape defined by the varargs sizes.

Parameters:
  • sizes (int...) – a set of integers defining the shape of the output tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.zeros(2, 3)

 0  0  0
 0  0  0
[torch.FloatTensor of size 2x3]

>>> torch.zeros(5)

 0
 0
 0
 0
 0
[torch.FloatTensor of size 5]
torch.zeros_like(input, out=None) → Tensor

Returns a tensor filled with the scalar value 0, with the same size as input.

Parameters:
  • input (Tensor) – the size of the input will determine the size of the output.
  • out (Tensor, optional) – the output tensor

Example:

>>> input = torch.FloatTensor(2, 3)
>>> torch.zeros_like(input)

 0  0  0
 0  0  0
[torch.FloatTensor of size 2x3]

Indexing, Slicing, Joining, Mutating Ops

torch.cat(seq, dim=0, out=None) → Tensor

Concatenates the given sequence seq of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().

cat() is best understood via examples.

Parameters:
  • seq (sequence of tensors) – any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
  • dim (int, optional) – the dimension over which the tensors are concatenated
  • out (Tensor, optional) – the output tensor

Example:

>>> x = torch.randn(2, 3)
>>> x

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]

>>> torch.cat((x, x, x), 0)

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]

>>> torch.cat((x, x, x), 1)

 0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x9]
torch.chunk(tensor, chunks, dim=0)

Splits a tensor into a specific number of chunks.

Parameters:
  • tensor (Tensor) – the tensor to split
  • chunks (int) – number of chunks to return
  • dim (int) – dimension along which to split the tensor
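This section does not say how the chunk size is chosen. Assuming chunk() uses a chunk size of ceil(length / chunks) and then splits into pieces of that size (an assumption, not stated above), the resulting sizes can be sketched as follows; under that assumption it can yield fewer than chunks pieces. `chunk_sizes` is a hypothetical helper that only computes lengths.

```python
def chunk_sizes(length, chunks):
    """Hypothetical helper: piece sizes along a dimension of `length` (assumed ceil rule)."""
    if length <= 0 or chunks <= 0:
        raise ValueError("length and chunks must be positive")
    split_size = -(-length // chunks)  # ceil(length / chunks) with integer arithmetic
    return [min(split_size, length - start) for start in range(0, length, split_size)]


print(chunk_sizes(10, 3))  # [4, 4, 2]
print(chunk_sizes(6, 4))   # [2, 2, 2] -- three pieces, not four
```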
torch.gather(input, dim, index, out=None) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

If input is an n-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, x_i, x_{i+1}, \ldots, x_{n-1})$ and dim = i, then index must be an n-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, y, x_{i+1}, \ldots, x_{n-1})$ where $y \ge 1$, and out will have the same size as index.

Parameters:
  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather
  • out (Tensor, optional) – the destination tensor
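The dim == 0 and dim == 1 rules above can be sketched for the 2-D case in plain Python. `gather_2d` is a hypothetical helper operating on nested lists, not PyTorch's implementation; it reproduces the example that follows.

```python
def gather_2d(inp, dim, index):
    """Hypothetical pure-Python gather for 2-D nested lists (dim must be 0 or 1)."""
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim 0 or 1")
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            if dim == 0:
                out_row.append(inp[k][j])  # out[i][j] = input[index[i][j]][j]
            else:
                out_row.append(inp[i][k])  # out[i][j] = input[i][index[i][j]]
        out.append(out_row)
    return out


t = [[1, 2], [3, 4]]
print(gather_2d(t, 1, [[0, 0], [1, 0]]))  # [[1, 1], [4, 3]]
```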

Example:

>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
 1  1
 4  3
[torch.FloatTensor of size 2x2]
torch.index_select(input, dim, index, out=None) → Tensor

Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.

The returned tensor has the same number of dimensions as the original tensor (input). The dim-th dimension has the same size as the length of index; other dimensions have the same size as in the original tensor.

Note

The returned tensor does not use the same storage as the original tensor. If out has a different shape than expected, we silently change it to the correct shape, reallocating the underlying storage if necessary.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension in which we index
  • index (LongTensor) – the 1-D tensor containing the indices to index
  • out (Tensor, optional) – the output tensor

Example:

>>> x = torch.randn(3, 4)
>>> x

 1.2045  2.4084  0.4001  1.1372
 0.5596  1.5677  0.6219 -0.7954
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]

>>> indices = torch.LongTensor([0, 2])
>>> torch.index_select(x, 0, indices)

 1.2045  2.4084  0.4001  1.1372
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 2x4]

>>> torch.index_select(x, 1, indices)

 1.2045  0.4001
 0.5596  0.6219
 1.3635 -0.5414
[torch.FloatTensor of size 3x2]
torch.masked_select(input, mask, out=None) → Tensor

Returns a new 1-D tensor which indexes the input tensor according to the binary mask mask which is a ByteTensor.

The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.

Note

The returned tensor does not use the same storage as the original tensor.

Parameters:
  • input (Tensor) – the input data
  • mask (ByteTensor) – the tensor containing the binary mask to index with
  • out (Tensor, optional) – the output tensor

Example:

>>> x = torch.randn(3, 4)
>>> x

 1.2045  2.4084  0.4001  1.1372
 0.5596  1.5677  0.6219 -0.7954
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]

>>> mask = x.ge(0.5)
>>> mask

 1  1  0  1
 1  1  1  0
 1  0  0  0
[torch.ByteTensor of size 3x4]

>>> torch.masked_select(x, mask)

 1.2045
 2.4084
 1.1372
 0.5596
 1.5677
 0.6219
 1.3635
[torch.FloatTensor of size 7]
torch.nonzero(input, out=None) → LongTensor

Returns a tensor containing the indices of all non-zero elements of input. Each row in the result contains the indices of a non-zero element in input.

If input has n dimensions, then the resulting indices tensor out is of size (z×n), where z is the total number of non-zero elements in the input tensor.

Parameters:
  • input (Tensor) – the input tensor
  • out (LongTensor, optional) – the output tensor containing indices
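The z × n result layout can be sketched in plain Python with a hypothetical `nonzero` helper that walks a nested list in row-major order and records one index list per non-zero element:

```python
def nonzero(values):
    """Hypothetical helper: one index list per non-zero entry of a nested list."""
    result = []

    def walk(x, prefix):
        if isinstance(x, list):
            for k, item in enumerate(x):
                walk(item, prefix + [k])  # descend one dimension, recording the index
        elif x != 0:
            result.append(prefix)  # one row of n indices per non-zero element

    walk(values, [])
    return result


print(nonzero([1, 1, 1, 0, 1]))  # [[0], [1], [2], [4]]
```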

Example:

>>> torch.nonzero(torch.Tensor([1, 1, 1, 0, 1]))

 0
 1
 2
 4
[torch.LongTensor of size 4x1]

>>> torch.nonzero(torch.Tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]))

 0  0
 1  1
 2  2
 3  3
[torch.LongTensor of size 4x2]
torch.split(tensor, split_size, dim=0)

Splits the tensor into chunks all of size split_size (if possible).

The last chunk will be smaller if the tensor size along the given dimension is not divisible by split_size.

Parameters:
  • tensor (Tensor) – the tensor to split
  • split_size (int) – size of a single chunk
  • dim (int) – dimension along which to split the tensor
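The sizes that the rule above produces can be sketched with a hypothetical `split_sizes` helper; it computes only the piece lengths along one dimension, not the tensors themselves.

```python
def split_sizes(length, split_size):
    """Hypothetical helper: sizes of the pieces split() yields along a dimension of `length`."""
    if split_size <= 0:
        raise ValueError("split_size must be positive")
    # every piece has split_size elements except possibly the last one
    return [min(split_size, length - start) for start in range(0, length, split_size)]


print(split_sizes(10, 3))  # [3, 3, 3, 1]
print(split_sizes(6, 2))   # [2, 2, 2]
```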
torch.squeeze(input, dim=None, out=None)

Returns a tensor with all the dimensions of input of size 1 removed.

For example, if input is of shape $(A \times 1 \times B \times C \times 1 \times D)$, then the out tensor will be of shape $(A \times B \times C \times D)$.

When dim is given, a squeeze operation is done only in the given dimension. If input is of shape $(A \times 1 \times B)$, squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape $(A \times B)$.

Note

As an exception to the above, a 1-dimensional tensor of size 1 will not have its dimensions changed.

Note

The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int, optional) – if given, the input will be squeezed only in this dimension
  • out (Tensor, optional) – the output tensor
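A small sketch of the shape rules, including the 1-D exception from the note, using a hypothetical `squeeze_shape` helper that works on shape tuples only. Handling a negative dim by counting from the end is an assumption here, and shapes made up entirely of size-1 dimensions are not treated specially.

```python
def squeeze_shape(shape, dim=None):
    """Hypothetical helper: the shape that squeeze() would return for `shape`."""
    shape = tuple(shape)
    if len(shape) == 1:
        return shape  # a 1-D tensor keeps its shape, even if its size is 1
    if dim is None:
        return tuple(s for s in shape if s != 1)  # drop every size-1 dimension
    if dim < 0:
        dim += len(shape)  # assumption: negative dims count from the end
    if shape[dim] != 1:
        return shape  # the requested dimension is not of size 1: unchanged
    return shape[:dim] + shape[dim + 1:]  # drop only the requested dimension


print(squeeze_shape((2, 1, 2, 1, 2)))     # (2, 2, 2)
print(squeeze_shape((2, 1, 2, 1, 2), 0))  # (2, 1, 2, 1, 2)
print(squeeze_shape((2, 1, 2, 1, 2), 1))  # (2, 2, 1, 2)
```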

Example:

>>> x = torch.zeros(2,1,2,1,2)
>>> x.size()
(2L, 1L, 2L, 1L, 2L)
>>> y = torch.squeeze(x)
>>> y.size()
(2L, 2L, 2L)
>>> y = torch.squeeze(x, 0)
>>> y.size()
(2L, 1L, 2L, 1L, 2L)
>>> y = torch.squeeze(x, 1)
>>> y.size()
(2L, 2L, 1L, 2L)
torch.stack(sequence, dim=0, out=None)

Concatenates a sequence of tensors along a new dimension.

All tensors need to be of the same size.

Parameters:
  • sequence (Sequence) – sequence of tensors to concatenate
  • dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated tensors (inclusive)
torch.t(input, out=None) → Tensor

Expects input to be a matrix (2-D tensor) and transposes dimensions 0 and 1.

Can be seen as a shorthand function for transpose(input, 0, 1).

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> x = torch.randn(2, 3)
>>> x

 0.4834  0.6907  1.3417
-0.1300  0.5295  0.2321
[torch.FloatTensor of size 2x3]

>>> torch.t(x)

 0.4834 -0.1300
 0.6907  0.5295
 1.3417  0.2321
[torch.FloatTensor of size 3x2]
torch.take(input, indices) → Tensor

Returns a new tensor with the elements of input at the given indices. The input tensor is treated as if it were viewed as a 1-D tensor. The result takes the same shape as the indices.

Parameters:
  • input (Tensor) – the input tensor
  • indices (LongTensor) – the indices into tensor
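The "viewed as 1-D" behaviour amounts to flattening in row-major order and then indexing. A pure-Python sketch with a hypothetical `take` helper, limited to a flat list of indices:

```python
def take(inp, indices):
    """Hypothetical helper: read `inp` as if flattened to 1-D (row-major), then index it."""
    flat = []

    def flatten(x):
        if isinstance(x, list):
            for item in x:
                flatten(item)
        else:
            flat.append(x)

    flatten(inp)
    return [flat[i] for i in indices]


src = [[4, 3, 5],
       [6, 7, 8]]
print(take(src, [0, 2, 5]))  # [4, 5, 8]
```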

Example:

>>> src = torch.Tensor([[4, 3, 5],
...                     [6, 7, 8]])
>>> torch.take(src, torch.LongTensor([0, 2, 5]))
 4
 5
 8
[torch.FloatTensor of size 3]
torch.transpose(input, dim0, dim1, out=None) → Tensor

Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.

The resulting out tensor shares its underlying storage with the input tensor, so changing the content of one would change the content of the other.

Parameters:
  • input (Tensor) – the input tensor
  • dim0 (int) – the first dimension to be transposed
  • dim1 (int) – the second dimension to be transposed

Example:

>>> x = torch.randn(2, 3)
>>> x

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]

>>> torch.transpose(x, 0, 1)

 0.5983  1.5981
-0.0341 -0.5265
 2.4918 -0.8735
[torch.FloatTensor of size 3x2]
torch.unbind(tensor, dim=0)

Removes a tensor dimension.

Returns a tuple of all slices along the given dimension, each with that dimension removed.

Parameters:
  • tensor (Tensor) – the tensor to unbind
  • dim (int) – dimension to remove
torch.unsqueeze(input, dim, out=None)

Returns a new tensor with a dimension of size one inserted at the specified position.

The returned tensor shares the same underlying data with this tensor.

A negative dim value can be used and will correspond to dim + input.dim() + 1.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the index at which to insert the singleton dimension
  • out (Tensor, optional) – the output tensor
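The negative-dim rule can be made concrete with a hypothetical `unsqueeze_shape` helper that operates on shape tuples:

```python
def unsqueeze_shape(shape, dim):
    """Hypothetical helper: the shape unsqueeze() returns, including negative `dim`."""
    shape = tuple(shape)
    if dim < 0:
        dim = dim + len(shape) + 1  # dim + input.dim() + 1
    if not 0 <= dim <= len(shape):
        raise IndexError("dim out of range")
    return shape[:dim] + (1,) + shape[dim:]  # insert the singleton dimension


print(unsqueeze_shape((4,), 0))   # (1, 4)
print(unsqueeze_shape((4,), -1))  # (4, 1)
```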

Example

>>> x = torch.Tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
 1  2  3  4
[torch.FloatTensor of size 1x4]
>>> torch.unsqueeze(x, 1)
 1
 2
 3
 4
[torch.FloatTensor of size 4x1]

Random sampling

torch.manual_seed(seed)

Sets the seed for generating random numbers and returns a torch._C.Generator object.

Parameters: seed (int or long) – The desired seed.
torch.initial_seed()

Returns the initial seed for generating random numbers as a python long.

torch.get_rng_state()

Returns the random number generator state as a ByteTensor.

torch.set_rng_state(new_state)

Sets the random number generator state.

Parameters: new_state (torch.ByteTensor) – The desired state
torch.default_generator = <torch._C.Generator object>
torch.bernoulli(input, out=None) → Tensor

Draws binary random numbers (0 or 1) from a Bernoulli distribution.

The input tensor should be a tensor containing probabilities to be used for drawing the binary random number. Hence, all values in input have to be in the range $0 \le \text{input}_i \le 1$.

The i-th element of the output tensor will be 1 with probability $\text{input}_i$ and 0 otherwise.

The returned out tensor only has values 0 or 1 and is of the same shape as input.

Parameters:
  • input (Tensor) – the input tensor of probability values for the Bernoulli distribution
  • out (Tensor, optional) – the output tensor
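Each output element is an independent draw that equals 1 with probability input_i. A pure-Python sketch using the standard `random` module; `bernoulli` is a hypothetical helper, not PyTorch's sampler:

```python
import random


def bernoulli(probs, rng=random):
    """Hypothetical helper: 1 with probability p and 0 otherwise, independently per element."""
    out = []
    for p in probs:
        if not 0.0 <= p <= 1.0:
            raise ValueError("probabilities must lie in [0, 1]")
        out.append(1 if rng.random() < p else 0)  # rng.random() is uniform on [0, 1)
    return out


print(bernoulli([1.0, 1.0, 1.0]))  # [1, 1, 1]
print(bernoulli([0.0, 0.0, 0.0]))  # [0, 0, 0]
```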

Example:

>>> a = torch.Tensor(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1]
>>> a

 0.7544  0.8140  0.9842
 0.5282  0.0595  0.6445
 0.1925  0.9553  0.9732
[torch.FloatTensor of size 3x3]

>>> torch.bernoulli(a)

 1  1  1
 0  0  1
 0  1  1
[torch.FloatTensor of size 3x3]

>>> a = torch.ones(3, 3) # probability of drawing "1" is 1
>>> torch.bernoulli(a)

 1  1  1
 1  1  1
 1  1  1
[torch.FloatTensor of size 3x3]

>>> a = torch.zeros(3, 3) # probability of drawing "1" is 0
>>> torch.bernoulli(a)

 0  0  0
 0  0  0
 0  0  0
[torch.FloatTensor of size 3x3]
torch.multinomial(input, num_samples, replacement=False, out=None) → LongTensor

Returns a tensor where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of tensor input.

Note

The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative and have a non-zero sum.

Indices are ordered from left to right according to when each was sampled (first samples are placed in first column).

If input is a vector, out is a vector of size num_samples.

If input is a matrix with m rows, out is a matrix of shape m × num_samples.

If replacement is True, samples are drawn with replacement.

If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.

This implies the constraint that num_samples must not exceed the length of input (or the number of columns of input if it is a matrix).

Parameters:
  • input (Tensor) – the input tensor containing probabilities
  • num_samples (int) – number of samples to draw
  • replacement (bool, optional) – whether to draw with replacement or not
  • out (Tensor, optional) – the output tensor
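Sampling without replacement can be sketched by zeroing out the weight of each drawn index. The hypothetical `multinomial` helper below uses Python's `random.choices`; it is not PyTorch's algorithm and, unlike the example that follows, it raises an error instead of returning index 0 once all remaining weights are zero.

```python
import random


def multinomial(weights, num_samples, replacement=False, rng=random):
    """Hypothetical helper: draw indices with probability proportional to `weights`."""
    weights = list(weights)
    if any(w < 0 for w in weights) or sum(weights) == 0:
        raise ValueError("weights must be non-negative with a non-zero sum")
    if not replacement and num_samples > len(weights):
        raise ValueError("cannot draw more samples than categories without replacement")
    samples = []
    for _ in range(num_samples):
        if sum(weights) == 0:
            raise ValueError("no categories with non-zero weight are left")
        idx = rng.choices(range(len(weights)), weights=weights, k=1)[0]
        samples.append(idx)  # samples are kept in the order they were drawn
        if not replacement:
            weights[idx] = 0  # a drawn index cannot be drawn again
    return samples


print(multinomial([0, 10, 3, 0], 2))                    # 1 and 2, in random order
print(multinomial([0, 10, 3, 0], 4, replacement=True))  # only 1s and 2s
```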

Example:

>>> weights = torch.Tensor([0, 10, 3, 0]) # create a tensor of weights
>>> torch.multinomial(weights, 4)

 1
 2
 0
 0
[torch.LongTensor of size 4]

>>> torch.multinomial(weights, 4, replacement=True)

 1
 2
 1
 2
[torch.LongTensor of size 4]
torch.normal()
torch.normal(means, std, out=None)

Returns a tensor of random numbers drawn from separate normal distributions whose means and standard deviations are given.

The means is a tensor with the mean of each output element’s normal distribution.

The std is a tensor with the standard deviation of each output element’s normal distribution.

The shapes of means and std don’t need to match, but the total number of elements in each tensor needs to be the same.

Note

When the shapes do not match, the shape of means is used as the shape for the returned output tensor

Parameters:
  • means (Tensor) – the tensor of per-element means
  • std (Tensor) – the tensor of per-element standard deviations
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.normal(means=torch.arange(1, 11), std=torch.arange(1, 0, -0.1))

 1.5104
 1.6955
 2.4895
 4.9185
 4.9895
 6.9155
 7.3683
 8.1836
 8.7164
 9.8916
[torch.FloatTensor of size 10]
torch.normal(mean=0.0, std, out=None)

Similar to the function above, but the mean is shared among all drawn elements.

Parameters:
  • mean (float, optional) – the mean for all distributions
  • std (Tensor) – the tensor of per-element standard deviations
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.normal(mean=0.5, std=torch.arange(1, 6))

  0.5723
  0.0871
 -0.3783
 -2.5689
 10.7893
[torch.FloatTensor of size 5]
torch.normal(means, std=1.0, out=None)

Similar to the function above, but the standard deviation is shared among all drawn elements.

Parameters:
  • means (Tensor) – the tensor of per-element means
  • std (float, optional) – the standard deviation for all distributions
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.normal(means=torch.arange(1, 6))

 1.1681
 2.8884
 3.7718
 2.5616
 4.2500
[torch.FloatTensor of size 5]
torch.rand(*sizes, out=None) → Tensor

Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1).

The shape of the tensor is defined by the varargs sizes.

Parameters:
  • sizes (int...) – a set of ints defining the shape of the output tensor.
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.rand(4)

 0.9193
 0.3347
 0.3232
 0.7715
[torch.FloatTensor of size 4]

>>> torch.rand(2, 3)

 0.5010  0.5140  0.0719
 0.1435  0.5636  0.0538
[torch.FloatTensor of size 2x3]
torch.randn(*sizes, out=None) → Tensor

Returns a tensor filled with random numbers from a normal distribution with zero mean and variance of one.

The shape of the tensor is defined by the varargs sizes.

Parameters:
  • sizes (int...) – a set of ints defining the shape of the output tensor.
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.randn(4)

-0.1145
 0.0094
-1.1717
 0.9846
[torch.FloatTensor of size 4]

>>> torch.randn(2, 3)

 1.4339  0.3351 -1.0999
 1.5458 -0.9643 -0.3558
[torch.FloatTensor of size 2x3]
torch.randperm(n, out=None) → LongTensor

Returns a random permutation of integers from 0 to n - 1.

Parameters: n (int) – the upper bound (exclusive)

Example:

>>> torch.randperm(4)

 2
 1
 3
 0
[torch.LongTensor of size 4]
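A random permutation can be sketched with the standard library's `random.shuffle`, which shuffles a list in place (Fisher-Yates). `randperm` is a hypothetical helper returning a list:

```python
import random


def randperm(n, rng=random):
    """Hypothetical helper: a uniformly random permutation of 0 .. n - 1."""
    perm = list(range(n))
    rng.shuffle(perm)  # in-place Fisher-Yates shuffle
    return perm


print(randperm(4))  # some ordering of [0, 1, 2, 3]
```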

In-place random sampling

There are a few more in-place random sampling functions defined on Tensors as well; refer to the torch.Tensor documentation for details.

Serialization

torch.save(obj, f, pickle_module=pickle, pickle_protocol=2)

Saves an object to a disk file.

See also: Recommended approach for saving a model

Parameters:
  • obj – saved object
  • f – a file-like object (has to implement fileno that returns a file descriptor) or a string containing a file name
  • pickle_module – module used for pickling metadata and objects
  • pickle_protocol – can be specified to override the default protocol
torch.load(f, map_location=None, pickle_module=pickle)

Loads an object saved with torch.save() from a file.

torch.load uses Python’s unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the runtime system doesn’t have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.

If map_location is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to map_location. The built-in location tags are ‘cpu’ for CPU tensors and ‘cuda:device_id’ (e.g. ‘cuda:2’) for CUDA tensors. map_location should return either None or a storage. If map_location returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, torch.load will fall back to the default behavior, as if map_location wasn’t specified.

If map_location is a string, it should be a device tag specifying where all tensors should be loaded.

Otherwise, if map_location is a dict, it will be used to remap location tags appearing in the file (keys), to ones that specify where to put the storages (values).

User extensions can register their own location tags and tagging and deserialization methods using register_package.

Parameters:
  • f – a file-like object (has to implement fileno that returns a file descriptor, and must implement seek), or a string containing a file name
  • map_location – a function, string or a dict specifying how to remap storage locations
  • pickle_module – module used for unpickling metadata and objects (has to match the pickle_module used to serialize file)
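To make the string and dict forms of map_location concrete, here is a pure-Python sketch that only computes which location tag a storage would end up on. The helper `resolve_location` is hypothetical; it does not model the callable form and does not move any storage.

```python
def resolve_location(saved_tag, map_location=None):
    """Hypothetical helper: the tag a storage saved on `saved_tag` would be restored to."""
    if map_location is None:
        return saved_tag  # default: restore to the device it was saved from
    if isinstance(map_location, str):
        return map_location  # every storage goes to this one device
    if isinstance(map_location, dict):
        return map_location.get(saved_tag, saved_tag)  # unmapped tags stay unchanged
    raise TypeError("this sketch only models None, str and dict map_location")


print(resolve_location("cuda:1", {"cuda:1": "cuda:0"}))  # cuda:0
print(resolve_location("cpu", {"cuda:1": "cuda:0"}))     # cpu
```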

Example

>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location='cpu')
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

Parallelism

torch.get_num_threads() → int

Gets the number of OpenMP threads used for parallelizing CPU operations.

torch.set_num_threads(int)

Sets the number of OpenMP threads used for parallelizing CPU operations.

Math operations

Pointwise Ops

torch.abs(input, out=None) → Tensor

Computes the element-wise absolute value of the given input tensor.

Example:

>>> torch.abs(torch.FloatTensor([-1, -2, 3]))

 1
 2
 3
[torch.FloatTensor of size 3]
torch.acos(input, out=None) → Tensor

Returns a new tensor with the arccosine of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.acos(a)
 2.2608
 1.2956
 1.1075
    nan
[torch.FloatTensor of size 4]
torch.add()
torch.add(input, value, out=None)

Adds the scalar value to each element of input and returns a new resulting tensor.

$$\text{out} = \text{input} + \text{value}$$

If input is of type FloatTensor or DoubleTensor, value must be a real number, otherwise it should be an integer.

Parameters:
  • input (Tensor) – the input tensor
  • value (Number) – the number to be added to each element of input
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 0.4050
-1.2227
 1.8688
-0.4185
[torch.FloatTensor of size 4]

>>> torch.add(a, 20)

 20.4050
 18.7773
 21.8688
 19.5815
[torch.FloatTensor of size 4]
torch.add(input, value=1, other, out=None)

Each element of the tensor other is multiplied by the scalar value and added to each element of the tensor input. The resulting tensor is returned.

The shapes of input and other must be broadcastable.

$$\text{out} = \text{input} + \text{value} \times \text{other}$$

If other is of type FloatTensor or DoubleTensor, value must be a real number, otherwise it should be an integer.

Parameters:
  • input (Tensor) – the first input tensor
  • value (Number) – the scalar multiplier for other
  • other (Tensor) – the second input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> import torch
>>> a = torch.randn(4)
>>> a

-0.9310
 2.0330
 0.0852
-0.2941
[torch.FloatTensor of size 4]

>>> b = torch.randn(4)
>>> b

 1.0663
 0.2544
-0.1513
 0.0749
[torch.FloatTensor of size 4]

>>> torch.add(a, 10, b)
 9.7322
 4.5770
-1.4279
 0.4552
[torch.FloatTensor of size 4]
torch.addcdiv(tensor, value=1, tensor1, tensor2, out=None) → Tensor

Performs the element-wise division of tensor1 by tensor2, multiplies the result by the scalar value, and adds it to tensor.

$$\text{out}_i = \text{tensor}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i}$$

The shapes of tensor, tensor1, and tensor2 must be broadcastable.

For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.

Parameters:
  • tensor (Tensor) – the tensor to be added
  • value (Number, optional) – multiplier for tensor1 / tensor2
  • tensor1 (Tensor) – the numerator tensor
  • tensor2 (Tensor) – the denominator tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> t = torch.randn(2, 3)
>>> t1 = torch.randn(1, 3)
>>> t2 = torch.randn(2, 1)
>>> torch.addcdiv(t, 0.1, t1, t2)

 0.0122 -0.0188 -0.2354
 0.7396 -1.5721  1.2878
[torch.FloatTensor of size 2x3]
torch.addcmul(tensor, value=1, tensor1, tensor2, out=None) → Tensor

Performs the element-wise multiplication of tensor1 by tensor2, multiplies the result by the scalar value, and adds it to tensor.

$$\text{out}_i = \text{tensor}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i$$

The shapes of tensor, tensor1, and tensor2 must be broadcastable.

For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.

Parameters:
  • tensor (Tensor) – the tensor to be added
  • value (Number, optional) – multiplier for tensor1 * tensor2
  • tensor1 (Tensor) – the tensor to be multiplied
  • tensor2 (Tensor) – the tensor to be multiplied
  • out (Tensor, optional) – the output tensor
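The addcdiv and addcmul formulas can be sketched element-wise in plain Python. The helpers below are hypothetical, take equal-length flat lists, and do not implement broadcasting:

```python
def addcdiv(tensor, value, tensor1, tensor2):
    """Hypothetical helper: out_i = tensor_i + value * tensor1_i / tensor2_i."""
    return [t + value * a / b for t, a, b in zip(tensor, tensor1, tensor2)]


def addcmul(tensor, value, tensor1, tensor2):
    """Hypothetical helper: out_i = tensor_i + value * tensor1_i * tensor2_i."""
    return [t + value * a * b for t, a, b in zip(tensor, tensor1, tensor2)]


print(addcdiv([1.0, 1.0], 0.5, [2.0, 4.0], [1.0, 2.0]))  # [2.0, 2.0]
print(addcmul([1.0, 1.0], 0.5, [2.0, 4.0], [1.0, 2.0]))  # [2.0, 5.0]
```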

Example:

>>> t = torch.randn(2, 3)
>>> t1 = torch.randn(1, 3)
>>> t2 = torch.randn(2, 1)
>>> torch.addcmul(t, 0.1, t1, t2)

 0.0122 -0.0188 -0.2354
 0.7396 -1.5721  1.2878
[torch.FloatTensor of size 2x3]
torch.asin(input, out=None) → Tensor

Returns a new tensor with the arcsine of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.asin(a)
-0.6900
 0.2752
 0.4633
    nan
[torch.FloatTensor of size 4]
torch.atan(input, out=None) → Tensor

Returns a new tensor with the arctangent of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.atan(a)
-0.5669
 0.2653
 0.4203
 0.9196
[torch.FloatTensor of size 4]
torch.atan2(input1, input2, out=None) → Tensor

Returns a new tensor with the element-wise arctangent of input1 / input2, where the signs of both inputs determine the quadrant of the result.

The shapes of input1 and input2 must be broadcastable.

Parameters:
  • input1 (Tensor) – the first input tensor
  • input2 (Tensor) – the second input tensor
  • out (Tensor, optional) – the output tensor
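atan2 differs from atan(input1 / input2) in that the signs of both arguments select the quadrant. A sketch built on Python's `math.atan2`, with a hypothetical `atan2` helper over flat lists:

```python
import math


def atan2(input1, input2):
    """Hypothetical helper: element-wise atan2(input1_i, input2_i) over flat lists."""
    return [math.atan2(y, x) for y, x in zip(input1, input2)]


# (1, 1) lies in the first quadrant, (1, -1) in the second,
# although 1 / 1 and 1 / -1 alone would only give +pi/4 and -pi/4 via atan.
print(atan2([1.0, 1.0], [1.0, -1.0]))  # approximately [pi/4, 3*pi/4]
```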

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.atan2(a, torch.randn(4))
-2.4167
 2.9755
 0.9363
 1.6613
[torch.FloatTensor of size 4]
torch.ceil(input, out=None) → Tensor

Returns a new tensor with the ceil of the elements of input, the smallest integer greater than or equal to each element.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> torch.ceil(a)

 2
 1
-0
-0
[torch.FloatTensor of size 4]
torch.clamp(input, min, max, out=None) → Tensor

Clamps all elements in input into the range [min, max] and returns the resulting tensor:

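$$y_i = \begin{cases} \text{min} & \text{if } x_i < \text{min} \\ x_i & \text{if } \text{min} \le x_i \le \text{max} \\ \text{max} & \text{if } x_i > \text{max} \end{cases}$$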

If input is of type FloatTensor or DoubleTensor, args min and max must be real numbers, otherwise they should be integers.

Parameters:
  • input (Tensor) – the input tensor
  • min (Number) – lower-bound of the range to be clamped to
  • max (Number) – upper-bound of the range to be clamped to
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> torch.clamp(a, min=-0.5, max=0.5)

 0.5000
 0.3912
-0.5000
-0.5000
[torch.FloatTensor of size 4]
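The piecewise rule above, together with the one-sided min-only and max-only overloads described next, can be sketched in plain Python. This is an illustrative sketch on lists, not the library's code; clamp_list and its lo/hi parameters are hypothetical names chosen to avoid shadowing Python's built-in min and max.

```python
from typing import List, Optional

def clamp_list(xs: List[float], lo: Optional[float] = None,
               hi: Optional[float] = None) -> List[float]:
    """Hypothetical helper: lo below the range, hi above it, x_i otherwise.

    Passing only lo or only hi mirrors the keyword-only min= / max= forms.
    """
    out = []
    for x in xs:
        if lo is not None and x < lo:
            x = lo
        elif hi is not None and x > hi:
            x = hi
        out.append(x)
    return out

a = [1.3869, 0.3912, -0.8634, -0.5468]
print(clamp_list(a, lo=-0.5, hi=0.5))  # [0.5, 0.3912, -0.5, -0.5]
print(clamp_list(a, lo=0.5))           # [1.3869, 0.5, 0.5, 0.5]
print(clamp_list(a, hi=0.5))           # [0.5, 0.3912, -0.8634, -0.5468]
```

The three calls reproduce the values shown in the three clamp examples in this section.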
torch.clamp(input, *, min, out=None) → Tensor

Clamps all elements in input to be greater than or equal to min.

If input is of type FloatTensor or DoubleTensor, min should be a real number, otherwise it should be an integer.

Parameters:
  • input (Tensor) – the input tensor
  • min (Number) – minimal value of each element in the output
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> torch.clamp(a, min=0.5)

 1.3869
 0.5000
 0.5000
 0.5000
[torch.FloatTensor of size 4]
torch.clamp(input, *, max, out=None) → Tensor

Clamps all elements in input to be less than or equal to max.

If input is of type FloatTensor or DoubleTensor, max should be a real number, otherwise it should be an integer.

Parameters:
  • input (Tensor) – the input tensor
  • max (Number) – maximal value of each element in the output
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> torch.clamp(a, max=0.5)

 0.5000
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]
torch.cos(input, out=None) → Tensor

Returns a new tensor with the cosine of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.cos(a)
 0.8041
 0.9633
 0.9018
 0.2557
[torch.FloatTensor of size 4]
torch.cosh(input, out=None) → Tensor

Returns a new tensor with the hyperbolic cosine of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.cosh(a)
 1.2095
 1.0372
 1.1015
 1.9917
[torch.FloatTensor of size 4]
torch.div()
torch.div(input, value, out=None)

Divides each element of input by the scalar value and returns a new resulting tensor.

$$\text{out}_i = \frac{\text{input}_i}{\text{value}}$$

If input is of type FloatTensor or DoubleTensor, value should be a real number, otherwise it should be an integer.

Parameters:
  • input (Tensor) – the input tensor
  • value (Number) – the number by which each element of input is divided
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(5)
>>> a

-0.6147
-1.1237
-0.1604
-0.6853
 0.1063
[torch.FloatTensor of size 5]

>>> torch.div(a, 0.5)

-1.2294
-2.2474
-0.3208
-1.3706
 0.2126
[torch.FloatTensor of size 5]
torch.div(input, other, out=None)

Each element of the tensor input is divided by the corresponding element of the tensor other. The resulting tensor is returned. The shapes of input and other must be broadcastable.

$$\text{out}_i = \frac{\text{input}_i}{\text{other}_i}$$

Parameters:
  • input (Tensor) – the numerator tensor
  • other (Tensor) – the denominator tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4,4)
>>> a

-0.1810  0.4017  0.2863 -0.1013
 0.6183  2.0696  0.9012 -1.5933
 0.5679  0.4743 -0.0117 -0.1266
-0.1213  0.9629  0.2682  1.5968
[torch.FloatTensor of size 4x4]

>>> b = torch.randn(4, 4)
>>> b

 0.8774  0.7650  0.8866  1.4805
-0.6490  1.1172  1.4259 -0.8146
 1.4633 -0.1228  0.4643 -0.6029
 0.3492  1.5270  1.6103 -0.6291
[torch.FloatTensor of size 4x4]

>>> torch.div(a, b)

-0.2062  0.5251  0.3229 -0.0684
-0.9528  1.8525  0.6320  1.9559
 0.3881 -3.8625 -0.0253  0.2099
-0.3473  0.6306  0.1666 -2.5381
[torch.FloatTensor of size 4x4]
torch.erf(tensor, out=None) → Tensor

Computes the error function of each element.

Example:

>>> torch.erf(torch.Tensor([0, -1., 10.]))
torch.FloatTensor([0., -0.8427, 1.])
torch.erfinv(tensor, out=None) → Tensor

Computes the inverse error function of each element.

Example:

>>> torch.erfinv(torch.Tensor([0, 0.5, -1.]))
torch.FloatTensor([0., 0.4769, -inf])
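The Python standard library has math.erf but no inverse error function, so as a cross-check the values above can be reproduced through the standard normal quantile, using the identity erf(x) = 2·Φ(x·√2) − 1. This is a minimal sketch, not how PyTorch computes erfinv; the helper name erfinv is hypothetical, it assumes Python 3.8+ for statistics.NormalDist, and inputs outside [-1, 1] are not handled.

```python
import math
from statistics import NormalDist

def erfinv(y: float) -> float:
    """Hypothetical helper: erfinv(y) = Phi^{-1}((y + 1) / 2) / sqrt(2)."""
    # The quantile is only defined on the open interval (0, 1), so the
    # endpoints y = +/-1 map directly to +/-infinity.
    if y == 1.0:
        return math.inf
    if y == -1.0:
        return -math.inf
    return NormalDist().inv_cdf((y + 1.0) / 2.0) / math.sqrt(2.0)

for y in (0.0, 0.5, -1.0):
    print(y, erfinv(y))

# Round trip: erf(erfinv(y)) recovers y inside (-1, 1).
print(math.erf(erfinv(0.3)))
```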
torch.exp(tensor, out=None) → Tensor

Computes the exponential of each element.

Example:

>>> import math
>>> torch.exp(torch.Tensor([0, math.log(2)]))
torch.FloatTensor([1, 2])
torch.floor(input, out=None) → Tensor

Returns a new tensor with the floor of the elements of input, the largest integer less than or equal to each element.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> torch.floor(a)

 1
 0
-1
-1
[torch.FloatTensor of size 4]
torch.fmod(input, divisor, out=None) → Tensor

Computes the element-wise remainder of division.

The dividend and divisor may contain both integer and floating point numbers. The remainder has the same sign as the dividend input.

When divisor is a tensor, the shapes of input and divisor must be broadcastable.

Parameters:
  • input (Tensor) – the dividend
  • divisor (Tensor or float) – the divisor, which may be either a number or a tensor whose shape is broadcastable with the dividend
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.fmod(torch.Tensor([-3, -2, -1, 1, 2, 3]), 2)
torch.FloatTensor([-1, -0, -1, 1, 0, 1])
>>> torch.fmod(torch.Tensor([1, 2, 3, 4, 5]), 1.5)
torch.FloatTensor([1.0, 0.5, 0.0, 1.0, 0.5])

See also

torch.remainder(), which computes the element-wise remainder of division equivalently to Python’s % operator
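The two sign conventions can be compared directly with Python's math.fmod (the C fmod() convention) and the % operator. This plain-Python sketch only illustrates the conventions behind torch.fmod() and torch.remainder(); it does not call PyTorch. It reproduces the example values given for both functions.

```python
import math

dividends = [-3, -2, -1, 1, 2, 3]

# C-style fmod: the result carries the sign of the dividend (cf. torch.fmod).
fmod_result = [math.fmod(x, 2) for x in dividends]
# Python's %: the result carries the sign of the divisor (cf. torch.remainder).
remainder_result = [x % 2 for x in dividends]

print(fmod_result)       # [-1.0, -0.0, -1.0, 1.0, 0.0, 1.0]
print(remainder_result)  # [1, 0, 1, 1, 0, 1]

# With positive dividends and a positive divisor the two conventions agree.
positives = [1, 2, 3, 4, 5]
print([math.fmod(x, 1.5) for x in positives])  # [1.0, 0.5, 0.0, 1.0, 0.5]
print([x % 1.5 for x in positives])            # [1.0, 0.5, 0.0, 1.0, 0.5]
```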

torch.frac(tensor, out=None) → Tensor

Computes the fractional portion of each element in tensor.

Example:

>>> torch.frac(torch.Tensor([1, 2.5, -3.2]))
torch.FloatTensor([0, 0.5, -0.2])
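The negative result for -3.2 in the example is consistent with defining the fractional portion as x minus its truncation toward zero, so the sign follows the input. A minimal plain-Python sketch under that reading (the helper name frac is hypothetical):

```python
import math

def frac(x: float) -> float:
    # Fractional part = x minus x truncated toward zero; sign follows x.
    return x - math.trunc(x)

values = [1.0, 2.5, -3.2]
print([frac(v) for v in values])  # -3.2 gives approximately -0.2 in floating point
```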
torch.lerp(start, end, weight, out=None)

Does a linear interpolation of two tensors start and end based on a scalar weight and returns the resulting out tensor.

$$\text{out}_i = \text{start}_i + \text{weight} \times (\text{end}_i - \text{start}_i)$$

The shapes of start and end must be broadcastable.

Parameters:
  • start (Tensor) – the tensor with the starting points
  • end (Tensor) – the tensor with the ending points
  • weight (float) – the weight for the interpolation formula
  • out (Tensor, optional) – the output tensor

Example:

>>> start = torch.arange(1, 5)
>>> end = torch.Tensor(4).fill_(10)
>>> start

 1
 2
 3
 4
[torch.FloatTensor of size 4]

>>> end

 10
 10
 10
 10
[torch.FloatTensor of size 4]

>>> torch.lerp(start, end, 0.5)

 5.5000
 6.0000
 6.5000
 7.0000
[torch.FloatTensor of size 4]
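The interpolation formula can be checked with a plain-Python sketch on lists that reproduces the example above. The helper name lerp is hypothetical, and broadcasting between start and end is not modeled (the lists are assumed to have equal length).

```python
def lerp(start, end, weight):
    # out_i = start_i + weight * (end_i - start_i), element by element
    return [s + weight * (e - s) for s, e in zip(start, end)]

start = [1.0, 2.0, 3.0, 4.0]
end = [10.0] * 4
print(lerp(start, end, 0.5))  # [5.5, 6.0, 6.5, 7.0]
```

A weight of 0 returns start unchanged and a weight of 1 returns end.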
torch.log(input, out=None) → Tensor

Returns a new tensor with the natural logarithm of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(5)
>>> a

-0.4183
 0.3722
-0.3091
 0.4149
 0.5857
[torch.FloatTensor of size 5]

>>> torch.log(a)

    nan
-0.9883
    nan
-0.8797
-0.5349
[torch.FloatTensor of size 5]
torch.log1p(input, out=None) → Tensor

Returns a new tensor with the natural logarithm of (1 + input).

$$y_i = \log(x_i + 1)$$

Note

This function is more accurate than torch.log() for small values of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(5)
>>> a

-0.4183
 0.3722
-0.3091
 0.4149
 0.5857
[torch.FloatTensor of size 5]

>>> torch.log1p(a)

-0.5418
 0.3164
-0.3697
 0.3471
 0.4611
[torch.FloatTensor of size 5]
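The accuracy note can be illustrated with Python's math.log and math.log1p on double-precision floats. This is a standard-library sketch of the numerical issue, not a measurement of torch.log1p: for a tiny x, 1 + x rounds to exactly 1.0 before the logarithm is taken, while log1p computes log(1 + x) without that intermediate rounding.

```python
import math

x = 1e-17
# 1 + 1e-17 rounds to exactly 1.0 in double precision, so log() returns 0.0.
naive = math.log(1 + x)
# log1p keeps the information; for tiny x, log(1 + x) is approximately x.
accurate = math.log1p(x)

print(naive)     # 0.0
print(accurate)
```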
torch.mul()
torch.mul(input, value, out=None)

Multiplies each element of input by the scalar value and returns a new resulting tensor.

$$\text{out}_i = \text{value} \times \text{input}_i$$

If input is of type FloatTensor or DoubleTensor, value should be a real number, otherwise it should be an integer.

Parameters:
  • input (Tensor) – the input tensor
  • value (Number) – the number by which each element of input is multiplied
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(3)
>>> a

-0.9374
-0.5254
-0.6069
[torch.FloatTensor of size 3]

>>> torch.mul(a, 100)

-93.7411
-52.5374
-60.6908
[torch.FloatTensor of size 3]
torch.mul(input, other, out=None)

Each element of the tensor input is multiplied by the corresponding element of the tensor other. The resulting tensor is returned.

The shapes of input and other must be broadcastable.

$$\text{out}_i = \text{input}_i \times \text{other}_i$$

Parameters:
  • input (Tensor) – the first multiplicand tensor
  • other (Tensor) – the second multiplicand tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4,4)
>>> a

-0.7280  0.0598 -1.4327 -0.5825
-0.1427 -0.0690  0.0821 -0.3270
-0.9241  0.5110  0.4070 -1.1188
-0.8308  0.7426 -0.6240 -1.1582
[torch.FloatTensor of size 4x4]

>>> b = torch.randn(4, 4)
>>> b

 0.0430 -1.0775  0.6015  1.1647
-0.6549  0.0308 -0.1670  1.0742
-1.2593  0.0292 -0.0849  0.4530
 1.2404 -0.4659 -0.1840  0.5974
[torch.FloatTensor of size 4x4]

>>> torch.mul(a, b)

-0.0313 -0.0645 -0.8618 -0.6784
 0.0934 -0.0021 -0.0137 -0.3513
 1.1638  0.0149 -0.0346 -0.5068
-1.0304 -0.3460  0.1148 -0.6919
[torch.FloatTensor of size 4x4]
torch.neg(input, out=None) → Tensor

Returns a new tensor with the negative of the elements of input.

$$\text{out} = -1 \times \text{input}$$

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(5)
>>> a

-0.4430
 1.1690
-0.8836
-0.4565
 0.2968
[torch.FloatTensor of size 5]

>>> torch.neg(a)

 0.4430
-1.1690
 0.8836
 0.4565
-0.2968
[torch.FloatTensor of size 5]
torch.pow()
torch.pow(input, exponent, out=None)

Takes the power of each element in input with exponent and returns a tensor with the result.

exponent can be either a single float number or a tensor.

When exponent is a scalar value, the operation applied is:

$$\text{out}_i = x_i^{\text{exponent}}$$

When exponent is a tensor, the operation applied is:

$$\text{out}_i = x_i^{\text{exponent}_i}$$

When exponent is a tensor, the shapes of input and exponent must be broadcastable.

Parameters:
  • input (Tensor) – the input tensor
  • exponent (float or tensor) – the exponent value
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

-0.5274
-0.8232
-2.1128
 1.7558
[torch.FloatTensor of size 4]

>>> torch.pow(a, 2)

 0.2781
 0.6776
 4.4640
 3.0829
[torch.FloatTensor of size 4]

>>> exp = torch.arange(1, 5)
>>> a = torch.arange(1, 5)
>>> a

 1
 2
 3
 4
[torch.FloatTensor of size 4]

>>> exp

 1
 2
 3
 4
[torch.FloatTensor of size 4]

>>> torch.pow(a, exp)

   1
   4
  27
 256
[torch.FloatTensor of size 4]
torch.pow(base, input, out=None)

base is a scalar float value, and input is a tensor. The returned tensor out has the same shape as input.

The operation applied is:

$$\text{out}_i = \text{base}^{\text{input}_i}$$

Parameters:
  • base (float) – the scalar base value for the power operation
  • input (Tensor) – the exponent tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> exp = torch.arange(1, 5)
>>> base = 2
>>> torch.pow(base, exp)

  2
  4
  8
 16
[torch.FloatTensor of size 4]
torch.reciprocal(input, out=None) → Tensor

Returns a new tensor with the reciprocal of the elements of input, i.e. $x^{-1} = \frac{1}{x}$.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> torch.reciprocal(a)

 0.7210
 2.5565
-1.1583
-1.8289
[torch.FloatTensor of size 4]
torch.remainder(input, divisor, out=None) → Tensor

Computes the element-wise remainder of division.

The divisor and dividend may contain both integer and floating point numbers. The remainder has the same sign as the divisor.

When divisor is a tensor, the shapes of input and divisor must be broadcastable.

Parameters:
  • input (Tensor) – the dividend
  • divisor (Tensor or float) – the divisor, which may be either a number or a tensor whose shape is broadcastable with the dividend
  • out (Tensor, optional) – the output tensor

Example:

>>> torch.remainder(torch.Tensor([-3, -2, -1, 1, 2, 3]), 2)
torch.FloatTensor([1, 0, 1, 1, 0, 1])
>>> torch.remainder(torch.Tensor([1, 2, 3, 4, 5]), 1.5)
torch.FloatTensor([1.0, 0.5, 0.0, 1.0, 0.5])

See also

torch.fmod(), which computes the element-wise remainder of division equivalently to the C library function fmod()

torch.round(input, out=None) → Tensor

Returns a new tensor with each of the elements of input rounded to the closest integer.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.2290
 1.3409
-0.5662
-0.0899
[torch.FloatTensor of size 4]

>>> torch.round(a)

 1
 1
-1
-0
[torch.FloatTensor of size 4]
torch.rsqrt(input, out=None) → Tensor

Returns a new tensor with the reciprocal of the square-root of each of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.2290
 1.3409
-0.5662
-0.0899
[torch.FloatTensor of size 4]

>>> torch.rsqrt(a)

 0.9020
 0.8636
    nan
    nan
[torch.FloatTensor of size 4]
torch.sigmoid(input, out=None) → Tensor

Returns a new tensor with the logistic sigmoid, 1 / (1 + exp(-x)), of each element of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

-0.4972
 1.3512
 0.1056
-0.2650
[torch.FloatTensor of size 4]

>>> torch.sigmoid(a)

 0.3782
 0.7943
 0.5264
 0.4341
[torch.FloatTensor of size 4]
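The logistic sigmoid 1 / (1 + exp(-x)) can be sketched for scalars in plain Python. This is an illustration of the formula, not PyTorch's kernel; the helper name sigmoid is hypothetical. The two branches are an implementation choice of this sketch so that exp() never receives a large positive argument.

```python
import math

def sigmoid(x: float) -> float:
    # 1 / (1 + exp(-x)), rearranged for negative x to avoid overflow in exp().
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

a = [-0.4972, 1.3512, 0.1056, -0.2650]
print([round(sigmoid(v), 4) for v in a])  # [0.3782, 0.7943, 0.5264, 0.4341]
print(sigmoid(-1000.0), sigmoid(1000.0))  # 0.0 1.0
```

The rounded values match the sigmoid example above.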
torch.sign(input, out=None) → Tensor

Returns a new tensor with the sign of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.sign(a)

-1
 1
 1
 1
[torch.FloatTensor of size 4]
torch.sin(input, out=None) → Tensor

Returns a new tensor with the sine of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.sin(a)
-0.5944
 0.2684
 0.4322
 0.9667
[torch.FloatTensor of size 4]
torch.sinh(input, out=None) → Tensor

Returns a new tensor with the hyperbolic sine of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.sinh(a)
-0.6804
 0.2751
 0.4619
 1.7225
[torch.FloatTensor of size 4]
torch.sqrt(input, out=None) → Tensor

Returns a new tensor with the square-root of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.2290
 1.3409
-0.5662
-0.0899
[torch.FloatTensor of size 4]

>>> torch.sqrt(a)

 1.1086
 1.1580
    nan
    nan
[torch.FloatTensor of size 4]
torch.tan(input, out=None) → Tensor

Returns a new tensor with the tangent of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.tan(a)
-0.7392
 0.2786
 0.4792
 3.7801
[torch.FloatTensor of size 4]
torch.tanh(input, out=None) → Tensor

Returns a new tensor with the hyperbolic tangent of the elements of input.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a
-0.6366
 0.2718
 0.4469
 1.3122
[torch.FloatTensor of size 4]

>>> torch.tanh(a)
-0.5625
 0.2653
 0.4193
 0.8648
[torch.FloatTensor of size 4]
torch.trunc(input, out=None) → Tensor

Returns a new tensor with the truncated integer values of the elements of input, i.e. each element rounded toward zero.

Parameters:
  • input (Tensor) – the input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

-0.4972
 1.3512
 0.1056
-0.2650
[torch.FloatTensor of size 4]

>>> torch.trunc(a)

-0
 1
 0
-0
[torch.FloatTensor of size 4]

Reduction Ops

torch.cumprod(input, dim, out=None) → Tensor

Returns the cumulative product of elements of input in the dimension dim.

For example, if input is a vector of size N, the result will also be a vector of size N, with elements:

$$y_i = x_1 \times x_2 \times x_3 \times \dots \times x_i$$

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to do the operation over
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(10)
>>> a

 1.1148
 1.8423
 1.4143
-0.4403
 1.2859
-1.2514
-0.4748
 1.1735
-1.6332
-0.4272
[torch.FloatTensor of size 10]

>>> torch.cumprod(a, dim=0)

 1.1148
 2.0537
 2.9045
-1.2788
-1.6444
 2.0578
-0.9770
-1.1466
 1.8726
-0.8000
[torch.FloatTensor of size 10]

>>> a[5] = 0.0
>>> torch.cumprod(a, dim=0)

 1.1148
 2.0537
 2.9045
-1.2788
-1.6444
-0.0000
 0.0000
 0.0000
-0.0000
 0.0000
[torch.FloatTensor of size 10]
torch.cumsum(input, dim, out=None) → Tensor

Returns the cumulative sum of elements of input in the dimension dim.

For example, if input is a vector of size N, the result will also be a vector of size N, with elements:

$$y_i = x_1 + x_2 + x_3 + \dots + x_i$$

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to do the operation over
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(10)
>>> a

-0.6039
-0.2214
-0.3705
-0.0169
 1.3415
-0.1230
 0.9719
 0.6081
-0.1286
 1.0947
[torch.FloatTensor of size 10]

>>> torch.cumsum(a, dim=0)

-0.6039
-0.8253
-1.1958
-1.2127
 0.1288
 0.0058
 0.9777
 1.5858
 1.4572
 2.5519
[torch.FloatTensor of size 10]
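For a 1-D input, both cumulative operations in this section reduce to a running fold, which Python's itertools.accumulate expresses directly. This standard-library sketch illustrates the formulas only; it does not model the dim argument of multi-dimensional tensors.

```python
from itertools import accumulate
import operator

x = [1.0, 2.0, 3.0, 4.0]

running_sum = list(accumulate(x))                 # y_i = x_1 + ... + x_i
running_prod = list(accumulate(x, operator.mul))  # y_i = x_1 * ... * x_i

print(running_sum)   # [1.0, 3.0, 6.0, 10.0]
print(running_prod)  # [1.0, 2.0, 6.0, 24.0]

# As in the cumprod example, a zero zeroes every later product.
print(list(accumulate([2.0, 0.0, 5.0], operator.mul)))  # [2.0, 0.0, 0.0]
```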
torch.dist(input, other, p=2) → float

Returns the p-norm of (input - other).

The shapes of input and other must be broadcastable.

Parameters:
  • input (Tensor) – the input tensor
  • other (Tensor) – the right-hand-side input tensor
  • p (float, optional) – the norm to be computed

Example:

>>> x = torch.randn(4)
>>> x

 0.2505
-0.4571
-0.3733
 0.7807
[torch.FloatTensor of size 4]

>>> y = torch.randn(4)
>>> y

 0.7782
-0.5185
 1.4106
-2.4063
[torch.FloatTensor of size 4]

>>> torch.dist(x, y, 3.5)
3.302832063224223
>>> torch.dist(x, y, 3)
3.3677282206393286
>>> torch.dist(x, y, 0)
inf
>>> torch.dist(x, y, 1)
5.560028076171875
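For p > 0, the p-norm of the difference is (Σ|x_i − y_i|^p)^(1/p). The plain-Python sketch below applies this to the four-decimal values printed for x and y above, so it agrees with the torch.dist results only up to that display rounding. The helper name p_norm_distance is hypothetical, and p = 0 and p = inf are deliberately not handled.

```python
def p_norm_distance(xs, ys, p):
    # (sum_i |x_i - y_i|^p)^(1/p), valid for p > 0 only
    return sum(abs(x - y) ** p for x, y in zip(xs, ys)) ** (1.0 / p)

x = [0.2505, -0.4571, -0.3733, 0.7807]
y = [0.7782, -0.5185, 1.4106, -2.4063]
print(p_norm_distance(x, y, 1))  # about 5.5600
print(p_norm_distance(x, y, 3))  # about 3.3677
```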
torch.mean()
torch.mean(input) → float

Returns the mean value of all elements in the input tensor.

Parameters:input (Tensor) – the input tensor

Example:

>>> a = torch.randn(1, 3)
>>> a

-0.2946 -0.9143  2.1809
[torch.FloatTensor of size 1x3]

>>> torch.mean(a)
0.32398951053619385
torch.mean(input, dim, keepdim=False, out=None) → Tensor

Returns the mean value of each row of the input tensor in the given dimension dim.

If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 fewer dimension than input.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool, optional) – whether the output tensor has dim retained or not
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4, 4)
>>> a

-1.2738 -0.3058  0.1230 -1.9615
 0.8771 -0.5430 -0.9233  0.9879
 1.4107  0.0317 -0.6823  0.2255
-1.3854  0.4953 -0.2160  0.2435
[torch.FloatTensor of size 4x4]

>>> torch.mean(a, 1)

-0.8545
 0.0997
 0.2464
-0.2157
[torch.FloatTensor of size 4]

>>> torch.mean(a, 1, True)

-0.8545
 0.0997
 0.2464
-0.2157
[torch.FloatTensor of size 4x1]
torch.median()
torch.median(input) → float

Returns the median value of all elements in the input tensor.

Parameters:input (Tensor) – the input tensor

Example:

>>> a = torch.randn(1, 3)
>>> a

 0.4729 -0.2266 -0.2085
[torch.FloatTensor of size 1x3]

>>> torch.median(a)
-0.2085
torch.median(input, dim=-1, keepdim=False, values=None, indices=None) -> (Tensor, LongTensor)

Returns the median value of each row of the input tensor in the given dimension dim. Also returns the index location of the median value as a LongTensor.

By default, dim is the last dimension of the input tensor.

If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensors have dim retained or not
  • values (Tensor, optional) – the output tensor
  • indices (Tensor, optional) – the output index tensor

Example:

>>> a = torch.randn(4, 5)
>>> a

 0.4056 -0.3372  1.0973 -2.4884  0.4334
 2.1336  0.3841  0.1404 -0.1821 -0.7646
-0.2403  1.3975 -2.0068  0.1298  0.0212
-1.5371 -0.7257 -0.4871 -0.2359 -1.1724
[torch.FloatTensor of size 4x5]

>>> torch.median(a, 1)
(
 0.4056
 0.1404
 0.0212
-0.7257
[torch.FloatTensor of size 4]
,
 0
 2
 4
 1
[torch.LongTensor of size 4]
)
torch.mode(input, dim=-1, keepdim=False, values=None, indices=None) -> (Tensor, LongTensor)

Returns the mode value of each row of the input tensor in the given dimension dim. Also returns the index location of the mode value as a LongTensor.

By default, dim is the last dimension of the input tensor.

If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.

Note

This function is not defined for torch.cuda.Tensor yet.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensors have dim retained or not
  • values (Tensor, optional) – the output tensor
  • indices (Tensor, optional) – the output index tensor

Example:

>>> a = torch.randn(4, 5)
>>> a

 0.4056 -0.3372  1.0973 -2.4884  0.4334
 2.1336  0.3841  0.1404 -0.1821 -0.7646
-0.2403  1.3975 -2.0068  0.1298  0.0212
-1.5371 -0.7257 -0.4871 -0.2359 -1.1724
[torch.FloatTensor of size 4x5]

>>> torch.mode(a, 1)
(
-2.4884
-0.7646
-2.0068
-1.5371
[torch.FloatTensor of size 4]
,
 3
 4
 2
 0
[torch.LongTensor of size 4]
)
torch.norm()
torch.norm(input, p=2) → float

Returns the p-norm of the input tensor.

Parameters:
  • input (Tensor) – the input tensor
  • p (float, optional) – the exponent value in the norm formulation

Example:

>>> a = torch.randn(1, 3)
>>> a

-0.4376 -0.5328  0.9547
[torch.FloatTensor of size 1x3]

>>> torch.norm(a, 3)
1.0338925067372466
torch.norm(input, p, dim, keepdim=False, out=None) → Tensor

Returns the p-norm of each row of the input tensor in the given dimension dim.

If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 fewer dimension than input.

Parameters:
  • input (Tensor) – the input tensor
  • p (float) – the exponent value in the norm formulation
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensor has dim retained or not
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4, 2)
>>> a

-0.6891 -0.6662
 0.2697  0.7412
 0.5254 -0.7402
 0.5528 -0.2399
[torch.FloatTensor of size 4x2]

>>> torch.norm(a, 2, 1)

 0.9585
 0.7888
 0.9077
 0.6026
[torch.FloatTensor of size 4]

>>> torch.norm(a, 0, 1, True)

 2
 2
 2
 2
[torch.FloatTensor of size 4x1]
torch.prod()
torch.prod(input) → float

Returns the product of all elements in the input tensor.

Parameters:input (Tensor) – the input tensor

Example:

>>> a = torch.randn(1, 3)
>>> a

 0.6170  0.3546  0.0253
[torch.FloatTensor of size 1x3]

>>> torch.prod(a)
0.005537458061418483
torch.prod(input, dim, keepdim=False, out=None) → Tensor

Returns the product of each row of the input tensor in the given dimension dim.

If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 fewer dimension than input.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensor has dim retained or not
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4, 2)
>>> a

 0.1598 -0.6884
-0.1831 -0.4412
-0.9925 -0.6244
-0.2416 -0.8080
[torch.FloatTensor of size 4x2]

>>> torch.prod(a, 1)

-0.1100
 0.0808
 0.6197
 0.1952
[torch.FloatTensor of size 4]
torch.std()
torch.std(input, unbiased=True) → float

Returns the standard-deviation of all elements in the input tensor.

If unbiased is False, then the standard-deviation will be calculated via the biased estimator. Otherwise, Bessel’s correction will be used.

Parameters:
  • input (Tensor) – the input tensor
  • unbiased (bool) – whether to use the unbiased estimation or not

Example:

>>> a = torch.randn(1, 3)
>>> a

-1.3063  1.4182 -0.3061
[torch.FloatTensor of size 1x3]

>>> torch.std(a)
1.3782334731508061
torch.std(input, dim, keepdim=False, unbiased=True, out=None) → Tensor

Returns the standard-deviation of each row of the input tensor in the given dimension dim.

If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 fewer dimension than input.

If unbiased is False, then the standard-deviation will be calculated via the biased estimator. Otherwise, Bessel’s correction will be used.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensor has dim retained or not
  • unbiased (bool) – whether to use the unbiased estimation or not
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4, 4)
>>> a

 0.1889 -2.4856  0.0043  1.8169
-0.7701 -0.4682 -2.2410  0.4098
 0.1919 -1.1856 -1.0361  0.9085
 0.0173  1.0662  0.2143 -0.5576
[torch.FloatTensor of size 4x4]

>>> torch.std(a, dim=1)

 1.7756
 1.1025
 1.0045
 0.6725
[torch.FloatTensor of size 4]
torch.sum()
torch.sum(input) → float

Returns the sum of all elements in the input tensor.

Parameters:input (Tensor) – the input tensor

Example:

>>> a = torch.randn(1, 3)
>>> a

 0.6170  0.3546  0.0253
[torch.FloatTensor of size 1x3]

>>> torch.sum(a)
0.9969287421554327
torch.sum(input, dim, keepdim=False, out=None) → Tensor

Returns the sum of each row of the input tensor in the given dimension dim.

If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 fewer dimension than input.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensor has dim retained or not
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4, 4)
>>> a

-0.4640  0.0609  0.1122  0.4784
-1.3063  1.6443  0.4714 -0.7396
-1.3561 -0.1959  1.0609 -1.9855
 2.6833  0.5746 -0.5709 -0.4430
[torch.FloatTensor of size 4x4]

>>> torch.sum(a, 1)

 0.1874
 0.0698
-2.4767
 2.2440
[torch.FloatTensor of size 4]
torch.var()
torch.var(input, unbiased=True) → float

Returns the variance of all elements in the input tensor.

If unbiased is False, then the variance will be calculated via the biased estimator. Otherwise, Bessel’s correction will be used.

Parameters:
  • input (Tensor) – the input tensor
  • unbiased (bool) – whether to use the unbiased estimation or not

Example:

>>> a = torch.randn(1, 3)
>>> a

-1.3063  1.4182 -0.3061
[torch.FloatTensor of size 1x3]

>>> torch.var(a)
1.899527506513334
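The difference between the unbiased (Bessel-corrected) and biased estimators used by torch.std() and torch.var() is the divisor, n − 1 versus n. Python's statistics module exposes both, so this standard-library sketch can be compared with the examples above; because it uses the four-decimal printed values, agreement is only to about three decimals.

```python
import statistics

a = [-1.3063, 1.4182, -0.3061]

# Unbiased estimators (Bessel's correction): divide by n - 1.
unbiased_std = statistics.stdev(a)
unbiased_var = statistics.variance(a)
# Biased estimators: divide by n.
biased_std = statistics.pstdev(a)
biased_var = statistics.pvariance(a)

print(unbiased_std, unbiased_var)  # about 1.3782 and 1.8994
print(biased_std, biased_var)
```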
torch.var(input, dim, keepdim=False, unbiased=True, out=None) → Tensor

Returns the variance of each row of the input tensor in the given dimension dim.

If keepdim is True, the output tensor is of the same size as input except in the dimension dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 fewer dimension than input.

If unbiased is False, then the variance will be calculated via the biased estimator. Otherwise, Bessel’s correction will be used.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensor has dim retained or not
  • unbiased (bool) – whether to use the unbiased estimation or not
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4, 4)
>>> a

-1.2738 -0.3058  0.1230 -1.9615
 0.8771 -0.5430 -0.9233  0.9879
 1.4107  0.0317 -0.6823  0.2255
-1.3854  0.4953 -0.2160  0.2435
[torch.FloatTensor of size 4x4]

>>> torch.var(a, 1)

 0.8859
 0.9509
 0.7548
 0.6949
[torch.FloatTensor of size 4]

Comparison Ops

torch.eq(input, other, out=None) → Tensor

Computes element-wise equality.

The second argument can be a number or a tensor whose shape is broadcastable with the first argument.

Parameters:
  • input (Tensor) – the tensor to compare
  • other (Tensor or float) – the tensor or value to compare
  • out (Tensor, optional) – the output tensor. Must be a ByteTensor or the same type as input.
Returns:

A torch.ByteTensor containing a 1 at each location where the tensors are equal and a 0 at every other location

Return type:

Tensor

Example:

>>> torch.eq(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]]))
1  0
0  1
[torch.ByteTensor of size 2x2]
torch.equal(tensor1, tensor2) → bool

True if two tensors have the same size and elements, False otherwise.

Example:

>>> torch.equal(torch.Tensor([1, 2]), torch.Tensor([1, 2]))
True
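The distinction between torch.eq(), which returns one 0/1 flag per position, and torch.equal(), which returns a single bool, can be sketched with plain Python lists standing in for flattened tensors. Both helpers are hypothetical illustrations, and broadcasting is not modeled.

```python
def eq(xs, ys):
    # Element-wise: one 1/0 flag per position, like the ByteTensor from eq()
    return [1 if x == y else 0 for x, y in zip(xs, ys)]

def equal(xs, ys):
    # Whole-tensor: a single bool, False whenever the sizes differ
    return len(xs) == len(ys) and all(x == y for x, y in zip(xs, ys))

print(eq([1, 2, 3, 4], [1, 1, 4, 4]))  # [1, 0, 0, 1]
print(equal([1, 2], [1, 2]))           # True
print(equal([1, 2], [1, 2, 3]))        # False
```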
torch.ge(input, other, out=None) → Tensor

Computes input >= other element-wise.

The second argument can be a number or a tensor whose shape is broadcastable with the first argument.

Parameters:
  • input (Tensor) – the tensor to compare
  • other (Tensor or float) – the tensor or value to compare
  • out (Tensor, optional) – the output tensor that must be a ByteTensor or the same type as input
Returns:

A torch.ByteTensor containing a 1 at each location where the comparison is true

Return type:

Tensor

Example:

>>> torch.ge(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]]))
 1  1
 0  1
[torch.ByteTensor of size 2x2]
torch.gt(input, other, out=None) → Tensor

Computes input > other element-wise.

The second argument can be a number or a tensor whose shape is broadcastable with the first argument.

Parameters:
  • input (Tensor) – the tensor to compare
  • other (Tensor or float) – the tensor or value to compare
  • out (Tensor, optional) – the output tensor that must be a ByteTensor or the same type as input
Returns:

A torch.ByteTensor containing a 1 at each location where the comparison is true

Return type:

Tensor

Example:

>>> torch.gt(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]]))
 0  1
 0  0
[torch.ByteTensor of size 2x2]
torch.kthvalue(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor)

Returns the k-th smallest element of the given input tensor along a given dimension.

If dim is not given, the last dimension of the input is chosen.

A tuple of (values, indices) is returned, where indices holds the index of each k-th smallest element in the original input tensor along dimension dim.

If keepdim is True, both the values and indices tensors are the same size as input, except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in both the values and indices tensors having 1 fewer dimension than the input tensor.

Parameters:
  • input (Tensor) – the input tensor
  • k (int) – k for the k-th smallest element
  • dim (int, optional) – the dimension to find the kth value along
  • keepdim (bool) – whether the output tensors have dim retained or not
  • out (tuple, optional) – the output tuple of (Tensor, LongTensor) can be optionally given to be used as output buffers

Example:

>>> x = torch.arange(1, 6)
>>> x

 1
 2
 3
 4
 5
[torch.FloatTensor of size 5]

>>> torch.kthvalue(x, 4)
(
 4
[torch.FloatTensor of size 1]
,
 3
[torch.LongTensor of size 1]
)

>>> x=torch.arange(1,7).resize_(2,3)
>>> x

1  2  3
4  5  6
[torch.FloatTensor of size 2x3]

>>> torch.kthvalue(x,2,0,True)
(
4  5  6
[torch.FloatTensor of size 1x3]
,
1  1  1
[torch.LongTensor of size 1x3]
)
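The (values, indices) result can be reproduced for simple cases with a pure-Python sketch that sorts positions by value and picks the k-th one (k is 1-based, as in torch.kthvalue). kthvalue and kthvalue_dim0 are hypothetical helpers. Which index torch reports when values are tied is not specified here; this sketch returns the first occurrence.

```python
def kthvalue(values, k):
    """Return (value, index) of the k-th smallest element of a 1-D list."""
    if not 1 <= k <= len(values):
        raise ValueError("k out of range")
    # Sort positions by value; Python's stable sort keeps earlier ties first.
    order = sorted(range(len(values)), key=lambda i: values[i])
    idx = order[k - 1]
    return values[idx], idx

def kthvalue_dim0(matrix, k):
    """Apply kthvalue down each column (dim=0) of a 2-D nested list."""
    pairs = [kthvalue(list(col), k) for col in zip(*matrix)]
    return [v for v, _ in pairs], [i for _, i in pairs]

print(kthvalue([1, 2, 3, 4, 5], 4))               # (4, 3)
print(kthvalue_dim0([[1, 2, 3], [4, 5, 6]], 2))   # ([4, 5, 6], [1, 1, 1])
```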
torch.le(input, other, out=None) → Tensor

Computes input <= other element-wise.

The second argument can be a number or a tensor whose shape is broadcastable with the first argument.

Parameters:
  • input (Tensor) – the tensor to compare
  • other (Tensor or float) – the tensor or value to compare
  • out (Tensor, optional) – the output tensor that must be a ByteTensor or the same type as input
Returns:

A torch.ByteTensor containing a 1 at each location where the comparison is true and a 0 elsewhere

Return type:

Tensor

Example:

>>> torch.le(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]]))
 1  0
 1  1
[torch.ByteTensor of size 2x2]
torch.lt(input, other, out=None) → Tensor

Computes input < other element-wise.

The second argument can be a number or a tensor whose shape is broadcastable with the first argument.

Parameters:
  • input (Tensor) – the tensor to compare
  • other (Tensor or float) – the tensor or value to compare
  • out (Tensor, optional) – the output tensor that must be a ByteTensor or the same type as input
Returns:

A torch.ByteTensor containing a 1 at each location where the comparison is true and a 0 elsewhere

Return type:

Tensor

Example:

>>> torch.lt(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]]))
 0  0
 1  0
[torch.ByteTensor of size 2x2]
torch.max()
torch.max(input) → float

Returns the maximum value of all elements in the input tensor.

Parameters:input (Tensor) – the input tensor

Example:

>>> a = torch.randn(1, 3)
>>> a

 0.4729 -0.2266 -0.2085
[torch.FloatTensor of size 1x3]

>>> torch.max(a)
0.4729
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

Returns the maximum value of each row of the input tensor in the given dimension dim. The second return value is the index location of each maximum value found (argmax).

If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensors have dim retained or not
  • out (tuple, optional) – the result tuple of two output tensors (max, max_indices)

Example:

>>> a = torch.randn(4, 4)
>>> a

 0.0692  0.3142  1.2513 -0.5428
 0.9288  0.8552 -0.2073  0.6409
 1.0695 -0.0101 -2.4507 -1.2230
 0.7426 -0.7666  0.4862 -0.6628
[torch.FloatTensor of size 4x4]

>>> torch.max(a, 1)
(
 1.2513
 0.9288
 1.0695
 0.7426
[torch.FloatTensor of size 4]
,
 2
 0
 0
 0
[torch.LongTensor of size 4]
)
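Reducing over dim=1 returns, for every row, the largest value and its position (argmax). A pure-Python sketch on a nested list, including the effect of keepdim on the result shape; max_along_rows is a hypothetical helper, and for ties it returns the first maximal position.

```python
def max_along_rows(matrix, keepdim=False):
    """Per-row (dim=1) maximum and argmax of a 2-D nested list."""
    values, indices = [], []
    for row in matrix:
        # argmax: index of the first occurrence of the largest element
        idx = max(range(len(row)), key=lambda j: row[j])
        values.append(row[idx])
        indices.append(idx)
    if keepdim:
        # Keep the reduced dimension with size 1: shape (rows, 1).
        return [[v] for v in values], [[i] for i in indices]
    return values, indices

a = [[0.0692, 0.3142, 1.2513, -0.5428],
     [0.9288, 0.8552, -0.2073, 0.6409],
     [1.0695, -0.0101, -2.4507, -1.2230],
     [0.7426, -0.7666, 0.4862, -0.6628]]
print(max_along_rows(a))  # ([1.2513, 0.9288, 1.0695, 0.7426], [2, 0, 0, 0])
```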
torch.max(input, other, out=None) → Tensor

Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise max is taken.

The shapes of input and other don’t need to match, but they must be broadcastable.

out_i = \max(input_i, other_i)

Note

When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.

Parameters:
  • input (Tensor) – the input tensor
  • other (Tensor) – the second input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> b = torch.randn(4)
>>> b

 1.0067
-0.8010
 0.6258
 0.3627
[torch.FloatTensor of size 4]

>>> torch.max(a, b)

 1.3869
 0.3912
 0.6258
 0.3627
[torch.FloatTensor of size 4]
torch.min()
torch.min(input) → float

Returns the minimum value of all elements in the input tensor.

Parameters:input (Tensor) – the input tensor

Example:

>>> a = torch.randn(1, 3)
>>> a

 0.4729 -0.2266 -0.2085
[torch.FloatTensor of size 1x3]

>>> torch.min(a)
-0.22663167119026184
torch.min(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

Returns the minimum value of each row of the input tensor in the given dimension dim. The second return value is the index location of each minimum value found (argmin).

If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensors having 1 fewer dimension than input.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output tensors have dim retained or not
  • out (tuple, optional) – the tuple of two output tensors (min, min_indices)

Example:

>>> a = torch.randn(4, 4)
>>> a

 0.0692  0.3142  1.2513 -0.5428
 0.9288  0.8552 -0.2073  0.6409
 1.0695 -0.0101 -2.4507 -1.2230
 0.7426 -0.7666  0.4862 -0.6628
[torch.FloatTensor of size 4x4]

>>> torch.min(a, 1)
(
-0.5428
-0.2073
-2.4507
-0.7666
[torch.FloatTensor of size 4]
,
 3
 2
 2
 1
[torch.LongTensor of size 4]
)
torch.min(input, other, out=None) → Tensor

Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise min is taken. The resulting tensor is returned.

The shapes of input and other don’t need to match, but they must be broadcastable.

out_i = \min(input_i, other_i)

Note

When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.

Parameters:
  • input (Tensor) – the input tensor
  • other (Tensor) – the second input tensor
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> b = torch.randn(4)
>>> b

 1.0067
-0.8010
 0.6258
 0.3627
[torch.FloatTensor of size 4]

>>> torch.min(a, b)

 1.0067
-0.8010
-0.8634
-0.5468
[torch.FloatTensor of size 4]
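The two-tensor forms torch.max(input, other) and torch.min(input, other) apply out_i = max(input_i, other_i) or out_i = min(input_i, other_i) position by position. A pure-Python sketch for 1-D lists follows. elementwise is a hypothetical helper, and a scalar other is used as the simplest stand-in for broadcasting; general shape broadcasting is not modelled.

```python
def elementwise(fn, input_list, other):
    """Apply fn pairwise; a scalar `other` is repeated for every element."""
    if isinstance(other, (int, float)):
        other = [other] * len(input_list)
    if len(other) != len(input_list):
        raise ValueError("shapes are not broadcastable in this sketch")
    return [fn(x, y) for x, y in zip(input_list, other)]

a = [1.3869, 0.3912, -0.8634, -0.5468]
b = [1.0067, -0.8010, 0.6258, 0.3627]
print(elementwise(max, a, b))    # [1.3869, 0.3912, 0.6258, 0.3627]
print(elementwise(min, a, b))    # [1.0067, -0.801, -0.8634, -0.5468]
print(elementwise(max, a, 0.0))  # [1.3869, 0.3912, 0.0, 0.0]
```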
torch.ne(input, other, out=None) → Tensor

Computes input != other element-wise.

The second argument can be a number or a tensor whose shape is broadcastable with the first argument.

Parameters:
  • input (Tensor) – the tensor to compare
  • other (Tensor or float) – the tensor or value to compare
  • out (Tensor, optional) – the output tensor that must be a ByteTensor or the same type as input
Returns:

A torch.ByteTensor containing a 1 at each location where the comparison is true and a 0 elsewhere.

Return type:

Tensor

Example:

>>> torch.ne(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]]))
 0  1
 1  0
[torch.ByteTensor of size 2x2]
torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)

Sorts the elements of the input tensor along a given dimension in ascending order by value.

If dim is not given, the last dimension of the input is chosen.

If descending is True then the elements are sorted in descending order by value.

A tuple of (sorted_tensor, sorted_indices) is returned, where the sorted_indices are the indices of the elements in the original input tensor.

Parameters:
  • input (Tensor) – the input tensor
  • dim (int, optional) – the dimension to sort along
  • descending (bool, optional) – controls the sorting order (ascending or descending)
  • out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers

Example:

>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted

-1.6747  0.0610  0.1190  1.4137
-1.4782  0.7159  1.0341  1.3678
-0.3324 -0.0782  0.3518  0.4763
[torch.FloatTensor of size 3x4]

>>> indices

 0  1  3  2
 2  1  0  3
 3  1  0  2
[torch.LongTensor of size 3x4]

>>> sorted, indices = torch.sort(x, 0)
>>> sorted

-1.6747 -0.0782 -1.4782 -0.3324
 0.3518  0.0610  0.4763  0.1190
 1.0341  0.7159  1.4137  1.3678
[torch.FloatTensor of size 3x4]

>>> indices

 0  2  1  2
 2  0  2  0
 1  1  0  1
[torch.LongTensor of size 3x4]
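sorted_indices records where each sorted element came from. A common way to get both outputs at once is to sort the positions rather than the values. The pure-Python sketch below (sort_with_indices is a hypothetical helper) uses the first row of x, reconstructed from the sorted values and indices printed above.

```python
def sort_with_indices(values, descending=False):
    """Return (sorted_values, sorted_indices) for a 1-D list."""
    # Sorting positions keeps track of the original index of every element.
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=descending)
    return [values[i] for i in order], order

row = [-1.6747, 0.0610, 1.4137, 0.1190]
print(sort_with_indices(row))  # ([-1.6747, 0.061, 0.119, 1.4137], [0, 1, 3, 2])
```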
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

Returns the k largest elements of the given input tensor along a given dimension.

If dim is not given, the last dimension of the input is chosen.

If largest is False then the k smallest elements are returned.

A tuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.

If the boolean option sorted is True, the returned k elements are themselves sorted.

Parameters:
  • input (Tensor) – the input tensor
  • k (int) – the k in “top-k”
  • dim (int, optional) – the dimension to sort along
  • largest (bool, optional) – controls whether to return largest or smallest elements
  • sorted (bool, optional) – controls whether to return the elements in sorted order
  • out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers

Example:

>>> x = torch.arange(1, 6)
>>> x

 1
 2
 3
 4
 5
[torch.FloatTensor of size 5]

>>> torch.topk(x, 3)
(
 5
 4
 3
[torch.FloatTensor of size 3]
,
 4
 3
 2
[torch.LongTensor of size 3]
)
>>> torch.topk(x, 3, 0, largest=False)
(
 1
 2
 3
[torch.FloatTensor of size 3]
,
 0
 1
 2
[torch.LongTensor of size 3]
)
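For a 1-D input, the top-k selection can be sketched in pure Python with heapq.nlargest and heapq.nsmallest over (index, value) pairs. topk is a hypothetical helper. Its sorted_output flag stands in for torch's sorted argument (renamed to avoid shadowing Python's built-in). When it is False, this sketch returns the elements in input order, whereas the order torch uses in that case is not specified here.

```python
import heapq

def topk(values, k, largest=True, sorted_output=True):
    """Return (values, indices) of the k largest (or smallest) elements."""
    pick = heapq.nlargest if largest else heapq.nsmallest
    # Rank (index, value) pairs by value only.
    best = pick(k, enumerate(values), key=lambda pair: pair[1])
    if not sorted_output:
        best.sort(key=lambda pair: pair[0])  # restore input order
    return [v for _, v in best], [i for i, _ in best]

x = [1, 2, 3, 4, 5]
print(topk(x, 3))                 # ([5, 4, 3], [4, 3, 2])
print(topk(x, 3, largest=False))  # ([1, 2, 3], [0, 1, 2])
```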

Other Operations

torch.cross(input, other, dim=-1, out=None) → Tensor

Returns the cross product of vectors in dimension dim of input and other.

input and other must have the same size, and the size of their dim dimension should be 3.

If dim is not given, it defaults to the first dimension found with the size 3.

Parameters:
  • input (Tensor) – the input tensor
  • other (Tensor) – the second input tensor
  • dim (int, optional) – the dimension to take the cross-product in.
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(4, 3)
>>> a

-0.6652 -1.0116 -0.6857
 0.2286  0.4446 -0.5272
 0.0476  0.2321  1.9991
 0.6199  1.1924 -0.9397
[torch.FloatTensor of size 4x3]

>>> b = torch.randn(4, 3)
>>> b

-0.1042 -1.1156  0.1947
 0.9947  0.1149  0.4701
-1.0108  0.8319 -0.0750
 0.9045 -1.3754  1.0976
[torch.FloatTensor of size 4x3]

>>> torch.cross(a, b, dim=1)

-0.9619  0.2009  0.6367
 0.2696 -0.6318 -0.4160
-1.6805 -2.0171  0.2741
 0.0163 -1.5304 -1.9311
[torch.FloatTensor of size 4x3]

>>> torch.cross(a, b)

-0.9619  0.2009  0.6367
 0.2696 -0.6318 -0.4160
-1.6805 -2.0171  0.2741
 0.0163 -1.5304 -1.9311
[torch.FloatTensor of size 4x3]
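With dim=1, each row of a and b is treated as a 3-vector. The standard component formula is written out below as a pure-Python sketch (cross3 is a hypothetical helper). Applied to the first rows of a and b, it gives values close to the first output row above. Small differences in the last digit come from the inputs being printed to only four decimals.

```python
def cross3(a, b):
    """Cross product of two 3-element sequences."""
    a0, a1, a2 = a
    b0, b1, b2 = b
    return [a1 * b2 - a2 * b1,
            a2 * b0 - a0 * b2,
            a0 * b1 - a1 * b0]

# First rows of a and b from the example; result is approximately [-0.9619, 0.2010, 0.6367].
print(cross3([-0.6652, -1.0116, -0.6857], [-0.1042, -1.1156, 0.1947]))
```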
torch.diag(input, diagonal=0, out=None) → Tensor
  • If input is a vector (1-D tensor), then returns a 2-D square tensor with the elements of input as the diagonal.
  • If input is a matrix (2-D tensor), then returns a 1-D tensor with the diagonal elements of input.

The argument diagonal controls which diagonal to consider:

  • If diagonal = 0, it is the main diagonal.
  • If diagonal > 0, it is above the main diagonal.
  • If diagonal < 0, it is below the main diagonal.
Parameters:
  • input (Tensor) – the input tensor
  • diagonal (int, optional) – the diagonal to consider
  • out (Tensor, optional) – the output tensor

Example:

Get the square matrix where the input vector is the diagonal:

>>> a = torch.randn(3)
>>> a

 1.0480
-2.3405
-1.1138
[torch.FloatTensor of size 3]

>>> torch.diag(a)

 1.0480  0.0000  0.0000
 0.0000 -2.3405  0.0000
 0.0000  0.0000 -1.1138
[torch.FloatTensor of size 3x3]

>>> torch.diag(a, 1)

 0.0000  1.0480  0.0000  0.0000
 0.0000  0.0000 -2.3405  0.0000
 0.0000  0.0000  0.0000 -1.1138
 0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size 4x4]

Get the k-th diagonal of a given matrix:

>>> a = torch.randn(3, 3)
>>> a

-1.5328 -1.3210 -1.5204
 0.8596  0.0471 -0.2239
-0.6617  0.0146 -1.0817
[torch.FloatTensor of size 3x3]

>>> torch.diag(a, 0)

-1.5328
 0.0471
-1.0817
[torch.FloatTensor of size 3]

>>> torch.diag(a, 1)

-1.3210
-0.2239
[torch.FloatTensor of size 2]
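Both directions of torch.diag, and the meaning of the diagonal offset, can be summarised by a pure-Python sketch on nested lists. diag here is a hypothetical helper, not the torch function. A positive offset moves the diagonal to the right (above the main diagonal), and a negative offset moves it down (below).

```python
def diag(x, diagonal=0):
    """Sketch of torch.diag semantics on nested lists.

    1-D input -> square matrix with x on the chosen diagonal.
    2-D input -> list of the elements on the chosen diagonal.
    """
    if not isinstance(x[0], list):
        n = len(x) + abs(diagonal)
        out = [[0] * n for _ in range(n)]
        for i, v in enumerate(x):
            row, col = (i, i + diagonal) if diagonal >= 0 else (i - diagonal, i)
            out[row][col] = v
        return out
    rows, cols = len(x), len(x[0])
    # Keep x[i][i + diagonal] only where that column exists.
    return [x[i][i + diagonal] for i in range(rows) if 0 <= i + diagonal < cols]

print(diag([1, 2, 3], 1))                          # [[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3], [0, 0, 0, 0]]
print(diag([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 1))  # [2, 6]
```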
torch.histc(input, bins=100, min=0, max=0, out=None) → Tensor

Computes the histogram of a tensor.

The elements are sorted into equal width bins between min and max. If min and max are both zero, the minimum and maximum values of the data are used.

Parameters:
  • input (Tensor) – the input tensor
  • bins (int) – number of histogram bins
  • min (int) – lower end of the range (inclusive)
  • max (int) – upper end of the range (inclusive)
  • out (Tensor, optional) – the output tensor
Returns:

Histogram represented as a tensor

Return type:

Tensor

Example:

>>> torch.histc(torch.FloatTensor([1, 2, 1]), bins=4, min=0, max=3)
FloatTensor([0, 2, 1, 0])
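The equal-width binning can be made concrete with a pure-Python sketch. The parameters are renamed lo and hi so they do not shadow Python's min and max. Several behaviours are assumptions of this sketch rather than statements from the text above: values outside [lo, hi] are skipped, a value equal to hi is counted in the last bin, and a zero-width range is widened by 1 on each side. In the example, bins=4 over [0, 3] gives a bin width of 0.75, so both 1s land in bin 1 and the 2 lands in bin 2.

```python
def histc(values, bins=100, lo=0.0, hi=0.0):
    """Equal-width histogram over [lo, hi]; lo == hi == 0 means 'use the data range'."""
    if lo == 0 and hi == 0:
        lo, hi = min(values), max(values)
    if lo == hi:
        lo, hi = lo - 1, hi + 1  # assumption: widen a zero-width range
    counts = [0] * bins
    width = (hi - lo) / bins
    for v in values:
        if v < lo or v > hi:
            continue  # assumption: out-of-range values are ignored
        idx = int((v - lo) / width)
        counts[min(idx, bins - 1)] += 1  # v == hi falls into the last bin
    return counts

print(histc([1, 2, 1], bins=4, lo=0, hi=3))  # [0, 2, 1, 0]
```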
torch.renorm(input, p, dim, maxnorm, out=None) → Tensor

Returns a tensor where each sub-tensor of input along dimension dim is normalized such that the p-norm of the sub-tensor does not exceed the value maxnorm.

Note

If the norm of a row is lower than maxnorm, the row is unchanged.

Parameters:
  • input (Tensor) – the input tensor
  • p (float) – the power for the norm computation
  • dim (int) – the dimension to slice over to get the sub-tensors
  • maxnorm (float) – the maximum norm to keep each sub-tensor under
  • out (Tensor, optional) – the output tensor

Example:

>>> x = torch.ones(3, 3)
>>> x[1].fill_(2)
>>> x[2].fill_(3)
>>> x

 1  1  1
 2  2  2
 3  3  3
[torch.FloatTensor of size 3x3]

>>> torch.renorm(x, 1, 0, 5)

 1.0000  1.0000  1.0000
 1.6667  1.6667  1.6667
 1.6667  1.6667  1.6667
[torch.FloatTensor of size 3x3]
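For dim=0 each row is a sub-tensor. A row whose p-norm exceeds maxnorm is multiplied by maxnorm / norm, and all other rows are left alone. A pure-Python sketch follows (renorm_rows is a hypothetical helper). torch may differ in small numerical details, such as guarding the division. With p=1 and maxnorm=5, the row of 2s (norm 6) and the row of 3s (norm 9) are both scaled to entries of about 1.6667, matching the example.

```python
def renorm_rows(matrix, p, maxnorm):
    """Rescale each row (dim=0 sub-tensor) whose p-norm exceeds maxnorm."""
    out = []
    for row in matrix:
        norm = sum(abs(v) ** p for v in row) ** (1.0 / p)
        # Rows already within the limit keep a scale factor of 1.
        scale = maxnorm / norm if norm > maxnorm else 1.0
        out.append([v * scale for v in row])
    return out

x = [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
print(renorm_rows(x, p=1, maxnorm=5))  # rows 2 and 3 become approximately 1.6667 each
```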
torch.trace(input) → float

Returns the sum of the elements of the diagonal of the input 2-D matrix.

Example:

>>> x = torch.arange(1, 10).view(3, 3)
>>> x

 1  2  3
 4  5  6
 7  8  9
[torch.FloatTensor of size 3x3]

>>> torch.trace(x)
15.0
torch.tril(input, diagonal=0, out=None) → Tensor

Returns the lower triangular part of the matrix (2-D tensor) input; the other elements of the result tensor out are set to 0.

The lower triangular part of the matrix is defined as the elements on and below the diagonal.

The argument diagonal controls which diagonal to consider:

  • If diagonal = 0, it is the main diagonal.
  • If diagonal > 0, it is above the main diagonal.
  • If diagonal < 0, it is below the main diagonal.
Parameters:
  • input (Tensor) – the input tensor
  • diagonal (int, optional) – the diagonal to consider
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(3,3)
>>> a

 1.3225  1.7304  1.4573
-0.3052 -0.3111 -0.1809
 1.2469  0.0064 -1.6250
[torch.FloatTensor of size 3x3]

>>> torch.tril(a)

 1.3225  0.0000  0.0000
-0.3052 -0.3111  0.0000
 1.2469  0.0064 -1.6250
[torch.FloatTensor of size 3x3]

>>> torch.tril(a, diagonal=1)

 1.3225  1.7304  0.0000
-0.3052 -0.3111 -0.1809
 1.2469  0.0064 -1.6250
[torch.FloatTensor of size 3x3]

>>> torch.tril(a, diagonal=-1)

 0.0000  0.0000  0.0000
-0.3052  0.0000  0.0000
 1.2469  0.0064  0.0000
[torch.FloatTensor of size 3x3]
torch.triu(input, diagonal=0, out=None) → Tensor

Returns the upper triangular part of the matrix (2-D tensor) input; the other elements of the result tensor out are set to 0.

The upper triangular part of the matrix is defined as the elements on and above the diagonal.

The argument diagonal controls which diagonal to consider:

  • If diagonal = 0, it is the main diagonal.
  • If diagonal > 0, it is above the main diagonal.
  • If diagonal < 0, it is below the main diagonal.
Parameters:
  • input (Tensor) – the input tensor
  • diagonal (int, optional) – the diagonal to consider
  • out (Tensor, optional) – the output tensor

Example:

>>> a = torch.randn(3,3)
>>> a

 1.3225  1.7304  1.4573
-0.3052 -0.3111 -0.1809
 1.2469  0.0064 -1.6250
[torch.FloatTensor of size 3x3]

>>> torch.triu(a)

 1.3225  1.7304  1.4573
 0.0000 -0.3111 -0.1809
 0.0000  0.0000 -1.6250
[torch.FloatTensor of size 3x3]

>>> torch.triu(a, diagonal=1)

 0.0000  1.7304  1.4573
 0.0000  0.0000 -0.1809
 0.0000  0.0000  0.0000
[torch.FloatTensor of size 3x3]

>>> torch.triu(a, diagonal=-1)

 1.3225  1.7304  1.4573
-0.3052 -0.3111 -0.1809
 0.0000  0.0064 -1.6250
[torch.FloatTensor of size 3x3]
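In both torch.tril and torch.triu, the diagonal argument is an offset on the column-minus-row index. tril keeps element (i, j) when j - i <= diagonal, and triu keeps it when j - i >= diagonal. A pure-Python sketch follows (tril and triu here are hypothetical helpers on nested lists).

```python
def tril(matrix, diagonal=0):
    """Keep elements with j - i <= diagonal (on or below the chosen diagonal)."""
    return [[v if j - i <= diagonal else 0 for j, v in enumerate(row)]
            for i, row in enumerate(matrix)]

def triu(matrix, diagonal=0):
    """Keep elements with j - i >= diagonal (on or above the chosen diagonal)."""
    return [[v if j - i >= diagonal else 0 for j, v in enumerate(row)]
            for i, row in enumerate(matrix)]

m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(tril(m))      # [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
print(triu(m, 1))   # [[0, 2, 3], [0, 0, 6], [0, 0, 0]]
```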

BLAS and LAPACK Operations

torch.addbmm(beta=1, mat, alpha=1, batch1, batch2, out=None) → Tensor

Performs a batch matrix-matrix product of matrices stored in batch1 and batch2, with a reduced add step (all matrix multiplications get accumulated along the first dimension). mat is added to the final result.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a (b \times n \times m) tensor and batch2 is a (b \times m \times p) tensor, then mat must be broadcastable with a (n \times p) tensor and out will be a (n \times p) tensor.

out = \beta\ mat + \alpha\ (\sum_{i=0}^{b-1} batch1_i \mathbin{@} batch2_i)

For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers, otherwise they should be integers.

Parameters:
  • beta (Number, optional) – multiplier for mat
  • mat (Tensor) – matrix to be added
  • alpha (Number, optional) – multiplier for batch1 @ batch2
  • batch1 (Tensor) – the first batch of matrices to be multiplied
  • batch2 (Tensor) – the second batch of matrices to be multiplied
  • out (Tensor, optional) – the output tensor

Example:

>>> M = torch.randn(3, 5)
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> torch.addbmm(M, batch1, batch2)

 -3.1162  11.0071   7.3102   0.1824  -7.6892
  1.8265   6.0739   0.4589  -0.5641  -5.4283
 -9.3387  -0.1794  -1.2318  -6.8841  -4.7239
[torch.FloatTensor of size 3x5]
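The reduced add step means that the b matrix products are summed into one (n x p) matrix before being combined with mat. A pure-Python sketch of the formula for small integer matrices follows. addbmm and matmul2d are hypothetical helpers. mat is assumed to already have the (n x p) result shape, so broadcasting is not modelled, and beta and alpha are taken as keyword arguments.

```python
def matmul2d(x, y):
    """Plain (n x m) @ (m x p) product of nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]

def addbmm(mat, batch1, batch2, beta=1, alpha=1):
    """beta * mat + alpha * sum_i batch1[i] @ batch2[i]."""
    n, p = len(batch1[0]), len(batch2[0][0])
    acc = [[0] * p for _ in range(n)]
    for b1, b2 in zip(batch1, batch2):
        prod = matmul2d(b1, b2)
        # Accumulate every product into a single (n x p) matrix.
        acc = [[a + q for a, q in zip(ra, rq)] for ra, rq in zip(acc, prod)]
    return [[beta * m + alpha * a for m, a in zip(rm, ra)] for rm, ra in zip(mat, acc)]

I = [[1, 0], [0, 1]]
batch1 = [I, [[2, 0], [0, 2]]]
batch2 = [[[1, 2], [3, 4]], I]
print(addbmm([[1, 1], [1, 1]], batch1, batch2))  # [[4, 3], [4, 7]]
```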
torch.addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) → Tensor

Performs a matrix multiplication of the matrices mat1 and mat2. The matrix mat is added to the final result.

If mat1 is a (n \times m) tensor, mat2 is a (m \times p) tensor, then mat must be broadcastable with a (n \times p) tensor and out will be a (n \times p) tensor.

alpha and beta are scaling factors on mat1 @ mat2 and mat respectively.

out = \beta\ mat + \alpha\ (mat1 \mathbin{@} mat2)

For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers, otherwise they should be integers.

Parameters:
  • beta (Number, optional) – multiplier for mat
  • mat (Tensor) – matrix to be added
  • alpha (Number, optional) – multiplier for mat1 @ mat2
  • mat1 (Tensor) – the first matrix to be multiplied
  • mat2 (Tensor) – the second matrix to be multiplied
  • out (Tensor, optional) – the output tensor

Example:

>>> M = torch.randn(2, 3)
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.addmm(M, mat1, mat2)

-0.4095 -1.9703  1.3561
 5.7674 -4.9760  2.7378
[torch.FloatTensor of size 2x3]
torch.addmv(beta=1, tensor, alpha=1, mat, vec, out=None) → Tensor

Performs a matrix-vector product of the matrix mat and the vector vec. The vector tensor is added to the final result.

If mat is a (n \times m) tensor, vec is a 1-D tensor of size m, then tensor must be broadcastable with a 1-D tensor of size n and out will be a 1-D tensor of size n.

alpha and beta are scaling factors on mat @ vec and tensor respectively.

out = \beta\ tensor + \alpha\ (mat \mathbin{@} vec)

For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers, otherwise they should be integers.

Parameters:
  • beta (Number, optional) – multiplier for tensor
  • tensor (Tensor) – vector to be added
  • alpha (Number, optional) – multiplier for mat @ vec
  • mat (Tensor) – matrix to be multiplied
  • vec (Tensor) – vector to be multiplied
  • out (Tensor, optional) – the output tensor

Example:

>>> M = torch.randn(2)
>>> mat = torch.randn(2, 3)
>>> vec = torch.randn(3)
>>> torch.addmv(M, mat, vec)

-2.0939
-2.2950
[torch.FloatTensor of size 2]
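The formula out = beta tensor + alpha (mat @ vec) can be checked by hand with a pure-Python sketch on lists. addmv here is a hypothetical helper, and tensor is assumed to already be a length-n list (no broadcasting).

```python
def addmv(tensor, mat, vec, beta=1, alpha=1):
    """beta * tensor + alpha * (mat @ vec) for a length-n list `tensor`."""
    mv = [sum(a * b for a, b in zip(row, vec)) for row in mat]  # matrix-vector product
    return [beta * t + alpha * m for t, m in zip(tensor, mv)]

print(addmv([1, 1], [[1, 2, 3], [4, 5, 6]], [1, 0, 1]))  # [5, 11]
```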
torch.addr(beta=1, mat, alpha=1, vec1, vec2, out=None) → Tensor

Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix mat.

Optional values beta and alpha are scalars that multiply mat and (vec1 \otimes vec2) respectively.

out = \beta\ mat + \alpha\ (vec1 \otimes vec2)

If vec1 is a vector of size n and vec2 is a vector of size m, then mat must be broadcastable with a matrix of size (n \times m) and out will be a matrix of size (n \times m).

For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers, otherwise they should be integers.

Parameters:
  • beta (Number, optional) – multiplier for mat
  • mat (Tensor) – matrix to be added
  • alpha (Number, optional) – multiplier for vec1 \otimes vec2
  • vec1 (Tensor) – the first vector of the outer product
  • vec2 (Tensor) – the second vector of the outer product
  • out (Tensor, optional) – the output tensor

Example:

>>> vec1 = torch.arange(1, 4)
>>> vec2 = torch.arange(1, 3)
>>> M = torch.zeros(3, 2)
>>> torch.addr(M, vec1, vec2)
 1  2
 2  4
 3  6
[torch.FloatTensor of size 3x2]
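Entry (i, j) of the outer product is vec1[i] * vec2[j], so each element of out is beta * mat[i][j] + alpha * vec1[i] * vec2[j]. A pure-Python sketch follows (addr here is a hypothetical helper) that reproduces the example above.

```python
def addr(mat, vec1, vec2, beta=1, alpha=1):
    """beta * mat + alpha * outer(vec1, vec2), with mat of shape (len(vec1), len(vec2))."""
    return [[beta * mat[i][j] + alpha * a * b for j, b in enumerate(vec2)]
            for i, a in enumerate(vec1)]

zeros = [[0, 0] for _ in range(3)]
print(addr(zeros, [1, 2, 3], [1, 2]))  # [[1, 2], [2, 4], [3, 6]]
```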
torch.baddbmm(beta=1, mat, alpha=1, batch1, batch2, out=None) → Tensor

Performs a batch matrix-matrix product of matrices in batch1 and batch2. mat is added to the final result.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a (b \times n \times m) tensor, batch2 is a (b \times m \times p) tensor, then mat must be broadcastable with a (b \times n \times p) tensor and out will be a (b \times n \times p) tensor.

out_i = \beta\ mat_i + \alpha\ (batch1_i \mathbin{@} batch2_i)

For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers, otherwise they should be integers.

Parameters:
  • beta (Number, optional) – multiplier for mat
  • mat (Tensor) – the tensor to be added
  • alpha (Number, optional) – multiplier for batch1 @ batch2
  • batch1 (Tensor) – the first batch of matrices to be multiplied
  • batch2 (Tensor) – the second batch of matrices to be multiplied
  • out (Tensor, optional) – the output tensor

Example:

>>> M = torch.randn(10, 3, 5)
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> torch.baddbmm(M, batch1, batch2).size()
torch.Size([10, 3, 5])
torch.bmm(batch1, batch2, out=None) → Tensor

Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a (b \times n \times m) tensor, batch2 is a (b \times m \times p) tensor, out will be a (b \times n \times p) tensor.

out_i = batch1_i \mathbin{@} batch2_i

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().

Parameters:
  • batch1 (Tensor) – the first batch of matrices to be multiplied
  • batch2 (Tensor) – the second batch of matrices to be multiplied
  • out (Tensor, optional) – the output tensor

Example:

>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(batch1, batch2)
>>> res.size()
torch.Size([10, 3, 5])
torch.btrifact(A, info=None, pivot=True) → Tensor, IntTensor

Batch LU factorization.

Returns a tuple containing the LU factorization and pivots. The optional argument info reports whether the factorization succeeded for each minibatch example. The info values come from dgetrf, and a non-zero value indicates that an error occurred. The specific values come from cuBLAS if CUDA is being used, otherwise from LAPACK. Pivoting is done if pivot is set.

Parameters:A (Tensor) – the tensor to factor

Example:

>>> A = torch.randn(2, 3, 3)
>>> A_LU = A.btrifact()
torch.btrisolve(b, LU_data, LU_pivots) → Tensor

Batch LU solve.

Returns the LU solve of the linear system Ax = b.

Parameters:
  • b (Tensor) – the RHS tensor
  • LU_data (Tensor) – the pivoted LU factorization of A from btrifact().
  • LU_pivots (IntTensor) – the pivots of the LU factorization

Example:

>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3)
>>> A_LU = torch.btrifact(A)
>>> x = b.btrisolve(*A_LU)
>>> torch.norm(A.bmm(x.unsqueeze(2)).squeeze(2) - b)
6.664001874625056e-08
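torch.btrifact and torch.btrisolve split solving Ax = b into a factorization step and a forward/back substitution step, for a whole batch at once. The following single-matrix pure-Python sketch shows the same factor-then-solve idea. Assumptions: it uses Doolittle LU with partial pivoting, and it stores the pivots as a plain permutation list rather than the LAPACK-style pivot tensor torch returns. lu_factor and lu_solve are hypothetical names.

```python
def lu_factor(a):
    """LU with partial pivoting: returns (lu, perm) such that A[perm] = L @ U.

    `lu` holds U on and above the diagonal and the multipliers of L
    (unit diagonal implied) below it; perm[k] is the original row now at row k.
    """
    n = len(a)
    lu = [list(row) for row in a]
    perm = list(range(n))
    for k in range(n):
        # Bring the largest remaining entry of column k into the pivot row.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        if lu[p][k] == 0:
            raise ValueError("matrix is singular")
        lu[k], lu[p] = lu[p], lu[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm

def lu_solve(lu, perm, b):
    """Solve A x = b given the output of lu_factor."""
    n = len(lu)
    y = [b[perm[i]] for i in range(n)]  # apply the row permutation
    for i in range(n):  # forward substitution with unit lower-triangular L
        y[i] -= sum(lu[i][j] * y[j] for j in range(i))
    x = [0.0] * n
    for i in reversed(range(n)):  # back substitution with upper-triangular U
        x[i] = (y[i] - sum(lu[i][j] * x[j] for j in range(i + 1, n))) / lu[i][i]
    return x

A = [[2, 1, 1], [4, -6, 0], [-2, 7, 2]]
print(lu_solve(*lu_factor(A), [5, -2, 9]))  # [1.0, 1.0, 2.0]
```

The factorization is computed once and can be reused for several right-hand sides, which is the point of separating btrifact from btrisolve.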
torch.dot(tensor1, tensor2) → float

Computes the dot product (inner product) of two tensors.

Note

This function does not broadcast.

Example:

>>> torch.dot(torch.Tensor([2, 3]), torch.Tensor([2, 1]))
7.0
torch.eig(a, eigenvectors=False, out=None) -> (Tensor, Tensor)

Computes the eigenvalues and eigenvectors of a real square matrix.

Parameters:
  • a (Tensor) – the square matrix for which the eigenvalues and eigenvectors will be computed
  • eigenvectors (bool) – True to compute both eigenvalues and eigenvectors; otherwise, only eigenvalues will be computed
  • out (tuple, optional) – the output tensors
Returns:

A tuple containing

  • e (Tensor): the right eigenvalues of a
  • v (Tensor): the eigenvectors of a if eigenvectors is True; otherwise an empty tensor

Return type:

(Tensor, Tensor)

torch.gels(B, A, out=None) → Tensor

Computes the solution to the least squares and least norm problems for a full rank m by n matrix A.

If m >= n, gels() solves the least-squares problem:

\begin{array}{ll} \mbox{minimize} & \|AX-B\|_F. \end{array}

If m < n, gels() solves the least-norm problem:

\begin{array}{llll} \mbox{minimize} & \|X\|_F & \mbox{subject to} & AX = B. \end{array}

The first n rows of the returned matrix X contain the solution. The remaining rows contain residual information: the Euclidean norm of each column starting at row n is the residual for the corresponding column.

Parameters:
  • B (Tensor) – the matrix B
  • A (Tensor) – the m by n matrix A
  • out (tuple, optional) – the optional destination tensor
Returns:

A tuple containing:

  • X (Tensor): the least squares solution
  • qr (Tensor): the details of the QR factorization

Return type:

(Tensor, Tensor)

Note

The returned matrices will always be transposed, irrespective of the strides of the input matrices. That is, they will have stride (1, m) instead of (m, 1).

Example:

>>> A = torch.Tensor([[1, 1, 1],
...                   [2, 3, 4],
...                   [3, 5, 2],
...                   [4, 2, 5],
...                   [5, 4, 3]])
>>> B = torch.Tensor([[-10, -3],
...                   [ 12, 14],
...                   [ 14, 12],
...                   [ 16, 16],
...                   [ 18, 16]])
>>> X, _ = torch.gels(B, A)
>>> X
2.0000  1.0000
1.0000  1.0000
1.0000  2.0000
[torch.FloatTensor of size 3x2]
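The least-squares solution in this example can be cross-checked in pure Python. LAPACK's gels uses a QR factorization, but the sketch below swaps in the normal equations (A^T A) x = A^T b instead. That approach is simpler to write, but it is less numerically robust and covers only the m >= n case with one right-hand-side column at a time. solve and least_squares are hypothetical helpers.

```python
def solve(a, b):
    """Gaussian elimination with partial pivoting for a square system a x = b."""
    n = len(a)
    m = [list(row) + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for k in range(n):
        p = max(range(k, n), key=lambda i: abs(m[i][k]))
        m[k], m[p] = m[p], m[k]
        for i in range(k + 1, n):
            f = m[i][k] / m[k][k]
            for j in range(k, n + 1):
                m[i][j] -= f * m[k][j]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (m[i][n] - sum(m[i][j] * x[j] for j in range(i + 1, n))) / m[i][i]
    return x

def least_squares(A, b):
    """Minimise ||A x - b|| via the normal equations (A^T A) x = A^T b."""
    At = list(zip(*A))  # columns of A
    AtA = [[sum(p * q for p, q in zip(r, c)) for c in At] for r in At]
    Atb = [sum(p * q for p, q in zip(r, b)) for r in At]
    return solve(AtA, Atb)

A = [[1, 1, 1], [2, 3, 4], [3, 5, 2], [4, 2, 5], [5, 4, 3]]
print(least_squares(A, [-10, 12, 14, 16, 18]))  # approximately [2.0, 1.0, 1.0]
```

The two columns of B give approximately [2, 1, 1] and [1, 1, 2], which are the two columns of X printed above.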
torch.geqrf(input, out=None) -> (Tensor, Tensor)

This is a low-level function for calling LAPACK directly.

You’ll generally want to use torch.qr() instead.

Computes a QR decomposition of input, but without constructing Q and R as explicit separate matrices.

Rather, this directly calls the underlying LAPACK function ?geqrf which produces a sequence of ‘elementary reflectors’.

See LAPACK documentation for further details.

Parameters:
  • input (Tensor) – the input matrix
  • out (tuple, optional) – the output tuple of (Tensor, Tensor)
torch.ger(vec1, vec2, out=None) → Tensor

Outer product of vec1 and vec2. If vec1 is a vector of size n and vec2 is a vector of size m, then out will be a matrix of size (n \times m).

Note

This function does not broadcast.

Parameters:
  • vec1 (Tensor) – 1-D input vector
  • vec2 (Tensor) – 1-D input vector
  • out (Tensor, optional) – optional output matrix

Example:

>>> v1 = torch.arange(1, 5)
>>> v2 = torch.arange(1, 4)
>>> torch.ger(v1, v2)

  1   2   3
  2   4   6
  3   6   9
  4   8  12
[torch.FloatTensor of size 4x3]
torch.gesv(B, A, out=None) -> (Tensor, Tensor)

X, LU = torch.gesv(B, A) returns the solution to the system of linear equations represented by AX = B.

LU contains L and U factors for LU factorization of A.

A has to be a square and non-singular matrix (2-D tensor).

If A is an (m \times m) matrix and B is (m \times k), the result LU is (m \times m) and X is (m \times k).

Note

Irrespective of the original strides, the returned matrices X and LU will be transposed, i.e. with strides (1, m) instead of (m, 1).

Parameters:
  • B (Tensor) – input matrix of (m \times k) dimensions
  • A (Tensor) – input square matrix of (m \times m) dimensions
  • out (Tensor, optional) – optional output matrix

Example:

>>> A = torch.Tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
...                   [-6.05, -3.30,  5.36, -4.44,  1.08],
...                   [-0.45,  2.58, -2.70,  0.27,  9.04],
...                   [8.32,  2.71,  4.35,  -7.17,  2.14],
...                   [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.Tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
...                   [-1.56,  4.00, -8.67,  1.75,  2.86],
...                   [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.gesv(B, A)
>>> torch.dist(B, torch.mm(A, X))
9.250057093890353e-06
torch.inverse(input, out=None) → Tensor

Takes the inverse of the square matrix input.

Note

Irrespective of the original strides, the returned matrix will be transposed, i.e. with strides (1, m) instead of (m, 1).

Parameters:
  • input (Tensor) – the input 2-D square tensor
  • out (Tensor, optional) – the optional output tensor

Example:

>>> x = torch.rand(10, 10)
>>> x

 0.7800  0.2267  0.7855  0.9479  0.5914  0.7119  0.4437  0.9131  0.1289  0.1982
 0.0045  0.0425  0.2229  0.4626  0.6210  0.0207  0.6338  0.7067  0.6381  0.8196
 0.8350  0.7810  0.8526  0.9364  0.7504  0.2737  0.0694  0.5899  0.8516  0.3883
 0.6280  0.6016  0.5357  0.2936  0.7827  0.2772  0.0744  0.2627  0.6326  0.9153
 0.7897  0.0226  0.3102  0.0198  0.9415  0.9896  0.3528  0.9397  0.2074  0.6980
 0.5235  0.6119  0.6522  0.3399  0.3205  0.5555  0.8454  0.3792  0.4927  0.6086
 0.1048  0.0328  0.5734  0.6318  0.9802  0.4458  0.0979  0.3320  0.3701  0.0909
 0.2616  0.3485  0.4370  0.5620  0.5291  0.8295  0.7693  0.1807  0.0650  0.8497
 0.1655  0.2192  0.6913  0.0093  0.0178  0.3064  0.6715  0.5101  0.2561  0.3396
 0.4370  0.4695  0.8333  0.1180  0.4266  0.4161  0.0699  0.4263  0.8865  0.2578
[torch.FloatTensor of size 10x10]

>>> y = torch.inverse(x)
>>> z = torch.mm(x, y)
>>> z

 1.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000
 0.0000  1.0000 -0.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000 -0.0000 -0.0000
 0.0000  0.0000  1.0000 -0.0000 -0.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000
 0.0000  0.0000  0.0000  1.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000  0.0000
 0.0000  0.0000 -0.0000 -0.0000  1.0000  0.0000  0.0000 -0.0000 -0.0000 -0.0000
 0.0000  0.0000  0.0000 -0.0000  0.0000  1.0000 -0.0000 -0.0000 -0.0000 -0.0000
 0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  1.0000  0.0000 -0.0000  0.0000
 0.0000  0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000  1.0000 -0.0000  0.0000
-0.0000  0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000 -0.0000  1.0000 -0.0000
-0.0000  0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000 -0.0000  0.0000  1.0000
[torch.FloatTensor of size 10x10]

>>> torch.max(torch.abs(z - torch.eye(10)))  # largest deviation from the identity
5.096662789583206e-07
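The round-trip check in the example (x @ inverse(x) should be close to the identity) can be reproduced with a small Gauss-Jordan sketch in pure Python. It row-reduces [A | I] to [I | A^-1]. It is an illustration only, not necessarily the routine torch uses, and inverse here is a hypothetical helper.

```python
def inverse(a):
    """Invert a square nested-list matrix by Gauss-Jordan elimination on [a | I]."""
    n = len(a)
    m = [[float(v) for v in row] + [float(i == j) for j in range(n)]
         for i, row in enumerate(a)]
    for k in range(n):
        # Partial pivoting for numerical stability.
        p = max(range(k, n), key=lambda i: abs(m[i][k]))
        if m[p][k] == 0:
            raise ValueError("matrix is singular")
        m[k], m[p] = m[p], m[k]
        pivot = m[k][k]
        m[k] = [v / pivot for v in m[k]]
        for i in range(n):
            if i != k:
                f = m[i][k]
                m[i] = [vi - f * vk for vi, vk in zip(m[i], m[k])]
    return [row[n:] for row in m]

print(inverse([[4, 7], [2, 6]]))  # approximately [[0.6, -0.7], [-0.2, 0.4]]
```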
torch.matmul(tensor1, tensor2, out=None)[source]

Matrix product of two tensors.

The behavior depends on the dimensionality of the tensors as follows:

  • If both tensors are 1-dimensional, the dot product (scalar) is returned.
  • If both arguments are 2-dimensional, the matrix-matrix product is returned.
  • If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
  • If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
  • If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). For example, if tensor1 is a (j \times 1 \times n \times m) tensor and tensor2 is a (k \times m \times p) tensor, out will be a (j \times k \times n \times p) tensor.

Note

The 1-dimensional dot product version of this function does not support an out parameter.

Parameters:
  • tensor1 (Tensor) – the first tensor to be multiplied
  • tensor2 (Tensor) – the second tensor to be multiplied
  • out (Tensor, optional) – the output tensor
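The dimension rules above can be condensed into a small shape calculator. This is a minimal pure-Python sketch, assuming shapes are given as tuples; matmul_shape and broadcast_batch are hypothetical helper names, not torch functions, and the sketch only predicts the shape of the result without multiplying anything.

```python
def broadcast_batch(a, b):
    """Broadcast two batch-shape tuples: right-aligned, size-1 dims stretch."""
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError("batch dimensions %r and %r are not broadcastable" % (a, b))
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def matmul_shape(s1, s2):
    """Return the result shape of a matmul of tensors with shapes s1 and s2."""
    s1, s2 = tuple(s1), tuple(s2)
    if not s1 or not s2:
        raise ValueError("both arguments must be at least 1-dimensional")
    if len(s1) == 1 and len(s2) == 1:
        if s1[0] != s2[0]:
            raise ValueError("dot product needs vectors of equal length")
        return ()  # dot product: a scalar
    prepend = len(s1) == 1  # treat a vector as a 1 x m matrix
    append = len(s2) == 1   # treat a vector as an m x 1 matrix
    if prepend:
        s1 = (1,) + s1
    if append:
        s2 = s2 + (1,)
    n, m = s1[-2:]
    m2, p = s2[-2:]
    if m != m2:
        raise ValueError("inner dimensions differ: %d vs %d" % (m, m2))
    batch = broadcast_batch(s1[:-2], s2[:-2])
    if prepend:
        return batch + (p,)  # drop the prepended 1
    if append:
        return batch + (n,)  # drop the appended 1
    return batch + (n, p)


# The example from the text: (j x 1 x n x m) and (k x m x p) give (j x k x n x p)
print(matmul_shape((5, 1, 2, 3), (4, 3, 6)))  # (5, 4, 2, 6)
```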
torch.mm(mat1, mat2, out=None) → Tensor

Performs a matrix multiplication of the matrices mat1 and mat2.

If mat1 is an (n \times m) tensor and mat2 is an (m \times p) tensor, out will be an (n \times p) tensor.

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().

Parameters:
  • mat1 (Tensor) – the first matrix to be multiplied
  • mat2 (Tensor) – the second matrix to be multiplied
  • out (Tensor, optional) – the output tensor

Example:

>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
 0.0519 -0.3304  1.2232
 4.3910 -5.1498  2.7571
[torch.FloatTensor of size 2x3]
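For reference, the (n \times m) by (m \times p) rule corresponds to the textbook triple loop below. naive_mm is a hypothetical pure-Python helper on lists of lists, meant only to make the indexing explicit; like torch.mm it does not broadcast, and it is far slower than torch.mm.

```python
def naive_mm(mat1, mat2):
    """(n x m) @ (m x p) -> (n x p) for lists of lists; no broadcasting."""
    n, m = len(mat1), len(mat1[0])
    if len(mat2) != m:
        raise ValueError("mat2 must have %d rows, got %d" % (m, len(mat2)))
    p = len(mat2[0])
    # out[i][j] is the dot product of row i of mat1 and column j of mat2
    return [[sum(mat1[i][k] * mat2[k][j] for k in range(m)) for j in range(p)]
            for i in range(n)]


print(naive_mm([[1, 2, 3], [4, 5, 6]], [[1, 0], [0, 1], [1, 1]]))  # [[4, 5], [10, 11]]
```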
torch.mv(mat, vec, out=None) → Tensor

Performs a matrix-vector product of the matrix mat and the vector vec.

If mat is an (n \times m) tensor and vec is a 1-D tensor of size m, out will be a 1-D tensor of size n.

Note

This function does not broadcast.

Parameters:
  • mat (Tensor) – matrix to be multiplied
  • vec (Tensor) – vector to be multiplied
  • out (Tensor, optional) – the output tensor

Example:

>>> mat = torch.randn(2, 3)
>>> vec = torch.randn(3)
>>> torch.mv(mat, vec)
-2.0939
-2.2950
[torch.FloatTensor of size 2]
torch.orgqr(a, tau) → Tensor

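Computes the orthogonal matrix Q of a QR factorization, from the (a, tau) tuple returned by torch.geqrf().

This directly calls the underlying LAPACK function ?orgqr. See ?orgqr LAPACK documentation for further details.

Parameters:
  • a (Tensor) – the a from the tuple returned by torch.geqrf()
  • tau (Tensor) – the tau from the tuple returned by torch.geqrf()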
torch.ormqr(a, tau, mat, left=True, transpose=False) -> (Tensor, Tensor)

Multiplies mat by the orthogonal Q matrix of the QR factorization formed by torch.geqrf() that is represented by (a, tau).

This directly calls the underlying LAPACK function ?ormqr. See ?ormqr LAPACK documentation for further details.
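Following the LAPACK documentation for ?geqrf, Q is represented as a product of Householder reflectors H_j = I - tau_j v_j v_j^T with v_j[0] = 1. The sketch below is a minimal pure-Python illustration of that representation, not torch's or LAPACK's code: householder_qr, form_q and apply_q are hypothetical names playing the roles of geqrf, orgqr and ormqr (left multiplication only), and the vectors v_j are kept in a separate list instead of being packed below the diagonal of a.

```python
import math


def householder_qr(a):
    """geqrf-style sketch: returns (r, vs, taus) with Q = H_0 H_1 ... H_{k-1}."""
    m, n = len(a), len(a[0])
    r = [[float(x) for x in row] for row in a]
    vs, taus = [], []
    for j in range(min(m, n)):
        alpha = r[j][j]
        tail = sum(r[i][j] ** 2 for i in range(j + 1, m))
        if tail == 0.0:
            v, tau = [1.0] + [0.0] * (m - j - 1), 0.0  # nothing to eliminate: H_j = I
        else:
            beta = -math.copysign(math.sqrt(alpha * alpha + tail), alpha)
            v = [1.0] + [r[i][j] / (alpha - beta) for i in range(j + 1, m)]
            tau = (beta - alpha) / beta
        # apply H_j to rows j..m-1 of the remaining columns
        for k in range(j, n):
            w = sum(v[i] * r[j + i][k] for i in range(len(v)))
            for i in range(len(v)):
                r[j + i][k] -= tau * v[i] * w
        for i in range(j + 1, m):
            r[i][j] = 0.0  # exact zeros below the diagonal (rounding leaves tiny values)
        vs.append(v)
        taus.append(tau)
    return r, vs, taus


def apply_q(vs, taus, mat, transpose=False):
    """ormqr-style sketch (left side): returns Q @ mat, or Q^T @ mat if transpose."""
    c = [[float(x) for x in row] for row in mat]
    # Q @ mat applies H_{k-1} first; Q^T @ mat applies H_0 first
    order = range(len(vs)) if transpose else reversed(range(len(vs)))
    for j in order:
        v, tau = vs[j], taus[j]
        for k in range(len(c[0])):
            w = sum(v[i] * c[j + i][k] for i in range(len(v)))
            for i in range(len(v)):
                c[j + i][k] -= tau * v[i] * w
    return c


def form_q(vs, taus, m):
    """orgqr-style sketch: the full m x m Q, obtained by applying Q to the identity."""
    eye = [[float(i == j) for j in range(m)] for i in range(m)]
    return apply_q(vs, taus, eye)


r, vs, taus = householder_qr([[12, -51, 4], [6, 167, -68], [-4, 24, -41]])
print([round(x, 4) for x in r[0]])  # [-14.0, -21.0, 14.0]
```

The first row of r matches the first row of r in the torch.qr example further down. For a tall input, keeping the first n columns of form_q's result gives the thin factor.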

torch.potrf(a, out=None)

potrf(a, upper, out=None)

Computes the Cholesky decomposition of a symmetric positive-definite matrix a: returns matrix u. If upper is True or not provided, u is upper triangular such that a = u^T u. If upper is False, u is lower triangular such that a = u u^T.

Parameters:
  • a (Tensor) – the input 2-D tensor, a symmetric positive-definite matrix
  • upper (bool, optional) – whether to return an upper (default) or lower triangular matrix
  • out (Tensor, optional) – the output tensor for u

Example:

>>> a = torch.randn(3,3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a

 2.3563  3.2318 -0.9406
 3.2318  4.9557 -2.1618
-0.9406 -2.1618  2.2443
[torch.FloatTensor of size 3x3]

>>> u

 1.5350  2.1054 -0.6127
 0.0000  0.7233 -1.2053
 0.0000  0.0000  0.6451
[torch.FloatTensor of size 3x3]

>>> torch.mm(u.t(),u)

 2.3563  3.2318 -0.9406
 3.2318  4.9557 -2.1618
-0.9406 -2.1618  2.2443
[torch.FloatTensor of size 3x3]
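The factor can be computed column by column directly from a = u^T u. Below is a minimal pure-Python sketch of the upper-triangular case; cholesky_upper is a hypothetical name, it assumes a is symmetric positive-definite (raising otherwise), and it is not the routine torch.potrf uses.

```python
import math


def cholesky_upper(a):
    """Return upper-triangular u with a = u^T u, for symmetric positive-definite a."""
    n = len(a)
    u = [[0.0] * n for _ in range(n)]
    for j in range(n):
        # a[j][j] = sum_{k<j} u[k][j]^2 + u[j][j]^2
        d = a[j][j] - sum(u[k][j] ** 2 for k in range(j))
        if d <= 0.0:
            raise ValueError("matrix is not positive definite")
        u[j][j] = math.sqrt(d)
        for i in range(j + 1, n):
            # a[j][i] = sum_{k<j} u[k][j] * u[k][i] + u[j][j] * u[j][i]
            u[j][i] = (a[j][i] - sum(u[k][j] * u[k][i] for k in range(j))) / u[j][j]
    return u


print(cholesky_upper([[4, 2, 2], [2, 5, 3], [2, 3, 6]]))
# [[2.0, 1.0, 1.0], [0.0, 2.0, 1.0], [0.0, 0.0, 2.0]]
```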
torch.potri(u, out=None)

potri(u, upper, out=None)

Computes the inverse of a symmetric positive-definite matrix given its Cholesky factor u: returns matrix inv. If upper is True or not provided, u is upper triangular and inv = (u^T u)^{-1}. If upper is False, u is lower triangular and inv = (u u^T)^{-1}.

Parameters:
  • u (Tensor) – the input 2-D tensor, an upper or lower triangular Cholesky factor
  • upper (bool, optional) – whether u is upper (default) or lower triangular
  • out (Tensor, optional) – the output tensor for inv

Example:

>>> a = torch.randn(3,3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a

 2.3563  3.2318 -0.9406
 3.2318  4.9557 -2.1618
-0.9406 -2.1618  2.2443
[torch.FloatTensor of size 3x3]

>>> torch.potri(u)

 12.5724 -10.1765  -4.5333
-10.1765   8.5852   4.0047
 -4.5333   4.0047   2.4031
[torch.FloatTensor of size 3x3]

>>> a.inverse()

 12.5723 -10.1765  -4.5333
-10.1765   8.5852   4.0047
 -4.5333   4.0047   2.4031
[torch.FloatTensor of size 3x3]
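Given the factor, the inverse follows from (u^T u)^{-1} = u^{-1} (u^{-1})^T, so only a triangular matrix has to be inverted. A minimal pure-Python sketch for the upper=True case only; upper_triangular_inverse and cholesky_inverse are hypothetical helper names, not torch functions.

```python
def upper_triangular_inverse(u):
    """Invert an upper-triangular matrix by back substitution, one column at a time."""
    n = len(u)
    w = [[0.0] * n for _ in range(n)]
    for col in range(n):
        # solve u x = e_col from the last row upwards; x is column col of u^{-1}
        for i in range(n - 1, -1, -1):
            rhs = (1.0 if i == col else 0.0) - sum(u[i][k] * w[k][col] for k in range(i + 1, n))
            w[i][col] = rhs / u[i][i]
    return w


def cholesky_inverse(u):
    """Given upper-triangular u with a = u^T u, return a^{-1} = u^{-1} (u^{-1})^T."""
    w = upper_triangular_inverse(u)
    n = len(u)
    return [[sum(w[i][k] * w[j][k] for k in range(n)) for j in range(n)] for i in range(n)]


# u = [[2, 1], [0, 3]] is the factor of a = u^T u = [[4, 2], [2, 10]]
print(cholesky_inverse([[2.0, 1.0], [0.0, 3.0]]))  # approximately [[10/36, -2/36], [-2/36, 4/36]]
```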
torch.potrs(b, u, out=None)

potrs(b, u, upper, out=None)

Solves a linear system of equations with a symmetric positive-definite coefficient matrix, given its Cholesky factor u: returns matrix c. If upper is True or not provided, u is upper triangular and c = (u^T u)^{-1} b. If upper is False, u is lower triangular and c = (u u^T)^{-1} b.

Note

b must always be a 2-D tensor; use b.unsqueeze(1) to convert a vector.

Parameters:
  • b (Tensor) – the right-hand side 2-D tensor
  • u (Tensor) – the input 2-D tensor, an upper or lower triangular Cholesky factor
  • upper (bool, optional) – whether u is upper (default) or lower triangular
  • out (Tensor, optional) – the output tensor for c

Example:

>>> a = torch.randn(3,3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> u = torch.potrf(a)
>>> a

 2.3563  3.2318 -0.9406
 3.2318  4.9557 -2.1618
-0.9406 -2.1618  2.2443
[torch.FloatTensor of size 3x3]

>>> b = torch.randn(3,2)
>>> b

-0.3119 -1.8224
-0.2798  0.1789
-0.3735  1.7451
[torch.FloatTensor of size 3x2]

>>> torch.potrs(b,u)

 0.6187 -32.6438
-0.7234  27.0703
-0.6039  13.1717
[torch.FloatTensor of size 3x2]

>>> torch.mm(a.inverse(),b)

 0.6187 -32.6436
-0.7234  27.0702
-0.6039  13.1717
[torch.FloatTensor of size 3x2]
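The solve never needs the inverse explicitly: with upper=True, first solve u^T y = b by forward substitution, then u c = y by back substitution. The sketch below assumes upper=True and uses the hypothetical name cholesky_solve; b is a list of rows with one column per right-hand side, matching the 2-D requirement in the note above.

```python
def cholesky_solve(b, u):
    """Solve (u^T u) c = b for c; u is upper triangular, b is n x k (lists of lists)."""
    n, k = len(u), len(b[0])
    c = [[0.0] * k for _ in range(n)]
    for col in range(k):
        # forward substitution with the lower-triangular u^T: u^T y = b[:, col]
        y = [0.0] * n
        for i in range(n):
            y[i] = (b[i][col] - sum(u[j][i] * y[j] for j in range(i))) / u[i][i]
        # back substitution: u c[:, col] = y
        for i in range(n - 1, -1, -1):
            c[i][col] = (y[i] - sum(u[i][j] * c[j][col] for j in range(i + 1, n))) / u[i][i]
    return c
```

Solving twice with a triangular matrix costs O(n^2) per right-hand side once the factor is known, which is why this route is preferred over forming the inverse, as the potri/inverse comparison above also suggests.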
torch.pstrf(a, out=None)

pstrf(a, upper, out=None)

Computes the pivoted Cholesky decomposition of a symmetric positive semidefinite matrix a: returns matrices u and piv. If upper is True or not provided, u is upper triangular such that a = p^T u^T u p, where p is the permutation matrix given by piv (row k of p has its single 1 in column piv[k]). If upper is False, u is lower triangular such that a = p^T u u^T p.

Parameters:
  • a (Tensor) – the input 2-D tensor
  • upper (bool, optional) – whether to return an upper (default) or lower triangular matrix
  • out (tuple, optional) – tuple of u and piv tensors

Example:

>>> a = torch.randn(3,3)
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
>>> a

 5.4417 -2.5280  1.3643
-2.5280  2.9689 -2.1368
 1.3643 -2.1368  4.6116
[torch.FloatTensor of size 3x3]

>>> u,piv = torch.pstrf(a)
>>> u

 2.3328  0.5848 -1.0837
 0.0000  2.0663 -0.7274
 0.0000  0.0000  1.1249
[torch.FloatTensor of size 3x3]

>>> piv

 0
 2
 1
[torch.IntTensor of size 3]

>>> p = torch.eye(3).index_select(0,piv.long()) # permutation matrix: p[k][piv[k]] = 1
>>> torch.mm(torch.mm(p.t(),torch.mm(u.t(),u)),p) # reconstruct a

 5.4417 -2.5280  1.3643
-2.5280  2.9689 -2.1368
 1.3643 -2.1368  4.6116
[torch.FloatTensor of size 3x3]
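A minimal pure-Python sketch of the pivoting idea, assuming complete (diagonal) pivoting, i.e. moving the largest remaining diagonal entry to the front at each step, which is the strategy LAPACK documents for ?pstrf. pivoted_cholesky and its tol threshold are hypothetical; piv is 0-based like the example above, and u^T u equals a with rows and columns reordered by piv, which is the same statement as a = p^T u^T u p.

```python
import math


def pivoted_cholesky(a, tol=1e-12):
    """Return (u, piv): u upper triangular with u^T u == a[piv][:, piv]."""
    n = len(a)
    s = [[float(x) for x in row] for row in a]  # working copy, permuted in place
    piv = list(range(n))
    low = [[0.0] * n for _ in range(n)]         # lower factor; u is its transpose
    for k in range(n):
        # pick the largest diagonal entry of the remaining Schur complement
        j = max(range(k, n), key=lambda i: s[i][i])
        if s[j][j] <= tol:
            break  # remaining block is numerically zero: rank k
        # symmetric swap of rows/columns k and j, plus the factor rows built so far
        s[k], s[j] = s[j], s[k]
        for row in s:
            row[k], row[j] = row[j], row[k]
        low[k], low[j] = low[j], low[k]
        piv[k], piv[j] = piv[j], piv[k]
        low[k][k] = math.sqrt(s[k][k])
        for i in range(k + 1, n):
            low[i][k] = s[i][k] / low[k][k]
        # update the trailing Schur complement
        for i in range(k + 1, n):
            for c in range(k + 1, n):
                s[i][c] -= low[i][k] * low[c][k]
    u = [[low[j][i] for j in range(n)] for i in range(n)]  # transpose
    return u, piv
```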
torch.qr(input, out=None) -> (Tensor, Tensor)

Computes the QR decomposition of a matrix input: returns matrices q and r such that input = q r, with q being an orthogonal matrix and r being an upper triangular matrix.

This returns the thin (reduced) QR factorization.

Note

Precision may be lost if the magnitudes of the elements of input are large.

Note

While it should always give you a valid decomposition, it may not give you the same one across platforms; the result depends on your LAPACK implementation.

Note

Irrespective of the original strides, the returned matrix q will be transposed, i.e. with strides (1, m) instead of (m, 1).

Parameters:
  • input (Tensor) – the input 2-D tensor
  • out (tuple, optional) – tuple of Q and R tensors

Example:

>>> a = torch.Tensor([[12, -51, 4], [6, 167, -68], [-4, 24, -41]])
>>> q, r = torch.qr(a)
>>> q

-0.8571  0.3943  0.3314
-0.4286 -0.9029 -0.0343
 0.2857 -0.1714  0.9429
[torch.FloatTensor of size 3x3]

>>> r

 -14.0000  -21.0000   14.0000
   0.0000 -175.0000   70.0000
   0.0000    0.0000  -35.0000
[torch.FloatTensor of size 3x3]

>>> torch.mm(q, r).round()

  12  -51    4
   6  167  -68
  -4   24  -41
[torch.FloatTensor of size 3x3]

>>> torch.mm(q.t(), q).round()

 1 -0  0
-0  1  0
 0  0  1
[torch.FloatTensor of size 3x3]
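For intuition, here is a minimal pure-Python thin QR by modified Gram-Schmidt. It is a different algorithm from the LAPACK-based routine torch.qr relies on; qr_mgs is a hypothetical name and it assumes the columns of the input are linearly independent. On the example matrix it produces r with a positive diagonal (14, 175, 35), i.e. the same factorization as above up to the signs of the columns of q and rows of r, which is the kind of implementation-dependent difference the note above describes.

```python
import math


def qr_mgs(a):
    """Thin QR of an m x n matrix (m >= n, full column rank) by modified Gram-Schmidt."""
    m, n = len(a), len(a[0])
    q = [[float(a[i][j]) for j in range(n)] for i in range(m)]  # orthonormalized in place
    r = [[0.0] * n for _ in range(n)]
    for j in range(n):
        r[j][j] = math.sqrt(sum(q[i][j] ** 2 for i in range(m)))
        for i in range(m):
            q[i][j] /= r[j][j]
        # remove the new direction from every later column right away
        for k in range(j + 1, n):
            r[j][k] = sum(q[i][j] * q[i][k] for i in range(m))
            for i in range(m):
                q[i][k] -= r[j][k] * q[i][j]
    return q, r


q, r = qr_mgs([[12, -51, 4], [6, 167, -68], [-4, 24, -41]])
print([round(r[j][j], 6) for j in range(3)])  # [14.0, 175.0, 35.0]
```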
torch.svd(input, some=True, out=None) -> (Tensor, Tensor, Tensor)

U, S, V = torch.svd(A) returns the singular value decomposition of a real matrix A of size (n \times m) such that A = U S V^T.

U is of shape (n \times n).

S is a diagonal matrix of shape (n \times m), represented as a vector of size \min(n, m) containing the diagonal entries.

V is of shape (m \times m).

If some is True (default), the returned U and V matrices will contain only min(n, m) orthonormal columns.

Note

Irrespective of the original strides, the returned matrix U will be transposed, i.e. with strides (1, n) instead of (n, 1).

Parameters:
  • input (Tensor) – the input 2-D tensor
  • some (bool, optional) – controls the shape of returned U and V
  • out (tuple, optional) – the output tuple of tensors

Example:

>>> a = torch.Tensor([[8.79,  6.11, -9.15,  9.57, -3.49,  9.84],
...                   [9.93,  6.91, -7.93,  1.64,  4.02,  0.15],
...                   [9.83,  5.04,  4.86,  8.83,  9.80, -8.99],
...                   [5.45, -0.27,  4.85,  0.74, 10.00, -6.02],
...                   [3.16,  7.98,  3.01,  5.80,  4.27, -5.31]]).t()
>>> a

  8.7900   9.9300   9.8300   5.4500   3.1600
  6.1100   6.9100   5.0400  -0.2700   7.9800
 -9.1500  -7.9300   4.8600   4.8500   3.0100
  9.5700   1.6400   8.8300   0.7400   5.8000
 -3.4900   4.0200   9.8000  10.0000   4.2700
  9.8400   0.1500  -8.9900  -6.0200  -5.3100
[torch.FloatTensor of size 6x5]

>>> u, s, v = torch.svd(a)
>>> u

-0.5911  0.2632  0.3554  0.3143  0.2299
-0.3976  0.2438 -0.2224 -0.7535 -0.3636
-0.0335 -0.6003 -0.4508  0.2334 -0.3055
-0.4297  0.2362 -0.6859  0.3319  0.1649
-0.4697 -0.3509  0.3874  0.1587 -0.5183
 0.2934  0.5763 -0.0209  0.3791 -0.6526
[torch.FloatTensor of size 6x5]

>>> s

 27.4687
 22.6432
  8.5584
  5.9857
  2.0149
[torch.FloatTensor of size 5]

>>> v

-0.2514  0.8148 -0.2606  0.3967 -0.2180
-0.3968  0.3587  0.7008 -0.4507  0.1402
-0.6922 -0.2489 -0.2208  0.2513  0.5891
-0.3662 -0.3686  0.3859  0.4342 -0.6265
-0.4076 -0.0980 -0.4932 -0.6227 -0.4396
[torch.FloatTensor of size 5x5]

>>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t()))
8.934150226306685e-06
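A compact way to see where U, S and V come from is one-sided Jacobi (Hestenes) orthogonalization: rotate pairs of columns of A until they are mutually orthogonal, accumulating the same rotations in V; the column norms are then the singular values and the normalized columns form U. This is a minimal pure-Python sketch, not how torch.svd computes the decomposition; svd_jacobi and its tol/max_sweeps parameters are hypothetical, and it assumes A has at least as many rows as columns (as in the 6 \times 5 example), returning the reduced factors that some=True describes.

```python
import math


def svd_jacobi(a, tol=1e-12, max_sweeps=60):
    """Reduced SVD of a matrix with at least as many rows as columns.

    Returns (u, s, v): u is rows x cols, s holds the singular values in
    descending order, v is cols x cols, and a ~= u diag(s) v^T.
    """
    rows, cols = len(a), len(a[0])
    w = [[float(x) for x in row] for row in a]  # invariant: w == a @ v
    v = [[float(i == j) for j in range(cols)] for i in range(cols)]
    for _ in range(max_sweeps):
        rotated = False
        for p in range(cols - 1):
            for q in range(p + 1, cols):
                alpha = sum(w[i][p] ** 2 for i in range(rows))
                beta = sum(w[i][q] ** 2 for i in range(rows))
                gamma = sum(w[i][p] * w[i][q] for i in range(rows))
                if abs(gamma) <= tol * math.sqrt(alpha * beta):
                    continue  # columns p and q are already orthogonal
                rotated = True
                # rotation that makes columns p and q orthogonal
                zeta = (beta - alpha) / (2.0 * gamma)
                t = math.copysign(1.0, zeta) / (abs(zeta) + math.sqrt(1.0 + zeta * zeta))
                c = 1.0 / math.sqrt(1.0 + t * t)
                sn = c * t
                for mat in (w, v):  # same rotation on both keeps w == a @ v
                    for row in mat:
                        xp, xq = row[p], row[q]
                        row[p] = c * xp - sn * xq
                        row[q] = sn * xp + c * xq
        if not rotated:
            break
    norms = [math.sqrt(sum(w[i][j] ** 2 for i in range(rows))) for j in range(cols)]
    order = sorted(range(cols), key=lambda j: -norms[j])
    s = [norms[j] for j in order]
    # zero singular values leave a zero column in u in this sketch
    u = [[w[i][j] / norms[j] if norms[j] > 0 else 0.0 for j in order] for i in range(rows)]
    v_sorted = [[v[i][j] for j in order] for i in range(cols)]
    return u, s, v_sorted
```

Jacobi methods are simple and accurate but cost more per sweep than bidiagonalization-based solvers, so this is meant for reading, not for production use.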
torch.symeig(input, eigenvectors=False, upper=True, out=None) -> (Tensor, Tensor)

e, V = torch.symeig(input, eigenvectors=True) returns the eigenvalues and eigenvectors of a real symmetric matrix input.

input and V are (m \times m) matrices and e is an m-dimensional vector.

This function calculates all eigenvalues (and vectors) of input such that input = V diag(e) V^T.

The boolean argument eigenvectors selects what is computed: if it is False (the default), only the eigenvalues are computed; if it is True, both eigenvalues and eigenvectors are computed.

Since input is assumed to be symmetric, only its upper triangular portion is used by default. If upper is False, the lower triangular portion is used instead.

Note

Irrespective of the original strides, the returned matrix V will be transposed, i.e. with strides (1, m) instead of (m, 1).

Parameters:
  • input (Tensor) – the input symmetric matrix
  • eigenvectors (boolean, optional) – controls whether eigenvectors have to be computed
  • upper (boolean, optional) – controls whether to consider upper-triangular or lower-triangular region
  • out (tuple, optional) – the output tuple of (Tensor, Tensor)

Examples:

>>> a = torch.Tensor([[ 1.96,  0.00,  0.00,  0.00,  0.00],
...                   [-6.49,  3.80,  0.00,  0.00,  0.00],
...                   [-0.47, -6.39,  4.17,  0.00,  0.00],
...                   [-7.20,  1.50, -1.51,  5.70,  0.00],
...                   [-0.65, -6.34,  2.67,  1.80, -7.10]]).t()

>>> e, v = torch.symeig(a, eigenvectors=True)
>>> e

-11.0656
 -6.2287
  0.8640
  8.8655
 16.0948
[torch.FloatTensor of size 5]

>>> v

-0.2981 -0.6075  0.4026 -0.3745  0.4896
-0.5078 -0.2880 -0.4066 -0.3572 -0.6053
-0.0816 -0.3843 -0.6600  0.5008  0.3991
-0.0036 -0.4467  0.4553  0.6204 -0.4564
-0.8041  0.4480  0.1725  0.3108  0.1622
[torch.FloatTensor of size 5x5]
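The classical cyclic Jacobi eigenvalue algorithm gives a compact illustration of input = V diag(e) V^T: each plane rotation zeroes one off-diagonal pair, and the accumulated rotations form V. This minimal pure-Python sketch is not the routine torch.symeig uses; symeig_jacobi and its tol/max_sweeps parameters are hypothetical. Like symeig it reads only one triangle (selected by upper) and returns the eigenvalues in ascending order, as in the example output above; eigenvector signs are arbitrary and may differ from torch's.

```python
import math


def symeig_jacobi(a, upper=True, tol=1e-12, max_sweeps=60):
    """Eigenvalues (ascending) and eigenvectors of a real symmetric matrix.

    Only one triangle of a is read (the upper one if upper is True). Returns
    (e, v) with eigenvectors as the columns of v, so sym(a) ~= v diag(e) v^T.
    """
    n = len(a)
    # rebuild the full symmetric matrix from the selected triangle
    s = [[float(a[min(i, j)][max(i, j)] if upper else a[max(i, j)][min(i, j)])
          for j in range(n)] for i in range(n)]
    v = [[float(i == j) for j in range(n)] for i in range(n)]
    scale = max(abs(x) for row in s for x in row) or 1.0
    for _ in range(max_sweeps):
        off = max((abs(s[i][j]) for i in range(n) for j in range(n) if i != j), default=0.0)
        if off <= tol * scale:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if s[p][q] == 0.0:
                    continue
                # rotation J in the (p, q) plane chosen so that (J^T s J)[p][q] == 0
                theta = (s[q][q] - s[p][p]) / (2.0 * s[p][q])
                t = math.copysign(1.0, theta) / (abs(theta) + math.sqrt(theta * theta + 1.0))
                c = 1.0 / math.sqrt(t * t + 1.0)
                sn = t * c
                for mat in (s, v):  # right-multiply by J: mixes columns p and q
                    for row in mat:
                        xp, xq = row[p], row[q]
                        row[p], row[q] = c * xp - sn * xq, sn * xp + c * xq
                # left-multiply s by J^T: mixes rows p and q
                rp, rq = s[p], s[q]
                s[p] = [c * x - sn * y for x, y in zip(rp, rq)]
                s[q] = [sn * x + c * y for x, y in zip(rp, rq)]
    order = sorted(range(n), key=lambda i: s[i][i])
    e = [s[i][i] for i in order]
    vecs = [[v[r][i] for i in order] for r in range(n)]
    return e, vecs
```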
torch.trtrs(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)

Solves a system of equations with a triangular coefficient matrix A and multiple right-hand sides b.

In particular, solves AX = b and assumes A is upper-triangular with the default keyword arguments.

This method is NOT implemented for CUDA tensors.

Parameters:
  • A (Tensor) – the input triangular coefficient matrix
  • b (Tensor) – multiple right-hand sides. Each column of b is a right-hand side for the system of equations.
  • upper (bool, optional) – whether to solve the upper-triangular system of equations (default) or the lower-triangular system of equations. Default: True.
  • transpose (bool, optional) – whether A should be transposed before being sent into the solver. Default: False.
  • unitriangular (bool, optional) – whether A is unit triangular. If True, the diagonal elements of A are assumed to be 1 and not referenced from A. Default: False.
Returns:

A tuple (X, M) where M is a clone of A and X is the solution to AX = b (or to whichever variant of the system the keyword arguments select).

Shape:
  • A: (N, N)
  • b: (N, C)
  • output[0]: (N, C)
  • output[1]: (N, N)

Examples:

>>> A = torch.randn(2,2).triu()
>>> A

-1.8793  0.1567
 0.0000 -2.1972
[torch.FloatTensor of size 2x2]

>>> b = torch.randn(2,3)
>>> b

 1.8776 -0.0759  1.6590
-0.5676  0.4771  0.7477
[torch.FloatTensor of size 2x3]

>>> torch.trtrs(b, A)

(
 -0.9775  0.0223 -0.9112
  0.2583 -0.2172 -0.3403
 [torch.FloatTensor of size 2x3],
 -1.8793  0.1567
  0.0000 -2.1972
 [torch.FloatTensor of size 2x2])
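The solve itself is plain back substitution (or forward substitution for a lower-triangular system). solve_triangular below is a hypothetical pure-Python sketch, not torch's implementation; it mirrors the upper, transpose and unitriangular flags described above, takes b as a list of rows with one column per right-hand side, and returns only X rather than the (X, M) tuple.

```python
def solve_triangular(b, a, upper=True, transpose=False, unitriangular=False):
    """Solve op(a) x = b, where op(a) is a (or a^T if transpose) and a is triangular.

    a is N x N and b is N x C (lists of lists); each column of b is one right-hand side.
    """
    n, ncols = len(a), len(b[0])
    # the transpose of an upper-triangular matrix is lower-triangular, and vice versa
    op_is_upper = upper != transpose
    op = (lambda i, j: a[j][i]) if transpose else (lambda i, j: a[i][j])
    rows = range(n - 1, -1, -1) if op_is_upper else range(n)  # back vs forward substitution
    x = [[0.0] * ncols for _ in range(n)]
    for c in range(ncols):
        for i in rows:
            known = range(i + 1, n) if op_is_upper else range(i)
            acc = b[i][c] - sum(op(i, j) * x[j][c] for j in known)
            # with unitriangular=True the diagonal is taken to be 1 and never read
            x[i][c] = acc if unitriangular else acc / op(i, i)
    return x


print(solve_triangular([[3.0], [8.0]], [[2.0, 1.0], [0.0, 4.0]]))  # [[0.5], [2.0]]
```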