.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/ctc_forced_alignment_api_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_tutorials_ctc_forced_alignment_api_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_ctc_forced_alignment_api_tutorial.py:


CTC forced alignment API tutorial
=================================

**Author**: `Xiaohui Zhang `__, `Moto Hira `__

Forced alignment is the process of aligning a transcript with speech.
This tutorial shows how to align transcripts to speech using
:py:func:`torchaudio.functional.forced_align`, which was developed
alongside the work of
`Scaling Speech Technology to 1,000+ Languages `__.

:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
implementations, which are faster and more accurate than a vanilla
Python implementation. It can also handle a partially missing transcript
with the special ``<star>`` token.

There is also a high-level API,
:py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, which wraps the
pre- and post-processing explained in this tutorial and makes it easy to
run forced alignment.
`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__
uses this API to illustrate how to align non-English transcripts.

.. GENERATED FROM PYTHON SOURCE LINES 27-29

Preparation
-----------

.. GENERATED FROM PYTHON SOURCE LINES 29-36

.. code-block:: default

    import torch
    import torchaudio

    print(torch.__version__)
    print(torchaudio.__version__)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2.3.0
    2.3.0

.. GENERATED FROM PYTHON SOURCE LINES 38-42

.. code-block:: default

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    cuda

.. GENERATED FROM PYTHON SOURCE LINES 44-50

.. code-block:: default

    import IPython
    import matplotlib.pyplot as plt

    import torchaudio.functional as F

.. GENERATED FROM PYTHON SOURCE LINES 51-54

First, we prepare the speech data and the transcript we are going to use.

.. GENERATED FROM PYTHON SOURCE LINES 54-60

.. code-block:: default

    SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
    waveform, _ = torchaudio.load(SPEECH_FILE)
    TRANSCRIPT = "i had that curiosity beside me at this moment".split()

.. GENERATED FROM PYTHON SOURCE LINES 61-80

Generating emissions
~~~~~~~~~~~~~~~~~~~~

:py:func:`~torchaudio.functional.forced_align` takes emission and token
sequences and outputs the timestamps of the tokens and their scores.

Emission represents the frame-wise probability distribution over tokens,
and it can be obtained by passing a waveform to an acoustic model.

Tokens are the numerical representation of a transcript. There are many
ways to tokenize transcripts, but here we simply map each letter to an
integer, which is how the labels were constructed when the acoustic
model we are going to use was trained.

We will use a pre-trained Wav2Vec2 model,
:py:data:`torchaudio.pipelines.MMS_FA`, to obtain the emission and
tokenize the transcript.

.. GENERATED FROM PYTHON SOURCE LINES 80-88

.. code-block:: default

    bundle = torchaudio.pipelines.MMS_FA

    model = bundle.get_model(with_star=False).to(device)
    with torch.inference_mode():
        emission, _ = model(waveform.to(device))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" to /root/.cache/torch/hub/checkpoints/model.pt

    0%| | 0.00/1.18G [00:00 a b a - - b -> a b a a - b -> a b a - a b -> a a b ^^^ ^^^

.. GENERATED FROM PYTHON SOURCE LINES 191-199

Token-level alignments
~~~~~~~~~~~~~~~~~~~~~~

The next step is to resolve the repetitions, so that each alignment does
not depend on the previous alignments.
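
Resolving the repetitions means scanning the frame-level alignment once
and grouping every run of identical non-blank tokens into a single span,
while blank frames only separate runs. The plain-Python sketch below
illustrates that rule on the collapsing examples above. It is not
torchaudio's implementation: ``Span`` and ``merge_frames`` are
hypothetical names, and we assume that the blank token has ID ``0`` and
that a span's score is the mean of its frame scores.

.. code-block:: python

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Span:
        """Hypothetical stand-in for a token span: one token over frames [start, end)."""

        token: int
        start: int
        end: int
        score: float


    def merge_frames(tokens: List[int], scores: List[float], blank: int = 0) -> List[Span]:
        """Group consecutive identical non-blank frames into spans (sketch)."""
        spans = []
        t = 0
        while t < len(tokens):
            end = t + 1
            # Extend the run while the same token repeats on adjacent frames.
            while end < len(tokens) and tokens[end] == tokens[t]:
                end += 1
            if tokens[t] != blank:
                run = scores[t:end]
                # Assumption: the span score is the mean of its frame scores.
                spans.append(Span(tokens[t], t, end, sum(run) / len(run)))
            t = end
        return spans


    # "a a - b" with a=1, b=2 and blank=0: the repeated "a" collapses into one span.
    spans = merge_frames([1, 1, 0, 2], [0.5, 1.0, 1.0, 0.75])
    print([(s.token, s.start, s.end, s.score) for s in spans])  # [(1, 0, 2, 0.75), (2, 3, 4, 0.75)]

    # "a - a b": the blank between the two "a" frames keeps them as two tokens.
    print([s.token for s in merge_frames([1, 0, 1, 2], [1.0] * 4)])  # [1, 1, 2]

Keeping ``start`` and ``end`` on each span is what later lets the token
spans be grouped into words and converted to seconds.
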
:py:func:`torchaudio.functional.merge_tokens` computes
:py:class:`~torchaudio.functional.TokenSpan` objects, each of which
records which token from the transcript is present over which time span.

.. GENERATED FROM PYTHON SOURCE LINES 202-210

.. code-block:: default

    token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

    print("Token\tTime\tScore")
    for s in token_spans:
        print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Token   Time    Score
    i       [ 32,  33)      1.00
    h       [ 35,  37)      0.96
    a       [ 37,  38)      1.00
    d       [ 41,  42)      1.00
    t       [ 44,  45)      1.00
    h       [ 45,  46)      1.00
    a       [ 47,  48)      1.00
    t       [ 50,  51)      1.00
    c       [ 54,  55)      1.00
    u       [ 58,  60)      0.98
    r       [ 63,  64)      1.00
    i       [ 65,  66)      1.00
    o       [ 72,  73)      1.00
    s       [ 79,  80)      1.00
    i       [ 83,  84)      1.00
    t       [ 85,  86)      1.00
    y       [ 88,  89)      1.00
    b       [ 93,  94)      1.00
    e       [ 95,  96)      1.00
    s       [101, 102)      1.00
    i       [110, 111)      1.00
    d       [113, 114)      1.00
    e       [114, 115)      0.85
    m       [116, 117)      1.00
    e       [119, 120)      1.00
    a       [124, 125)      1.00
    t       [127, 128)      1.00
    t       [129, 130)      1.00
    h       [130, 131)      1.00
    i       [132, 133)      1.00
    s       [136, 137)      1.00
    m       [141, 142)      1.00
    o       [144, 145)      1.00
    m       [148, 149)      1.00
    e       [151, 152)      1.00
    n       [153, 154)      1.00
    t       [155, 156)      1.00

.. GENERATED FROM PYTHON SOURCE LINES 211-215

Word-level alignments
~~~~~~~~~~~~~~~~~~~~~

Now let’s group the token-level alignments into word-level alignments.

.. GENERATED FROM PYTHON SOURCE LINES 215-230

.. code-block:: default

    def unflatten(list_, lengths):
        assert len(list_) == sum(lengths)
        i = 0
        ret = []
        for l in lengths:
            ret.append(list_[i : i + l])
            i += l
        return ret


    word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])

.. GENERATED FROM PYTHON SOURCE LINES 231-234

Audio previews
~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 234-251

.. code-block:: default

    # Compute average score weighted by the span length
    def _score(spans):
        return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)


    def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
        ratio = waveform.size(1) / num_frames
        x0 = int(ratio * spans[0].start)
        x1 = int(ratio * spans[-1].end)
        print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
        segment = waveform[:, x0:x1]
        return IPython.display.Audio(segment.numpy(), rate=sample_rate)


    num_frames = emission.size(1)

.. GENERATED FROM PYTHON SOURCE LINES 252-257

.. code-block:: default

    # Generate the audio for each segment
    print(TRANSCRIPT)
    IPython.display.Audio(SPEECH_FILE)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ['i', 'had', 'that', 'curiosity', 'beside', 'me', 'at', 'this', 'moment']


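The ``preview_word`` helper above turns frame indices into sample
indices by multiplying by the number of waveform samples per emission
frame, and ``_score`` weights each token's score by the number of frames
it spans. The sketch below repeats that arithmetic in plain Python so it
can be checked without loading the model. ``frames_to_seconds`` and
``weighted_score`` are hypothetical helpers, spans are assumed to be
``(start, end, score)`` tuples instead of ``TokenSpan`` objects, and the
numbers in the example are made up for illustration.

.. code-block:: python

    def frames_to_seconds(start, end, num_samples, num_frames, sample_rate):
        """Map the frame range [start, end) to seconds, mirroring preview_word."""
        ratio = num_samples / num_frames  # waveform samples per emission frame
        x0 = int(ratio * start)  # int() truncates, as in preview_word
        x1 = int(ratio * end)
        return x0 / sample_rate, x1 / sample_rate


    def weighted_score(spans):
        """Length-weighted mean score of (start, end, score) spans, mirroring _score."""
        total = sum(end - start for start, end, _ in spans)
        return sum(score * (end - start) for start, end, score in spans) / total


    # 3 s of 16 kHz audio mapped to 150 frames gives 320 samples per frame.
    print(frames_to_seconds(50, 75, num_samples=48000, num_frames=150, sample_rate=16000))  # (1.0, 1.5)
    # A 1-frame span scoring 1.0 and a 2-frame span scoring 0.25 average to 0.5.
    print(weighted_score([(0, 1, 1.0), (1, 3, 0.25)]))  # 0.5

Weighting by length means a token that covers more frames contributes
proportionally more to the word's score than a one-frame token.
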
.. GENERATED FROM PYTHON SOURCE LINES 259-262

.. code-block:: default

    preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    i (1.00): 0.644 - 0.664 sec


.. GENERATED FROM PYTHON SOURCE LINES 264-267

.. code-block:: default

    preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    had (0.98): 0.704 - 0.845 sec


.. GENERATED FROM PYTHON SOURCE LINES 269-272

.. code-block:: default

    preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    that (1.00): 0.885 - 1.026 sec


.. GENERATED FROM PYTHON SOURCE LINES 274-277

.. code-block:: default

    preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    curiosity (1.00): 1.086 - 1.790 sec


.. GENERATED FROM PYTHON SOURCE LINES 279-282

.. code-block:: default

    preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    beside (0.97): 1.871 - 2.314 sec


.. GENERATED FROM PYTHON SOURCE LINES 284-287

.. code-block:: default

    preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    me (1.00): 2.334 - 2.414 sec


.. GENERATED FROM PYTHON SOURCE LINES 289-292

.. code-block:: default

    preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    at (1.00): 2.495 - 2.575 sec


.. GENERATED FROM PYTHON SOURCE LINES 294-297

.. code-block:: default

    preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    this (1.00): 2.595 - 2.756 sec


.. GENERATED FROM PYTHON SOURCE LINES 299-302

.. code-block:: default

    preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    moment (1.00): 2.837 - 3.138 sec


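Each word preview above runs from the start of the word's first token
span to the end of its last one. Frames between two words therefore
belong to no word, while blank frames that fall between the letters of a
single word are counted as part of it. The sketch below makes both
quantities explicit, using ``(start, end)`` frame pairs copied from the
token-level table for "i", "had", and "that". ``word_extents``,
``gaps_between``, and ``inner_blank_frames`` are hypothetical helpers,
not torchaudio functions.

.. code-block:: python

    def word_extents(word_spans):
        """(start, end) frame extent of each word; a word is a list of (start, end) spans."""
        return [(spans[0][0], spans[-1][1]) for spans in word_spans]


    def gaps_between(extents):
        """Frames between the end of each word and the start of the next word."""
        return [nxt[0] - cur[1] for cur, nxt in zip(extents, extents[1:])]


    def inner_blank_frames(spans):
        """Frames inside one word's extent that no token span covers."""
        covered = sum(end - start for start, end in spans)
        return (spans[-1][1] - spans[0][0]) - covered


    # Token spans of "i", "had" and "that" from the token-level table.
    words = [
        [(32, 33)],
        [(35, 37), (37, 38), (41, 42)],
        [(44, 45), (45, 46), (47, 48), (50, 51)],
    ]
    extents = word_extents(words)
    print(extents)  # [(32, 33), (35, 42), (44, 51)]
    print(gaps_between(extents))  # [2, 2]
    print([inner_blank_frames(w) for w in words])  # [0, 3, 3]

In "had", for example, frames 38 to 40 are blank but fall inside the
word, whereas the two blank frames between "i" and "had" are left
outside every word.
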
.. GENERATED FROM PYTHON SOURCE LINES 303-308

Visualization
~~~~~~~~~~~~~

Now let's look at the alignment result and segment the original speech
into words.

.. GENERATED FROM PYTHON SOURCE LINES 308-334

.. code-block:: default

    def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
        ratio = waveform.size(1) / emission.size(1) / sample_rate

        fig, axes = plt.subplots(2, 1)
        axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
        axes[0].set_title("Emission")
        axes[0].set_xticks([])

        axes[1].specgram(waveform[0], Fs=sample_rate)
        for t_spans, chars in zip(token_spans, transcript):
            t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
            axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
            axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
            axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

            for span, char in zip(t_spans, chars):
                t0 = span.start * ratio
                axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)

        axes[1].set_xlabel("time [second]")
        axes[1].set_xlim([0, None])
        fig.tight_layout()

.. GENERATED FROM PYTHON SOURCE LINES 336-339

.. code-block:: default

    plot_alignments(waveform, word_spans, emission, TRANSCRIPT)

.. image-sg:: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_002.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 340-350

Inconsistent treatment of ``blank`` token
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When splitting the token-level alignments into words, you will notice
that some blank tokens are treated differently, which makes the
interpretation of the result somewhat ambiguous.

This is easy to see when we plot the scores. The following figure shows
word regions and non-word regions, with the frame-level scores of
non-blank tokens.

.. GENERATED FROM PYTHON SOURCE LINES 351-372

.. code-block:: default

    def plot_scores(word_spans, scores):
        fig, ax = plt.subplots()
        span_xs, span_hs = [], []
        ax.axvspan(word_spans[0][0].start - 0.05, word_spans[-1][-1].end + 0.05, facecolor="paleturquoise", edgecolor="none", zorder=-1)
        for t_span in word_spans:
            for span in t_span:
                for t in range(span.start, span.end):
                    span_xs.append(t + 0.5)
                    span_hs.append(scores[t].item())
                ax.annotate(LABELS[span.token], (span.start, -0.07))
            ax.axvspan(t_span[0].start - 0.05, t_span[-1].end + 0.05, facecolor="mistyrose", edgecolor="none", zorder=-1)
        ax.bar(span_xs, span_hs, color="lightsalmon", edgecolor="coral")
        ax.set_title("Frame-level scores and word segments")
        ax.set_ylim(-0.1, None)
        ax.grid(True, axis="y")
        ax.axhline(0, color="black")
        fig.tight_layout()


    plot_scores(word_spans, alignment_scores)

.. image-sg:: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_003.png
   :alt: Frame-level scores and word segments
   :srcset: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 373-410

In this plot, the blank tokens are the highlighted areas without
vertical bars. You can see that some blank tokens are interpreted as
part of a word (highlighted red), while others (highlighted blue) are
not.

One reason for this is that the model was trained without a label for
the word boundary. The blank tokens are treated not just as repetition
but also as silence between words.

But then a question arises: should frames immediately after or near the
end of a word be treated as silence or as repetition?

In the above example, if you go back to the previous plot of the
spectrogram and word regions, you can see that after "y" in "curiosity"
there is still some activity in multiple frequency bins.

Would it be more accurate if that frame were included in the word?

Unfortunately, CTC does not provide a comprehensive solution to this.
Models trained with CTC are known to exhibit a "peaky" response; that
is, they tend to spike at an occurrence of a label, but the spike does
not last for the duration of the label. (Note: pre-trained Wav2Vec2
models tend to spike at the beginning of label occurrences, but this is
not always the case.)

:cite:`zeyer2021does` has an in-depth analysis of the peaky behavior of
CTC. We encourage those interested in understanding more to refer to the
paper. The following is a quote from the paper, which describes the
exact issue we are facing here.

   *Peaky behavior can be problematic in certain cases,*
   *e.g. when an application requires to not use the blank label,*
   *e.g. to get meaningful time accurate alignments of phonemes*
   *to a transcription.*

.. GENERATED FROM PYTHON SOURCE LINES 412-425

Advanced: Handling transcripts with ``<star>`` token
----------------------------------------------------

Now let’s look at how the ``<star>`` token, which is capable of modeling
any token, can improve alignment quality when the transcript is
partially missing.

Here we use the same English example as above, but we remove the
beginning text ``“i had that curiosity beside me at”`` from the
transcript. Aligning audio with such a transcript results in a wrong
alignment of the remaining word “this”. However, this issue can be
mitigated by using the ``<star>`` token to model the missing text.

.. GENERATED FROM PYTHON SOURCE LINES 427-428

First, we extend the dictionary to include the ``<star>`` token.

.. GENERATED FROM PYTHON SOURCE LINES 428-431

.. code-block:: default

    DICTIONARY["*"] = len(DICTIONARY)

.. GENERATED FROM PYTHON SOURCE LINES 432-435

Next, we extend the emission tensor with an extra dimension
corresponding to the ``<star>`` token.

.. GENERATED FROM PYTHON SOURCE LINES 435-443

.. code-block:: default

    star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
    emission = torch.cat((emission, star_dim), 2)

    assert len(DICTIONARY) == emission.shape[2]

    plot_emission(emission[0])

.. image-sg:: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_004.png
   :alt: Frame-wise class probabilities
   :srcset: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 444-446

The following function combines all the processes and computes word
segments from the emission in one go.

.. GENERATED FROM PYTHON SOURCE LINES 446-456

.. code-block:: default

    def compute_alignments(emission, transcript, dictionary):
        tokens = [dictionary[char] for word in transcript for char in word]
        alignment, scores = align(emission, tokens)
        token_spans = F.merge_tokens(alignment, scores)
        word_spans = unflatten(token_spans, [len(word) for word in transcript])
        return word_spans

.. GENERATED FROM PYTHON SOURCE LINES 457-459

Full Transcript
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 459-463

.. code-block:: default

    word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
    plot_alignments(waveform, word_spans, emission, TRANSCRIPT)

.. image-sg:: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_005.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_005.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 464-468

Partial Transcript with ``<star>`` token
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now we replace the first part of the transcript with the ``<star>``
token.

.. GENERATED FROM PYTHON SOURCE LINES 468-473

.. code-block:: default

    transcript = "* this moment".split()
    word_spans = compute_alignments(emission, transcript, DICTIONARY)
    plot_alignments(waveform, word_spans, emission, transcript)

.. image-sg:: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_006.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_006.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 475-478

.. code-block:: default

    preview_word(waveform, word_spans[0], num_frames, transcript[0])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    * (1.00): 0.000 - 2.595 sec


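The preview above shows the ``<star>`` token absorbing everything before
"this". To see why, recall that the appended ``<star>`` column is all
zeros. Assuming the emission holds log-probabilities, a zero is a
probability of one, so the aligner can let ``<star>`` cover any number of
frames at no cost, while every transcript token still has to be placed
somewhere after it. The toy sketch below is our own pure-Python Viterbi
CTC aligner, not torchaudio's implementation; ``ctc_viterbi_align``, the
four-class vocabulary, and the five-frame emission are all hypothetical,
with frames 0 to 2 standing in for speech that is missing from the
transcript.

.. code-block:: python

    import math

    NEG_INF = float("-inf")


    def ctc_viterbi_align(log_probs, targets, blank=0):
        """Toy CTC forced alignment by Viterbi search (sketch, not torchaudio's code).

        log_probs: T rows (T >= 1), each a list of per-class log-probabilities.
        targets:   non-blank token IDs in transcript order.
        Returns one token ID per frame, using ``blank`` for blank frames.
        """
        # Interleave blanks: [blank, y1, blank, y2, ..., yL, blank].
        ext = [blank]
        for tok in targets:
            ext += [tok, blank]
        T, S = len(log_probs), len(ext)
        score = [[NEG_INF] * S for _ in range(T)]
        back = [[0] * S for _ in range(T)]
        # A path starts either on the leading blank or on the first token.
        for s in range(min(2, S)):
            score[0][s] = log_probs[0][ext[s]]
        for t in range(1, T):
            for s in range(S):
                # Stay in the same state, advance by one, or skip the blank
                # between two *different* tokens.
                preds = [s]
                if s >= 1:
                    preds.append(s - 1)
                if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                    preds.append(s - 2)
                best = max(preds, key=lambda p: score[t - 1][p])
                if score[t - 1][best] > NEG_INF:
                    score[t][s] = score[t - 1][best] + log_probs[t][ext[s]]
                    back[t][s] = best
        # A path ends either on the last token or on the trailing blank.
        s = max((p for p in (S - 1, S - 2) if p >= 0), key=lambda p: score[T - 1][p])
        if score[T - 1][s] == NEG_INF:
            raise ValueError("not enough frames for this transcript")
        path = []
        for t in range(T - 1, -1, -1):
            path.append(ext[s])
            s = back[t][s]
        return path[::-1]


    BLANK, A, B, STAR = 0, 1, 2, 3
    probs = [  # columns: blank, "a", "b"
        [0.1, 0.8, 0.1],  # frames 0-2: speech missing from the transcript
        [0.1, 0.1, 0.8],
        [0.8, 0.1, 0.1],
        [0.2, 0.7, 0.1],  # frame 3: the transcribed "a"
        [0.2, 0.1, 0.7],  # frame 4: the transcribed "b"
    ]
    # Log-probabilities plus a <star> column of 0.0 (= log 1) on every frame.
    toy_emission = [[math.log(p) for p in row] + [0.0] for row in probs]

    print(ctc_viterbi_align(toy_emission, [A, B], blank=BLANK))  # [1, 2, 0, 0, 0]
    print(ctc_viterbi_align(toy_emission, [STAR, A, B], blank=BLANK))  # [3, 3, 3, 1, 2]

Without ``<star>``, "a" and "b" are pulled onto the untranscribed frames
0 and 1 because those frames happen to score higher; with ``<star>``
covering frames 0 to 2, they land on frames 3 and 4. Real emissions are
far less clean, but the mechanism is the same one that the
partial-transcript examples in this section illustrate.
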
.. GENERATED FROM PYTHON SOURCE LINES 480-483

.. code-block:: default

    preview_word(waveform, word_spans[1], num_frames, transcript[1])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    this (1.00): 2.595 - 2.756 sec


.. GENERATED FROM PYTHON SOURCE LINES 485-488

.. code-block:: default

    preview_word(waveform, word_spans[2], num_frames, transcript[2])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    moment (1.00): 2.837 - 3.138 sec


.. GENERATED FROM PYTHON SOURCE LINES 489-495

Partial Transcript without ``<star>`` token
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

As a comparison, the following aligns the partial transcript without
using the ``<star>`` token. It demonstrates the effect of the ``<star>``
token in dealing with deletion errors.

.. GENERATED FROM PYTHON SOURCE LINES 495-500

.. code-block:: default

    transcript = "this moment".split()
    word_spans = compute_alignments(emission, transcript, DICTIONARY)
    plot_alignments(waveform, word_spans, emission, transcript)

.. image-sg:: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_007.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_ctc_forced_alignment_api_tutorial_007.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 501-509

Conclusion
----------

In this tutorial, we looked at how to use torchaudio’s forced alignment
API to align and segment speech files, and demonstrated one advanced
usage: how introducing a ``<star>`` token can improve alignment accuracy
when transcription errors exist.

.. GENERATED FROM PYTHON SOURCE LINES 512-518

Acknowledgement
---------------

Thanks to `Vineel Pratap `__ and `Zhaoheng Ni `__ for developing and
open-sourcing the forced aligner API.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 7.789 seconds)

.. _sphx_glr_download_tutorials_ctc_forced_alignment_api_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: ctc_forced_alignment_api_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: ctc_forced_alignment_api_tutorial.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_