sample
- torchtune.generation.sample(logits: Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None, q: Optional[Tensor] = None) → Tensor
Generic sampling from the probability distribution defined by the logits. Supports temperature scaling and top-k filtering.
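To show what temperature and top-k filtering do to the distribution, here is a minimal, dependency-free sketch that works on a single row of logits stored as a plain Python list. It is not torchtune's implementation, which operates on batched tensors. The helper names `sketch_probs` and `sketch_sample` are hypothetical. The 1e-5 floor on temperature is an assumption of this sketch. The final draw uses `random.choices` rather than the `q`-based trick.

```python
import math
import random
from typing import List, Optional


def sketch_probs(logits: List[float], temperature: float = 1.0,
                 top_k: Optional[int] = None) -> List[float]:
    """Turn one row of logits into sampling probabilities (hypothetical helper)."""
    # Temperature: divide the logits. The 1e-5 floor is an assumption of this
    # sketch to avoid dividing by zero.
    scaled = [x / max(temperature, 1e-5) for x in logits]
    if top_k is not None:
        # Pivot = k-th largest scaled logit; anything below it is masked to -inf
        # (ties with the pivot are kept, so more than k ids can survive).
        k = min(top_k, len(scaled))
        pivot = sorted(scaled, reverse=True)[k - 1]
        scaled = [x if x >= pivot else -math.inf for x in scaled]
    # Numerically stable softmax; math.exp(-inf) == 0.0, so masked ids get 0.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sketch_sample(logits: List[float], temperature: float = 1.0,
                  top_k: Optional[int] = None,
                  rng: Optional[random.Random] = None) -> int:
    """Draw one token id from the filtered, temperature-scaled distribution."""
    rng = rng or random.Random()
    probs = sketch_probs(logits, temperature, top_k)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    logits = [1.0, 2.0, 3.0]
    print(sketch_probs(logits, top_k=2))           # id 0 pruned to probability 0.0
    print(sketch_probs(logits, temperature=0.01))  # nearly all mass on id 2
    print(sketch_sample(logits, top_k=1))          # greedy: always 2
```

Lowering the temperature toward 0 concentrates probability on the largest logit. Setting `top_k=1` reduces sampling to greedy decoding.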
- Parameters:
logits (torch.Tensor) – unnormalized logits from which to sample; sampling is performed over the last dimension.
temperature (float) – value the predicted logits are divided by before the softmax. Values above 1.0 flatten the distribution, and values below 1.0 sharpen it. Default 1.0.
top_k (Optional[int]) – if specified, sampling is restricted to the top_k token ids with the highest probabilities. Default None.
q (Optional[torch.Tensor]) – randomly sampled tensor used by the softmax sampling trick. If None, a random tensor is drawn internally. Default None.
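The `q` argument refers to a randomized sampling trick. This sketch assumes a common form of it, the exponential race, which this page does not confirm. Draw q_i ~ Exponential(1) independently for each id and return argmax_i p_i / q_i. This selects id i with probability p_i. The function `race_sample` is hypothetical and works on a plain list of probabilities. It illustrates why passing a fixed `q` makes the draw reproducible.

```python
import random
from collections import Counter
from typing import Optional, Sequence


def race_sample(probs: Sequence[float],
                q: Optional[Sequence[float]] = None,
                rng: Optional[random.Random] = None) -> int:
    """Pick an index via argmax_i probs[i] / q[i] (hypothetical helper)."""
    if q is None:
        rng = rng or random.Random()
        # Assumption: q_i ~ Exponential(1), drawn independently per token id.
        q = [rng.expovariate(1.0) for _ in probs]
    # q_i / p_i ~ Exponential(rate=p_i), and the smallest of independent
    # exponentials is index i with probability p_i / sum(p).
    ratios = [p / qi for p, qi in zip(probs, q)]
    return max(range(len(ratios)), key=ratios.__getitem__)


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.2, 0.5, 0.3]
    n = 20_000
    counts = Counter(race_sample(probs, rng=rng) for _ in range(n))
    # Empirical frequencies should be close to probs.
    print({i: round(counts[i] / n, 3) for i in range(len(probs))})
    # A fixed q gives a deterministic result.
    print(race_sample(probs, q=[1.0, 1.0, 1.0]))  # index of the largest prob: 1
```

Compared with building a cumulative distribution, the race needs only an elementwise divide and an argmax. That maps cleanly onto batched tensor operations.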
Example
>>> import torch
>>> from torchtune.generation import sample
>>> logits = torch.empty(3, 3).uniform_(0, 1)
>>> sample(logits)  # random; one sampled id per row
tensor([[1],
        [2],
        [0]], dtype=torch.int32)
- Returns:
sampled token ids, one per row of logits: the last dimension is reduced to size 1, and the dtype is torch.int32.
- Return type: