.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tacotron2_pipeline_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tacotron2_pipeline_tutorial.py:


Text-to-Speech with Tacotron2
=============================

**Author** `Yao-Yuan Yang `__,
`Moto Hira `__

.. GENERATED FROM PYTHON SOURCE LINES 9-14

.. code-block:: default


    import IPython
    import matplotlib
    import matplotlib.pyplot as plt

.. GENERATED FROM PYTHON SOURCE LINES 15-50

Overview
--------

This tutorial shows how to build a text-to-speech pipeline using the
pretrained Tacotron2 model in torchaudio.

The text-to-speech pipeline goes as follows:

1. Text preprocessing

   First, the input text is encoded into a list of symbols. In this
   tutorial, we will use English characters and phonemes as the symbols.

2. Spectrogram generation

   From the encoded text, a spectrogram is generated. We use the
   ``Tacotron2`` model for this.

3. Time-domain conversion

   The last step is converting the spectrogram into the waveform. The
   component that generates speech from a spectrogram is called a
   vocoder. In this tutorial, three different vocoders are used:
   `WaveRNN `__,
   `Griffin-Lim `__,
   and
   `Nvidia's WaveGlow `__.

The following figure illustrates the whole process.

.. image:: https://download.pytorch.org/torchaudio/tutorial-assets/tacotron2_tts_pipeline.png

All the related components are bundled in
:py:class:`torchaudio.pipelines.Tacotron2TTSBundle`,
but this tutorial will also cover the process under the hood.

.. GENERATED FROM PYTHON SOURCE LINES 52-59

Preparation
-----------

First, we install the necessary dependencies. In addition to
``torchaudio``, ``DeepPhonemizer`` is required to perform phoneme-based
encoding.

.. GENERATED FROM PYTHON SOURCE LINES 59-76

.. code-block:: default


    # When running this example in a notebook, install DeepPhonemizer
    # !pip3 install deep_phonemizer

    import torch
    import torchaudio

    matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]

    torch.random.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(torch.__version__)
    print(torchaudio.__version__)
    print(device)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    1.12.0
    0.12.0
    cpu

.. GENERATED FROM PYTHON SOURCE LINES 77-80

Text Processing
---------------

.. GENERATED FROM PYTHON SOURCE LINES 83-101

Character-based encoding
~~~~~~~~~~~~~~~~~~~~~~~~

In this section, we will go through how character-based encoding
works. Since the pretrained Tacotron2 model expects a specific symbol
table, ``torchaudio`` provides the matching encoding functionality;
this section mainly explains the basics of the encoding.

First, we define the set of symbols, for example
``'_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'``. Then, we map each
character of the input text to the index of the corresponding symbol
in the table.

The following is an example of such processing. In the example,
symbols that are not in the table are ignored.

.. GENERATED FROM PYTHON SOURCE LINES 101-116

.. code-block:: default


    symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
    look_up = {s: i for i, s in enumerate(symbols)}
    symbols = set(symbols)


    def text_to_sequence(text):
        text = text.lower()
        return [look_up[s] for s in text if s in symbols]


    text = "Hello world! Text to speech!"
    print(text_to_sequence(text))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 11, 31, 16, 35, 31, 11, 31, 26, 11, 30, 27, 16, 16, 14, 19, 2]

.. GENERATED FROM PYTHON SOURCE LINES 117-122

As mentioned above, the symbol table and indices must match what the
pretrained Tacotron2 model expects. ``torchaudio`` provides this
transform along with the pretrained model.
For example, you can instantiate and use such a transform as follows.

.. GENERATED FROM PYTHON SOURCE LINES 122-132

.. code-block:: default


    processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()

    text = "Hello world! Text to speech!"
    processed, lengths = processor(text)

    print(processed)
    print(lengths)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 11, 31, 16, 35, 31, 11, 31, 26, 11, 30, 27, 16, 16, 14, 19, 2]])
    tensor([28], dtype=torch.int32)

.. GENERATED FROM PYTHON SOURCE LINES 133-140

The ``processor`` object takes either a single text or a list of texts
as input. When a list of texts is provided, the returned ``lengths``
variable holds the valid length of each processed token sequence in
the output batch.

The intermediate representation can be retrieved as follows.

.. GENERATED FROM PYTHON SOURCE LINES 140-144

.. code-block:: default


    print([processor.tokens[i] for i in processed[0, : lengths[0]]])

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', ' ', 't', 'e', 'x', 't', ' ', 't', 'o', ' ', 's', 'p', 'e', 'e', 'c', 'h', '!']

.. GENERATED FROM PYTHON SOURCE LINES 145-164

Phoneme-based encoding
~~~~~~~~~~~~~~~~~~~~~~

Phoneme-based encoding is similar to character-based encoding, but it
uses a symbol table based on phonemes and a G2P (Grapheme-to-Phoneme)
model.

The details of the G2P model are out of the scope of this tutorial;
we will just look at what the conversion looks like.

As with character-based encoding, the encoding process must match what
the pretrained Tacotron2 model was trained on. ``torchaudio`` has an
interface to create this process.

The following code illustrates how to create and use the process.
Behind the scenes, a G2P model is created using the ``DeepPhonemizer``
package, and the pretrained weights published by the author of
``DeepPhonemizer`` are fetched.

.. GENERATED FROM PYTHON SOURCE LINES 164-177

.. code-block:: default


    bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH

    processor = bundle.get_text_processor()

    text = "Hello world! Text to speech!"
    with torch.inference_mode():
        processed, lengths = processor(text)

    print(processed)
    print(lengths)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    0%| | 0.00/63.6M [00:00

It is easy to instantiate a Tacotron2 model with pretrained weights;
however, note that the input to Tacotron2 models needs to be processed
by the matching text processor.
:py:class:`torchaudio.pipelines.Tacotron2TTSBundle` bundles the
matching models and processors together so that it is easy to create
the pipeline.

For the available bundles and their usage, please refer to
:py:mod:`torchaudio.pipelines`.

.. GENERATED FROM PYTHON SOURCE LINES 204-221

.. code-block:: default


    bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
    processor = bundle.get_text_processor()
    tacotron2 = bundle.get_tacotron2().to(device)

    text = "Hello world! Text to speech!"

    with torch.inference_mode():
        processed, lengths = processor(text)
        processed = processed.to(device)
        lengths = lengths.to(device)
        spec, _, _ = tacotron2.infer(processed, lengths)


    plt.imshow(spec[0].cpu().detach())

.. image-sg:: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_001.png
   :alt: tacotron2 pipeline tutorial
   :srcset: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/torchaudio/models/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth

    0%| | 0.00/107M [00:00

.. GENERATED FROM PYTHON SOURCE LINES 222-225

Note that the ``Tacotron2.infer`` method is not deterministic: the
model keeps dropout active in its pre-net at inference time, so the
process of generating the spectrogram incurs randomness.
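
The randomness is drawn from PyTorch's global random number generator,
so re-seeding it with ``torch.manual_seed`` right before each
``tacotron2.infer(processed, lengths)`` call should make the output
repeatable (on GPU, some kernels may still introduce small run-to-run
differences). The toy sketch below illustrates only the mechanism;
``noisy_layer`` is a hypothetical stand-in that keeps dropout on at
inference time, not part of the torchaudio model.

.. code-block:: python

    import torch
    import torch.nn.functional as F


    def noisy_layer(x: torch.Tensor) -> torch.Tensor:
        """Hypothetical layer that keeps dropout active even at inference time."""
        # training=True forces dropout regardless of eval() or inference_mode().
        return F.dropout(torch.relu(x), p=0.5, training=True)


    x = torch.ones(1, 256)
    with torch.inference_mode():
        # Two unseeded calls draw different dropout masks.
        first, second = noisy_layer(x), noisy_layer(x)
        # Re-seeding before each call reproduces the same mask.
        torch.manual_seed(0)
        seeded_a = noisy_layer(x)
        torch.manual_seed(0)
        seeded_b = noisy_layer(x)

    print(torch.equal(first, second), torch.equal(seeded_a, seeded_b))  # False True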

.. GENERATED FROM PYTHON SOURCE LINES 225-235

.. code-block:: default


    fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3))
    for i in range(3):
        with torch.inference_mode():
            spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
        print(spec[0].shape)
        ax[i].imshow(spec[0].cpu().detach())
    plt.show()

.. image-sg:: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_002.png
   :alt: tacotron2 pipeline tutorial
   :srcset: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    torch.Size([80, 155])
    torch.Size([80, 167])
    torch.Size([80, 164])

.. GENERATED FROM PYTHON SOURCE LINES 236-245

Waveform Generation
-------------------

Once the spectrogram is generated, the last step is to recover the
waveform from it.

``torchaudio`` provides vocoders based on ``GriffinLim`` and
``WaveRNN``.

.. GENERATED FROM PYTHON SOURCE LINES 248-254

WaveRNN
~~~~~~~

Continuing from the previous section, we can instantiate the matching
WaveRNN model from the same bundle.

.. GENERATED FROM PYTHON SOURCE LINES 254-278

.. code-block:: default


    bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH

    processor = bundle.get_text_processor()
    tacotron2 = bundle.get_tacotron2().to(device)
    vocoder = bundle.get_vocoder().to(device)

    text = "Hello world! Text to speech!"

    with torch.inference_mode():
        processed, lengths = processor(text)
        processed = processed.to(device)
        lengths = lengths.to(device)
        spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
        waveforms, lengths = vocoder(spec, spec_lengths)

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
    ax1.imshow(spec[0].cpu().detach())
    ax2.plot(waveforms[0].cpu().detach())

    torchaudio.save("_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
    IPython.display.Audio("_assets/output_wavernn.wav")

.. image-sg:: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_003.png
   :alt: tacotron2 pipeline tutorial
   :srcset: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/torchaudio/models/wavernn_10k_epochs_8bits_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/wavernn_10k_epochs_8bits_ljspeech.pth

    0%| | 0.00/16.7M [00:00
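
When the processor is given a list of texts, the vocoder returns a
padded batch together with per-item ``lengths``. The sketch below
assumes, as the plotting code above suggests, that ``waveforms`` has
shape (batch, time), and that samples beyond each item's length are
padding; the helper ``unbatch_waveforms`` is hypothetical. Slicing with
``i : i + 1`` keeps a leading channel dimension, which is the
(channels, time) layout ``torchaudio.save`` expects, and is also why the
code above passes ``waveforms[0:1]`` rather than ``waveforms[0]``.

.. code-block:: python

    from typing import List

    import torch


    def unbatch_waveforms(waveforms: torch.Tensor, lengths: torch.Tensor) -> List[torch.Tensor]:
        """Split a padded (batch, time) tensor into per-item (1, time) tensors.

        Samples past ``lengths[i]`` are treated as padding and dropped.
        """
        return [waveforms[i : i + 1, : int(lengths[i])] for i in range(waveforms.size(0))]


    # Usage with a dummy batch of two "waveforms" padded to 5 samples.
    batch = torch.arange(10, dtype=torch.float32).reshape(2, 5)
    items = unbatch_waveforms(batch, torch.tensor([3, 5]))
    print([tuple(item.shape) for item in items])  # [(1, 3), (1, 5)]

Each item could then be written with
``torchaudio.save(path, item.cpu(), sample_rate=vocoder.sample_rate)``.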

.. GENERATED FROM PYTHON SOURCE LINES 279-285

Griffin-Lim
~~~~~~~~~~~

Using the Griffin-Lim vocoder works the same way as WaveRNN. You can
instantiate the vocoder object with the ``get_vocoder`` method and
pass it the spectrogram.

.. GENERATED FROM PYTHON SOURCE LINES 285-311

.. code-block:: default


    bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH

    processor = bundle.get_text_processor()
    tacotron2 = bundle.get_tacotron2().to(device)
    vocoder = bundle.get_vocoder().to(device)

    with torch.inference_mode():
        processed, lengths = processor(text)
        processed = processed.to(device)
        lengths = lengths.to(device)
        spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
    waveforms, lengths = vocoder(spec, spec_lengths)

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
    ax1.imshow(spec[0].cpu().detach())
    ax2.plot(waveforms[0].cpu().detach())

    torchaudio.save(
        "_assets/output_griffinlim.wav",
        waveforms[0:1].cpu(),
        sample_rate=vocoder.sample_rate,
    )
    IPython.display.Audio("_assets/output_griffinlim.wav")

.. image-sg:: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_004.png
   :alt: tacotron2 pipeline tutorial
   :srcset: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_004.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/torchaudio/models/tacotron2_english_phonemes_1500_epochs_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/tacotron2_english_phonemes_1500_epochs_ljspeech.pth

    0%| | 0.00/107M [00:00
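
The Griffin-Lim algorithm recovers a waveform from a magnitude-only
spectrogram by iteratively estimating the missing phase: start from a
random phase, then alternate an inverse STFT and an STFT, each time
keeping the target magnitude and only the re-estimated phase. The
sketch below shows just that core loop in plain PyTorch, assuming a
linear-magnitude STFT as input. The bundled vocoder works on mel
spectrograms and presumably maps them back to a linear scale first, a
step omitted here; ``griffin_lim`` and its parameters are illustrative,
not a torchaudio API.

.. code-block:: python

    import math
    from typing import Optional

    import torch


    def griffin_lim(
        magnitude: torch.Tensor,
        n_fft: int,
        hop_length: int,
        n_iter: int = 32,
        length: Optional[int] = None,
    ) -> torch.Tensor:
        """Estimate a waveform whose STFT magnitude approximates ``magnitude``.

        ``magnitude`` is a linear-amplitude STFT magnitude of shape
        (n_fft // 2 + 1, frames), computed with a Hann window and ``hop_length``.
        """
        window = torch.hann_window(n_fft, device=magnitude.device)
        unit = torch.ones_like(magnitude)
        # Start from a random phase estimate (unit-modulus complex values).
        phase = torch.polar(unit, 2 * math.pi * torch.rand_like(magnitude))
        for _ in range(n_iter):
            # Pair the target magnitude with the current phase and invert the STFT.
            signal = torch.istft(magnitude * phase, n_fft, hop_length=hop_length,
                                 window=window, length=length)
            # Re-analyse that signal and keep only its phase for the next round.
            rebuilt = torch.stft(signal, n_fft, hop_length=hop_length,
                                 window=window, return_complex=True)
            phase = torch.polar(unit, rebuilt.angle())
        return torch.istft(magnitude * phase, n_fft, hop_length=hop_length,
                           window=window, length=length)


    # Usage: rebuild a 440 Hz tone from its STFT magnitude alone.
    sample_rate, n_fft, hop_length = 8000, 256, 64
    tone = torch.sin(2 * math.pi * 440 * torch.arange(sample_rate) / sample_rate)
    magnitude = torch.stft(tone, n_fft, hop_length=hop_length,
                           window=torch.hann_window(n_fft), return_complex=True).abs()
    estimate = griffin_lim(magnitude, n_fft, hop_length, length=tone.numel())

``torchaudio.transforms.GriffinLim`` provides a ready-made version of
this loop, with an additional momentum term to speed up convergence.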

.. GENERATED FROM PYTHON SOURCE LINES 312-319

WaveGlow
~~~~~~~~

WaveGlow is a vocoder published by NVIDIA. The pretrained weights are
published on Torch Hub, so one can instantiate the model using the
``torch.hub`` module.

.. GENERATED FROM PYTHON SOURCE LINES 319-349

.. code-block:: default


    # Workaround to load a checkpoint that was saved on GPU:
    # build the model without weights, then load them with map_location.
    # https://stackoverflow.com/a/61840832
    waveglow = torch.hub.load(
        "NVIDIA/DeepLearningExamples:torchhub",
        "nvidia_waveglow",
        model_math="fp32",
        pretrained=False,
    )
    checkpoint = torch.hub.load_state_dict_from_url(
        "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth",  # noqa: E501
        progress=False,
        map_location=device,
    )
    state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}

    waveglow.load_state_dict(state_dict)
    waveglow = waveglow.remove_weightnorm(waveglow)
    waveglow = waveglow.to(device)
    waveglow.eval()

    with torch.no_grad():
        waveforms = waveglow.infer(spec)

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
    ax1.imshow(spec[0].cpu().detach())
    ax2.plot(waveforms[0].cpu().detach())

    # 22050 Hz is the sample rate of the LJSpeech dataset
    torchaudio.save("_assets/output_waveglow.wav", waveforms[0:1].cpu(), sample_rate=22050)
    IPython.display.Audio("_assets/output_waveglow.wav")

.. image-sg:: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_005.png
   :alt: tacotron2 pipeline tutorial
   :srcset: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /usr/local/envs/python3.8/lib/python3.8/site-packages/torch/hub.py:266: UserWarning: You are about to download and run code from an untrusted repository. In a future release, this won't be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, or load(..., trust_repo=True), which will assume that the prompt is to be answered with 'yes'. You can also use load(..., trust_repo='check') which will only prompt for confirmation if the repo is not already trusted. This will eventually be the default behaviour
      warnings.warn(
    Downloading: "https://github.com/NVIDIA/DeepLearningExamples/zipball/torchhub" to /root/.cache/torch/hub/torchhub.zip
    /root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/Classification/ConvNets/image_classification/models/common.py:13: UserWarning: pytorch_quantization module not found, quantization will not be available
      warnings.warn(
    /root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/Classification/ConvNets/image_classification/models/efficientnet.py:17: UserWarning: pytorch_quantization module not found, quantization will not be available
      warnings.warn(
    Downloading: "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth" to /root/.cache/torch/hub/checkpoints/nvidia_waveglowpyt_fp32_20190306.pth
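
The checkpoint keys start with ``module.``, presumably the prefix that
``torch.nn.DataParallel`` adds when the state dict of a wrapped model is
saved. ``str.replace``, as used in the ``state_dict`` comprehension in
the WaveGlow code above, would also rewrite a ``module.`` appearing later
in a key. A slightly more defensive sketch, using a hypothetical helper
``strip_prefix``, removes the prefix only at the start of each key
(recent PyTorch versions also provide
``torch.nn.modules.utils.consume_prefix_in_state_dict_if_present``,
which does the same in place).

.. code-block:: python

    from collections import OrderedDict


    def strip_prefix(state_dict, prefix="module."):
        """Return a copy of ``state_dict`` with ``prefix`` removed from the start of each key."""
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )


    # Hypothetical keys as saved from a DataParallel-wrapped model.
    saved = {"module.conv.weight": 1, "module.submodule.bias": 2}
    print(list(strip_prefix(saved)))  # ['conv.weight', 'submodule.bias']
    print([k.replace("module.", "") for k in saved])  # ['conv.weight', 'subbias']

With this helper, the loading step would read
``waveglow.load_state_dict(strip_prefix(checkpoint["state_dict"]))``.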


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 5 minutes 32.190 seconds)


.. _sphx_glr_download_tutorials_tacotron2_pipeline_tutorial.py:


.. only:: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example

     .. container:: sphx-glr-download sphx-glr-download-python

        :download:`Download Python source code: tacotron2_pipeline_tutorial.py `

     .. container:: sphx-glr-download sphx-glr-download-jupyter

        :download:`Download Jupyter notebook: tacotron2_pipeline_tutorial.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_