Source code for torchaudio.models.decoder._cuda_ctc_decoder
from __future__ import annotations
import math
from typing import List, NamedTuple, Union
import torch
import torchaudio
torchaudio._extension._load_lib("libctc_prefix_decoder")
import torchaudio.lib.pybind11_prefixctc as cuctc
__all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"]


def _get_vocab_list(vocab_file):
    """Read the vocabulary from ``vocab_file``: one token per line; only the first
    whitespace-separated field of each line is kept."""
    vocab = []
    with open(vocab_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip().split()
            vocab.append(line[0])
    return vocab
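

# Illustrative token-file layout accepted by ``_get_vocab_list`` (a sketch; the tokens and
# indices below are hypothetical): one token per line, optionally followed by extra fields
# such as an index, which are ignored since only the first field of each line is kept.
#
#     <blk> 0
#     a 1
#     b 2
#     c 3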


class CUCTCHypothesis(NamedTuple):
    r"""Represents a hypothesis generated by the CUDA CTC beam search decoder :class:`CUCTCDecoder`."""

    tokens: List[int]
    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""

    words: List[str]
    """List of predicted tokens, aligned with the modeling unit."""

    score: float
    """Score corresponding to the hypothesis"""
_DEFAULT_BLANK_SKIP_THREASHOLD = 0.95


class CUCTCDecoder:
    """CUDA CTC beam search decoder.

    .. devices:: CUDA

    Note:
        To build the decoder, please use the factory function :func:`cuda_ctc_decoder`.
    """

    def __init__(
        self,
        vocab_list: List[str],
        blank_id: int = 0,
        beam_size: int = 10,
        nbest: int = 1,
        blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
        cuda_stream: torch.cuda.streams.Stream = None,
    ):
"""
Args:
blank_id (int): token id corresopnding to blank, only support 0 for now. (Default: 0)
vocab_list (List[str]): list of vocabulary tokens
beam_size (int, optional): max number of hypos to hold after each decode step (Default: 10)
nbest (int): number of best decodings to return
blank_skip_threshold (float):
skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
(Default: 0.95).
cuda_stream (torch.cuda.streams.Stream): using assigned cuda stream (Default: using default stream)
"""
        if cuda_stream:
            if not isinstance(cuda_stream, torch.cuda.streams.Stream):
                raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
        cuda_stream_ = cuda_stream.cuda_stream if cuda_stream else torch.cuda.current_stream().cuda_stream
        self.internal_data = cuctc.prefixCTC_alloc(cuda_stream_)
        # Reusable GPU workspace; grown on demand in __call__ when the decoder reports a larger required size.
        self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
        if blank_id != 0:
            raise AssertionError("blank_id must be 0")
        self.blank_id = blank_id
        self.vocab_list = vocab_list
        self.space_id = 0
        self.nbest = nbest
        if not (blank_skip_threshold >= 0 and blank_skip_threshold <= 1):
            raise AssertionError("blank_skip_threshold must be between 0 and 1")
        self.blank_skip_threshold = math.log(blank_skip_threshold)
        self.beam_size = min(beam_size, len(vocab_list))  # beam size must be smaller than vocab size
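
    # Worked example (illustrative): with the default ``blank_skip_threshold`` of 0.95,
    # ``self.blank_skip_threshold`` becomes ``math.log(0.95)``, about -0.0513, so frames
    # whose blank log-probability exceeds -0.0513 (blank probability above 0.95) are
    # skipped during decoding.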

    def __del__(self):
        if cuctc is not None:
            cuctc.prefixCTC_free(self.internal_data)

    def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
        """
        Args:
            log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                log-probability distributions over tokens, i.e. the log_softmax of the acoustic model output.
            encoder_out_lens (dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid length along
                the time axis of ``log_prob`` for each sequence in the batch.

        Returns:
            List[List[CUCTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.
        """
        if not encoder_out_lens.dtype == torch.int32:
            raise AssertionError("encoder_out_lens must be torch.int32")
        if not log_prob.dtype == torch.float32:
            raise AssertionError("log_prob must be torch.float32")
        if not (log_prob.is_cuda and encoder_out_lens.is_cuda):
            raise AssertionError("inputs must be cuda tensors")
        if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()):
            raise AssertionError("input tensors must be contiguous")
        # First pass: run the decoder with the current workspace; it reports ``required_size``.
        required_size, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
            self.internal_data,
            self.memory.data_ptr(),
            self.memory.size(0),
            log_prob.data_ptr(),
            encoder_out_lens.data_ptr(),
            log_prob.size(),
            log_prob.stride(),
            self.beam_size,
            self.blank_id,
            self.space_id,
            self.blank_skip_threshold,
        )
        if required_size > 0:
            # Re-allocate the GPU workspace to the reported size and run decoding again.
            self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device).contiguous()
            _, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
                self.internal_data,
                self.memory.data_ptr(),
                self.memory.size(0),
                log_prob.data_ptr(),
                encoder_out_lens.data_ptr(),
                log_prob.size(),
                log_prob.stride(),
                self.beam_size,
                self.blank_id,
                self.space_id,
                self.blank_skip_threshold,
            )
        # score_hyps[i] holds (score, token_ids) pairs for the i-th utterance; keep the leading
        # ``nbest`` entries and map token IDs back to vocabulary units.
        batch_size = len(score_hyps)
        hypos = []
        for i in range(batch_size):
            hypos.append(
                [
                    CUCTCHypothesis(
                        tokens=score_hyps[i][j][1],
                        words=[self.vocab_list[word_id] for word_id in score_hyps[i][j][1]],
                        score=score_hyps[i][j][0],
                    )
                    for j in range(self.nbest)
                ]
            )
        return hypos


def cuda_ctc_decoder(
    tokens: Union[str, List[str]],
    nbest: int = 1,
    beam_size: int = 10,
    blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
) -> CUCTCDecoder:
    """Builds an instance of :class:`CUCTCDecoder`.

    Args:
        tokens (str or List[str]): File or list containing valid tokens.
            If using a file, the expected format is for tokens mapping to the same index to be on the same line
        nbest (int): The number of best decodings to return (Default: 1)
        beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
        blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold),
            to speed up decoding (Default: 0.95).

    Returns:
        CUCTCDecoder: decoder

    Example
        >>> decoder = cuda_ctc_decoder(
        >>>     tokens="tokens.txt",
        >>>     blank_skip_threshold=0.95,
        >>> )
        >>> results = decoder(log_probs, encoder_out_lens)  # List of shape (B, nbest) of Hypotheses
    """
    if isinstance(tokens, str):
        tokens = _get_vocab_list(tokens)
    return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
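

# A minimal end-to-end sketch (not part of the original module; the token list, shapes, and
# values below are illustrative). It assumes a CUDA-capable device and shows how the decoder
# inputs are expected to be prepared: log-softmaxed float32 emissions and int32 lengths,
# both contiguous and on the GPU.
if __name__ == "__main__":
    if torch.cuda.is_available():
        tokens = ["<blk>", "a", "b", "c"]  # hypothetical vocabulary; index 0 is the blank token
        decoder = cuda_ctc_decoder(tokens, nbest=2, beam_size=10, blank_skip_threshold=0.95)

        batch, frames = 1, 50
        emissions = torch.randn(batch, frames, len(tokens), device="cuda")
        log_prob = torch.nn.functional.log_softmax(emissions, dim=-1)  # float32, contiguous
        encoder_out_lens = torch.full((batch,), frames, dtype=torch.int32, device="cuda")

        results = decoder(log_prob, encoder_out_lens)  # List of shape (batch, nbest)
        for hyp in results[0]:
            print(hyp.score, hyp.words)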