Shortcuts

Source code for torchaudio.models.rnnt_decoder

from typing import Callable, Dict, List, Optional, NamedTuple, Tuple

import torch
from torchaudio.models import RNNT


__all__ = ["Hypothesis", "RNNTBeamSearch"]


class Hypothesis(NamedTuple):
    r"""One partial transcription tracked by the ``RNNTBeamSearch`` decoder.

    :ivar List[int] tokens: Token indices emitted so far, oldest first.
    :ivar torch.Tensor predictor_out: Prediction network output for the most recent token.
    :ivar List[List[torch.Tensor]] state: Prediction network internal state after the most recent token.
    :ivar float score: Accumulated log-domain score of this hypothesis.
    :ivar List[int] alignment: Time step at which each entry of ``tokens`` was emitted; the i-th value
        corresponds to the i-th token.
    :ivar int blank: Index of the blank token.
    :ivar str key: String form of the token sequence; two ``Hypothesis`` instances with equal keys
        are treated as the same token sequence.
    """

    # --- decoded output ---
    tokens: List[int]
    # --- prediction network bookkeeping ---
    predictor_out: torch.Tensor
    state: List[List[torch.Tensor]]
    # --- ranking / alignment ---
    score: float
    alignment: List[int]
    # --- identity ---
    blank: int
    key: str
def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
    """Concatenate the prediction-network states of ``hypos`` along dim 0.

    Every hypothesis is expected to share the nested layout of ``hypos[0].state``. The result has
    that same layout, where entry ``[g][p]`` is ``torch.cat`` of every hypothesis's ``state[g][p]``.
    """
    reference = hypos[0].state
    batched: List[List[torch.Tensor]] = []
    for group in range(len(reference)):
        parts: List[torch.Tensor] = []
        for part in range(len(reference[group])):
            pieces = [h.state[group][part] for h in hypos]
            parts.append(torch.cat(pieces))
        batched.append(parts)
    return batched


def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
    """Pick row ``idx`` of every tensor in a batched state, keeping dim 0 (size 1)."""
    row = torch.tensor([idx], device=device)
    sliced: List[List[torch.Tensor]] = []
    for group in states:
        sliced.append([tensor.index_select(0, row) for tensor in group])
    return sliced


def _default_hypo_sort_key(hypo: Hypothesis) -> float:
    """Default ranking key: hypothesis score divided by (number of tokens + 1)."""
    denominator = len(hypo.tokens) + 1
    return hypo.score / denominator


