• Docs >
  • Module code >
  • torchaudio.models.decoder._cuda_ctc_decoder >
  • Nightly (unstable)
Shortcuts

Source code for torchaudio.models.decoder._cuda_ctc_decoder

from __future__ import annotations

import math

from typing import List, NamedTuple, Optional, Union

import torch
import torchaudio

torchaudio._extension._load_lib("libctc_prefix_decoder")
import torchaudio.lib.pybind11_prefixctc as cuctc


# Names exported by a star-import of this module; the vocab-file helper
# ``_get_vocab_list`` is intentionally not listed.
__all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"]


def _get_vocab_list(vocab_file):
    vocab = []
    with open(vocab_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip().split()
            vocab.append(line[0])
    return vocab


class CUCTCHypothesis(NamedTuple):
    r"""Represents hypothesis generated by CUCTC beam search decoder :class:`CUCTCDecoder`."""

    tokens: List[int]
    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""

    words: List[str]
    """List of predicted tokens, aligned with the modeling unit (one entry per token ID)."""

    score: float
    """Score corresponding to hypothesis"""
# Default value of ``blank_skip_threshold`` (a probability in [0, 1]).
# NOTE: the misspelled name ("THREASHOLD") is kept on purpose because other
# definitions in this module reference it.
_DEFAULT_BLANK_SKIP_THREASHOLD = 0.95


class CUCTCDecoder:
    """CUDA CTC beam search decoder.

    .. devices:: CUDA

    Note:
        To build the decoder, please use the factory function :func:`cuda_ctc_decoder`.
    """

    def __init__(
        self,
        vocab_list: List[str],
        blank_id: int = 0,
        beam_size: int = 10,
        nbest: int = 1,
        blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
        cuda_stream: Optional[torch.cuda.streams.Stream] = None,
    ):
        """
        Args:
            vocab_list (List[str]): list of vocabulary tokens
            blank_id (int): token id corresponding to blank, only support 0 for now. (Default: 0)
            beam_size (int, optional): max number of hypos to hold after each decode step;
                capped at ``len(vocab_list)``. (Default: 10)
            nbest (int): number of best decodings to return (Default: 1)
            blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold),
                to speed up decoding. Must lie in ``[0, 1]``; ``0`` is stored as ``-inf``.
                (Default: 0.95).
            cuda_stream (torch.cuda.streams.Stream, optional): using assigned cuda stream
                (Default: using the current stream)

        Raises:
            AssertionError: if ``cuda_stream`` is not a ``torch.cuda.streams.Stream``,
                ``blank_id`` is not 0, or ``blank_skip_threshold`` is outside ``[0, 1]``.
        """
        # Validate every argument before touching the native library, so a
        # rejected call never allocates decoder state.
        if cuda_stream is not None and not isinstance(cuda_stream, torch.cuda.streams.Stream):
            raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
        if blank_id != 0:
            raise AssertionError("blank_id must be 0")
        if not (blank_skip_threshold >= 0 and blank_skip_threshold <= 1):
            raise AssertionError("blank_skip_threshold must be between 0 and 1")

        if cuda_stream is None:
            cuda_stream = torch.cuda.current_stream()
        self.internal_data = cuctc.prefixCTC_alloc(cuda_stream.cuda_stream)
        # Workspace handed to the native decoder; starts empty and is grown
        # in ``__call__`` when the decoder reports a required size.
        self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
        self.blank_id = blank_id
        self.vocab_list = vocab_list
        self.space_id = 0
        self.nbest = nbest
        # The threshold is stored in the log domain. ``math.log(0)`` raises
        # ValueError, so the valid boundary value 0 is mapped to -inf explicitly.
        self.blank_skip_threshold = math.log(blank_skip_threshold) if blank_skip_threshold > 0 else -math.inf
        self.beam_size = min(beam_size, len(vocab_list))  # beam size cannot exceed vocab size

    def __del__(self):
        # ``internal_data`` is missing when ``__init__`` raised before the
        # allocation; ``cuctc`` may already be None during interpreter shutdown.
        internal_data = getattr(self, "internal_data", None)
        if internal_data is not None and cuctc is not None:
            cuctc.prefixCTC_free(internal_data)

    def _run_beam_search(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
        """Invoke the native batched beam search once with the current workspace buffer."""
        return cuctc.ctc_beam_search_decoder_batch_gpu_v2(
            self.internal_data,
            self.memory.data_ptr(),
            self.memory.size(0),
            log_prob.data_ptr(),
            encoder_out_lens.data_ptr(),
            log_prob.size(),
            log_prob.stride(),
            self.beam_size,
            self.blank_id,
            self.space_id,
            self.blank_skip_threshold,
        )

    def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
        """
        Args:
            log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; log_softmax(output of acoustic model).
            encoder_out_lens (dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid length
                along the time axis of the output Tensor in each batch.

        Returns:
            List[List[CUCTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch. Each inner list
                holds at most ``nbest`` entries.

        Raises:
            AssertionError: if the dtypes are not ``float32`` / ``int32``, or the inputs are not
                contiguous CUDA tensors.
        """
        if not encoder_out_lens.dtype == torch.int32:
            raise AssertionError("encoder_out_lens must be torch.int32")
        if not log_prob.dtype == torch.float32:
            raise AssertionError("log_prob must be torch.float32")
        if not (log_prob.is_cuda and encoder_out_lens.is_cuda):
            raise AssertionError("inputs must be cuda tensors")
        if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()):
            raise AssertionError("input tensors must be contiguous")

        required_size, score_hyps = self._run_beam_search(log_prob, encoder_out_lens)
        if required_size > 0:
            # A positive ``required_size`` is treated as a request for a larger
            # workspace: allocate that many bytes and run the decode again.
            self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device)
            _, score_hyps = self._run_beam_search(log_prob, encoder_out_lens)

        hypos = []
        for batch_hyps in score_hyps:
            # Never index past the hypotheses actually returned for this utterance.
            n_return = min(self.nbest, len(batch_hyps))
            hypos.append(
                [
                    # Each native hypothesis is indexed as (score, token_ids).
                    CUCTCHypothesis(
                        tokens=batch_hyps[j][1],
                        words=[self.vocab_list[word_id] for word_id in batch_hyps[j][1]],
                        score=batch_hyps[j][0],
                    )
                    for j in range(n_return)
                ]
            )
        return hypos
def cuda_ctc_decoder(
    tokens: Union[str, List[str]],
    nbest: int = 1,
    beam_size: int = 10,
    blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
) -> CUCTCDecoder:
    """Builds an instance of :class:`CUCTCDecoder`.

    The blank token ID is not configurable here; the decoder is built with
    its default ``blank_id`` of 0.

    Args:
        tokens (str or List[str]): File or list containing valid tokens.
            If using a file, the expected format is for tokens mapping to the same index to be on the same line
        nbest (int): The number of best decodings to return (Default: 1)
        beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
        blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold),
            to speed up decoding (Default: 0.95).

    Returns:
        CUCTCDecoder: decoder

    Example
        >>> decoder = cuda_ctc_decoder(
        >>>     tokens="tokens.txt",
        >>>     blank_skip_threshold=0.95,
        >>> )
        >>> results = decoder(log_probs, encoder_out_lens) # List of shape (B, nbest) of Hypotheses
    """
    # A string is treated as a path to a token file; anything else is used
    # as the token list directly.
    if isinstance(tokens, str):
        tokens = _get_vocab_list(tokens)
    return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources