torch.Tensor.scatter_reduce_

Tensor.scatter_reduce_(dim, index, src, reduce, *, include_self=True) → Tensor

Reduces all values from the src tensor into the self tensor at the indices specified in the index tensor, using the reduction named by the reduce argument ("sum", "prod", "mean", "amax", "amin"). Each value in src is reduced into the position in self given by its own index in src for every dimension != dim and by the corresponding value in index for dimension = dim. If include_self=True, the values already in the self tensor are included in the reduction.

self, index and src should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.
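As a rough illustration of these shape rules, the sketch below uses an index that is smaller than both src and self. The tensor names and values are made up for the example; the point is only that elements of src outside the extent of index are never read.

```python
import torch

# self is 3x3, src is 2x3, index is only 1x2:
# just src[:1, :2] takes part in the scatter.
target = torch.zeros(3, 3)
src = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])
index = torch.tensor([[2, 0]])

target.scatter_reduce_(0, index, src, reduce="sum")

# target[2][0] receives src[0][0]; target[0][1] receives src[0][1].
print(target)
```

All three tensors are 2-D here, which satisfies the requirement that they have the same number of dimensions.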

For a 3-D tensor with reduce="sum" and include_self=True the output is given as:

self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
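The update rule above can be spelled out as plain nested loops. The following is a pure-Python sketch on nested lists, intended only to illustrate the indexing; the helper name scatter_sum_3d is hypothetical and this is not how PyTorch implements the operation.

```python
def scatter_sum_3d(self_, dim, index, src):
    """Apply the reduce="sum", include_self=True rule in place on nested lists."""
    for i in range(len(index)):
        for j in range(len(index[i])):
            for k in range(len(index[i][j])):
                # Keep the src position, except along `dim`,
                # where the value stored in index is used instead.
                pos = [i, j, k]
                pos[dim] = index[i][j][k]
                a, b, c = pos
                self_[a][b][c] += src[i][j][k]
    return self_


target = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(2)]  # shape (2, 2, 2)
index = [[[0, 1], [1, 1]]]                             # shape (1, 2, 2)
src = [[[1.0, 2.0], [3.0, 4.0]]]                       # shape (1, 2, 2)

result = scatter_sum_3d(target, 2, index, src)
print(result)  # [[[1.0, 2.0], [0.0, 7.0]], [[0.0, 0.0], [0.0, 0.0]]]
```

Two source values (3.0 and 4.0) map to the same slot along dim 2, so they are summed into a single entry.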

Note

This operation may behave nondeterministically when given tensors on a CUDA device. See Reproducibility for more information.

Note

The backward pass is implemented only for src.shape == index.shape.
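A minimal sketch of differentiating through the operation in the supported case, where src and index have the same shape. The weights tensor is an illustrative assumption used only to make the gradients non-trivial.

```python
import torch

src = torch.tensor([1., 2., 3., 4.], requires_grad=True)
index = torch.tensor([0, 1, 0, 1])  # same shape as src
base = torch.zeros(2, requires_grad=True)

# Clone first: an in-place op on a leaf tensor that requires grad is an error.
out = base.clone().scatter_reduce_(0, index, src, reduce="sum")

weights = torch.tensor([10., 20.])
(out * weights).sum().backward()

print(src.grad)   # each src element receives the weight of the slot it landed in
print(base.grad)  # with include_self=True, base contributes directly to each slot
```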

Warning

This function is in beta and may change in the near future.

Parameters
  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to scatter and reduce

  • src (Tensor) – the source elements to scatter and reduce

  • reduce (str) – the reduction operation to apply for non-unique indices ("sum", "prod", "mean", "amax", "amin")

  • include_self (bool) – whether elements from the self tensor are included in the reduction

Example:

>>> src = torch.tensor([1., 2., 3., 4., 5., 6.])
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
>>> input = torch.tensor([1., 2., 3., 4.])
>>> input.scatter_reduce(0, index, src, reduce="sum")
tensor([5., 14., 8., 4.])
>>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False)
tensor([4., 12., 5., 4.])
>>> input2 = torch.tensor([5., 4., 3., 2.])
>>> input2.scatter_reduce(0, index, src, reduce="amax")
tensor([5., 6., 5., 2.])
>>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False)
tensor([3., 6., 5., 2.])
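The calls above use the out-of-place scatter_reduce, which leaves input unchanged. A sketch of the in-place variant on a 2-D tensor, with illustrative values, might look like this:

```python
import torch

x = torch.zeros(2, 3)
index = torch.tensor([[0, 2, 0],
                      [1, 1, 2]])
src = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])

# Scatter along dim=1: each row of src is reduced into the same row of x.
ret = x.scatter_reduce_(1, index, src, reduce="sum")

print(x)         # rows become [4., 0., 2.] and [0., 9., 6.]
print(ret is x)  # the in-place method returns the modified tensor itself
```

In row 0, src values 1 and 3 both target column 0 and are summed to 4; in row 1, values 4 and 5 both target column 1 and are summed to 9.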
