.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/forced_alignment_for_multilingual_data_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_forced_alignment_for_multilingual_data_tutorial.py:


Forced alignment for multilingual data
======================================

**Authors**: `Xiaohui Zhang `__, `Moto Hira `__.

This tutorial shows how to align transcripts to speech for non-English
languages.

Aligning a non-English (normalized) transcript works exactly like aligning
an English (normalized) transcript, and the English case is covered in
detail in the `CTC forced alignment tutorial <./ctc_forced_alignment_api_tutorial.html>`__.

In this tutorial, we use TorchAudio's high-level API,
:py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, which packages the
pre-trained model, tokenizer and aligner, to perform the forced alignment
with less code.

.. GENERATED FROM PYTHON SOURCE LINES 16-27

.. code-block:: default


    import torch
    import torchaudio

    print(torch.__version__)
    print(torchaudio.__version__)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2.3.0
    2.3.0
    cuda


.. GENERATED FROM PYTHON SOURCE LINES 29-34

.. code-block:: default


    from typing import List

    import IPython
    import matplotlib.pyplot as plt

.. GENERATED FROM PYTHON SOURCE LINES 35-67

Creating the pipeline
---------------------

First, we instantiate the model and the pre- and post-processing pipelines.

The following diagram illustrates the process of alignment.

.. image:: https://download.pytorch.org/torchaudio/doc-assets/pipelines-wav2vec2fabundle.png

The waveform is passed to an acoustic model, which produces a sequence of
probability distributions over tokens.
The transcript is passed to a tokenizer, which converts it to a sequence of
tokens. The aligner takes the outputs of the acoustic model and the
tokenizer and generates a timestamp for each token.

.. note::

   This process expects that the input transcript is already normalized.
   Normalization, which involves romanization for non-English languages,
   is language-dependent, so it is not covered in detail in this tutorial,
   but we will briefly look into it.

The acoustic model and the tokenizer must use the same set of tokens.
To facilitate the creation of matching processors,
:py:class:`~torchaudio.pipelines.Wav2Vec2FABundle` associates a
pre-trained acoustic model and a tokenizer.
:py:data:`torchaudio.pipelines.MMS_FA` is one such instance.

The following code instantiates a pre-trained acoustic model, a tokenizer
which uses the same set of tokens as the model, and an aligner.

.. GENERATED FROM PYTHON SOURCE LINES 67-76

.. code-block:: default


    from torchaudio.pipelines import MMS_FA as bundle

    model = bundle.get_model()
    model.to(device)

    tokenizer = bundle.get_tokenizer()
    aligner = bundle.get_aligner()

.. GENERATED FROM PYTHON SOURCE LINES 77-84

.. note::

   The model instantiated by :py:data:`~torchaudio.pipelines.MMS_FA`'s
   :py:meth:`~torchaudio.pipelines.Wav2Vec2FABundle.get_model`
   method includes, by default, the feature dimension for the star token
   (``*`` in the dictionary printed below).
   You can disable this by passing ``with_star=False``.

.. GENERATED FROM PYTHON SOURCE LINES 86-94

The acoustic model of :py:data:`~torchaudio.pipelines.MMS_FA` was
created and open-sourced as part of the research project
`Scaling Speech Technology to 1,000+ Languages `__.
It was trained with 23,000 hours of audio from 1100+ languages.

The tokenizer simply maps the normalized characters to integers.
You can check the mapping as follows:

.. GENERATED FROM PYTHON SOURCE LINES 94-98

.. code-block:: default


    print(bundle.get_dict())


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '*': 28}


.. GENERATED FROM PYTHON SOURCE LINES 99-106

The aligner internally uses :py:func:`torchaudio.functional.forced_align`
and :py:func:`torchaudio.functional.merge_tokens` to infer the timestamps
of the input tokens.

The details of the underlying mechanism are covered in the
`CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__,
so please refer to it.

.. GENERATED FROM PYTHON SOURCE LINES 109-112

We define a utility function that performs the forced alignment with the
above model, tokenizer and aligner.

.. GENERATED FROM PYTHON SOURCE LINES 112-119

.. code-block:: default


    def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
        with torch.inference_mode():
            # Frame-wise token scores, shape (1, num_frames, num_tokens)
            emission, _ = model(waveform.to(device))
            # One list of token spans per word of the transcript
            token_spans = aligner(emission[0], tokenizer(transcript))
        return emission, token_spans

.. GENERATED FROM PYTHON SOURCE LINES 120-122

We also define utility functions for plotting the result and previewing
the audio segments.

.. GENERATED FROM PYTHON SOURCE LINES 122-151

.. code-block:: default


    # Compute average score weighted by the span length
    def _score(spans):
        return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)


    def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
        # Seconds per emission frame
        ratio = waveform.size(1) / emission.size(1) / sample_rate

        fig, axes = plt.subplots(2, 1)
        axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
        axes[0].set_title("Emission")
        axes[0].set_xticks([])

        axes[1].specgram(waveform[0], Fs=sample_rate)
        for t_spans, chars in zip(token_spans, transcript):
            t0, t1 = t_spans[0].start, t_spans[-1].end
            axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
            axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
            axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

            for span, char in zip(t_spans, chars):
                t0 = span.start * ratio
                axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)

        axes[1].set_xlabel("time [second]")
        fig.tight_layout()

.. GENERATED FROM PYTHON SOURCE LINES 153-162

.. code-block:: default


    def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
        # Audio samples per emission frame
        ratio = waveform.size(1) / num_frames
        x0 = int(ratio * spans[0].start)
        x1 = int(ratio * spans[-1].end)
        print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
        segment = waveform[:, x0:x1]
        return IPython.display.Audio(segment.numpy(), rate=sample_rate)

.. GENERATED FROM PYTHON SOURCE LINES 163-222

Normalizing the transcript
--------------------------

The transcripts passed to the pipeline must be normalized beforehand.
The exact process of normalization depends on the language.

Languages that do not have explicit word boundaries
(such as Chinese, Japanese and Korean) require segmentation first.
There are dedicated tools for this, but let's assume we already have
a segmented transcript.

The first step of normalization is romanization.
`uroman `__ is a tool that supports many languages.
Here are Bash commands that romanize an input text file and write the
output to another text file using ``uroman``.

.. code-block:: bash

   $ echo "Cette page concerne des événements d'actualité qui se sont produits durant l'année 1882" > text.txt
   $ uroman/bin/uroman.pl < text.txt > text_romanized.txt
   $ cat text_romanized.txt

.. code-block:: text

   Cette page concerne des evenements d'actualite qui se sont produits durant l'annee 1882

The next step is to remove non-alphabetic characters and punctuation.
The following snippet normalizes the romanized transcript.

.. code-block:: python

   import re


   def normalize_uroman(text):
       text = text.lower()
       text = text.replace("’", "'")  # unify curly and straight apostrophes
       text = re.sub("([^a-z' ])", " ", text)  # replace everything except a-z, ' and space
       text = re.sub(" +", " ", text)  # collapse repeated spaces
       return text.strip()


   with open("text_romanized.txt", "r", encoding="utf-8") as f:
       for line in f:
           text_normalized = normalize_uroman(line)
           print(text_normalized)

Running the script on the above example produces the following.

.. code-block:: text

   cette page concerne des evenements d'actualite qui se sont produits durant l'annee

Note that, in this example, since "1882" was not romanized by ``uroman``,
it was removed in the normalization step.
To avoid this, numbers need to be converted into words, but this is known
to be a non-trivial task.

.. GENERATED FROM PYTHON SOURCE LINES 224-232

Aligning transcripts to speech
------------------------------

Now we perform the forced alignment for multiple languages.


German
~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 232-241

.. code-block:: default


    text_raw = "aber seit ich bei ihnen das brot hole"
    text_normalized = "aber seit ich bei ihnen das brot hole"

    url = "https://download.pytorch.org/torchaudio/tutorial-assets/10349_8674_000087.flac"
    waveform, sample_rate = torchaudio.load(
        url, frame_offset=int(0.5 * bundle.sample_rate), num_frames=int(2.5 * bundle.sample_rate)
    )

.. GENERATED FROM PYTHON SOURCE LINES 243-245

.. code-block:: default


    assert sample_rate == bundle.sample_rate

.. GENERATED FROM PYTHON SOURCE LINES 247-260

.. code-block:: default


    transcript = text_normalized.split()
    tokens = tokenizer(transcript)

    emission, token_spans = compute_alignments(waveform, transcript)
    num_frames = emission.size(1)

    plot_alignments(waveform, token_spans, emission, transcript)

    print("Raw Transcript: ", text_raw)
    print("Normalized Transcript: ", text_normalized)
    IPython.display.Audio(waveform, rate=sample_rate)


.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_001.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/modules/conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608839953/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
      return F.conv1d(input, weight, bias, self.stride,
    Raw Transcript:  aber seit ich bei ihnen das brot hole
    Normalized Transcript:  aber seit ich bei ihnen das brot hole


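The call ``tokens = tokenizer(transcript)`` above turns each normalized word
into a list of token IDs. Since the tokenizer simply maps normalized
characters to integers, its behavior can be approximated in plain Python with
the dictionary printed by ``bundle.get_dict()``. The following is a minimal
sketch under that assumption; ``tokenize_words`` is a hypothetical helper, not
part of TorchAudio, and it deliberately rejects characters that are missing
from the dictionary so that an incompletely normalized transcript is caught
before alignment.

.. code-block:: python

    from typing import Dict, List

    # Character-to-ID mapping copied from the ``bundle.get_dict()`` output above.
    MMS_FA_DICT: Dict[str, int] = {
        "-": 0, "a": 1, "i": 2, "e": 3, "n": 4, "o": 5, "u": 6, "t": 7,
        "s": 8, "r": 9, "m": 10, "k": 11, "l": 12, "d": 13, "g": 14,
        "h": 15, "y": 16, "b": 17, "p": 18, "w": 19, "c": 20, "v": 21,
        "j": 22, "z": 23, "f": 24, "'": 25, "q": 26, "x": 27, "*": 28,
    }


    def tokenize_words(words: List[str], dictionary: Dict[str, int] = MMS_FA_DICT) -> List[List[int]]:
        """Hypothetical helper: map each normalized word to a list of token IDs."""
        token_ids = []
        for word in words:
            unknown = sorted({ch for ch in word if ch not in dictionary})
            if unknown:
                # Fail early if the transcript still contains accents, digits, etc.
                raise ValueError(f"{word!r} contains characters not in the dictionary: {unknown}")
            token_ids.append([dictionary[ch] for ch in word])
        return token_ids


    print(tokenize_words(["aber", "seit"]))

With this mapping, ``["aber", "seit"]`` becomes ``[[1, 17, 3, 9], [8, 3, 2, 7]]``,
while an unnormalized word such as ``"1882"`` raises ``ValueError``.
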
.. GENERATED FROM PYTHON SOURCE LINES 262-265

.. code-block:: default


    preview_word(waveform, token_spans[0], num_frames, transcript[0])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    aber (0.96): 0.222 - 0.464 sec


.. GENERATED FROM PYTHON SOURCE LINES 267-270

.. code-block:: default


    preview_word(waveform, token_spans[1], num_frames, transcript[1])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    seit (0.78): 0.565 - 0.766 sec


.. GENERATED FROM PYTHON SOURCE LINES 272-275

.. code-block:: default


    preview_word(waveform, token_spans[2], num_frames, transcript[2])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ich (0.91): 0.847 - 0.948 sec


.. GENERATED FROM PYTHON SOURCE LINES 277-280

.. code-block:: default


    preview_word(waveform, token_spans[3], num_frames, transcript[3])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    bei (0.96): 1.028 - 1.190 sec


.. GENERATED FROM PYTHON SOURCE LINES 282-285

.. code-block:: default


    preview_word(waveform, token_spans[4], num_frames, transcript[4])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ihnen (0.65): 1.331 - 1.532 sec


.. GENERATED FROM PYTHON SOURCE LINES 287-290

.. code-block:: default


    preview_word(waveform, token_spans[5], num_frames, transcript[5])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    das (0.54): 1.573 - 1.774 sec


.. GENERATED FROM PYTHON SOURCE LINES 292-295

.. code-block:: default


    preview_word(waveform, token_spans[6], num_frames, transcript[6])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    brot (0.86): 1.855 - 2.117 sec


.. GENERATED FROM PYTHON SOURCE LINES 297-300

.. code-block:: default


    preview_word(waveform, token_spans[7], num_frames, transcript[7])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    hole (0.71): 2.177 - 2.480 sec


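The ``preview_word`` calls above print the start time, end time and score of
one word at a time. To export the whole alignment at once (for example, to
write a label or subtitle file), the same frame-to-second conversion can be
applied to every word. The sketch below assumes that each span exposes
``start``, ``end`` and ``score`` and that ``len(span)`` is its length in
frames, which is what ``_score`` and ``preview_word`` above rely on. The
``Span`` class and the ``word_segments`` function are hypothetical and exist
only so that the example runs without the model.

.. code-block:: python

    from dataclasses import dataclass
    from typing import List, Sequence, Tuple


    @dataclass
    class Span:
        """Hypothetical stand-in for an aligner token span (frame indices and score)."""

        start: int
        end: int
        score: float

        def __len__(self) -> int:
            # Number of emission frames covered by this token.
            return self.end - self.start


    def word_segments(
        token_spans: Sequence[Sequence[Span]],
        transcript: Sequence[str],
        num_samples: int,
        num_frames: int,
        sample_rate: int,
    ) -> List[Tuple[str, float, float, float]]:
        """Return (word, start_sec, end_sec, score) per word, using the same math as preview_word."""
        ratio = num_samples / num_frames  # audio samples per emission frame
        segments = []
        for spans, word in zip(token_spans, transcript):
            x0 = int(ratio * spans[0].start)
            x1 = int(ratio * spans[-1].end)
            # Span-length-weighted average score, as in ``_score``.
            score = sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
            segments.append((word, x0 / sample_rate, x1 / sample_rate, score))
        return segments


    # Toy input: a two-token word "ab" in 2 seconds of 16 kHz audio mapped to 100 frames.
    toy_spans = [[Span(start=10, end=12, score=1.0), Span(start=12, end=15, score=0.5)]]
    for word, start, end, score in word_segments(toy_spans, ["ab"], 32000, 100, 16000):
        print(f"{word} ({score:.2f}): {start:.3f} - {end:.3f} sec")

The toy input prints ``ab (0.70): 0.200 - 0.300 sec``. With the objects from
the German example, the call would be
``word_segments(token_spans, transcript, waveform.size(1), num_frames, bundle.sample_rate)``.
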
.. GENERATED FROM PYTHON SOURCE LINES 301-311

Chinese
~~~~~~~

Chinese is a character-based language, and its raw written form has no
explicit word-level tokenization (words are not separated by spaces).
In order to obtain word-level alignments, you need to first tokenize the
transcripts at the word level using a word tokenizer like
`“Stanford Tokenizer” `__.
However, this is not needed if you only want character-level alignments.

.. GENERATED FROM PYTHON SOURCE LINES 311-315

.. code-block:: default


    text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
    text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"

.. GENERATED FROM PYTHON SOURCE LINES 317-322

.. code-block:: default


    url = "https://download.pytorch.org/torchaudio/tutorial-assets/mvdr/clean_speech.wav"
    waveform, sample_rate = torchaudio.load(url)
    waveform = waveform[0:1]  # keep only the first channel

.. GENERATED FROM PYTHON SOURCE LINES 324-326

.. code-block:: default


    assert sample_rate == bundle.sample_rate

.. GENERATED FROM PYTHON SOURCE LINES 328-339

.. code-block:: default


    transcript = text_normalized.split()

    emission, token_spans = compute_alignments(waveform, transcript)
    num_frames = emission.size(1)

    plot_alignments(waveform, token_spans, emission, transcript)

    print("Raw Transcript: ", text_raw)
    print("Normalized Transcript: ", text_normalized)
    IPython.display.Audio(waveform, rate=sample_rate)


.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_002.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/modules/conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608839953/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
      return F.conv1d(input, weight, bias, self.stride,
    Raw Transcript:  关 服务 高端 产品 仍 处于 供不应求 的 局面
    Normalized Transcript:  guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian


.. GENERATED FROM PYTHON SOURCE LINES 341-344

.. code-block:: default


    preview_word(waveform, token_spans[0], num_frames, transcript[0])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    guan (0.33): 0.020 - 0.141 sec


.. GENERATED FROM PYTHON SOURCE LINES 346-349

.. code-block:: default


    preview_word(waveform, token_spans[1], num_frames, transcript[1])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    fuwu (0.31): 0.221 - 0.583 sec


.. GENERATED FROM PYTHON SOURCE LINES 351-354

.. code-block:: default


    preview_word(waveform, token_spans[2], num_frames, transcript[2])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    gaoduan (0.74): 0.724 - 1.065 sec


.. GENERATED FROM PYTHON SOURCE LINES 356-359

.. code-block:: default


    preview_word(waveform, token_spans[3], num_frames, transcript[3])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    chanpin (0.73): 1.126 - 1.528 sec


.. GENERATED FROM PYTHON SOURCE LINES 361-364

.. code-block:: default


    preview_word(waveform, token_spans[4], num_frames, transcript[4])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    reng (0.86): 1.608 - 1.809 sec


.. GENERATED FROM PYTHON SOURCE LINES 366-369

.. code-block:: default


    preview_word(waveform, token_spans[5], num_frames, transcript[5])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    chuyu (0.80): 1.849 - 2.151 sec


.. GENERATED FROM PYTHON SOURCE LINES 371-374

.. code-block:: default


    preview_word(waveform, token_spans[6], num_frames, transcript[6])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    gongbuyingqiu (0.93): 2.251 - 2.894 sec


.. GENERATED FROM PYTHON SOURCE LINES 376-379

.. code-block:: default


    preview_word(waveform, token_spans[7], num_frames, transcript[7])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    de (0.98): 2.935 - 3.015 sec


.. GENERATED FROM PYTHON SOURCE LINES 381-385

.. code-block:: default


    preview_word(waveform, token_spans[8], num_frames, transcript[8])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    jumian (0.95): 3.075 - 3.477 sec


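In this example, ``guan`` (0.33) and ``fuwu`` (0.31) score much lower than the
other words, so their boundaries deserve a closer listen. When aligning many
utterances, such words can be flagged automatically for manual review. The
sketch below does this with a fixed cut-off; the threshold of 0.5 is an
assumption chosen for illustration, not a value recommended by TorchAudio, and
``flag_low_confidence`` is a hypothetical helper. The scores are copied from
the outputs above; in a real pipeline they could be computed with ``_score``.

.. code-block:: python

    from typing import List, Sequence, Tuple

    # (word, score) pairs copied from the ``preview_word`` outputs of the Chinese example above.
    chinese_scores = [
        ("guan", 0.33), ("fuwu", 0.31), ("gaoduan", 0.74), ("chanpin", 0.73), ("reng", 0.86),
        ("chuyu", 0.80), ("gongbuyingqiu", 0.93), ("de", 0.98), ("jumian", 0.95),
    ]


    def flag_low_confidence(scores: Sequence[Tuple[str, float]], threshold: float = 0.5) -> List[str]:
        """Return the words whose alignment score is below ``threshold``.

        The default threshold of 0.5 is an arbitrary, illustrative choice.
        """
        return [word for word, score in scores if score < threshold]


    print(flag_low_confidence(chinese_scores))

Here it returns ``['guan', 'fuwu']``. A fixed threshold is the simplest
option; since typical scores differ between recordings (compare the German
scores, which range from 0.54 to 0.96, with the Polish ones below), the
cut-off would likely need tuning per dataset.
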
.. GENERATED FROM PYTHON SOURCE LINES 386-388

Polish
~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 388-395

.. code-block:: default


    text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę"
    text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane"

    url = "https://download.pytorch.org/torchaudio/tutorial-assets/5090_1447_000088.flac"
    waveform, sample_rate = torchaudio.load(url, num_frames=int(4.5 * bundle.sample_rate))

.. GENERATED FROM PYTHON SOURCE LINES 397-399

.. code-block:: default


    assert sample_rate == bundle.sample_rate

.. GENERATED FROM PYTHON SOURCE LINES 401-412

.. code-block:: default


    transcript = text_normalized.split()

    emission, token_spans = compute_alignments(waveform, transcript)
    num_frames = emission.size(1)

    plot_alignments(waveform, token_spans, emission, transcript)

    print("Raw Transcript: ", text_raw)
    print("Normalized Transcript: ", text_normalized)
    IPython.display.Audio(waveform, rate=sample_rate)


.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_003.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/modules/conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608839953/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
      return F.conv1d(input, weight, bias, self.stride,
    Raw Transcript:  wtedy ujrzałem na jego brzuchu okrągłą czarną ranę
    Normalized Transcript:  wtedy ujrzalem na jego brzuchu okragla czarna rane


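Here the normalized transcript differs from the raw one only in its
diacritics (``ujrzałem`` becomes ``ujrzalem``, ``okrągłą`` becomes
``okragla``). For Latin-script languages such as Polish or Portuguese, this
step can be approximated with Python's standard ``unicodedata`` module instead
of ``uroman``. The following is a minimal sketch under that assumption, not
necessarily how the transcripts in this tutorial were produced: Unicode NFKD
decomposition separates most accents from their base letters, but some letters
such as ``ł`` have no decomposition and need an explicit mapping (the
``SPECIAL_CASES`` table is a hypothetical, incomplete example). Non-Latin
scripts still require a proper romanizer.

.. code-block:: python

    import re
    import unicodedata

    # Letters that Unicode does not decompose into "base letter + combining mark".
    # Hypothetical table that only covers the Polish example; extend as needed.
    SPECIAL_CASES = str.maketrans({"ł": "l", "Ł": "L"})


    def strip_diacritics(text: str) -> str:
        text = text.translate(SPECIAL_CASES)
        # NFKD splits e.g. "ą" into "a" + U+0328 (combining ogonek); drop the combining marks.
        decomposed = unicodedata.normalize("NFKD", text)
        return "".join(ch for ch in decomposed if not unicodedata.combining(ch))


    def normalize_latin(text: str) -> str:
        # After stripping diacritics, apply the same cleanup as ``normalize_uroman``.
        text = strip_diacritics(text).lower()
        text = text.replace("’", "'")
        text = re.sub("([^a-z' ])", " ", text)
        text = re.sub(" +", " ", text)
        return text.strip()


    print(normalize_latin("wtedy ujrzałem na jego brzuchu okrągłą czarną ranę"))
    print(normalize_latin("na imensa extensão onde se esconde o inconsciente imortal"))

Both calls reproduce the ``text_normalized`` strings used for the Polish and
Portuguese examples in this tutorial.
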
.. GENERATED FROM PYTHON SOURCE LINES 414-417

.. code-block:: default


    preview_word(waveform, token_spans[0], num_frames, transcript[0])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    wtedy (1.00): 0.783 - 1.145 sec


.. GENERATED FROM PYTHON SOURCE LINES 419-422

.. code-block:: default


    preview_word(waveform, token_spans[1], num_frames, transcript[1])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ujrzalem (0.96): 1.286 - 1.788 sec


.. GENERATED FROM PYTHON SOURCE LINES 424-427

.. code-block:: default


    preview_word(waveform, token_spans[2], num_frames, transcript[2])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    na (1.00): 1.868 - 1.949 sec


.. GENERATED FROM PYTHON SOURCE LINES 429-432

.. code-block:: default


    preview_word(waveform, token_spans[3], num_frames, transcript[3])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    jego (1.00): 2.009 - 2.230 sec


.. GENERATED FROM PYTHON SOURCE LINES 434-437

.. code-block:: default


    preview_word(waveform, token_spans[4], num_frames, transcript[4])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    brzuchu (0.97): 2.330 - 2.732 sec


.. GENERATED FROM PYTHON SOURCE LINES 439-442

.. code-block:: default


    preview_word(waveform, token_spans[5], num_frames, transcript[5])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    okragla (1.00): 2.893 - 3.415 sec


.. GENERATED FROM PYTHON SOURCE LINES 444-447

.. code-block:: default


    preview_word(waveform, token_spans[6], num_frames, transcript[6])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    czarna (0.90): 3.556 - 3.938 sec


.. GENERATED FROM PYTHON SOURCE LINES 449-452

.. code-block:: default


    preview_word(waveform, token_spans[7], num_frames, transcript[7])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    rane (1.00): 4.098 - 4.399 sec


.. GENERATED FROM PYTHON SOURCE LINES 453-455

Portuguese
~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 455-464

.. code-block:: default


    text_raw = "na imensa extensão onde se esconde o inconsciente imortal"
    text_normalized = "na imensa extensao onde se esconde o inconsciente imortal"

    url = "https://download.pytorch.org/torchaudio/tutorial-assets/6566_5323_000027.flac"
    waveform, sample_rate = torchaudio.load(
        url, frame_offset=int(bundle.sample_rate), num_frames=int(4.6 * bundle.sample_rate)
    )

.. GENERATED FROM PYTHON SOURCE LINES 466-468

.. code-block:: default


    assert sample_rate == bundle.sample_rate

.. GENERATED FROM PYTHON SOURCE LINES 470-481

.. code-block:: default


    transcript = text_normalized.split()

    emission, token_spans = compute_alignments(waveform, transcript)
    num_frames = emission.size(1)

    plot_alignments(waveform, token_spans, emission, transcript)

    print("Raw Transcript: ", text_raw)
    print("Normalized Transcript: ", text_normalized)
    IPython.display.Audio(waveform, rate=sample_rate)


.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_004.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/modules/conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608839953/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
      return F.conv1d(input, weight, bias, self.stride,
    Raw Transcript:  na imensa extensão onde se esconde o inconsciente imortal
    Normalized Transcript:  na imensa extensao onde se esconde o inconsciente imortal


.. GENERATED FROM PYTHON SOURCE LINES 483-486

.. code-block:: default


    preview_word(waveform, token_spans[0], num_frames, transcript[0])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    na (1.00): 0.020 - 0.080 sec


.. GENERATED FROM PYTHON SOURCE LINES 488-491

.. code-block:: default


    preview_word(waveform, token_spans[1], num_frames, transcript[1])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    imensa (0.90): 0.120 - 0.502 sec


.. GENERATED FROM PYTHON SOURCE LINES 493-496

.. code-block:: default


    preview_word(waveform, token_spans[2], num_frames, transcript[2])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    extensao (0.92): 0.542 - 1.205 sec


.. GENERATED FROM PYTHON SOURCE LINES 498-501

.. code-block:: default


    preview_word(waveform, token_spans[3], num_frames, transcript[3])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    onde (1.00): 1.446 - 1.667 sec


.. GENERATED FROM PYTHON SOURCE LINES 503-506

.. code-block:: default


    preview_word(waveform, token_spans[4], num_frames, transcript[4])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    se (0.99): 1.748 - 1.828 sec


.. GENERATED FROM PYTHON SOURCE LINES 508-511

.. code-block:: default


    preview_word(waveform, token_spans[5], num_frames, transcript[5])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    esconde (0.99): 1.888 - 2.591 sec


.. GENERATED FROM PYTHON SOURCE LINES 513-516

.. code-block:: default


    preview_word(waveform, token_spans[6], num_frames, transcript[6])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    o (0.98): 2.852 - 2.872 sec


.. GENERATED FROM PYTHON SOURCE LINES 518-521

.. code-block:: default


    preview_word(waveform, token_spans[7], num_frames, transcript[7])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    inconsciente (0.80): 2.933 - 3.897 sec


.. GENERATED FROM PYTHON SOURCE LINES 523-527

.. code-block:: default


    preview_word(waveform, token_spans[8], num_frames, transcript[8])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    imortal (0.86): 3.937 - 4.560 sec


.. GENERATED FROM PYTHON SOURCE LINES 528-530

Italian
~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 530-537

.. code-block:: default


    text_raw = "elle giacean per terra tutte quante"
    text_normalized = "elle giacean per terra tutte quante"

    url = "https://download.pytorch.org/torchaudio/tutorial-assets/642_529_000025.flac"
    waveform, sample_rate = torchaudio.load(url, num_frames=int(4 * bundle.sample_rate))

.. GENERATED FROM PYTHON SOURCE LINES 539-541

.. code-block:: default


    assert sample_rate == bundle.sample_rate

.. GENERATED FROM PYTHON SOURCE LINES 543-554

.. code-block:: default


    transcript = text_normalized.split()

    emission, token_spans = compute_alignments(waveform, transcript)
    num_frames = emission.size(1)

    plot_alignments(waveform, token_spans, emission, transcript)

    print("Raw Transcript: ", text_raw)
    print("Normalized Transcript: ", text_normalized)
    IPython.display.Audio(waveform, rate=sample_rate)


.. image-sg:: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_005.png
   :alt: Emission
   :srcset: /tutorials/images/sphx_glr_forced_alignment_for_multilingual_data_tutorial_005.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Raw Transcript:  elle giacean per terra tutte quante
    Normalized Transcript:  elle giacean per terra tutte quante


.. GENERATED FROM PYTHON SOURCE LINES 556-559

.. code-block:: default


    preview_word(waveform, token_spans[0], num_frames, transcript[0])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    elle (1.00): 0.563 - 0.864 sec


.. GENERATED FROM PYTHON SOURCE LINES 561-564

.. code-block:: default


    preview_word(waveform, token_spans[1], num_frames, transcript[1])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    giacean (0.99): 0.945 - 1.467 sec


.. GENERATED FROM PYTHON SOURCE LINES 566-569

.. code-block:: default


    preview_word(waveform, token_spans[2], num_frames, transcript[2])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    per (1.00): 1.588 - 1.789 sec


.. GENERATED FROM PYTHON SOURCE LINES 571-574

.. code-block:: default


    preview_word(waveform, token_spans[3], num_frames, transcript[3])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    terra (1.00): 1.950 - 2.392 sec


.. GENERATED FROM PYTHON SOURCE LINES 576-579

.. code-block:: default


    preview_word(waveform, token_spans[4], num_frames, transcript[4])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tutte (1.00): 2.533 - 2.975 sec


.. GENERATED FROM PYTHON SOURCE LINES 581-584

.. code-block:: default


    preview_word(waveform, token_spans[5], num_frames, transcript[5])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    quante (1.00): 3.055 - 3.678 sec


.. GENERATED FROM PYTHON SOURCE LINES 585-592

Conclusion
----------

In this tutorial, we looked at how to use TorchAudio's forced alignment
API and a Wav2Vec2 pre-trained multilingual acoustic model to align
speech data to transcripts in five languages.

.. GENERATED FROM PYTHON SOURCE LINES 594-601

Acknowledgement
---------------

Thanks to `Vineel Pratap `__ and `Zhaoheng Ni `__
for developing and open-sourcing the forced aligner API.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 4.458 seconds)


.. _sphx_glr_download_tutorials_forced_alignment_for_multilingual_data_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: forced_alignment_for_multilingual_data_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: forced_alignment_for_multilingual_data_tutorial.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_