
sample

torchtune.generation.sample(logits: Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None, q: Optional[Tensor] = None) → Tensor

Generic sampling from the probability distribution defined by a tensor of logits, with support for temperature scaling and top-k sampling.

Parameters:
  • logits (torch.Tensor) – logits from which to sample

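  • temperature (float) – value by which the predicted logits are divided before sampling; values below 1.0 sharpen the distribution and values above 1.0 flatten it. Default 1.0.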

  • top_k (Optional[int]) – if specified, sampling is restricted to the top_k token ids with the highest probabilities. Default None.

  • q (Optional[torch.Tensor]) – randomly sampled tensor used by the softmax sampling trick. If None, one is sampled internally. Default None.
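In formula form, the steps for a row of logits z with temperature T and top_k = k can be sketched as follows. This assumes that q is drawn i.i.d. from an Exp(1) distribution, which is the usual form of the softmax sampling trick; tau_k below denotes the k-th largest scaled logit.

```latex
\tilde{z}_i = \frac{z_i}{T}, \qquad
\tilde{z}_i \leftarrow -\infty \ \text{ if } \tilde{z}_i < \tau_k \\
p_i = \frac{\exp(\tilde{z}_i)}{\sum_j \exp(\tilde{z}_j)}, \qquad
q_i \overset{\text{iid}}{\sim} \operatorname{Exp}(1), \qquad
\hat{\imath} = \arg\max_i \frac{p_i}{q_i} \\
\Pr[\hat{\imath} = i] = \Pr\Big[\arg\min_j \tfrac{q_j}{p_j} = i\Big] = \frac{p_i}{\sum_j p_j} = p_i
```

The last line holds because q_j / p_j is exponentially distributed with rate p_j, and the minimum of independent exponentials falls on index i with probability proportional to its rate. Pruned tokens have p_j = 0 and can never be selected.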

Example

>>> import torch
>>> from torchtune.generation import sample
>>> logits = torch.empty(3, 3).uniform_(0, 1)
>>> sample(logits)
tensor([[1],
        [2],
        [0]], dtype=torch.int32)

Returns:

sampled token ids, with the last dimension of logits reduced to size 1 (e.g. shape [3, 1] for [3, 3] logits)

Return type:

torch.Tensor
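To make the temperature, top-k and sampling-trick steps concrete, here is a minimal pure-Python sketch for a single row of logits. It is not torchtune's implementation: the helper name `sample_index`, its `rng` argument, and the 1e-5 floor on the temperature are hypothetical choices of this sketch, and it assumes `q` is drawn i.i.d. from Exp(1).

```python
import math
import random
from typing import List, Optional


def sample_index(
    logits: List[float],
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    q: Optional[List[float]] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: draw one token index from a single row of logits.

    Illustrative only; not part of torchtune.
    """
    # 1. Temperature: divide the logits by T. The 1e-5 floor that avoids
    #    dividing by zero is an assumption of this sketch.
    scaled = [z / max(temperature, 1e-5) for z in logits]

    # 2. Top-k: send every logit strictly below the k-th largest value to
    #    -inf (assumes top_k >= 1; ties with the pivot are kept).
    if top_k is not None:
        k = min(top_k, len(scaled))
        pivot = sorted(scaled, reverse=True)[k - 1]
        scaled = [z if z >= pivot else -math.inf for z in scaled]

    # 3. Softmax, subtracting the max logit for numerical stability;
    #    exp(-inf) evaluates to 0.0, so pruned entries get probability 0.
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # 4. Exponential-race trick: with q_i drawn i.i.d. from Exp(1),
    #    argmax_i p_i / q_i equals i with probability p_i.
    if q is None:
        rng = rng or random.Random()
        q = [rng.expovariate(1.0) for _ in probs]

    def score(i: int) -> float:
        # p_i / q_i; a zero q_i (vanishingly rare) can only win if p_i > 0.
        if q[i] > 0:
            return probs[i] / q[i]
        return math.inf if probs[i] > 0 else 0.0

    return max(range(len(probs)), key=score)
```

Supplying `q` explicitly makes the result deterministic (for example, `sample_index([1.0, 2.0, 3.0], q=[1.0, 1.0, 1.0])` returns 2), which is convenient for testing. A tensor version would apply the same steps along the last dimension of `logits`.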
