torch.multinomial

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor

Returns a tensor where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of tensor input.

Note

The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum.
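The note above can be checked empirically. The following is a minimal sketch, not part of the original reference: it draws many samples from unnormalized weights and compares the observed frequencies with the normalized weights. The seed, the sample count, and the variable names are arbitrary choices made for illustration.

```python
import torch

# Unnormalized, non-negative weights; they sum to 13, not 1.
weights = torch.tensor([0.0, 10.0, 3.0, 0.0])

# Seeded generator so the run is repeatable (the seed value is arbitrary).
gen = torch.Generator().manual_seed(0)

# Draw many samples with replacement and count how often each index appears.
samples = torch.multinomial(weights, 10_000, replacement=True, generator=gen)
counts = torch.bincount(samples, minlength=weights.numel())

# Observed frequencies should be close to weights / weights.sum().
empirical = counts.float() / counts.sum()
expected = weights / weights.sum()
print(empirical)
print(expected)
```

The two printed tensors should agree to within sampling noise, which suggests the weights are normalized per row before sampling.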

Indices are ordered from left to right according to when each was sampled (first samples are placed in first column).

If input is a vector, out is a vector of size num_samples.

If input is a matrix with m rows, out is a matrix of shape (m × num_samples).
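The two shape rules can be illustrated with a short sketch. The tensors `vec` and `mat` are made up for the example; only the output shapes and dtype are of interest here.

```python
import torch

# 1-D input: a single distribution over 4 categories.
vec = torch.tensor([1.0, 2.0, 3.0, 4.0])

# 2-D input: m = 3 rows, each an independent distribution over 4 categories.
mat = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                    [4.0, 3.0, 2.0, 1.0],
                    [1.0, 1.0, 1.0, 1.0]])

vec_out = torch.multinomial(vec, 2)
mat_out = torch.multinomial(mat, 2)

print(vec_out.shape)  # torch.Size([2])
print(mat_out.shape)  # torch.Size([3, 2])
print(mat_out.dtype)  # torch.int64
```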

If replacement is True, samples are drawn with replacement.

If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
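As a rough illustration of the difference, the sketch below samples from the same weights in both modes. The weights and seed are arbitrary; drawing all four categories without replacement must return each index exactly once, while drawing eight samples with replacement must repeat at least one index.

```python
import torch

weights = torch.tensor([0.1, 0.2, 0.3, 0.4])
gen = torch.Generator().manual_seed(0)  # arbitrary seed for repeatability

# Without replacement: every index appears at most once, in sampled order.
no_repl = torch.multinomial(weights, 4, replacement=False, generator=gen)
print(sorted(no_repl.tolist()))  # [0, 1, 2, 3]

# With replacement: num_samples may exceed the number of categories,
# so indices can (and here must) repeat.
with_repl = torch.multinomial(weights, 8, replacement=True, generator=gen)
print(with_repl.unique().numel() <= 4)  # True
```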

Note

When drawn without replacement, num_samples must not exceed the number of non-zero elements in input (or, if input is a matrix, the minimum number of non-zero elements across its rows).
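One way to respect this limit is to count the non-zero weights before sampling. The helper `max_samples_without_replacement` below is hypothetical and not part of PyTorch. It assumes the count itself is an allowed value, which matches the example on this page where 2 samples are drawn from 2 non-zero weights.

```python
import torch

def max_samples_without_replacement(weights: torch.Tensor) -> int:
    """Hypothetical helper: largest num_samples allowed by the rule above."""
    # Count non-zero entries along the last dimension (one count per row).
    nonzero_per_row = (weights != 0).sum(dim=-1)
    # For a matrix the row with the fewest non-zero entries decides;
    # for a vector there is only one count.
    return int(nonzero_per_row.min())

vec = torch.tensor([0.0, 10.0, 3.0, 0.0])
mat = torch.tensor([[0.0, 10.0, 3.0, 0.0],
                    [1.0, 1.0, 1.0, 0.0]])

print(max_samples_without_replacement(vec))  # 2
print(max_samples_without_replacement(mat))  # 2

# Sampling exactly that many indices per row stays within the limit.
k = max_samples_without_replacement(mat)
print(torch.multinomial(mat, k).shape)  # torch.Size([2, 2])
```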

Parameters:
  • input (Tensor) – the input tensor containing probabilities

  • num_samples (int) – number of samples to draw

  • replacement (bool, optional) – whether to draw with replacement or not

Keyword Arguments:
  • generator (torch.Generator, optional) – a pseudorandom number generator for sampling

  • out (Tensor, optional) – the output tensor.
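The keyword arguments can be sketched as follows. The seed value and the names `g1`, `g2`, and `buf` are illustrative assumptions. The `out` buffer is created as a LongTensor to match the documented return type.

```python
import torch

weights = torch.tensor([0.0, 10.0, 3.0, 0.0])

# Two generators seeded identically produce identical draws.
g1 = torch.Generator().manual_seed(1234)
g2 = torch.Generator().manual_seed(1234)
a = torch.multinomial(weights, 5, replacement=True, generator=g1)
b = torch.multinomial(weights, 5, replacement=True, generator=g2)
print(torch.equal(a, b))  # True

# `out` receives the sampled indices; a LongTensor matches the return type.
buf = torch.empty(5, dtype=torch.long)
torch.multinomial(weights, 5, replacement=True, generator=g1, out=buf)
print(buf.shape)  # torch.Size([5])
```

Passing an explicit generator keeps the sampling independent of the global random state, which can help when reproducibility matters.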

Example:

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 4) # ERROR!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])
