torch.masked

Introduction

Motivation

Warning

The PyTorch API of masked tensors is in the prototype stage and may change in the future.

MaskedTensor serves as an extension to torch.Tensor that provides the user with the ability to:

  • use any masked semantics (e.g. variable length tensors, nan* operators, etc.)

  • differentiate between 0 and NaN gradients

  • support various sparse applications (see tutorial below)

“Specified” and “unspecified” have a long history in PyTorch without formal semantics and certainly without consistency; indeed, MaskedTensor was born out of a buildup of issues that the vanilla torch.Tensor class could not properly address. Thus, a primary goal of MaskedTensor is to become the source of truth for said “specified” and “unspecified” values in PyTorch, where they are first-class citizens instead of an afterthought. In turn, this should further unlock sparsity’s potential, enable safer and more consistent operators, and provide a smoother and more intuitive experience for users and developers alike.

What is a MaskedTensor?

A MaskedTensor is a tensor subclass that consists of 1) an input (data), and 2) a mask. The mask tells us which entries from the input should be included or ignored.

By way of example, suppose that we wanted to mask out all values that are equal to 0 (represented by the gray) and take the max:

[Figure: tensor_comparison.jpg]

On top is the vanilla tensor example while the bottom is MaskedTensor where all the 0’s are masked out. The result clearly differs depending on whether the mask is applied, and this flexible structure allows the user to systematically ignore any elements they’d like during computation.
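As a minimal sketch of this comparison, the snippet below masks out every zero and takes the maximum with and without the mask. The values are illustrative and are not the ones shown in the figure. It assumes masked_tensor is imported from torch.masked, and it silences the prototype-stage UserWarning that constructing a MaskedTensor may emit.

```python
import warnings

import torch
from torch.masked import masked_tensor

# Constructing a MaskedTensor may emit a prototype-stage UserWarning; silence it here.
warnings.filterwarnings("ignore", category=UserWarning)

# Illustrative values (not the ones in the figure): every nonzero entry is negative.
data = torch.tensor([[0.0, -3.0, 0.0], [-2.0, 0.0, -1.0]])
mask = data != 0  # True marks entries to include, False marks entries to ignore

mt = masked_tensor(data, mask)

plain_max = data.max()   # the zeros take part, so the maximum is 0
masked_max = mt.amax()   # the zeros are ignored, so the maximum is -1

# get_data() and get_mask() expose the two components of the result.
print(plain_max.item(), masked_max.get_data().item(), masked_max.get_mask().item())
```

The reduction over the MaskedTensor returns another MaskedTensor, so its value is read back here through get_data().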

We have already written a number of tutorials to help users onboard, including the Overview and Advanced semantics tutorials.

Supported Operators

Unary Operators

Unary operators are operators that take only a single input. Applying them to MaskedTensors is relatively straightforward: if the data is not masked out at a given index, we apply the operator; otherwise, we continue to mask out the data.
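A small sketch of this behavior, using torch.abs as a representative unary operator: the operator is applied to the specified entries, and the mask is carried over unchanged. The sample values are made up for illustration, and to_tensor() is used only to obtain a dense tensor that is easy to inspect.

```python
import warnings

import torch
from torch.masked import masked_tensor

# Constructing a MaskedTensor may emit a prototype-stage UserWarning; silence it here.
warnings.filterwarnings("ignore", category=UserWarning)

data = torch.tensor([-1.0, -4.0, 9.0, -16.0])
mask = torch.tensor([True, False, True, False])  # True = specified, False = masked out
mt = masked_tensor(data, mask)

result = torch.abs(mt)

# The output mask is the same as the input mask.
mask_unchanged = torch.equal(result.get_mask(), mask)

# Fill the masked-out positions with 0 to get a plain tensor for inspection.
dense = result.to_tensor(0.0)
print(mask_unchanged, dense)
```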

The available unary operators are:

abs

Computes the absolute value of each element in input.

absolute

Alias for torch.abs().

acos

Computes the inverse cosine of each element in input.

arccos

Alias for torch.acos().

acosh

Returns a new tensor with the inverse hyperbolic cosine of the elements of input.

arccosh

Alias for torch.acosh().

angle

Computes the element-wise angle (in radians) of the given input tensor.

asin

Returns a new tensor with the arcsine of the elements of input.

arcsin

Alias for torch.asin().

asinh

Returns a new tensor with the inverse hyperbolic sine of the elements of input.

arcsinh

Alias for torch.asinh().

atan

Returns a new tensor with the arctangent of the elements of input.

arctan

Alias for torch.atan().

atanh

Returns a new tensor with the inverse hyperbolic tangent of the elements of input.

arctanh

Alias for torch.atanh().

bitwise_not

Computes the bitwise NOT of the given input tensor.

ceil

Returns a new tensor with the ceil of the elements of input, the smallest integer greater than or equal to each element.

clamp

Clamps all elements in input into the range [ min, max ].

clip

Alias for torch.clamp().

conj_physical

Computes the element-wise conjugate of the given input tensor.

cos

Returns a new tensor with the cosine of the elements of input.

cosh

Returns a new tensor with the hyperbolic cosine of the elements of input.

deg2rad

Returns a new tensor with each of the elements of input converted from angles in degrees to radians.

digamma

Alias for torch.special.digamma().

erf

Alias for torch.special.erf().

erfc

Alias for torch.special.erfc().

erfinv

Alias for torch.special.erfinv().

exp

Returns a new tensor with the exponential of the elements of the input tensor input.

exp2

Alias for torch.special.exp2().

expm1

Alias for torch.special.expm1().

fix

Alias for torch.trunc().

floor

Returns a new tensor with the floor of the elements of input, the largest integer less than or equal to each element.

frac

Computes the fractional portion of each element in input.

lgamma

Computes the natural logarithm of the absolute value of the gamma function on input.

log

Returns a new tensor with the natural logarithm of the elements of input.

log10

Returns a new tensor with the logarithm to the base 10 of the elements of input.

log1p

Returns a new tensor with the natural logarithm of (1 + input).

log2

Returns a new tensor with the logarithm to the base 2 of the elements of input.

logit

Alias for torch.special.logit().

i0

Alias for torch.special.i0().

isnan

Returns a new tensor with boolean elements representing if each element of input is NaN or not.

nan_to_num

Replaces NaN, positive infinity, and negative infinity values in input with the values specified by nan, posinf, and neginf, respectively.

neg

Returns a new tensor with the negative of the elements of input.

negative

Alias for torch.neg().

positive

Returns input.

pow

Takes the power of each element in input with exponent and returns a tensor with the result.

rad2deg

Returns a new tensor with each of the elements of input converted from angles in radians to degrees.

reciprocal

Returns a new tensor with the reciprocal of the elements of input.

round

Rounds elements of input to the nearest integer.

rsqrt

Returns a new tensor with the reciprocal of the square-root of each of the elements of input.

sigmoid

Alias for torch.special.expit().

sign

Returns a new tensor with the signs of the elements of input.

sgn

This function is an extension of torch.sign() to complex tensors.

signbit

Tests if each element of input has its sign bit set or not.

sin

Returns a new tensor with the sine of the elements of input.

sinc

Alias for torch.special.sinc().

sinh

Returns a new tensor with the hyperbolic sine of the elements of input.

sqrt

Returns a new tensor with the square-root of the elements of input.

square

Returns a new tensor with the square of the elements of input.

tan

Returns a new tensor with the tangent of the elements of input.

tanh

Returns a new tensor with the hyperbolic tangent of the elements of input.

trunc

Returns a new tensor with the truncated integer values of the elements of input.

The available in-place unary operators are all of the above except:

angle

Computes the element-wise angle (in radians) of the given input tensor.

positive

Returns input.

signbit

Tests if each element of input has its sign bit set or not.

isnan

Returns a new tensor with boolean elements representing if each element of input is NaN or not.

Binary Operators

As you may have seen in the tutorial, MaskedTensor also has binary operations implemented with the caveat that the masks in the two MaskedTensors must match or else an error will be raised. As noted in the error, if you need support for a particular operator or have proposed semantics for how they should behave instead, please open an issue on GitHub. For now, we have decided to go with the most conservative implementation to ensure that users know exactly what is going on and are being intentional about their decisions with masked semantics.
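The mask-matching rule can be sketched as follows: adding two MaskedTensors with identical masks succeeds, while adding two with different masks raises an error. The sample values are illustrative, and catching ValueError reflects the exception type we expect from the current prototype, which may change.

```python
import warnings

import torch
from torch.masked import masked_tensor

# Constructing a MaskedTensor may emit a prototype-stage UserWarning; silence it here.
warnings.filterwarnings("ignore", category=UserWarning)

data_a = torch.tensor([1.0, 2.0, 3.0, 4.0])
data_b = torch.tensor([10.0, 20.0, 30.0, 40.0])
mask = torch.tensor([True, False, True, True])
other_mask = torch.tensor([True, True, False, True])

a = masked_tensor(data_a, mask)
b = masked_tensor(data_b, mask)        # same mask as a
c = masked_tensor(data_b, other_mask)  # different mask from a

# Matching masks: the operation is applied and the shared mask is kept.
total = a + b
dense_total = total.to_tensor(0.0)  # masked-out positions filled with 0

# Mismatched masks: the operation is rejected.
mismatch_raised = False
try:
    a + c
except ValueError:  # assumed exception type for the current prototype
    mismatch_raised = True

print(dense_total, mismatch_raised)
```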

The available binary operators are:

add

Adds other, scaled by alpha, to input.

atan2

Element-wise arctangent of $\text{input}_{i} / \text{other}_{i}$ with consideration of the quadrant.

arctan2

Alias for torch.atan2().

bitwise_and

Computes the bitwise AND of input and other.

bitwise_or

Computes the bitwise OR of input and other.

bitwise_xor

Computes the bitwise XOR of input and other.

bitwise_left_shift

Computes the left arithmetic shift of input by other bits.

bitwise_right_shift

Computes the right arithmetic shift of input by other bits.

div

Divides each element of the input input by the corresponding element of other.

divide

Alias for torch.div().

floor_divide

fmod

Applies C++'s std::fmod entrywise.

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base-2.

mul

Multiplies input by other.

multiply

Alias for torch.mul().

nextafter

Return the next floating-point value after input towards other, elementwise.

remainder

Computes Python's modulus operation entrywise.

sub

Subtracts other, scaled by alpha, from input.

subtract

Alias for torch.sub().

true_divide

Alias for torch.div() with rounding_mode=None.

eq

Computes element-wise equality.

ne

Computes $\text{input} \neq \text{other}$ element-wise.

le

Computes $\text{input} \leq \text{other}$ element-wise.

ge

Computes $\text{input} \geq \text{other}$ element-wise.

greater

Alias for torch.gt().

greater_equal

Alias for torch.ge().

gt

Computes $\text{input} > \text{other}$ element-wise.

less_equal

Alias for torch.le().

lt

Computes $\text{input} < \text{other}$ element-wise.

less

Alias for torch.lt().

maximum

Computes the element-wise maximum of input and other.

minimum

Computes the element-wise minimum of input and other.

fmax

Computes the element-wise maximum of input and other.

fmin

Computes the element-wise minimum of input and other.

not_equal

Alias for torch.ne().

The available in-place binary operators are all of the above except:

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base-2.

equal

True if two tensors have the same size and elements, False otherwise.

fmin

Computes the element-wise minimum of input and other.

minimum

Computes the element-wise minimum of input and other.

fmax

Computes the element-wise maximum of input and other.

Reductions

The following reductions are available (with autograd support). For more information, the Overview tutorial details some examples of reductions, while the Advanced semantics tutorial has some further in-depth discussions about how we decided on certain reduction semantics.

sum

Returns the sum of all elements in the input tensor.

mean

amin

Returns the minimum value of each slice of the input tensor in the given dimension(s) dim.

amax

Returns the maximum value of each slice of the input tensor in the given dimension(s) dim.

argmin

Returns the indices of the minimum value(s) of the flattened tensor or along a dimension.

argmax

Returns the indices of the maximum value of all elements in the input tensor.

prod

Returns the product of all elements in the input tensor.

all

Tests if all elements in input evaluate to True.

norm

Returns the matrix norm or vector norm of a given tensor.

var

Calculates the variance over the dimensions specified by dim.

std

Calculates the standard deviation over the dimensions specified by dim.
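As a brief sketch of how a reduction treats masked-out entries, the snippet below sums a MaskedTensor along a dimension and over all elements. The mask is made up for illustration and deliberately leaves one row fully masked, so the corresponding output entry is itself masked out.

```python
import warnings

import torch
from torch.masked import masked_tensor

# Constructing a MaskedTensor may emit a prototype-stage UserWarning; silence it here.
warnings.filterwarnings("ignore", category=UserWarning)

data = torch.arange(12, dtype=torch.float).reshape(3, 4)
mask = torch.tensor(
    [
        [True, False, False, True],    # specified values: 0 and 3
        [False, False, False, False],  # nothing specified in this row
        [True, True, True, True],      # specified values: 8, 9, 10 and 11
    ]
)
mt = masked_tensor(data, mask)

# Sum over each row, using only the specified entries.
row_sums = mt.sum(dim=1)
row_mask = row_sums.get_mask()         # False where a row had no specified entries
row_dense = row_sums.to_tensor(-1.0)   # masked-out results filled with -1

# Sum over all specified entries.
total = mt.sum()

print(row_mask, row_dense, total.get_data().item())
```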

View and select functions

We’ve included a number of view and select functions as well; intuitively, these operators will apply to both the data and the mask and then wrap the result in a MaskedTensor. For a quick example, consider select():

>>> import torch
>>> from torch.masked import masked_tensor
>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

The following ops are currently supported:

atleast_1d

Returns a 1-dimensional view of each input tensor with zero dimensions.

broadcast_tensors

Broadcasts the given tensors according to Broadcasting semantics.

broadcast_to

Broadcasts input to the shape shape.

cat

Concatenates the given sequence of tensors in tensors in the given dimension.

chunk

Attempts to split a tensor into the specified number of chunks.

column_stack

Creates a new tensor by horizontally stacking the tensors in tensors.

dsplit

Splits input, a tensor with three or more dimensions, into multiple tensors depthwise according to indices_or_sections.

flatten

Flattens input by reshaping it into a one-dimensional tensor.

hsplit

Splits input, a tensor with one or more dimensions, into multiple tensors horizontally according to indices_or_sections.

hstack

Stack tensors in sequence horizontally (column wise).

kron

Computes the Kronecker product, denoted by $\otimes$, of input and other.

meshgrid

Creates grids of coordinates specified by the 1D inputs in tensors.

narrow

Returns a new tensor that is a narrowed version of input tensor.

nn.functional.unfold

Extract sliding local blocks from a batched input tensor.

ravel

Return a contiguous flattened tensor.

select

Slices the input tensor along the selected dimension at the given index.

split

Splits the tensor into chunks.

stack

Concatenates a sequence of tensors along a new dimension.

t

Expects input to be <= 2-D tensor and transposes dimensions 0 and 1.

transpose

Returns a tensor that is a transposed version of input.

vsplit

Splits input, a tensor with two or more dimensions, into multiple tensors vertically according to indices_or_sections.

vstack

Stack tensors in sequence vertically (row wise).

Tensor.expand

Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

Tensor.expand_as

Expand this tensor to the same size as other.

Tensor.reshape

Returns a tensor with the same data and number of elements as self but with the specified shape.

Tensor.reshape_as

Returns this tensor as the same shape as other.

Tensor.unfold

Returns a view of the original tensor which contains all slices of size size from self tensor in the dimension dimension.

Tensor.view

Returns a new tensor with the same data as the self tensor but of a different shape.
