torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)

Samples from the Gumbel-Softmax distribution (Maddison et al., 2016; Jang et al., 2016) and optionally discretizes.
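For intuition, here is a minimal sketch of one soft sample, assuming the standard Gumbel-Softmax formulation softmax((logits + g) / tau), where g is standard Gumbel(0, 1) noise drawn as g = -log(-log(U)) with U uniform. The helper `soft_gumbel_sample` is hypothetical and not part of PyTorch, and the library may draw the noise differently internally.

```python
import torch


def soft_gumbel_sample(logits: torch.Tensor, tau: float = 1.0, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: draw one soft Gumbel-Softmax sample along `dim`."""
    # Uniform noise in [0, 1); clamp away from 0 so both logs stay finite.
    u = torch.rand_like(logits).clamp(min=1e-10)
    # Standard Gumbel(0, 1) noise via the inverse-CDF transform g = -log(-log(U)).
    g = -torch.log(-torch.log(u))
    # Perturb the logits, divide by the temperature, and normalize with softmax.
    return torch.softmax((logits + g) / tau, dim=dim)


torch.manual_seed(0)
logits = torch.randn(4, 5)
y = soft_gumbel_sample(logits, tau=0.5)
print(y.sum(dim=-1))  # each row sums to 1, up to float rounding
```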

Parameters

  • logits – […, num_features] unnormalized log probabilities.

  • tau – non-negative scalar temperature. Lower values push the soft samples toward one-hot vectors; higher values push them toward a uniform distribution.

  • hard – if True, the returned samples are discretized as one-hot vectors, but are differentiated in autograd as if they were the soft samples.

  • dim (int) – the dimension along which softmax is computed. Default: -1.
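A rough illustration of the temperature, using only the documented signature: as tau shrinks, more of each soft sample's mass lands on a single category. The exact numbers depend on the random seed and the logits chosen here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1000, 10)

peaks = {}
for tau in (0.1, 1.0, 10.0):
    y = F.gumbel_softmax(logits, tau=tau, hard=False)
    # Average probability assigned to the top category of each soft sample.
    peaks[tau] = y.max(dim=-1).values.mean().item()
    print(f"tau={tau}: mean max probability = {peaks[tau]:.3f}")
```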


Returns: a tensor of the same shape as logits, sampled from the Gumbel-Softmax distribution. If hard=True, the returned samples are one-hot; otherwise they are probability distributions that sum to 1 across dim.
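The return contract can be checked directly. This sketch only calls F.gumbel_softmax with the documented arguments; the sizes and tau=0.5 are arbitrary choices for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 6)

soft = F.gumbel_softmax(logits, tau=0.5, hard=False)
hard = F.gumbel_softmax(logits, tau=0.5, hard=True)

# Both outputs have the same shape as `logits`.
assert soft.shape == logits.shape and hard.shape == logits.shape

# Soft samples are probability distributions along dim=-1.
assert torch.allclose(soft.sum(dim=-1), torch.ones(8))

# Hard samples are one-hot: every entry is 0 or 1 (checked up to float
# rounding) and each row contains a single 1.
assert torch.allclose(hard, hard.round())
assert torch.allclose(hard.sum(dim=-1), torch.ones(8))
print(hard.argmax(dim=-1))  # index of the sampled category in each row
```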


Note: This function is kept for legacy reasons and may be removed from nn.functional in the future.


Note: The main trick for hard=True is to compute y_hard - y_soft.detach() + y_soft.

It achieves two things:

  • it makes the forward value exactly one-hot, because -y_soft.detach() and +y_soft cancel, leaving y_hard;

  • it makes the gradient equal to the gradient of y_soft, because y_hard and y_soft.detach() carry no gradient, so only the +y_soft term contributes.
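The straight-through construction can be reproduced by hand to confirm both properties. This minimal sketch deliberately omits the Gumbel noise so that y_soft is a plain softmax and the result is deterministic; the names y_soft, y_hard and y mirror the note above, and `weights` is an arbitrary stand-in for a downstream loss.

```python
import torch

logits = torch.tensor([[1.0, 2.0, 0.5]], requires_grad=True)

# Soft sample (Gumbel noise omitted for a deterministic illustration).
y_soft = torch.softmax(logits, dim=-1)

# One-hot vector at the argmax of the soft sample; it carries no gradient.
index = y_soft.argmax(dim=-1, keepdim=True)
y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)

# Straight-through estimator: forward value is y_hard, gradient is y_soft's.
y = y_hard - y_soft.detach() + y_soft
print(y.detach())  # tensor([[0., 1., 0.]])

# Backpropagate the same downstream objective through y and through y_soft.
weights = torch.tensor([[0.3, -1.0, 2.0]])
(grad_st,) = torch.autograd.grad((y * weights).sum(), logits, retain_graph=True)
(grad_soft,) = torch.autograd.grad((y_soft * weights).sum(), logits)
print(torch.allclose(grad_st, grad_soft))  # True
```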

Examples:

>>> import torch
>>> import torch.nn.functional as F
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using the reparameterization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using the "straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)
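As a small end-to-end check, a hard sample still passes gradients back to the logits, which a raw argmax cannot do. The linearly spaced `scores` tensor and the loss built from it are purely illustrative assumptions standing in for a real downstream objective.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(20, 32, requires_grad=True)

# Hard one-hot samples in the forward pass.
y = F.gumbel_softmax(logits, tau=1, hard=True)

# Illustrative objective: reward picking higher-scoring categories.
scores = torch.linspace(-1.0, 1.0, 32)
loss = -(y * scores).sum()
loss.backward()

# Gradients reach the logits through the soft sample (straight-through).
print(logits.grad.abs().sum() > 0)  # tensor(True)
# For comparison, argmax returns integer indices with no gradient path.
print(logits.argmax(dim=-1).requires_grad)  # False
```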

