generate_next_token¶
- torchtune.generation.generate_next_token(model: TransformerDecoder, input_pos: Tensor, x: Tensor, q: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None) → Tuple[Tensor, Tensor][source]¶
Generates the next token for each prompt in the batch, and also returns the corresponding logits.
- Parameters:
model (TransformerDecoder) – model used for generation
input_pos (torch.Tensor) – tensor with the positional encodings associated with the given prompt, with shape [bsz x seq_length].
x (torch.Tensor) – tensor with the token IDs associated with the given prompt, with shape [bsz x seq_length].
q (Optional[torch.Tensor]) – randomly sampled tensor used in the softmax sampling trick; see https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40. Default None.
mask (Optional[torch.Tensor]) – attention mask with shape [bsz x seq_length x seq_length], default None.
temperature (float) – value to scale the predicted logits by, default 1.0.
top_k (Optional[int]) – Top-k value to use for sampling, default None.
- Returns:
- tuple of two tensors:
- tokens (torch.Tensor): tensor with the generated tokens,
with shape [bsz x 1].
- logits (torch.Tensor): tensor with the logits associated with the generated tokens,
with shape [bsz x seq_length x vocab_size].
- Return type:
Tuple[torch.Tensor, torch.Tensor]
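The temperature, top_k, and q parameters together define how the next token is sampled from the model's logits. A minimal sketch of that sampling step in plain PyTorch (temperature scaling, optional top-k filtering, and the exponential-noise argmax trick linked above); this is an illustration of the technique, not torchtune's internal implementation, and the helper name is hypothetical:

```python
import torch

def sample_next_token(logits, temperature=1.0, top_k=None, q=None):
    # Hypothetical helper illustrating the sampling step.
    # logits: [bsz x vocab_size] -- logits for the final position only.
    logits = logits / max(temperature, 1e-5)  # temperature scaling
    if top_k is not None:
        # Keep only the top_k largest logits; mask the rest to -inf
        # so they receive zero probability after the softmax.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(logits < v[..., -1:], float("-inf"), logits)
    probs = torch.softmax(logits, dim=-1)
    if q is None:
        # Softmax sampling trick: dividing probabilities by i.i.d.
        # Exponential(1) noise and taking the argmax is equivalent to
        # drawing one sample from the categorical distribution `probs`.
        q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True)  # shape [bsz x 1]
```

Passing a pre-sampled q makes the draw reproducible across calls, which is why generate_next_token exposes it as an argument. With top_k=1 the draw degenerates to greedy decoding, since only the argmax logit survives the mask.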