sample
- torchtune.generation.sample(logits: Tensor, q: Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None) → Tensor
Generic sampling from a probability distribution, with optional temperature scaling and top-k filtering.
- Parameters:
logits (torch.Tensor) – logits from which to sample
q (torch.Tensor) – randomly sampled tensor used for the softmax sampling trick.
temperature (float) – value the predicted logits are divided by before the softmax; default 1.0.
top_k (Optional[int]) – if specified, sampling is restricted to the top_k token ids with the highest probabilities.
- Returns:
sampled token id
- Return type:
torch.Tensor
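To make the roles of temperature, top_k and q concrete, here is a minimal pure-Python sketch of the same technique on a single list of logits. It is not torchtune's code. The function name `sample_sketch`, the `1e-5` floor on the temperature, and the assumption that q holds i.i.d. Exponential(1) draws are choices made for this illustration; the parameter description above only calls q "randomly sampled". Under that assumption, taking `argmax(probs / q)` picks index i with probability `probs[i]`. Each `q[i] / probs[i]` is exponential with rate `probs[i]`, and the smallest of several independent exponentials comes from index i with probability proportional to its rate.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax; a logit of -inf maps to probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_sketch(
    logits: Sequence[float],
    q: Sequence[float],
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> int:
    """Hypothetical illustration: temperature + top-k + argmax(probs / q).

    Assumes q contains positive i.i.d. Exponential(1) draws, one per logit.
    """
    # Divide the logits by the temperature (floored here to avoid dividing by 0).
    scaled = [x / max(temperature, 1e-5) for x in logits]
    if top_k is not None:
        # The k-th largest logit is the pivot; anything strictly smaller is pruned.
        k = min(top_k, len(scaled))
        pivot = sorted(scaled, reverse=True)[k - 1]
        scaled = [x if x >= pivot else -math.inf for x in scaled]
    probs = softmax(scaled)
    # Pruned tokens have probability 0, so their ratio is 0 and they never win.
    ratios = [p / qi for p, qi in zip(probs, q)]
    return max(range(len(ratios)), key=ratios.__getitem__)


# Usage: draw fresh exponential noise for each sampling step.
random.seed(0)
logits = [2.0, 1.0, 0.1]
q = [random.expovariate(1.0) for _ in logits]
token = sample_sketch(logits, q, temperature=0.8, top_k=2)
```

Passing the noise in explicitly, rather than drawing it inside the function, makes a sampling step reproducible: the same logits and the same q always yield the same token id.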