Source code for torch.functional

import torch
from operator import mul
from functools import reduce
import math

__all__ = [
    'argmax',
    'argmin',
    'bartlett_window',
    'btrifact',
    'btriunpack',
    'hamming_window',
    'hann_window',
    'isnan',
    'split',
    'unbind',
    'unique',
]


[docs]def split(tensor, split_size_or_sections, dim=0): r"""Splits the tensor into chunks. If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension :attr:`dim= is not divisible by :attr:`split_size`. If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according to :attr:`split_size_or_sections`. Arguments: tensor (Tensor): tensor to split. split_size_or_sections (int) or (list(int)): size of a single chunk or list of sizes for each chunk dim (int): dimension along which to split the tensor. """ # Overwriting reason: # This dispatches to two ATen functions depending on the type of # split_size_or_sections. The branching code is in tensor.py, which we # call here. return tensor.split(split_size_or_sections, dim)
[docs]def btrifact(A, info=None, pivot=True): r"""Batch LU factorization. Returns a tuple containing the LU factorization and pivots. Pivoting is done if :attr:`pivot` is set. The optional argument :attr:`info` stores information if the factorization succeeded for each minibatch example. The :attr:`info` is provided as an `IntTensor`, its values will be filled from dgetrf and a non-zero value indicates an error occurred. Specifically, the values are from cublas if cuda is being used, otherwise LAPACK. .. warning:: The :attr:`info` argument is deprecated in favor of :meth:`torch.btrifact_with_info`. Arguments: A (Tensor): the tensor to factor info (IntTensor, optional): (deprecated) an `IntTensor` to store values indicating whether factorization succeeds pivot (bool, optional): controls whether pivoting is done Returns: A tuple containing factorization and pivots. Example:: >>> A = torch.randn(2, 3, 3) >>> A_LU, pivots = torch.btrifact(A) >>> A_LU tensor([[[ 1.3506, 2.5558, -0.0816], [ 0.1684, 1.1551, 0.1940], [ 0.1193, 0.6189, -0.5497]], [[ 0.4526, 1.2526, -0.3285], [-0.7988, 0.7175, -0.9701], [ 0.2634, -0.9255, -0.3459]]]) >>> pivots tensor([[ 3, 3, 3], [ 3, 3, 3]], dtype=torch.int32) """ # Overwriting reason: # `info` is being deprecated in favor of `btrifact_with_info`. This warning # is in tensor.py, which we call here. return A.btrifact(info, pivot)
[docs]def unbind(tensor, dim=0): r"""Removes a tensor dimension. Returns a tuple of all slices along a given dimension, already without it. Arguments: tensor (Tensor): the tensor to unbind dim (int): dimension to remove """ return tuple(tensor.select(dim, i) for i in range(tensor.size(dim)))
[docs]def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor. Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``. Arguments: LU_data (Tensor): the packed LU factorization data LU_pivots (Tensor): the packed LU factorization pivots unpack_data (bool): flag indicating if the data should be unpacked unpack_pivots (bool): flag indicating if the pivots should be unpacked Example:: >>> A = torch.randn(2, 3, 3) >>> A_LU, pivots = A.btrifact() >>> P, A_L, A_U = torch.btriunpack(A_LU, pivots) >>> >>> # can recover A from factorization >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U)) """ nBatch, sz, _ = LU_data.size() if unpack_data: I_U = torch.triu(torch.ones(sz, sz)).type_as(LU_data).byte().unsqueeze(0).expand(nBatch, sz, sz) I_L = 1 - I_U L = LU_data.new(LU_data.size()).zero_() U = LU_data.new(LU_data.size()).zero_() I_diag = torch.eye(sz).type_as(LU_data).byte().unsqueeze(0).expand(nBatch, sz, sz) L[I_diag] = 1.0 L[I_L] = LU_data[I_L] U[I_U] = LU_data[I_U] else: L = U = None if unpack_pivots: P = torch.eye(sz).type_as(LU_data).unsqueeze(0).repeat(nBatch, 1, 1) for i in range(nBatch): for j in range(sz): k = int(LU_pivots[i, j] - 1) t = P[i, :, j].clone() P[i, :, j] = P[i, :, k] P[i, :, k] = t else: P = None return P, L, U
[docs]def hann_window(window_length, periodic=True, dtype=torch.float32): r"""Hann window function. This method computes the Hann window function: .. math:: w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = \sin^2 \left( \frac{\pi n}{N - 1} \right), where :math:`N` is the full window size. The input :attr:`window_length` is a positive integer controlling the returned window size. :attr:`periodic` flag determines whether the returned window trims off the last duplicate value from the symmetric window and is ready to be used as a periodic window with functions like :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in above formula is in fact :math:`\text{window_length} + 1`. Also, we always have ``torch.hann_window(L, periodic=True)`` equal to ``torch.hann_window(L + 1, periodic=False)[:-1])``. .. note:: If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. Arguments: window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic function. If False, return a symmetric window. dtype (:class:`torch.dtype`, optional): the desired type of returned window. Default: `torch.float32` Returns: Tensor: A 1-D tensor of size :math:`(\text{window_length},)` containing the window """ if not dtype.is_floating_point: raise ValueError("dtype must be a floating point type, but got dtype={}".format(dtype)) if window_length <= 0: raise ValueError('window_length must be positive') return hamming_window(window_length, periodic=periodic, alpha=0.5, beta=0.5, dtype=dtype)
[docs]def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, dtype=torch.float32): r"""Hamming window function. This method computes the Hamming window function: .. math:: w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), where :math:`N` is the full window size. The input :attr:`window_length` is a positive integer controlling the returned window size. :attr:`periodic` flag determines whether the returned window trims off the last duplicate value from the symmetric window and is ready to be used as a periodic window with functions like :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in above formula is in fact :math:`\text{window_length} + 1`. Also, we always have ``torch.hamming_window(L, periodic=True)`` equal to ``torch.hamming_window(L + 1, periodic=False)[:-1])``. .. note:: If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. .. note:: This is a generalized version of :meth:`torch.hann_window`. Arguments: window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic function. If False, return a symmetric window. dtype (:class:`torch.dtype`, optional): the desired type of returned window. Default: `torch.float32` Returns: Tensor: A 1-D tensor of size :math:`(\text{window_length},)` containing the window """ if not dtype.is_floating_point: raise ValueError("dtype must be a floating point type, but got dtype={}".format(dtype)) if window_length <= 0: raise ValueError('window_length must be positive') if window_length == 1: return torch.ones(window_length, dtype=dtype) window_length += int(periodic) window = torch.arange(window_length, dtype=dtype) window = window.mul_(math.pi * 2 / (window_length - 1)).cos_().mul_(-beta).add_(alpha) if periodic: return window[:-1] else: return window
[docs]def bartlett_window(window_length, periodic=True, dtype=torch.float32): r"""Bartlett window function. This method computes the Bartlett window function: .. math:: w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ \end{cases}, where :math:`N` is the full window size. The input :attr:`window_length` is a positive integer controlling the returned window size. :attr:`periodic` flag determines whether the returned window trims off the last duplicate value from the symmetric window and is ready to be used as a periodic window with functions like :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in above formula is in fact :math:`\text{window_length} + 1`. Also, we always have ``torch.bartlett_window(L, periodic=True)`` equal to ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. .. note:: If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. Arguments: window_length (int): the size of returned window periodic (bool, optional): If True, returns a window to be used as periodic function. If False, return a symmetric window. dtype (:class:`torch.dtype`, optional): the desired type of returned window. Default: `torch.float32` Returns: Tensor: A 1-D tensor of size :math:`(\text{window_length},)` containing the window """ if not dtype.is_floating_point: raise ValueError("dtype must be a floating point type, but got dtype={}".format(dtype)) if window_length <= 0: raise ValueError('window_length must be positive') if window_length == 1: return torch.ones(window_length, dtype=dtype) window_length += int(periodic) window = torch.arange(window_length, dtype=dtype).mul_(2.0 / (window_length - 1)) first_half_size = ((window_length - 1) >> 1) + 1 window.narrow(0, first_half_size, window_length - first_half_size).mul_(-1).add_(2) if periodic: return window[:-1] else: return window
[docs]def isnan(tensor): r"""Returns a new tensor with boolean elements representing if each element is `NaN` or not. Arguments: tensor (Tensor): A tensor to check Returns: Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements. Example:: >>> torch.isnan(torch.tensor([1, float('nan'), 2])) tensor([ 0, 1, 0], dtype=torch.uint8) """ if not isinstance(tensor, torch.Tensor): raise ValueError("The argument is not a tensor") return tensor != tensor
[docs]def unique(input, sorted=False, return_inverse=False): r"""Returns the unique scalar elements of the input tensor as a 1-D tensor. Arguments: input (Tensor): the input tensor sorted (bool): Whether to sort the unique elements in ascending order before returning as output. return_inverse (bool): Whether to also return the indices for where elements in the original input ended up in the returned unique list. Returns: (Tensor, Tensor (optional)): A tensor or a tuple of tensors containing - **output** (*Tensor*): the output list of unique scalar elements. - **inverse_indices** (*Tensor*): (optional) if :attr:`return_inverse` is True, there will be a 2nd returned tensor (same shape as input) representing the indices for where elements in the original input map to in the output; otherwise, this function will only return a single tensor. Example:: >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long)) >>> output tensor([ 2, 3, 1]) >>> output, inverse_indices = torch.unique( torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True) >>> output tensor([ 1, 2, 3]) >>> inverse_indices tensor([ 0, 2, 1, 2]) >>> output, inverse_indices = torch.unique( torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True) >>> output tensor([ 1, 2, 3]) >>> inverse_indices tensor([[ 0, 2], [ 1, 2]]) """ output, inverse_indices = torch._unique( input, sorted=sorted, return_inverse=return_inverse, ) if return_inverse: return output, inverse_indices else: return output
[docs]def argmax(input, dim=None, keepdim=False): """Returns the indices of the maximum values of a tensor across a dimension. This is the second value returned by :meth:`torch.max`. See its documentation for the exact semantics of this method. Args: input (Tensor): the input tensor dim (int): the dimension to reduce. If ``None``, the argmax of the flattened input is returned. keepdim (bool): whether the output tensors have :attr:`dim` retained or not. Ignored if ``dim=None``. Example:: >>> a = torch.randn(4, 4) >>> a tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]]) >>> torch.argmax(a, dim=1) tensor([ 0, 2, 0, 1]) """ if dim is None: return torch._argmax(input.contiguous().view(-1), dim=0, keepdim=False) return torch._argmax(input, dim, keepdim)
[docs]def argmin(input, dim=None, keepdim=False): """Returns the indices of the minimum values of a tensor across a dimension. This is the second value returned by :meth:`torch.min`. See its documentation for the exact semantics of this method. Args: input (Tensor): the input tensor dim (int): the dimension to reduce. If ``None``, the argmin of the flattened input is returned. keepdim (bool): whether the output tensors have :attr:`dim` retained or not. Ignored if ``dim=None``. Example:: >>> a = torch.randn(4, 4) >>> a tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], [ 1.0100, -1.1975, -0.0102, -0.4732], [-0.9240, 0.1207, -0.7506, -1.0213], [ 1.7809, -1.2960, 0.9384, 0.1438]]) >>> torch.argmin(a, dim=1) tensor([ 2, 1, 3, 1]) """ if dim is None: return torch._argmin(input.contiguous().view(-1), dim=0, keepdim=False) return torch._argmin(input, dim, keepdim)