Source code for torchtune.generation._generation
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, List, Optional, Tuple
import torch
from torchtune.modules.transformer import TransformerDecoder
def multinomial_sample_one(probs: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
"""Samples from a multinomial distribution."""
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
def sample(
logits: torch.Tensor,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
q: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Generic sample from a probability distribution. Includes support for Top-K sampling
and Temperature.
Args:
logits (torch.Tensor): logits from which to sample
temperature (float): value to scale the predicted logits by, default 1.0.
top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities
        q (Optional[torch.Tensor]): exponentially distributed noise tensor for the softmax sampling trick.
            If None, the noise is drawn inside this function. Default None.
Example:
>>> from torchtune.generation import sample
>>> logits = torch.empty(3, 3).uniform_(0, 1)
>>> sample(logits)
tensor([[1],
[2],
[0]], dtype=torch.int32)
Returns:
        torch.Tensor: sampled token ids with shape ``[bsz x 1]``
"""
# scale the logits based on temperature
logits = logits / max(temperature, 1e-5)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
# select the very last value from the top_k above as the pivot
pivot = v.select(-1, -1).unsqueeze(-1)
        # set everything smaller than the pivot value to -inf so those
        # tokens are pruned from the distribution
logits = torch.where(logits < pivot, -float("Inf"), logits)
# change logits into probabilities
probs = torch.nn.functional.softmax(logits, dim=-1)
# if q is None, we use the default softmax sampling trick
if q is None:
# alternative to torch.empty_like(probs).exponential_(1)
# so it is reproducible in stable and nightly
uniform_val = torch.rand_like(probs)
epsilon = torch.finfo(uniform_val.dtype).eps / 2
condition = uniform_val >= 1.0 - epsilon
q = -torch.where(condition, -epsilon, torch.log(uniform_val))
return multinomial_sample_one(probs, q)
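

# Illustrative sketch (not part of the torchtune source): reproducible top-k sampling by
# pre-drawing the exponential noise ``q`` yourself. The shapes, seed, and sampling settings
# below are assumptions chosen only for illustration.
def _example_sample() -> torch.Tensor:
    g = torch.Generator().manual_seed(0)
    logits = torch.randn(2, 8, generator=g)  # [bsz, vocab_size]
    # noise for the softmax sampling trick; reusing the same q reproduces the same draw
    q = torch.empty(2, 8).exponential_(1, generator=g)
    return sample(logits, temperature=0.7, top_k=3, q=q)  # [bsz, 1], dtype torch.int32
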
def generate_next_token(
model: TransformerDecoder,
input_pos: torch.Tensor,
x: torch.Tensor,
q: Optional[torch.Tensor] = None,
*,
mask: Optional[torch.Tensor] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generates the next tokens given a prompt, and also returns the corresponding logits.
Args:
model (TransformerDecoder): model used for generation
input_pos (torch.Tensor): tensor with the positional encodings associated with the given prompt,
with shape [bsz x seq_length].
x (torch.Tensor): tensor with the token IDs associated with the given prompt,
with shape [bsz x seq_length].
q (Optional[torch.Tensor]): randomly sampled tensor for softmax sampling trick.
See https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40
mask (Optional[torch.Tensor]): attention mask with shape [bsz x seq_length x seq_length],
default None.
temperature (float): value to scale the predicted logits by, default 1.0.
top_k (Optional[int]): Top-k value to use for sampling, default None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors:
- tokens (torch.Tensor): tensor with the generated tokens,
with shape [bsz x 1].
- logits (torch.Tensor): tensor with the logits associated with the generated tokens,
with shape [bsz x 1 x vocab_size].
"""
# model produces logits in [bsz, seq_length, vocab_size]
    # we only need the last position's logits to sample the next token
logits = model(x, input_pos=input_pos, mask=mask)[:, -1]
return (
sample(logits.clone(), temperature=temperature, top_k=top_k, q=q),
logits.unsqueeze(1),
)
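

# Illustrative sketch (not part of the torchtune source): wiring up a single prefill step.
# It assumes ``model`` is any TransformerDecoder with caches disabled and that tensors live
# on the model's device; the batch size and prompt length are arbitrary assumptions.
def _example_generate_next_token(model: TransformerDecoder) -> None:
    bsz, prompt_length = 2, 6
    vocab_size = model.tok_embeddings.num_embeddings
    x = torch.randint(0, vocab_size, (bsz, prompt_length))
    input_pos = torch.arange(prompt_length).unsqueeze(0).expand(bsz, -1)
    tokens, logits = generate_next_token(
        model, input_pos=input_pos, x=x, temperature=0.8, top_k=50
    )
    assert tokens.shape == (bsz, 1)
    assert logits.shape == (bsz, 1, vocab_size)
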
def update_stop_tokens_tracker(
tokens: torch.Tensor, stop_tokens: torch.Tensor, stop_token_reached: torch.Tensor
) -> torch.Tensor:
"""Updates which sequences have reached a stop token."""
# tokens: [bsz, 1]
# stop_tokens: [num_stop_tokens]
# stop_token_reached: [bsz]
stop_token_reached_curr = torch.isin(tokens, stop_tokens).flatten()
stop_token_reached |= stop_token_reached_curr
return stop_token_reached
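

# Illustrative sketch (not part of the torchtune source): how the tracker accumulates stop
# hits across decoding steps. The token values below are arbitrary assumptions.
def _example_update_stop_tokens_tracker() -> None:
    tokens = torch.tensor([[5], [2], [9]])  # [bsz, 1], last sampled token per sequence
    stop_tokens = torch.tensor([2, 9])  # [num_stop_tokens]
    stop_token_reached = torch.zeros(3, dtype=torch.bool)  # [bsz]
    stop_token_reached = update_stop_tokens_tracker(tokens, stop_tokens, stop_token_reached)
    # stop_token_reached is now tensor([False, True, True]) and stays True on later calls
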
def get_causal_mask_from_padding_mask(
padding_mask: torch.Tensor, target_seq_len: Optional[int] = None
) -> torch.Tensor:
"""
Converts a padding mask of shape ``[bsz, seq_len]`` to a ``[bsz, seq_len, seq_len]`` causal attention mask suitable for
consumption by :func:`~torch.nn.functional.scaled_dot_product_attention`. If ``target_seq_len``
is provided, this will return a mask of shape ``[bsz, seq_len, target_seq_len]``. This is useful
    when generating masks for static KV caches where the maximum length the caches have been set up with
    is longer than the current sequence.
Args:
padding_mask (torch.Tensor): Boolean tensor where False indicates the corresponding token in the sequence
is a padding token and should be masked out in attention, with shape [bsz x seq_length]
target_seq_len (Optional[int]): target sequence length to create attention mask with. Default None.
Returns:
torch.Tensor: Boolean causal mask with shape
- [bsz, seq_length, seq_length] or
- [bsz, seq_length, target_seq_len] if ``target_seq_len`` was specified.
Raises:
AssertionError: if ``target_seq_len < seq_len``, the sequence length of the padding mask.
Example:
>>> padding_mask = torch.tensor([[False, True, True, True]])
>>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5)
        tensor([[[ True, False, False, False, False],
                 [False,  True, False, False, False],
                 [False,  True,  True, False, False],
                 [False,  True,  True,  True, False]]])
"""
bsz, seq_len = padding_mask.shape
target_seq_len = seq_len if target_seq_len is None else target_seq_len
if target_seq_len < seq_len:
raise AssertionError(
"target_seq_len cannot be shorter than the sequence length of the padding mask."
)
mask = torch.tril(
torch.ones(seq_len, target_seq_len, device=padding_mask.device, dtype=bool),
diagonal=0,
).repeat(bsz, 1, 1)
mask.narrow(2, 0, seq_len).mul_(padding_mask[:, None, :].expand(-1, seq_len, -1))
mask.diagonal(dim1=1, dim2=2).copy_(torch.Tensor([True]))
return mask
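

# Illustrative sketch (not part of the torchtune source): building a mask for a padded batch
# sized for a longer static KV cache. The padding pattern and cache length are assumptions
# chosen for the demonstration.
def _example_get_causal_mask_from_padding_mask() -> torch.Tensor:
    padding_mask = torch.tensor([[False, True, True, True]])  # one left-padded sequence
    mask = get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5)
    assert mask.shape == (1, 4, 5)
    # the diagonal is forced to True, so every query position attends to at least itself
    assert bool(mask.diagonal(dim1=1, dim2=2).all())
    return mask
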
def get_position_ids_from_padding_mask(
padding_mask: torch.Tensor,
):
"""
    Calculates position ids from a padding mask, right-shifting the ids so that counting
    starts from the first non-padding token.
Args:
padding_mask (torch.Tensor): Boolean tensor where False indicates the corresponding token in the sequence
is a padding token and should be masked out in attention. Shape [bsz, seq_len]
Returns:
torch.Tensor: position ids which are appropriately shifted according to any padding values.
Example:
        >>> padding_mask = torch.tensor([[False, False, False, True, True, True, True, True]])
        >>> get_position_ids_from_padding_mask(padding_mask)
        tensor([[0, 0, 0, 0, 1, 2, 3, 4]], dtype=torch.int32)
"""
return ((padding_mask.cumsum(-1) - 1) * padding_mask).to(torch.int)
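

# Illustrative sketch (not part of the torchtune source): position ids for a left-padded
# sequence; the padding pattern is an arbitrary assumption.
def _example_get_position_ids_from_padding_mask() -> torch.Tensor:
    padding_mask = torch.tensor([[False, False, True, True, True]])
    # padding positions collapse to 0 and real tokens count up from 0:
    # tensor([[0, 0, 0, 1, 2]], dtype=torch.int32)
    return get_position_ids_from_padding_mask(padding_mask)
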
@torch.no_grad()
def generate(
model: TransformerDecoder,
prompt: torch.Tensor,
*,
max_generated_tokens: int,
pad_id: int = 0,
temperature: float = 1.0,
top_k: Optional[int] = None,
stop_tokens: Optional[List[int]] = None,
rng: Optional[torch.Generator] = None,
custom_generate_next_token: Optional[Callable] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generates tokens from a model conditioned on a prompt, and also returns logits for the generations.
Args:
model (TransformerDecoder): model used for generation
prompt (torch.Tensor): tensor with the token IDs associated with the given prompt,
with shape either [seq_length] or [bsz x seq_length].
        max_generated_tokens (int): maximum number of tokens to generate; generation may stop
            earlier if ``stop_tokens`` is provided and every sequence hits one.
pad_id (int): token ID to use for padding, default 0.
temperature (float): value to scale the predicted logits by, default 1.0.
top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities,
default None.
stop_tokens (Optional[List[int]]): If specified, generation is stopped when any of these tokens are generated,
default None.
rng (Optional[torch.Generator]): random number generator, default None.
        custom_generate_next_token (Optional[Callable]): If specified, we'll use the
            ``custom_generate_next_token`` function instead. This is generally only useful if
            you want to pass a ``torch.compile``-d version of the next-token step for
            performance reasons. If None, we use the default :func:`generate_next_token`.
            Default is None.
Note:
This function has only been tested with decoder-only models.
Examples:
>>> model = torchtune.models.llama3.llama3_8b()
>>> tokenizer = torchtune.models.llama3.llama3_tokenizer()
>>> prompt = tokenizer.encode("Hi my name is")
        >>> rng = torch.Generator()
        >>> rng.manual_seed(42)
        >>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0, rng=rng)
>>> print(tokenizer.decode(output[0].tolist()))
Hi my name is Jeremy and I'm a friendly language model assistant!
Returns:
Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors:
- tokens (torch.Tensor): tensor with the generated tokens,
with shape ``[bsz x seq_len + num_generated_tokens]`` where ``num_generated_tokens``
may be less than ``max_generated_tokens`` if ``stop_tokens`` are provided.
- logits (torch.Tensor): tensor with the logits associated with the generated tokens,
with shape ``[bsz x num_generated_tokens x vocab_size]``.
"""
prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt
if custom_generate_next_token is None:
custom_generate_next_token = generate_next_token
bsz, prompt_length = prompt.size()
total_response_length = prompt_length + max_generated_tokens
generated_tokens = prompt.clone()
incremental_decoding = model.caches_are_enabled()
# grab the correct max_seq_len to generate full causal masks/position ids
# this is the model's max cache len if incremental decoding, or the sequence
# length otherwise
max_seq_len = (
total_response_length
if not incremental_decoding
else model.decoder_max_cache_seq_len
)
padding_masks = generated_tokens != pad_id
if not padding_masks.all():
# we have padding in the prompt due to varying-length sequences in a batch
# extend padding masks out to the correct seq len
padding_masks = torch.nn.functional.pad(
padding_masks, (0, max_generated_tokens), value=True
)
# generate the full causal mask for the whole padding mask with padding ignored
masks = get_causal_mask_from_padding_mask(
padding_masks, target_seq_len=max_seq_len
)
# right-shift position IDs to account for padding
input_pos = get_position_ids_from_padding_mask(padding_masks)
else:
# just use a regular causal mask if there is no padding
masks = torch.tril(
torch.ones(
total_response_length,
max_seq_len,
dtype=torch.bool,
device=prompt.device,
)
).unsqueeze(0)
input_pos = torch.arange(
0, total_response_length, device=generated_tokens.device
).unsqueeze(0)
if incremental_decoding:
# if KV-caches are enabled, we need a causal mask of shape [bsz, prompt_length, max_cache_len]
# to match the key/value cache tensor shapes
curr_masks = masks[:, :prompt_length]
else:
# otherwise the causal mask is shape [bsz, prompt_length, prompt_length] because key/value
# tensors are of identical shape to the prompt
curr_masks = masks[:, :prompt_length, :prompt_length]
q = None
if rng is not None:
uniform_val = torch.rand(
bsz,
model.tok_embeddings.num_embeddings,
generator=rng,
device=prompt.device,
)
epsilon = torch.finfo(uniform_val.dtype).eps / 2
condition = uniform_val >= 1.0 - epsilon
q = -torch.where(condition, -epsilon, torch.log(uniform_val))
tokens, generated_logits = generate_next_token(
model,
input_pos=input_pos[:, :prompt_length].squeeze(),
mask=curr_masks,
x=prompt,
temperature=temperature,
top_k=top_k,
q=q,
)
generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
curr_pos = prompt_length
    # track which sequences have already hit a stop token so we can stop generation early
stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device)
stop_tokens = (
torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype)
if stop_tokens
else None
)
# everything in stop_token_mask starts as 1s, and we'll set them to 0 for sequences
# that already hit a stop token
stop_token_mask = torch.ones(
(bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device
)
# stop early if we reach a stop token in every seq
if stop_tokens is not None:
stop_token_reached = update_stop_tokens_tracker(
tokens, stop_tokens, stop_token_reached
)
if stop_token_reached.all().item():
return generated_tokens, generated_logits
for _ in range(max_generated_tokens - 1):
# update stop_token_mask if we reached a stop token in a previous step
# by appending the logical not of stop_token_reached to the end of the mask
# reshaped to be bsz first
if stop_tokens is not None:
stop_token_mask = torch.cat(
[stop_token_mask, ~stop_token_reached.reshape(bsz, 1)], dim=-1
)
# if incremental decoding is enabled, we can use the current position
# otherwise, we take the whole sequence up to the current position
if incremental_decoding:
curr_input_pos = input_pos[:, curr_pos].contiguous()
curr_masks = masks[:, curr_pos, None, :].contiguous()
else:
tokens = generated_tokens.clone()
curr_input_pos = input_pos[:, : curr_pos + 1]
curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1]
q = None
if rng is not None:
uniform_val = torch.rand(
bsz,
model.tok_embeddings.num_embeddings,
generator=rng,
device=prompt.device,
)
epsilon = torch.finfo(uniform_val.dtype).eps / 2
condition = uniform_val >= 1.0 - epsilon
q = -torch.where(condition, -epsilon, torch.log(uniform_val))
tokens, logits = custom_generate_next_token(
model,
input_pos=curr_input_pos,
x=tokens.clone(),
mask=curr_masks,
temperature=temperature,
top_k=top_k,
q=q,
)
generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
generated_logits = torch.cat([generated_logits, logits], dim=1)
curr_pos += 1
if stop_tokens is not None:
stop_token_reached = update_stop_tokens_tracker(
tokens, stop_tokens, stop_token_reached
)
if stop_token_reached.all():
break
# mask out generated tokens in seqs that already hit a stop token
if stop_tokens is not None:
generated_tokens *= stop_token_mask
generated_logits *= stop_token_mask[:, -generated_logits.shape[1] :, None]
return generated_tokens, generated_logits
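

# Illustrative sketch (not part of the torchtune source): passing a ``torch.compile``-d
# decode step through ``custom_generate_next_token`` and seeding generation with an rng.
# It assumes ``model`` is any TransformerDecoder and ``prompt`` is a [bsz, seq_len] tensor
# of token ids on the same device; the sampling settings are arbitrary assumptions.
def _example_generate_with_compiled_decode(
    model: TransformerDecoder, prompt: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    compiled_step = torch.compile(generate_next_token, mode="reduce-overhead")
    rng = torch.Generator(device=prompt.device).manual_seed(0)
    return generate(
        model,
        prompt,
        max_generated_tokens=32,
        temperature=0.7,
        top_k=50,
        rng=rng,
        custom_generate_next_token=compiled_step,
    )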