Shortcuts

Source code for torchtune.generation._generation

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, List, Optional, Tuple

import torch

from torchtune.modules.transformer import TransformerDecoder


def multinomial_sample_one(probs: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Samples from a multinomial distribution."""
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


[docs]def sample( logits: torch.Tensor, *, temperature: float = 1.0, top_k: Optional[int] = None, q: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Generic sample from a probability distribution. Includes support for Top-K sampling and Temperature. Args: logits (torch.Tensor): logits from which to sample temperature (float): value to scale the predicted logits by, default 1.0. top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities q (Optional[torch.Tensor]): randomly sampled tensor for softmax sampling trick. If None, we use the default softmax sampling trick. Default None. Example: >>> from torchtune.generation import sample >>> logits = torch.empty(3, 3).uniform_(0, 1) >>> sample(logits) tensor([[1], [2], [0]], dtype=torch.int32) Returns: torch.Tensor: sampled token id """ # scale the logits based on temperature logits = logits / max(temperature, 1e-5) if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) # select the very last value from the top_k above as the pivot pivot = v.select(-1, -1).unsqueeze(-1) # set everything smaller than pivot value to inf since these # should be pruned logits = torch.where(logits < pivot, -float("Inf"), logits) # change logits into probabilities probs = torch.nn.functional.softmax(logits, dim=-1) # if q is None, we use the default softmax sampling trick if q is None: # alternative to torch.empty_like(probs).exponential_(1) # so it is reproducible in stable and nightly uniform_val = torch.rand_like(probs) epsilon = torch.finfo(uniform_val.dtype).eps / 2 condition = uniform_val >= 1.0 - epsilon q = -torch.where(condition, -epsilon, torch.log(uniform_val)) return multinomial_sample_one(probs, q)
[docs]def generate_next_token( model: TransformerDecoder, input_pos: torch.Tensor, x: torch.Tensor, q: Optional[torch.Tensor] = None, *, mask: Optional[torch.Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Generates the next tokens given a prompt, and also returns the corresponding logits. Args: model (TransformerDecoder): model used for generation input_pos (torch.Tensor): tensor with the positional encodings associated with the given prompt, with shape [bsz x seq_length]. x (torch.Tensor): tensor with the token IDs associated with the given prompt, with shape [bsz x seq_length]. q (Optional[torch.Tensor]): randomly sampled tensor for softmax sampling trick. See https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40 mask (Optional[torch.Tensor]): attention mask with shape [bsz x seq_length x seq_length], default None. temperature (float): value to scale the predicted logits by, default 1.0. top_k (Optional[int]): Top-k value to use for sampling, default None. Returns: Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors: - tokens (torch.Tensor): tensor with the generated tokens, with shape [bsz x 1]. - logits (torch.Tensor): tensor with the logits associated with the generated tokens, with shape [bsz x 1 x vocab_size]. """ # model produces logits in [bsz, seq_length, vocab_size] # we want to take the last token's logits as the input to the next model call logits = model(x, input_pos=input_pos, mask=mask)[:, -1] return ( sample(logits.clone(), temperature=temperature, top_k=top_k, q=q), logits.unsqueeze(1), )
def update_stop_tokens_tracker( tokens: torch.Tensor, stop_tokens: torch.Tensor, stop_token_reached: torch.Tensor ) -> torch.Tensor: """Updates which sequences have reached a stop token.""" # tokens: [bsz, 1] # stop_tokens: [num_stop_tokens] # stop_token_reached: [bsz] stop_token_reached_curr = torch.isin(tokens, stop_tokens).flatten() stop_token_reached |= stop_token_reached_curr return stop_token_reached
[docs]def get_causal_mask_from_padding_mask( padding_mask: torch.Tensor, target_seq_len: Optional[int] = None ) -> torch.Tensor: """ Converts a padding mask of shape ``[bsz, seq_len]`` to a ``[bsz, seq_len, seq_len]`` causal attention mask suitable for consumption by :func:`~torch.nn.functional.scaled_dot_product_attention`. If ``target_seq_len`` is provided, this will return a mask of shape ``[bsz, seq_len, target_seq_len]``. This is useful when generating masks for static KV caches where the maximum length the caches have been setup with are longer than the current sequence. Args: padding_mask (torch.Tensor): Boolean tensor where False indicates the corresponding token in the sequence is a padding token and should be masked out in attention, with shape [bsz x seq_length] target_seq_len (Optional[int]): target sequence length to create attention mask with. Default None. Returns: torch.Tensor: Boolean causal mask with shape - [bsz, seq_length, seq_length] or - [bsz, seq_length, target_seq_len] if ``target_seq_len`` was specified. Raises: AssertionError: if ``target_seq_len < seq_len``, the sequence length of the padding mask. Example: >>> padding_mask = torch.tensor([[False, True, True, True]]) >>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5) tensor([[[ True, False, False, False, False], [False, True, False, False, False], [False, True, True, False, False], [False, True, True, True, False]]]) ]) """ bsz, seq_len = padding_mask.shape target_seq_len = seq_len if target_seq_len is None else target_seq_len if target_seq_len < seq_len: raise AssertionError( "target_seq_len cannot be shorter than the sequence length of the padding mask." ) mask = torch.tril( torch.ones(seq_len, target_seq_len, device=padding_mask.device, dtype=bool), diagonal=0, ).repeat(bsz, 1, 1) mask.narrow(2, 0, seq_len).mul_(padding_mask[:, None, :].expand(-1, seq_len, -1)) mask.diagonal(dim1=1, dim2=2).copy_(torch.Tensor([True])) return mask
[docs]def get_position_ids_from_padding_mask( padding_mask: torch.Tensor, ): """ Calculates position ids given a padding mask which right-shifts position ids to start from the first valid token. Args: padding_mask (torch.Tensor): Boolean tensor where False indicates the corresponding token in the sequence is a padding token and should be masked out in attention. Shape [bsz, seq_len] Returns: torch.Tensor: position ids which are appropriately shifted according to any padding values. Example: >>> padding_mask = torch.tensor([False, False, False, True, True, True, True, True]) >>> get_position_ids_from_padding_mask(padding_mask) torch.Tensor([0, 0, 0, 0, 1, 2, 3, 4]) """ return ((padding_mask.cumsum(-1) - 1) * padding_mask).to(torch.int)
[docs]@torch.no_grad() def generate( model: TransformerDecoder, prompt: torch.Tensor, *, max_generated_tokens: int, pad_id: int = 0, temperature: float = 1.0, top_k: Optional[int] = None, stop_tokens: Optional[List[int]] = None, rng: Optional[torch.Generator] = None, custom_generate_next_token: Optional[Callable] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Generates tokens from a model conditioned on a prompt, and also returns logits for the generations. Args: model (TransformerDecoder): model used for generation prompt (torch.Tensor): tensor with the token IDs associated with the given prompt, with shape either [seq_length] or [bsz x seq_length]. max_generated_tokens (int): number of tokens to be generated pad_id (int): token ID to use for padding, default 0. temperature (float): value to scale the predicted logits by, default 1.0. top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities, default None. stop_tokens (Optional[List[int]]): If specified, generation is stopped when any of these tokens are generated, default None. rng (Optional[torch.Generator]): random number generator, default None. custom_generate_next_token (Optional[Callable]): If specified, we'll use the ``custom_generate_next_token function``. This is generally only useful if you want to specify a ``torch.compile`` version of the generate next token for performance reasons. If None, we use the default :func:`generate_next_token`. Default is None. Note: This function has only been tested with decoder-only models. Examples: >>> model = torchtune.models.llama3.llama3_8b() >>> tokenizer = torchtune.models.llama3.llama3_tokenizer() >>> prompt = tokenizer.encode("Hi my name is") >>> rng.manual_seed(42) >>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0) >>> print(tokenizer.decode(output[0].tolist())) Hi my name is Jeremy and I'm a friendly language model assistant! Returns: Tuple[torch.Tensor, torch.Tensor]: tuple of two tensors: - tokens (torch.Tensor): tensor with the generated tokens, with shape ``[bsz x seq_len + num_generated_tokens]`` where ``num_generated_tokens`` may be less than ``max_generated_tokens`` if ``stop_tokens`` are provided. - logits (torch.Tensor): tensor with the logits associated with the generated tokens, with shape ``[bsz x num_generated_tokens x vocab_size]``. """ prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt if custom_generate_next_token is None: custom_generate_next_token = generate_next_token bsz, prompt_length = prompt.size() total_response_length = prompt_length + max_generated_tokens generated_tokens = prompt.clone() incremental_decoding = model.caches_are_enabled() # grab the correct max_seq_len to generate full causal masks/position ids # this is the model's max cache len if incremental decoding, or the sequence # length otherwise max_seq_len = ( total_response_length if not incremental_decoding else model.decoder_max_cache_seq_len ) padding_masks = generated_tokens != pad_id if not padding_masks.all(): # we have padding in the prompt due to varying-length sequences in a batch # extend padding masks out to the correct seq len padding_masks = torch.nn.functional.pad( padding_masks, (0, max_generated_tokens), value=True ) # generate the full causal mask for the whole padding mask with padding ignored masks = get_causal_mask_from_padding_mask( padding_masks, target_seq_len=max_seq_len ) # right-shift position IDs to account for padding input_pos = get_position_ids_from_padding_mask(padding_masks) else: # just use a regular causal mask if there is no padding masks = torch.tril( torch.ones( total_response_length, max_seq_len, dtype=torch.bool, device=prompt.device, ) ).unsqueeze(0) input_pos = torch.arange( 0, total_response_length, device=generated_tokens.device ).unsqueeze(0) if incremental_decoding: # if KV-caches are enabled, we need a causal mask of shape [bsz, prompt_length, max_cache_len] # to match the key/value cache tensor shapes curr_masks = masks[:, :prompt_length] else: # otherwise the causal mask is shape [bsz, prompt_length, prompt_length] because key/value # tensors are of identical shape to the prompt curr_masks = masks[:, :prompt_length, :prompt_length] q = None if rng is not None: uniform_val = torch.rand( bsz, model.tok_embeddings.num_embeddings, generator=rng, device=prompt.device, ) epsilon = torch.finfo(uniform_val.dtype).eps / 2 condition = uniform_val >= 1.0 - epsilon q = -torch.where(condition, -epsilon, torch.log(uniform_val)) tokens, generated_logits = generate_next_token( model, input_pos=input_pos[:, :prompt_length].squeeze(), mask=curr_masks, x=prompt, temperature=temperature, top_k=top_k, q=q, ) generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) curr_pos = prompt_length # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device) stop_tokens = ( torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype) if stop_tokens else None ) # everything in stop_token_mask starts as 1s, and we'll set them to 0 for sequences # that already hit a stop token stop_token_mask = torch.ones( (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device ) # stop early if we reach a stop token in every seq if stop_tokens is not None: stop_token_reached = update_stop_tokens_tracker( tokens, stop_tokens, stop_token_reached ) if stop_token_reached.all().item(): return generated_tokens, generated_logits for _ in range(max_generated_tokens - 1): # update stop_token_mask if we reached a stop token in a previous step # by appending the logical not of stop_token_reached to the end of the mask # reshaped to be bsz first if stop_tokens is not None: stop_token_mask = torch.cat( [stop_token_mask, ~stop_token_reached.reshape(bsz, 1)], dim=-1 ) # if incremental decoding is enabled, we can use the current position # otherwise, we take the whole sequence up to the current position if incremental_decoding: curr_input_pos = input_pos[:, curr_pos].contiguous() curr_masks = masks[:, curr_pos, None, :].contiguous() else: tokens = generated_tokens.clone() curr_input_pos = input_pos[:, : curr_pos + 1] curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1] q = None if rng is not None: uniform_val = torch.rand( bsz, model.tok_embeddings.num_embeddings, generator=rng, device=prompt.device, ) epsilon = torch.finfo(uniform_val.dtype).eps / 2 condition = uniform_val >= 1.0 - epsilon q = -torch.where(condition, -epsilon, torch.log(uniform_val)) tokens, logits = custom_generate_next_token( model, input_pos=curr_input_pos, x=tokens.clone(), mask=curr_masks, temperature=temperature, top_k=top_k, q=q, ) generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) generated_logits = torch.cat([generated_logits, logits], dim=1) curr_pos += 1 if stop_tokens is not None: stop_token_reached = update_stop_tokens_tracker( tokens, stop_tokens, stop_token_reached ) if stop_token_reached.all(): break # mask out generated tokens in seqs that already hit a stop token if stop_tokens is not None: generated_tokens *= stop_token_mask generated_logits *= stop_token_mask[:, -generated_logits.shape[1] :, None] return generated_tokens, generated_logits

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources