# The Tensor classes are added to this module by python_tensor.cpp
from typing import Optional, Tuple, List, Union

import torch
from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
from torch import Tensor

# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from torch.types import _dtype as DType
    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[Tuple[int]]


__all__ = [
    'addmm',
    'mm',
    'sum',
    'softmax',
    'log_softmax',
]


addmm = _add_docstr(_sparse._sparse_addmm, r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor

This function does the exact same thing as :func:`torch.addmm` in the forward,
except that it supports backward for sparse COO matrix :attr:`mat1`.
When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
When inputs are COO tensors, this function also supports backward for both inputs.

Supports both CSR and COO storage formats.

.. note::
    This function doesn't support computing derivatives with respect to CSR matrices.

Args:
    mat (Tensor): a dense matrix to be added
    mat1 (Tensor): a sparse matrix to be multiplied
    mat2 (Tensor): a dense matrix to be multiplied
    beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
""")


mm = _add_docstr(_sparse._sparse_mm, r"""
    Performs a matrix multiplication of the sparse matrix :attr:`mat1`
    and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
    :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
    :math:`(n \times p)` tensor.
    When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
    When inputs are COO tensors, this function also supports backward for both inputs.

    Supports both CSR and COO storage formats.

    .. note::
        This function doesn't support computing derivatives with respect to CSR matrices.

    Args:
        mat1 (Tensor): the first sparse matrix to be multiplied
        mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense

    Shape:
        The format of the output tensor of this function follows:
        - sparse x sparse -> sparse
        - sparse x dense -> dense

    Example::

        >>> a = torch.randn(2, 3).to_sparse().requires_grad_(True)
        >>> a
        tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                               [0, 1, 2, 0, 1, 2]]),
               values=tensor([ 1.5901,  0.0183, -0.6146,  1.8061, -0.0112,  0.6302]),
               size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)

        >>> b = torch.randn(3, 2, requires_grad=True)
        >>> b
        tensor([[-0.6479,  0.7874],
                [-1.2056,  0.5641],
                [-1.1716, -0.9923]], requires_grad=True)

        >>> y = torch.sparse.mm(a, b)
        >>> y
        tensor([[-0.3323,  1.8723],
                [-1.8951,  0.7904]], grad_fn=<SparseAddmmBackward>)

        >>> y.sum().backward()
        >>> a.grad
        tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                               [0, 1, 2, 0, 1, 2]]),
               values=tensor([ 0.1394, -0.6415, -2.1639,  0.1394, -0.6415, -2.1639]),
               size=(2, 3), nnz=6, layout=torch.sparse_coo)
    """)


sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor

Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.

Mathematically this performs the following operation:
.. math::

    \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}

where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha`
and :attr:`beta` are the scaling factors.

:math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere.

.. note::
    :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors.
    This function is implemented only for tensors on CUDA devices.

Args:
    input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute
        the sampled matrix multiplication
    mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied
    mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied

Keyword args:
    beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.

Examples::

    >>> input = torch.eye(3, device='cuda').to_sparse_csr()
    >>> mat1 = torch.randn(3, 5, device='cuda')
    >>> mat2 = torch.randn(5, 3, device='cuda')
    >>> torch.sparse.sampled_addmm(input, mat1, mat2)
    tensor(crow_indices=tensor([0, 1, 2, 3]),
           col_indices=tensor([0, 1, 2]),
           values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
           size=(3, 3), nnz=3, layout=torch.sparse_csr)
    >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
    tensor([[ 0.2847,  0.0000,  0.0000],
            [ 0.0000, -0.7805,  0.0000],
            [ 0.0000,  0.0000, -0.1900]], device='cuda:0')
    >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
    tensor(crow_indices=tensor([0, 1, 2, 3]),
           col_indices=tensor([0, 1, 2]),
           values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
           size=(3, 3), nnz=3, layout=torch.sparse_csr)
""")
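
# A minimal usage sketch, not part of the original module: it illustrates the
# `torch.sparse.addmm` call documented above with a sparse COO `mat1` and dense
# `mat`/`mat2`, since the addmm docstring has no example of its own.
# The helper name `_addmm_example` is hypothetical and chosen only for illustration.
def _addmm_example() -> Tensor:
    mat = torch.randn(2, 2)                                     # dense matrix to be added
    mat1 = torch.randn(2, 3).to_sparse().requires_grad_(True)   # sparse COO, sparse_dim = 2
    mat2 = torch.randn(3, 2)                                    # dense matrix to be multiplied
    # out = beta * mat + alpha * (mat1 @ mat2)
    out = torch.sparse.addmm(mat, mat1, mat2, beta=0.5, alpha=2.0)
    out.sum().backward()  # gradients propagate back to the sparse mat1
    return out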
def sum(input: Tensor, dim: DimOrDims = None,
        dtype: Optional[DType] = None) -> Tensor:
    r"""
    Returns the sum of each row of the sparse tensor :attr:`input` in the given
    dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
    reduce over all of them. When summing over all ``sparse_dim``, this method
    returns a dense tensor instead of a sparse tensor.

    All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting in an output
    tensor having :attr:`dim` fewer dimensions than :attr:`input`.

    During backward, only gradients at ``nnz`` locations of :attr:`input`
    will propagate back. Note that the gradient of :attr:`input` is coalesced.

    Args:
        input (Tensor): the input sparse tensor
        dim (int or tuple of ints): a dimension or a list of dimensions to reduce.
            Default: reduce over all dims.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
            Default: dtype of :attr:`input`.

    Example::

        >>> nnz = 3
        >>> dims = [5, 5, 2, 3]
        >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
                           torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
        >>> V = torch.randn(nnz, dims[2], dims[3])
        >>> size = torch.Size(dims)
        >>> S = torch.sparse_coo_tensor(I, V, size)
        >>> S
        tensor(indices=tensor([[2, 0, 3],
                               [2, 4, 1]]),
               values=tensor([[[-0.6438, -1.6467,  1.4004],
                               [ 0.3411,  0.0918, -0.2312]],

                              [[ 0.5348,  0.0634, -2.0494],
                               [-0.7125, -1.0646,  2.1844]],

                              [[ 0.1276,  0.1874, -0.6334],
                               [-1.9682, -0.5340,  0.7483]]]),
               size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)

        # when sum over only part of sparse_dims, return a sparse tensor
        >>> torch.sparse.sum(S, [1, 3])
        tensor(indices=tensor([[0, 2, 3]]),
               values=tensor([[-1.4512,  0.4073],
                              [-0.8901,  0.2017],
                              [-0.3183, -1.7539]]),
               size=(5, 2), nnz=3, layout=torch.sparse_coo)

        # when sum over all sparse dims, return a dense tensor
        # with summed dims squeezed
        >>> torch.sparse.sum(S, [0, 1, 3])
        tensor([-2.6596, -1.1450])
    """
    if dtype is None:
        if dim is not None:
            return torch._sparse_sum(input, dim)
        else:
            return torch._sparse_sum(input)
    else:
        if dim is not None:
            return torch._sparse_sum(input, dim, dtype=dtype)
        else:
            return torch._sparse_sum(input, dtype=dtype)
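
# A minimal sketch, not part of the original module, exercising the four dispatch
# paths of `sum` above (with and without `dim`/`dtype`). The helper name
# `_sum_example` is hypothetical and chosen only for illustration.
def _sum_example() -> None:
    i = torch.tensor([[0, 1, 1],
                      [2, 0, 2]])
    v = torch.randn(3, 4)
    s = torch.sparse_coo_tensor(i, v, (2, 3, 4))   # sparse_dim = 2, dense_dim = 1
    torch.sparse.sum(s)                            # no dim: reduce everything -> 0-dim dense tensor
    torch.sparse.sum(s, dim=[0, 1])                # all sparse dims -> dense tensor of shape (4,)
    torch.sparse.sum(s, dim=[0])                   # part of the sparse dims -> sparse tensor
    torch.sparse.sum(s, dtype=torch.float64)       # upcast before the reduction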
softmax = _add_docstr(_sparse._sparse_softmax, r"""
sparse.softmax(input, dim, *, dtype=None) -> Tensor

Applies a softmax function.

Softmax is defined as:

:math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`

where :math:`i, j` run over sparse tensor indices and unspecified
entries are ignored. This is equivalent to defining unspecified
entries as negative infinity so that :math:`\exp(x_k) = 0` when the
entry with index :math:`k` is not specified.

It is applied to all slices along `dim`, and will re-scale them so
that the elements lie in the range `[0, 1]` and sum to 1.

Args:
    input (Tensor): input
    dim (int): A dimension along which softmax will be computed.
    dtype (:class:`torch.dtype`, optional): the desired data type
        of returned tensor.  If specified, the input tensor is
        cast to :attr:`dtype` before the operation is performed.
        This is useful for preventing data type overflows. Default: None
""")


log_softmax = _add_docstr(_sparse._sparse_log_softmax, r"""
sparse.log_softmax(input, dim, *, dtype=None) -> Tensor

Applies a softmax function followed by logarithm.

See :func:`torch.sparse.softmax` for more details.

Args:
    input (Tensor): input
    dim (int): A dimension along which softmax will be computed.
    dtype (:class:`torch.dtype`, optional): the desired data type
        of returned tensor.  If specified, the input tensor is
        cast to :attr:`dtype` before the operation is performed.
        This is useful for preventing data type overflows. Default: None
""")
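
# A minimal sketch, not part of the original module: unspecified entries behave
# as negative infinity, so each row's softmax is taken only over its specified
# entries. The helper name `_softmax_example` is hypothetical.
def _softmax_example() -> Tuple[Tensor, Tensor]:
    x = torch.tensor([[0.0, 2.0, 0.0],
                      [3.0, 0.0, 1.0]]).to_sparse()    # zeros become unspecified entries
    p = torch.sparse.softmax(x, dim=1)                 # sparse result with the same sparsity pattern
    logp = torch.sparse.log_softmax(x, dim=1)          # elementwise log of the above
    return p, logp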