generate
- torchtune.generation.generate(model: TransformerDecoder, prompt: Tensor, *, max_generated_tokens: int, pad_id: int = 0, temperature: float = 1.0, top_k: Optional[int] = None, stop_tokens: Optional[List[int]] = None, rng: Optional[Generator] = None, custom_generate_next_token: Optional[Callable] = None) → Tuple[Tensor, Tensor]
Generates tokens from a model conditioned on a prompt, and also returns logits for the generations.
- Parameters:
model (TransformerDecoder) – model used for generation
prompt (torch.Tensor) – tensor with the token IDs associated with the given prompt, with shape either [seq_length] or [bsz x seq_length].
max_generated_tokens (int) – number of tokens to be generated
pad_id (int) – token ID to use for padding, default 0.
temperature (float) – value to scale the predicted logits by, default 1.0.
top_k (Optional[int]) – If specified, sampling is restricted to the token IDs with the top_k highest probabilities (see the sampling sketch after this parameter list), default None.
stop_tokens (Optional[List[int]]) – If specified, generation is stopped when any of these tokens are generated, default None.
rng (Optional[torch.Generator]) – random number generator, default None.
custom_generate_next_token (Optional[Callable]) – If specified, this function is used in place of the default generate_next_token(). This is generally only useful if you want to supply a torch.compile-compiled version of generate_next_token for performance reasons (see the sketch under Examples below). Default is None.
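For instance, the sampling controls above can be combined in a single call. A minimal sketch follows; the tokenizer path is illustrative, and stop_tokens is assumed to be exposed by the tokenizer:
>>> import torch
>>> from torchtune.generation import generate
>>> model = torchtune.models.llama3.llama3_8b()
>>> tokenizer = torchtune.models.llama3.llama3_tokenizer("/tmp/tokenizer.model")
>>> prompt = torch.tensor(tokenizer.encode("Tell me a joke.", add_eos=False))  # no EOS on the prompt
>>> rng = torch.Generator()
>>> rng.manual_seed(1234)  # seed for reproducible sampling
>>> output, logits = generate(
...     model,
...     prompt,
...     max_generated_tokens=64,
...     temperature=0.7,  # temperature < 1 sharpens the sampling distribution
...     top_k=50,  # sample only from the 50 most probable token IDs
...     stop_tokens=tokenizer.stop_tokens,  # assumed attribute: EOS-style token IDs to halt on
...     rng=rng,
... )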
Note
This function has only been tested with decoder-only models.
Examples
>>> model = torchtune.models.llama3.llama3_8b()
>>> tokenizer = torchtune.models.llama3.llama3_tokenizer("/tmp/tokenizer.model")  # path to the tokenizer model file
>>> prompt = tokenizer.encode("Hi my name is")
>>> rng = torch.Generator()  # the generator must be created before seeding it
>>> rng.manual_seed(42)
>>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0, rng=rng)
>>> print(tokenizer.decode(output[0].tolist()))
Hi my name is Jeremy and I'm a friendly language model assistant!
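The custom_generate_next_token hook is typically used to pass a torch.compile-compiled copy of the default next-token step. A sketch of that pattern, reusing the model and prompt from above (the compile mode is illustrative):
>>> import torch
>>> from torchtune.generation import generate, generate_next_token
>>> compiled_next_token = torch.compile(generate_next_token, mode="max-autotune")  # compile the per-token hot loop
>>> output, logits = generate(
...     model,
...     torch.tensor(prompt),
...     max_generated_tokens=100,
...     pad_id=0,
...     custom_generate_next_token=compiled_next_token,
... )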
- Returns:
  tuple of two tensors:
  - tokens (torch.Tensor): tensor with the generated tokens, with shape [bsz x seq_len + num_generated_tokens], where num_generated_tokens may be less than max_generated_tokens if stop_tokens are provided (the prompt tokens are included in the output; see the slicing sketch below).
  - logits (torch.Tensor): tensor with the logits associated with the generated tokens, with shape [bsz x seq_len + num_generated_tokens x vocab_size].
- Return type:
Tuple[torch.Tensor, torch.Tensor]
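Because the returned tokens include the prompt, the continuation alone can be recovered by slicing off the prompt length. A minimal sketch, assuming the unbatched prompt from the example above:
>>> prompt_len = len(prompt)  # prompt is the 1-D token list from tokenizer.encode
>>> new_tokens = output[:, prompt_len:]  # shape [bsz, num_generated_tokens]
>>> print(tokenizer.decode(new_tokens[0].tolist()))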