torch.masked¶
Introduction¶
Motivation¶
Warning
The PyTorch API of masked tensors is in the prototype stage and may or may not change in the future.
MaskedTensor serves as an extension to torch.Tensor
that provides the user with the ability to:
use any masked semantics (e.g. variable length tensors, nan* operators, etc.)
differentiate between 0 and NaN gradients
various sparse applications (see tutorial below)
“Specified” and “unspecified” have a long history in PyTorch without formal semantics and certainly without
consistency; indeed, MaskedTensor was born out of a build up of issues that the vanilla torch.Tensor
class could not properly address. Thus, a primary goal of MaskedTensor is to become the source of truth for
said “specified” and “unspecified” values in PyTorch where they are a first class citizen instead of an afterthought.
In turn, this should further unlock sparsity’s potential,
enable safer and more consistent operators, and provide a smoother and more intuitive experience
for users and developers alike.
What is a MaskedTensor?¶
A MaskedTensor is a tensor subclass that consists of 1) an input (data), and 2) a mask. The mask tells us which entries from the input should be included or ignored.
By way of example, suppose that we wanted to mask out all values that are equal to 0 (represented by the gray) and take the max:
On top is the vanilla tensor example while the bottom is MaskedTensor where all the 0’s are masked out. This clearly yields a different result depending on whether we have the mask, but this flexible structure allows the user to systematically ignore any elements they’d like during computation.
There are already a number of existing tutorials that we’ve written to help users onboard, such as:
Supported Operators¶
Unary Operators¶
Unary operators are operators that only contain only a single input. Applying them to MaskedTensors is relatively straightforward: if the data is masked out at a given index, we apply the operator, otherwise we’ll continue to mask out the data.
The available unary operators are:
abs |
Computes the absolute value of each element in |
absolute |
Alias for |
acos |
Computes the inverse cosine of each element in |
arccos |
Alias for |
acosh |
Returns a new tensor with the inverse hyperbolic cosine of the elements of |
arccosh |
Alias for |
angle |
Computes the element-wise angle (in radians) of the given |
asin |
Returns a new tensor with the arcsine of the elements of |
arcsin |
Alias for |
asinh |
Returns a new tensor with the inverse hyperbolic sine of the elements of |
arcsinh |
Alias for |
atan |
Returns a new tensor with the arctangent of the elements of |
arctan |
Alias for |
atanh |
Returns a new tensor with the inverse hyperbolic tangent of the elements of |
arctanh |
Alias for |
bitwise_not |
Computes the bitwise NOT of the given input tensor. |
ceil |
Returns a new tensor with the ceil of the elements of |
clamp |
|
clip |
Alias for |
conj_physical |
Computes the element-wise conjugate of the given |
cos |
Returns a new tensor with the cosine of the elements of |
cosh |
Returns a new tensor with the hyperbolic cosine of the elements of |
deg2rad |
Returns a new tensor with each of the elements of |
digamma |
Alias for |
erf |
Alias for |
erfc |
Alias for |
erfinv |
Alias for |
exp |
Returns a new tensor with the exponential of the elements of the input tensor |
exp2 |
Alias for |
expm1 |
Alias for |
fix |
Alias for |
floor |
Returns a new tensor with the floor of the elements of |
frac |
Computes the fractional portion of each element in |
lgamma |
Computes the natural logarithm of the absolute value of the gamma function on |
log |
Returns a new tensor with the natural logarithm of the elements of |
log10 |
Returns a new tensor with the logarithm to the base 10 of the elements of |
log1p |
Returns a new tensor with the natural logarithm of (1 + |
log2 |
Returns a new tensor with the logarithm to the base 2 of the elements of |
logit |
Alias for |
i0 |
Alias for |
isnan |
Returns a new tensor with boolean elements representing if each element of |
nan_to_num |
Replaces |
neg |
Returns a new tensor with the negative of the elements of |
negative |
Alias for |
positive |
Returns |
pow |
Takes the power of each element in |
rad2deg |
Returns a new tensor with each of the elements of |
reciprocal |
Returns a new tensor with the reciprocal of the elements of |
round |
Rounds elements of |
rsqrt |
Returns a new tensor with the reciprocal of the square-root of each of the elements of |
sigmoid |
Alias for |
sign |
Returns a new tensor with the signs of the elements of |
sgn |
This function is an extension of torch.sign() to complex tensors. |
signbit |
Tests if each element of |
sin |
Returns a new tensor with the sine of the elements of |
sinc |
Alias for |
sinh |
Returns a new tensor with the hyperbolic sine of the elements of |
sqrt |
Returns a new tensor with the square-root of the elements of |
square |
Returns a new tensor with the square of the elements of |
tan |
Returns a new tensor with the tangent of the elements of |
tanh |
Returns a new tensor with the hyperbolic tangent of the elements of |
trunc |
Returns a new tensor with the truncated integer values of the elements of |
The available inplace unary operators are all of the above except:
angle |
Computes the element-wise angle (in radians) of the given |
positive |
Returns |
signbit |
Tests if each element of |
isnan |
Returns a new tensor with boolean elements representing if each element of |
Binary Operators¶
As you may have seen in the tutorial, MaskedTensor
also has binary operations implemented with the caveat
that the masks in the two MaskedTensors must match or else an error will be raised. As noted in the error, if you
need support for a particular operator or have proposed semantics for how they should behave instead, please open
an issue on GitHub. For now, we have decided to go with the most conservative implementation to ensure that users
know exactly what is going on and are being intentional about their decisions with masked semantics.
The available binary operators are:
add |
Adds |
atan2 |
Element-wise arctangent of with consideration of the quadrant. |
arctan2 |
Alias for |
bitwise_and |
Computes the bitwise AND of |
bitwise_or |
Computes the bitwise OR of |
bitwise_xor |
Computes the bitwise XOR of |
bitwise_left_shift |
Computes the left arithmetic shift of |
bitwise_right_shift |
Computes the right arithmetic shift of |
div |
Divides each element of the input |
divide |
Alias for |
floor_divide |
|
fmod |
Applies C++'s std::fmod entrywise. |
logaddexp |
Logarithm of the sum of exponentiations of the inputs. |
logaddexp2 |
Logarithm of the sum of exponentiations of the inputs in base-2. |
mul |
Multiplies |
multiply |
Alias for |
nextafter |
Return the next floating-point value after |
remainder |
Computes Python's modulus operation entrywise. |
sub |
Subtracts |
subtract |
Alias for |
true_divide |
Alias for |
eq |
Computes element-wise equality |
ne |
Computes element-wise. |
le |
Computes element-wise. |
ge |
Computes element-wise. |
greater |
Alias for |
greater_equal |
Alias for |
gt |
Computes element-wise. |
less_equal |
Alias for |
lt |
Computes element-wise. |
less |
Alias for |
maximum |
Computes the element-wise maximum of |
minimum |
Computes the element-wise minimum of |
fmax |
Computes the element-wise maximum of |
fmin |
Computes the element-wise minimum of |
not_equal |
Alias for |
The available inplace binary operators are all of the above except:
logaddexp |
Logarithm of the sum of exponentiations of the inputs. |
logaddexp2 |
Logarithm of the sum of exponentiations of the inputs in base-2. |
equal |
|
fmin |
Computes the element-wise minimum of |
minimum |
Computes the element-wise minimum of |
fmax |
Computes the element-wise maximum of |
Reductions¶
The following reductions are available (with autograd support). For more information, the Overview tutorial details some examples of reductions, while the Advanced semantics tutorial has some further in-depth discussions about how we decided on certain reduction semantics.
sum |
Returns the sum of all elements in the |
mean |
Returns the mean value of all elements in the |
amin |
Returns the minimum value of each slice of the |
amax |
Returns the maximum value of each slice of the |
argmin |
Returns the indices of the minimum value(s) of the flattened tensor or along a dimension |
argmax |
Returns the indices of the maximum value of all elements in the |
prod |
Returns the product of all elements in the |
all |
Tests if all elements in |
norm |
Returns the matrix norm or vector norm of a given tensor. |
var |
Calculates the variance over the dimensions specified by |
std |
Calculates the standard deviation over the dimensions specified by |
View and select functions¶
We’ve included a number of view and select functions as well; intuitively, these operators will apply to
both the data and the mask and then wrap the result in a MaskedTensor
. For a quick example,
consider select()
:
>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False, True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
[ --, 5.0000, --, --]
)
The following ops are currently supported:
atleast_1d |
Returns a 1-dimensional view of each input tensor with zero dimensions. |
broadcast_tensors |
Broadcasts the given tensors according to Broadcasting semantics. |
broadcast_to |
Broadcasts |
cat |
Concatenates the given sequence of tensors in |
chunk |
Attempts to split a tensor into the specified number of chunks. |
column_stack |
Creates a new tensor by horizontally stacking the tensors in |
dsplit |
Splits |
flatten |
Flattens |
hsplit |
Splits |
hstack |
Stack tensors in sequence horizontally (column wise). |
kron |
Computes the Kronecker product, denoted by , of |
meshgrid |
Creates grids of coordinates specified by the 1D inputs in attr:tensors. |
narrow |
Returns a new tensor that is a narrowed version of |
Extract sliding local blocks from a batched input tensor. |
|
ravel |
Return a contiguous flattened tensor. |
select |
Slices the |
split |
Splits the tensor into chunks. |
stack |
Concatenates a sequence of tensors along a new dimension. |
t |
Expects |
transpose |
Returns a tensor that is a transposed version of |
vsplit |
Splits |
vstack |
Stack tensors in sequence vertically (row wise). |
Returns a new view of the |
|
Expand this tensor to the same size as |
|
Returns a tensor with the same data and number of elements as |
|
Returns this tensor as the same shape as |
|
Returns a view of the original tensor which contains all slices of size |
|
Returns a new tensor with the same data as the |