generate

torchtune.utils.generate(model: TransformerDecoder, prompt: Tensor, *, max_generated_tokens: int, pad_id: int = 0, temperature: float = 1.0, top_k: Optional[int] = None, stop_tokens: Optional[List[int]] = None, custom_generate_next_token: Optional[Callable] = None) → List[List[int]]

Generates tokens from a model conditioned on a prompt.

Parameters:
  • model (TransformerDecoder) – model used for generation

  • prompt (torch.Tensor) – tensor with the token IDs associated with the given prompt, with shape either [seq_length] or [bsz x seq_length]

  • max_generated_tokens (int) – maximum number of tokens to generate; generation may stop earlier if a stop token is produced

  • pad_id (int) – token ID to use for padding, default 0.

  • temperature (float) – value used to scale the predicted logits before sampling; higher values flatten the distribution and produce more random output, default 1.0.

  • top_k (Optional[int]) – If specified, sampling is restricted to the top_k tokens with the highest predicted probabilities; all other tokens are pruned before sampling (see the sketch after this list), default None.

  • stop_tokens (Optional[List[int]]) – If specified, generation is stopped when any of these tokens are generated, default None.

  • custom_generate_next_token (Optional[Callable]) – If specified, this function is called to generate each next token in place of the default generate_next_token function. This is generally only useful for passing a torch.compile-d version of generate_next_token for performance reasons. If None, the default generate_next_token function is used. Default is None.
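
For intuition, temperature and top_k combine as follows: the logits are scaled by the temperature, everything outside the k highest-scoring tokens is masked out, and a token is drawn from the resulting distribution. A minimal sketch of that sampling step (illustrative only, not torchtune's internal implementation; the helper name sample_next_token is made up):

import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
    # Divide logits by temperature: values > 1.0 flatten the distribution,
    # values < 1.0 sharpen it toward the most likely token.
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        # Find the k-th largest logit and mask everything below it to -inf
        # so those tokens receive zero probability after the softmax.
        top_values, _ = torch.topk(logits, top_k)
        logits = torch.where(
            logits < top_values[..., -1:],
            torch.full_like(logits, float("-inf")),
            logits,
        )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Draw a single token id from the pruned, renormalized distribution.
    return torch.multinomial(probs, num_samples=1)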

Examples

>>> model = torchtune.models.llama3.llama3_8b()
>>> tokenizer = torchtune.models.llama3.llama3_tokenizer("/path/to/tokenizer.model")
>>> prompt = torch.tensor(tokenizer.encode("Hi my name is", add_eos=False))
>>> output = generate(model, prompt, max_generated_tokens=100)
>>> print(tokenizer.decode(output[0]))
Hi my name is Jeremy and I'm a friendly language model assistant!
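
For a batched prompt of shape [bsz x seq_length], rows must share a common length, so shorter prompts can be padded with pad_id first (left padding is assumed here for illustration; the docstring only states that pad_id is the token used for padding):

>>> texts = ["Hi my name is", "The capital of France is"]
>>> encoded = [tokenizer.encode(t, add_eos=False) for t in texts]
>>> max_len = max(len(e) for e in encoded)
>>> batch = torch.tensor([[0] * (max_len - len(e)) + e for e in encoded])
>>> output = generate(model, batch, max_generated_tokens=50, pad_id=0)
>>> len(output)  # one list of generated tokens per prompt in the batch
2
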
Returns:
    collection of lists of generated tokens, one list per prompt in the batch

Return type:
    List[List[int]]
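
Since the docstring notes that custom_generate_next_token exists mainly to pass in a torch.compile-d decode step, a typical use looks like the sketch below. The import path torchtune.utils.generate_next_token is an assumption; the docstring only says a default generate_next_token function is used when this argument is None.

import torch
import torchtune

# Compile the per-token decode step once up front, then reuse it for
# every generated token to amortize compilation cost.
compiled_next_token = torch.compile(
    torchtune.utils.generate_next_token,  # assumed location of the default helper
    mode="max-autotune",
    fullgraph=True,
)

output = torchtune.utils.generate(
    model,
    prompt,
    max_generated_tokens=100,
    custom_generate_next_token=compiled_next_token,
)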
