.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tacotron2_pipeline_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tacotron2_pipeline_tutorial.py:


Text-to-Speech with Tacotron2
=============================

**Author**: `Yao-Yuan Yang `__, `Moto Hira `__

.. GENERATED FROM PYTHON SOURCE LINES 11-45

Overview
--------

This tutorial shows how to build a text-to-speech pipeline using the
pretrained Tacotron2 model in torchaudio.

The text-to-speech pipeline goes as follows:

1. Text preprocessing

   First, the input text is encoded into a list of symbols. In this
   tutorial, we will use English characters and phonemes as the symbols.

2. Spectrogram generation

   From the encoded text, a spectrogram is generated. We use the
   ``Tacotron2`` model for this.

3. Time-domain conversion

   The last step is converting the spectrogram into the waveform. The
   component that generates speech from a spectrogram is called a vocoder.
   In this tutorial, three different vocoders are used:
   :py:class:`~torchaudio.models.WaveRNN`,
   :py:class:`~torchaudio.transforms.GriffinLim`, and
   `Nvidia's WaveGlow `__.

The following figure illustrates the whole process.

.. image:: https://download.pytorch.org/torchaudio/tutorial-assets/tacotron2_tts_pipeline.png

All the related components are bundled in
:py:class:`torchaudio.pipelines.Tacotron2TTSBundle`,
but this tutorial will also cover the process under the hood.

.. GENERATED FROM PYTHON SOURCE LINES 47-54

Preparation
-----------

First, we install the necessary dependencies. In addition to
``torchaudio``, ``DeepPhonemizer`` is required to perform phoneme-based
encoding.

.. GENERATED FROM PYTHON SOURCE LINES 56-60

.. code-block:: bash

    %%bash
    pip3 install deep_phonemizer

.. GENERATED FROM PYTHON SOURCE LINES 60-72
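
Because ``DeepPhonemizer`` is only needed for phoneme-based encoding, a
notebook can check whether it is importable before choosing a bundle and
fall back to character-based encoding otherwise. The sketch below assumes
that the ``deep_phonemizer`` package is imported under the module name
``dp``; ``phonemizer_available`` and ``choose_bundle_name`` are
hypothetical helpers, and the returned name could be turned into a bundle
object with ``getattr(torchaudio.pipelines, choose_bundle_name())``.

.. code-block:: python

    import importlib.util


    def phonemizer_available() -> bool:
        # Assumption: the deep_phonemizer package installs a top-level module
        # named "dp". find_spec returns None when it cannot be imported.
        return importlib.util.find_spec("dp") is not None


    def choose_bundle_name() -> str:
        # Hypothetical helper: prefer the phoneme-based bundle when the G2P
        # dependency is importable, otherwise fall back to characters.
        if phonemizer_available():
            return "TACOTRON2_WAVERNN_PHONE_LJSPEECH"
        return "TACOTRON2_WAVERNN_CHAR_LJSPEECH"


    print(choose_bundle_name())
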
.. code-block:: default

    import torch
    import torchaudio

    torch.random.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(torch.__version__)
    print(torchaudio.__version__)
    print(device)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2.4.0.dev20240423
    2.2.0.dev20240424
    cuda

.. GENERATED FROM PYTHON SOURCE LINES 74-79

.. code-block:: default

    import IPython
    import matplotlib.pyplot as plt

.. GENERATED FROM PYTHON SOURCE LINES 80-83

Text Processing
---------------

.. GENERATED FROM PYTHON SOURCE LINES 86-100

Character-based encoding
~~~~~~~~~~~~~~~~~~~~~~~~

In this section, we will go through how the character-based encoding
works.

Since the pretrained Tacotron2 model expects a specific symbol table,
the same functionality is available in ``torchaudio``. However, we will
first implement the encoding manually to aid understanding.

First, we define the set of symbols
``'_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'``. Then, we map each
character of the input text to the index of the corresponding symbol in
the table. Symbols that are not in the table are ignored.

.. GENERATED FROM PYTHON SOURCE LINES 100-115

.. code-block:: default

    symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
    # Map each symbol to its position in the table.
    look_up = {s: i for i, s in enumerate(symbols)}
    symbols = set(symbols)


    def text_to_sequence(text):
        # Lowercase the input and drop any character that is not in the table.
        text = text.lower()
        return [look_up[s] for s in text if s in symbols]


    text = "Hello world! Text to speech!"
    print(text_to_sequence(text))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 11, 31, 16, 35, 31, 11, 31, 26, 11, 30, 27, 16, 16, 14, 19, 2]

.. GENERATED FROM PYTHON SOURCE LINES 116-121

As mentioned above, the symbol table and indices must match what the
pretrained Tacotron2 model expects. ``torchaudio`` provides the same
transform along with the pretrained model, and you can instantiate it
with the bundle's ``get_text_processor()`` method.

.. GENERATED FROM PYTHON SOURCE LINES 121-131
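
Because each symbol maps to exactly one index, the hand-written encoding
can be inverted, which makes it easy to see which input characters were
silently dropped. The sketch below rebuilds the same symbol table under
new names; ``encode`` and ``decode`` are hypothetical helpers and are not
part of ``torchaudio``.

.. code-block:: python

    SYMBOL_TABLE = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
    INDEX_OF = {s: i for i, s in enumerate(SYMBOL_TABLE)}


    def encode(text):
        # Same rule as text_to_sequence: lowercase, drop unknown characters.
        return [INDEX_OF[s] for s in text.lower() if s in INDEX_OF]


    def decode(sequence):
        # Hypothetical inverse: map each index back to its symbol.
        return "".join(SYMBOL_TABLE[i] for i in sequence)


    ids = encode("Hi 42!")
    print(ids)          # [19, 20, 11, 2]
    print(decode(ids))  # hi !   (the digits were dropped)
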
.. code-block:: default

    processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()

    text = "Hello world! Text to speech!"
    processed, lengths = processor(text)

    print(processed)
    print(lengths)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 11, 31, 16, 35, 31, 11, 31, 26, 11, 30, 27, 16, 16, 14, 19, 2]])
    tensor([28], dtype=torch.int32)

.. GENERATED FROM PYTHON SOURCE LINES 132-139

Note: the output of our manual encoding matches the output of the
``torchaudio`` text processor, which means we correctly re-implemented
what the library does internally. The processor takes either a text or a
list of texts as input. When a list of texts is provided, the returned
``lengths`` variable represents the valid length of each processed token
sequence in the output batch.

The intermediate representation can be retrieved as follows:

.. GENERATED FROM PYTHON SOURCE LINES 139-143

.. code-block:: default

    print([processor.tokens[i] for i in processed[0, : lengths[0]]])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', ' ', 't', 'e', 'x', 't', ' ', 't', 'o', ' ', 's', 'p', 'e', 'e', 'c', 'h', '!']

.. GENERATED FROM PYTHON SOURCE LINES 144-163

Phoneme-based encoding
~~~~~~~~~~~~~~~~~~~~~~

Phoneme-based encoding is similar to character-based encoding, but it
uses a symbol table based on phonemes and a G2P (Grapheme-to-Phoneme)
model.

The details of the G2P model are out of the scope of this tutorial; we
will just look at what the conversion looks like.

Similar to the case of character-based encoding, the encoding process is
expected to match what a pretrained Tacotron2 model is trained on.
``torchaudio`` has an interface to create the process.

The following code illustrates how to make and use the process.
Behind the scenes, a G2P model is created using the ``DeepPhonemizer``
package, and the pretrained weights published by the author of
``DeepPhonemizer`` are fetched.

.. GENERATED FROM PYTHON SOURCE LINES 163-176

.. code-block:: default

    bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH

    processor = bundle.get_text_processor()

    text = "Hello world! Text to speech!"
    with torch.inference_mode():
        processed, lengths = processor(text)

    print(processed)
    print(lengths)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0%| | 0.00/63.6M [00:00

`__. It is easy to instantiate a Tacotron2 model with pretrained weights;
however, note that the input to Tacotron2 models needs to be processed by
the matching text processor.

:py:class:`torchaudio.pipelines.Tacotron2TTSBundle` bundles the matching
models and processors together so that it is easy to create the pipeline.

For the available bundles and their usage, please refer to
:py:class:`~torchaudio.pipelines.Tacotron2TTSBundle`.

.. GENERATED FROM PYTHON SOURCE LINES 204-221

.. code-block:: default

    bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
    processor = bundle.get_text_processor()
    tacotron2 = bundle.get_tacotron2().to(device)

    text = "Hello world! Text to speech!"

    with torch.inference_mode():
        processed, lengths = processor(text)
        processed = processed.to(device)
        lengths = lengths.to(device)
        spec, _, _ = tacotron2.infer(processed, lengths)


    _ = plt.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")

.. image-sg:: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_001.png
   :alt: tacotron2 pipeline tutorial
   :srcset: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_001.png
   :class: sphx-glr-single-img
.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/modules/transformer.py:306: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
      warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
    Downloading: "https://download.pytorch.org/torchaudio/models/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth

    0%| | 0.00/107M [00:00
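
The text processor and ``tacotron2.infer`` both operate on batches:
sequences of different lengths are padded to a common length, and a
separate ``lengths`` tensor records how many entries of each row are
valid. The sketch below reproduces that bookkeeping with plain Python
lists. The helper ``pad_batch`` is hypothetical, and right-padding with
``0`` is an assumption for illustration rather than a statement about
``torchaudio``'s internals.

.. code-block:: python

    def pad_batch(sequences, pad_value=0):
        # Record the valid length of each sequence before padding.
        lengths = [len(seq) for seq in sequences]
        max_len = max(lengths)
        # Right-pad every sequence with pad_value up to the longest length.
        padded = [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]
        return padded, lengths


    # Toy batch: "hi!" and "text" encoded with the character table.
    toy_batch = [[19, 20, 2], [31, 16, 35, 31]]
    toy_padded, toy_lengths = pad_batch(toy_batch)
    print(toy_padded)   # [[19, 20, 2, 0], [31, 16, 35, 31]]
    print(toy_lengths)  # [3, 4]

    # Slicing each row by its length recovers the original sequences, just as
    # processed[0, : lengths[0]] does in the character-based section.
    assert [row[:n] for row, n in zip(toy_padded, toy_lengths)] == toy_batch
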

.. GENERATED FROM PYTHON SOURCE LINES 293-301

Griffin-Lim Vocoder
~~~~~~~~~~~~~~~~~~~

Using the Griffin-Lim vocoder is the same as using WaveRNN. You can
instantiate the vocoder object with the
:py:func:`~torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder`
method and pass the spectrogram to it.

.. GENERATED FROM PYTHON SOURCE LINES 301-315

.. code-block:: default

    bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH

    processor = bundle.get_text_processor()
    tacotron2 = bundle.get_tacotron2().to(device)
    vocoder = bundle.get_vocoder().to(device)

    with torch.inference_mode():
        processed, lengths = processor(text)
        processed = processed.to(device)
        lengths = lengths.to(device)
        spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
        waveforms, lengths = vocoder(spec, spec_lengths)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/modules/transformer.py:306: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
      warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
    Downloading: "https://download.pytorch.org/torchaudio/models/tacotron2_english_phonemes_1500_epochs_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/tacotron2_english_phonemes_1500_epochs_ljspeech.pth

    0%| | 0.00/107M [00:00
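
As with the text processor, the vocoder call above returns a padded batch
together with ``lengths``. Assuming ``lengths`` holds the valid length of
each waveform in samples, the sketch below trims each waveform and
converts its length to seconds. It uses plain Python lists as stand-ins
for tensors, ``trim_and_measure`` is a hypothetical helper, and the
22050 Hz rate is an assumption borrowed from the
``plot(waveforms, spec, 22050)`` call in the WaveGlow section.

.. code-block:: python

    # Assumption: 22050 Hz, the rate passed to plot() in the WaveGlow section.
    SAMPLE_RATE = 22050


    def trim_and_measure(waveforms, lengths, sample_rate=SAMPLE_RATE):
        # Keep only the valid samples of each padded waveform...
        trimmed = [wave[:n] for wave, n in zip(waveforms, lengths)]
        # ...and convert each valid length from samples to seconds.
        durations = [n / sample_rate for n in lengths]
        return trimmed, durations


    # Toy batch: two "waveforms" zero-padded to a common length of 6 samples.
    toy_waveforms = [[0.1, -0.2, 0.3, 0.0, 0.0, 0.0], [0.5, 0.4, -0.1, 0.2, -0.3, 0.1]]
    toy_lengths = [3, 6]
    trimmed, durations = trim_and_measure(toy_waveforms, toy_lengths)
    print(trimmed[0])  # [0.1, -0.2, 0.3]

    # 44100 valid samples at 22050 Hz correspond to 2 seconds of audio.
    print(trim_and_measure([[0.0] * 44100], [44100])[1])  # [2.0]
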

.. GENERATED FROM PYTHON SOURCE LINES 322-329

WaveGlow Vocoder
~~~~~~~~~~~~~~~~

WaveGlow is a vocoder published by Nvidia. The pretrained weights are
published on Torch Hub. One can instantiate the model using the
``torch.hub`` module.

.. GENERATED FROM PYTHON SOURCE LINES 329-353

.. code-block:: default

    # Workaround to load a checkpoint that was saved with GPU tensors:
    # build the model without pretrained weights, then load the weights
    # separately with map_location set to the current device.
    # https://stackoverflow.com/a/61840832
    waveglow = torch.hub.load(
        "NVIDIA/DeepLearningExamples:torchhub",
        "nvidia_waveglow",
        model_math="fp32",
        pretrained=False,
    )
    checkpoint = torch.hub.load_state_dict_from_url(
        "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth",  # noqa: E501
        progress=False,
        map_location=device,
    )
    # Remove "module." from the checkpoint keys so they match the model's parameter names.
    state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}

    waveglow.load_state_dict(state_dict)
    waveglow = waveglow.remove_weightnorm(waveglow)
    waveglow = waveglow.to(device)
    waveglow.eval()

    with torch.no_grad():
        waveforms = waveglow.infer(spec)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /pytorch/audio/ci_env/lib/python3.10/site-packages/torch/hub.py:293: UserWarning: You are about to download and run code from an untrusted repository. In a future release, this won't be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, or load(..., trust_repo=True), which will assume that the prompt is to be answered with 'yes'. You can also use load(..., trust_repo='check') which will only prompt for confirmation if the repo is not already trusted.
    This will eventually be the default behaviour
      warnings.warn(
    Downloading: "https://github.com/NVIDIA/DeepLearningExamples/zipball/torchhub" to /root/.cache/torch/hub/torchhub.zip
    /root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/Classification/ConvNets/image_classification/models/common.py:13: UserWarning: pytorch_quantization module not found, quantization will not be available
      warnings.warn(
    /root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/Classification/ConvNets/image_classification/models/efficientnet.py:17: UserWarning: pytorch_quantization module not found, quantization will not be available
      warnings.warn(
    /pytorch/audio/ci_env/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
      warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
    Downloading: "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth" to /root/.cache/torch/hub/checkpoints/nvidia_waveglowpyt_fp32_20190306.pth

.. GENERATED FROM PYTHON SOURCE LINES 355-357

.. code-block:: default

    plot(waveforms, spec, 22050)

.. image-sg:: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_005.png
   :alt: tacotron2 pipeline tutorial
   :srcset: /tutorials/images/sphx_glr_tacotron2_pipeline_tutorial_005.png
   :class: sphx-glr-single-img
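
The WaveGlow loading code above renames checkpoint keys with
``key.replace("module.", "")``, which removes every occurrence of
``module.`` in a key. That works for the checkpoint used here, but a key
that merely contains ``module.`` further along (for example
``submodule.bias``) would be mangled. A slightly stricter variant,
sketched below with a hypothetical helper ``strip_prefix`` and toy keys,
removes the string only when it is a leading prefix, using
``str.removeprefix`` (Python 3.9+).

.. code-block:: python

    def strip_prefix(state_dict, prefix="module."):
        # Remove prefix only at the start of each key; occurrences elsewhere
        # in a key are left untouched.
        return {key.removeprefix(prefix): value for key, value in state_dict.items()}


    # Toy keys (hypothetical): the second one contains "module." inside "submodule.".
    toy_checkpoint = {"module.layer.weight": 1, "module.submodule.bias": 2}

    print(strip_prefix(toy_checkpoint))
    # {'layer.weight': 1, 'submodule.bias': 2}
    print({key.replace("module.", ""): v for key, v in toy_checkpoint.items()})
    # {'layer.weight': 1, 'subbias': 2}
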


.. rst-class:: sphx-glr-timing

**Total running time of the script:** ( 1 minutes 17.877 seconds)


.. _sphx_glr_download_tutorials_tacotron2_pipeline_tutorial.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: tacotron2_pipeline_tutorial.py `

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: tacotron2_pipeline_tutorial.ipynb `


.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_