Sparse semantics
Introduction
Sparsity in PyTorch is a quickly growing area that has found a lot of support and demand due to its efficiency in both memory and compute. This tutorial is meant to be used in conjunction with the PyTorch link above, as sparse tensors are ultimately the building blocks for MaskedTensors (just as regular torch.Tensors are as well).
Sparse storage formats have proven to be powerful in a variety of ways. As a primer, the first use case most practitioners think about is when the majority of elements are equal to zero (a high degree of sparsity), but even in cases of lower sparsity, certain formats (e.g. BSR) can take advantage of substructures within a matrix. There are a number of different sparse storage formats that can be leveraged, with various tradeoffs and degrees of adoption.
“Specified” and “unspecified” elements (i.e. elements that are stored vs. not) have a long history in PyTorch without formal semantics and certainly without consistency; indeed, MaskedTensor was partially born out of a buildup of issues (e.g. the nan_grad tutorial) that vanilla tensors could not address. A major goal of the MaskedTensor project is to become the primary source of truth for specified/unspecified semantics, where they are a first-class citizen instead of an afterthought.
Note: Currently, only the COO and CSR sparse storage formats are supported in MaskedTensor (BSR and CSC will be developed in the future). If you have another format that you would like supported, please file an issue!
Principles
1. input and mask must have the same storage format, whether that's torch.strided, torch.sparse_coo, or torch.sparse_csr.
2. input and mask must have the same size, indicated by size().
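Before diving in, here is a quick sketch of what these principles require of the two constituent tensors, checking layout and size on plain torch tensors (the variable names are just for illustration):
import torch

# Principle #1: input and mask share a storage format (here, both sparse COO)
data = torch.tensor([[0, 0, 3], [4, 0, 5]]).to_sparse()
mask = torch.tensor([[False, False, True], [False, False, True]]).to_sparse()
print(data.layout == mask.layout)  # True: both are torch.sparse_coo
# Principle #2: input and mask have the same size
print(data.size() == mask.size())  # True: both are (2, 3)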
Sparse COO Tensors
import torch
from maskedtensor import masked_tensor
In accordance with Principle #1, a sparse MaskedTensor is created by passing in two sparse tensors, which can be initialized with any of the constructors, e.g. torch.sparse_coo_tensor.
As a recap of sparse COO tensors, the COO format stands for “Coordinate format”, where the specified elements are stored as tuples of their indices and the corresponding values. That is, the following are provided:
indices: array of size (ndim, nse) and dtype torch.int64
values: array of size (nse,) with any integer or floating point number dtype
where ndim is the dimensionality of the tensor and nse is the number of specified elements.
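For instance, a small sketch that builds a 2x3 tensor with two specified elements directly, to show the shapes involved:
# indices has shape (ndim, nse) = (2, 2); values has shape (nse,) = (2,)
indices = torch.tensor([[0, 1],
                        [2, 2]])
values = torch.tensor([3, 5])
s = torch.sparse_coo_tensor(indices, values, (2, 3))
print(s.to_dense())  # tensor([[0, 0, 3], [0, 0, 5]])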
For both sparse COO and CSR tensors, you can construct a MaskedTensor by doing either:
masked_tensor(sparse_tensor_data, sparse_tensor_mask)
dense_masked_tensor.to_sparse_coo() or dense_masked_tensor.to_sparse_csr()
The second method is easier to illustrate, so we show it below; for more on the first and the nuances behind that approach, please read the Appendix at the bottom.
# To start, create a MaskedTensor
values = torch.tensor(
    [[0, 0, 3],
     [4, 0, 5]]
)
mask = torch.tensor(
    [[False, False, True],
     [False, False, True]]
)
mt = masked_tensor(values, mask)
sparse_coo_mt = mt.to_sparse_coo()
print("masked tensor:\n", mt)
print("sparse coo masked tensor:\n", sparse_coo_mt)
print("sparse data:\n", sparse_coo_mt.data())
masked tensor:
masked_tensor(
[
[ --, --, 3],
[ --, --, 5]
]
)
sparse coo masked tensor:
masked_tensor(
[
[ --, --, 3],
[ --, --, 5]
]
)
sparse data:
tensor(indices=tensor([[0, 1],
[2, 2]]),
values=tensor([3, 5]),
size=(2, 3), nnz=2, layout=torch.sparse_coo)
Sparse CSR Tensors
Similarly, MaskedTensor also supports the CSR (Compressed Sparse Row) sparse tensor format. Instead of storing the tuples of the indices like sparse COO tensors, sparse CSR tensors aim to decrease the memory requirements by storing compressed row indices. In particular, a CSR sparse tensor consists of three 1-D tensors:
crow_indices: array of compressed row indices with size (size[0] + 1,). This array indicates which row a given entry in values lives in. The last element is the number of specified elements, while crow_indices[i+1] - crow_indices[i] indicates the number of specified elements in row i.
col_indices: array of size (nnz,). Indicates the column indices for each value.
values: array of size (nnz,). Contains the values of the CSR tensor.
Of note, both sparse COO and CSR tensors are in a beta state.
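To make the crow_indices arithmetic concrete, here is a small sketch that builds the same 2x3 matrix as above directly in CSR form:
# crow_indices[i+1] - crow_indices[i] gives the number of specified
# elements in row i: here 1 - 0 = 1 in row 0 and 2 - 1 = 1 in row 1
crow_indices = torch.tensor([0, 1, 2])
col_indices = torch.tensor([2, 2])
values = torch.tensor([3, 5])
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 3))
print(csr.to_dense())  # tensor([[0, 0, 3], [0, 0, 5]])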
By way of example (and again, you can find more examples in the Appendix):
mt_sparse_csr = mt.to_sparse_csr()
print("values:\n", mt_sparse_csr.data())
print("mask:\n", mt_sparse_csr.mask())
print("mt:\n", mt_sparse_csr)
values:
tensor(crow_indices=tensor([0, 1, 2]),
col_indices=tensor([2, 2]),
values=tensor([3, 5]), size=(2, 3), nnz=2, layout=torch.sparse_csr)
mask:
tensor(crow_indices=tensor([0, 1, 2]),
col_indices=tensor([2, 2]),
values=tensor([True, True]), size=(2, 3), nnz=2,
layout=torch.sparse_csr)
mt:
masked_tensor(
[
[ --, --, 3],
[ --, --, 5]
]
)
Supported Operations
Unary
All unary operations are supported, e.g.:
mt.sin()
masked_tensor(
[
[ --, --, 0.1411],
[ --, --, -0.9589]
]
)
Binary
Binary operations are also supported, but the input masks from the two masked tensors must match.
i = [[0, 1, 1],
     [2, 0, 2]]
v1 = [3, 4, 5]
v2 = [20, 30, 40]
m = torch.tensor([True, False, True])
s1 = torch.sparse_coo_tensor(i, v1, (2, 3))
s2 = torch.sparse_coo_tensor(i, v2, (2, 3))
mask = torch.sparse_coo_tensor(i, m, (2, 3))
mt1 = masked_tensor(s1, mask)
mt2 = masked_tensor(s2, mask)
print("mt1:\n", mt1)
print("mt2:\n", mt2)
print("torch.div(mt2, mt1):\n", torch.div(mt2, mt1))
print("torch.mul(mt1, mt2):\n", torch.mul(mt1, mt2))
mt1:
masked_tensor(
[
[ --, --, 3],
[ --, --, 5]
]
)
mt2:
masked_tensor(
[
[ --, --, 20],
[ --, --, 40]
]
)
torch.div(mt2, mt1):
masked_tensor(
[
[ --, --, 6.6667],
[ --, --, 8.0000]
]
)
torch.mul(mt1, mt2):
masked_tensor(
[
[ --, --, 60],
[ --, --, 200]
]
)
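Conversely, two MaskedTensors whose masks do not match should be rejected by binary operations. A small sketch, assuming the prototype raises an error in this case (the exact exception type and message may differ):
# mask_other keeps the middle element that mask drops, so the masks differ
mask_other = torch.sparse_coo_tensor(i, torch.tensor([True, True, True]), (2, 3))
mt3 = masked_tensor(s2, mask_other)
try:
    torch.mul(mt1, mt3)
except Exception as e:
    print(e)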
Reductions
At the moment, when the underlying data is sparse, only reductions across all dimensions are supported, not reductions along a particular dimension (e.g. mt.sum() is supported but not mt.sum(dim=1)). Dimension-wise reductions are next in line to be worked on.
print("mt:\n", mt)
print("mt.sum():\n", mt.sum())
print("mt.amin():\n", mt.amin())
mt:
masked_tensor(
[
[ --, --, 3],
[ --, --, 5]
]
)
mt.sum():
masked_tensor(8, True)
mt.amin():
masked_tensor(3, True)
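And to confirm the limitation above, a dimension-wise reduction on a sparse MaskedTensor is expected to fail for now (the exact error may vary):
try:
    sparse_coo_mt.sum(dim=1)  # reductions along a given dim are not yet supported for sparse data
except Exception as e:
    print(e)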
MaskedTensor methods and sparse
to_dense()
mt.to_dense()
masked_tensor(
[
[ --, --, 3],
[ --, --, 5]
]
)
to_sparse_coo()
v = [[3, 0, 0],
     [0, 4, 5]]
m = [[True, False, False],
     [False, True, True]]
mt = masked_tensor(torch.tensor(v), torch.tensor(m))
mt_sparse = mt.to_sparse_coo()
to_sparse_csr()
v = [[3, 0, 0],
     [0, 4, 5]]
m = [[True, False, False],
     [False, True, True]]
mt = masked_tensor(torch.tensor(v), torch.tensor(m))
mt_sparse_csr = mt.to_sparse_csr()
is_sparse / is_sparse_coo / is_sparse_csr
print("mt.is_sparse: ", mt.is_sparse())
print("mt_sparse.is_sparse: ", mt_sparse.is_sparse())
print("mt.is_sparse_coo: ", mt.is_sparse_coo())
print("mt_sparse.is_sparse_coo: ", mt_sparse.is_sparse_coo())
print("mt.is_sparse_csr: ", mt.is_sparse_csr())
print("mt_sparse_csr.is_sparse_csr: ", mt_sparse_csr.is_sparse_csr())
mt.is_sparse: False
mt_sparse.is_sparse: True
mt.is_sparse_coo: False
mt_sparse.is_sparse_coo: True
mt.is_sparse_csr: False
mt_sparse_csr.is_sparse_csr: True
Appendix
Sparse COO construction
Recall that in our original example, we created a MaskedTensor and then converted it to a sparse COO MaskedTensor with mt.to_sparse_coo().
Alternatively, we can also construct a sparse COO MaskedTensor by passing in two sparse COO tensors!
values = torch.tensor([[0, 0, 3], [4, 0, 5]]).to_sparse()
mask = torch.tensor([[False, False, True], [False, False, True]]).to_sparse()
mt = masked_tensor(values, mask)
print("values:\n", values)
print("mask:\n", mask)
print("mt:\n", mt)
values:
tensor(indices=tensor([[0, 1, 1],
[2, 0, 2]]),
values=tensor([3, 4, 5]),
size=(2, 3), nnz=3, layout=torch.sparse_coo)
mask:
tensor(indices=tensor([[0, 1],
[2, 2]]),
values=tensor([True, True]),
size=(2, 3), nnz=2, layout=torch.sparse_coo)
mt:
masked_tensor(
[
[ --, --, 3],
[ --, --, 5]
]
)
Instead of doing dense_tensor.to_sparse(), we can also create the sparse COO tensors directly, which brings us to a word of warning: when using a function like .to_sparse_coo(), if the user does not specify the indices as in the above example, then the 0 values will be “unspecified” by default.
i = [[0, 1, 1],
     [2, 0, 2]]
v = [3, 4, 5]
m = torch.tensor([True, False, True])
values = torch.sparse_coo_tensor(i, v, (2, 3))
mask = torch.sparse_coo_tensor(i, m, (2, 3))
mt2 = masked_tensor(values, mask)
print("values:\n", values)
print("mask:\n", mask)
print("mt2:\n", mt2)
values:
tensor(indices=tensor([[0, 1, 1],
[2, 0, 2]]),
values=tensor([3, 4, 5]),
size=(2, 3), nnz=3, layout=torch.sparse_coo)
mask:
tensor(indices=tensor([[0, 1, 1],
[2, 0, 2]]),
values=tensor([ True, False, True]),
size=(2, 3), nnz=3, layout=torch.sparse_coo)
mt2:
masked_tensor(
[
[ --, --, 3],
[ --, --, 5]
]
)
Note that mt and mt2 will have the same value in the vast majority of operations, but this brings us to a note on the implementation under the hood: input and mask (only for sparse formats) can have a different number of elements (tensor.nnz()) at creation, but the indices of mask must then be a subset of the indices of input. In this case, input will assume the shape of mask using the function input.sparse_mask(mask); in other words, any of the elements in input that are not True in mask will be thrown away.
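To see what sparse_mask does in isolation, here is a small sketch with fresh tensors (note that Tensor.sparse_mask is called on a dense tensor here and keeps only the entries at the mask's specified indices):
dense = torch.tensor([[0, 0, 3], [4, 0, 5]])
keep = torch.tensor([[False, False, True], [False, False, True]]).to_sparse()
print(dense.sparse_mask(keep))  # sparse tensor with nnz=2 and values [3, 5]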
Therefore, under the hood, the data looks slightly different: mt2 retains the value 4 in its underlying data (masked out), while mt drops it entirely. In other words, their underlying data still have different shapes, so mt + mt2 is invalid.
print("mt.masked_data:\n", mt.data())
print("mt2.masked_data:\n", mt2.data())
mt.masked_data:
tensor(indices=tensor([[0, 1],
[2, 2]]),
values=tensor([3, 5]),
size=(2, 3), nnz=2, layout=torch.sparse_coo)
mt2.masked_data:
tensor(indices=tensor([[0, 1, 1],
[2, 0, 2]]),
values=tensor([3, 4, 5]),
size=(2, 3), nnz=3, layout=torch.sparse_coo)
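If you do need to combine mt and mt2, one workaround (a sketch, not a prescribed path) is to convert both to dense MaskedTensors first; after densification their masks match, so the addition is valid:
mt_dense = mt.to_dense()
mt2_dense = mt2.to_dense()
# the two unmasked entries should become 3 + 3 = 6 and 5 + 5 = 10
print(mt_dense + mt2_dense)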
Sparse CSR
We can also construct a sparse CSR MaskedTensor using sparse CSR tensors, and like the example above, they have a similar treatment under the hood.
crow_indices = torch.tensor([0, 2, 4])
col_indices = torch.tensor([0, 1, 0, 1])
values = torch.tensor([1, 2, 3, 4])
mask_values = torch.tensor([True, False, False, True])
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double)
mask = torch.sparse_csr_tensor(crow_indices, col_indices, mask_values, dtype=torch.bool)
mt = masked_tensor(csr, mask)
print("csr tensor:\n", csr.to_dense())
print("mask csr tensor:\n", mask.to_dense())
print("masked tensor:\n", mt)
csr tensor:
tensor([[1., 2.],
[3., 4.]], dtype=torch.float64)
mask csr tensor:
tensor([[ True, False],
[False, True]])
masked tensor:
masked_tensor(
[
[ 1.0000, --],
[ --, 4.0000]
]
)
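As in the COO case, we can inspect the underlying storage with data() and mask(); following the treatment described above, the False entries should remain in the underlying CSR data while being masked out in the MaskedTensor view:
print("mt data:\n", mt.data())
print("mt mask:\n", mt.mask())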