def _compute_updated_scores(
    hypos: List[Hypothesis],
    next_token_probs: torch.Tensor,
    beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Find the ``beam_width`` best non-blank one-token extensions over all of ``hypos``.

    ``next_token_probs`` holds one row per hypothesis. Its last column is excluded here because the
    search treats it as the blank token.

    Returns:
        (scores, hypo_idx, token): the extended scores; the index into ``hypos`` each came from;
        and the token index appended.
    """
    previous = torch.tensor([h.score for h in hypos]).unsqueeze(1)
    candidates = previous + next_token_probs[:, :-1]  # [num_hypos, num_tokens - 1]
    width = candidates.shape[1]

    flat = candidates.reshape(-1)
    best_scores, best_flat = flat.topk(beam_width)

    # Undo the flattening: row = originating hypothesis, column = appended token.
    origin = best_flat.div(width, rounding_mode="trunc")
    token = best_flat % width
    return best_scores, origin, token


def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
    """Delete, in place, the first entry of ``hypo_list`` whose ``key`` equals ``hypo.key``."""
    position = 0
    while position < len(hypo_list):
        if hypo_list[position].key == hypo.key:
            del hypo_list[position]
            break
        position += 1
class RNNTBeamSearch(torch.nn.Module):
    r"""Beam search decoder for RNN-T model.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: None)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
    """

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature

        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key

        self.step_max_tokens = step_max_tokens

    def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
        """Build the single-element starting beam.

        If ``hypo`` is given, the beam is seeded from its last token and its prediction-network state.
        Otherwise it starts from the blank token with no state. The seed keeps only that one token;
        score restarts at 0.0 and alignment at ``[-1]``.
        """
        if hypo is not None:
            token = hypo.tokens[-1]
            state = hypo.state
        else:
            token = self.blank
            state = None

        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
        init_hypo = Hypothesis(
            tokens=[token],
            predictor_out=pred_out[0].detach(),
            state=pred_state,
            score=0.0,
            alignment=[-1],
            blank=self.blank,
            key=str([token]),
        )
        return [init_hypo]

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        """Return temperature-scaled log-probabilities over the vocabulary, one row per hypothesis."""
        one_tensor = torch.tensor([1], device=device)
        predictor_out = torch.stack([h.predictor_out for h in hypos], dim=0)
        joined_out, _, _ = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
        )  # [beam_width, 1, 1, num_tokens]
        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        """Extend every hypothesis in ``a_hypos`` with blank and merge the results into ``b_hypos``.

        If ``b_hypos`` already holds a hypothesis with the same key, the two scores are combined
        with ``logaddexp``. The alignment of the higher-scoring contributor is kept.
        ``b_hypos`` and ``key_to_b_hypo`` are updated in place. The return value is ``b_hypos``
        sorted by ascending score.
        """
        for i in range(len(a_hypos)):
            h_a = a_hypos[i]
            # The last column of next_token_probs is treated as blank.
            append_blank_score = h_a.score + next_token_probs[i, -1]
            if h_a.key in key_to_b_hypo:
                h_b = key_to_b_hypo[h_a.key]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(h_b.score).logaddexp(append_blank_score))
                alignment = h_a.alignment if h_b.score < h_a.score else h_b.alignment
            else:
                score = float(append_blank_score)
                alignment = h_a.alignment
            h_b = Hypothesis(
                tokens=h_a.tokens,
                predictor_out=h_a.predictor_out,
                state=h_a.state,
                score=score,
                alignment=alignment,
                blank=self.blank,
                key=h_a.key,
            )
            b_hypos.append(h_b)
            key_to_b_hypo[h_b.key] = h_b
        _, sorted_idx = torch.tensor([hypo.score for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Produce non-blank extensions of ``a_hypos`` that can still enter the final beam.

        A candidate survives only if its score exceeds the ``beam_width``-th best score in
        ``b_hypos``. The caller passes ``b_hypos`` sorted by ascending score.
        """
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)

        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            # b_hypos is ascending, so this is the beam_width-th best blank-terminated score.
            b_nbest_score = b_hypos[-beam_width].score

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        for i in range(beam_width):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(int(nonblank_nbest_token[i]))
                new_scores.append(score)

        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        else:
            new_hypos: List[Hypothesis] = []

        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Append ``tokens[i]`` to ``base_hypos[i]``, running the prediction network once as a batch."""
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = h_a.tokens + [tokens[i]]
            new_hypos.append(
                Hypothesis(
                    tokens=new_tokens,
                    predictor_out=pred_out[i].detach(),
                    state=_slice_state(pred_states, i, device),
                    score=scores[i],
                    alignment=h_a.alignment + [t],
                    blank=self.blank,
                    key=str(new_tokens),
                )
            )
        return new_hypos

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[Hypothesis],
        beam_width: int,
    ) -> List[Hypothesis]:
        """Run beam search over every time step of ``enc_out``.

        At each step, hypotheses are repeatedly extended with non-blank tokens until no candidate can
        beat the current beam or ``step_max_tokens`` is reached. The beam is then pruned to at most
        ``beam_width`` entries using ``hypo_sort_key``, best first.
        """
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(hypo, device)
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                # Score bookkeeping in the helpers builds tensors on the default (CPU) device.
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(
                    b_hypos,
                    a_hypos,
                    next_token_probs,
                    key_to_b_hypo,
                )

                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            # The beam can hold fewer than beam_width entries (e.g. step_max_tokens == 0).
            # topk(k) raises when k exceeds the number of elements, so clamp k.
            k = min(beam_width, len(b_hypos))
            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(k)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def _normalize_inputs(self, input: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Validate ``input``/``length`` shapes and add a leading batch dimension where missing.

        Returns:
            (input, length) with shapes (1, T, D) and (1,).
        """
        assert input.dim() == 2 or (
            input.dim() == 3 and input.shape[0] == 1
        ), "input must be of shape (T, D) or (1, T, D)"
        if input.dim() == 2:
            input = input.unsqueeze(0)

        assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
        if length.dim() == 0:
            # Give a scalar length the same batch dimension that was added to ``input``.
            length = length.unsqueeze(0)
        return input, length

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Returns:
            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search (fewer if the
            search produced fewer), best first.
        """
        input, length = self._normalize_inputs(input, length)
        enc_out, _ = self.model.transcribe(input, length)
        return self._search(enc_out, None, beam_width)

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[Hypothesis] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
                search with. (Default: ``None``)

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    top-``beam_width`` hypotheses found by beam search (fewer if the search
                    produced fewer), best first.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.
        """
        input, length = self._normalize_inputs(input, length)
        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
        return self._search(enc_out, hypothesis, beam_width), state

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources