
Text-to-Speech with Tacotron2

Author: Yao-Yuan Yang, Moto Hira

import IPython
import matplotlib
import matplotlib.pyplot as plt

Overview

This tutorial shows how to build a text-to-speech pipeline using the pretrained Tacotron2 model in torchaudio.

The text-to-speech pipeline goes as follows:

  1. Text preprocessing

    First, the input text is encoded into a list of symbols. In this tutorial, we will use English characters and phonemes as the symbols.

  2. Spectrogram generation

    From the encoded text, a spectrogram is generated. We use the Tacotron2 model for this.

  3. Time-domain conversion

    The last step is converting the spectrogram into a waveform. The component that generates speech from a spectrogram is called a vocoder. In this tutorial, three different vocoders are used: WaveRNN, Griffin-Lim, and Nvidia’s WaveGlow.

The following figure illustrates the whole process.

https://download.pytorch.org/torchaudio/tutorial-assets/tacotron2_tts_pipeline.png

All the related components are bundled in torchaudio.pipelines.Tacotron2TTSBundle, but this tutorial also covers the process under the hood.

Preparation

First, we install the necessary dependencies. In addition to torchaudio, DeepPhonemizer is required to perform phoneme-based encoding.

# When running this example in a notebook, install DeepPhonemizer
# !pip3 install deep_phonemizer

import os

import torch
import torchaudio

matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]

# Directory where the generated audio files are saved later in this tutorial.
os.makedirs("_assets", exist_ok=True)

torch.random.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

print(torch.__version__)
print(torchaudio.__version__)
print(device)

Out:

1.11.0+cpu
0.11.0+cpu
cpu

Text Processing

Character-based encoding

In this section, we will go through how the character-based encoding works.

Since the pretrained Tacotron2 model expects a specific symbol table, torchaudio provides the matching encoding functionality, so you do not need to implement it yourself. This section mainly explains the basics of how encoding works.

First, we define the set of symbols. For example, we can use '_-!\'(),.:;? abcdefghijklmnopqrstuvwxyz'. Then, we map each character of the input text to the index of the corresponding symbol in the table.

The following is an example of such processing. In the example, symbols that are not in the table are ignored.

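# Symbol table: the position of each character is its encoded value.
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)}
# Keep a set of the symbols for fast membership tests.
symbols = set(symbols)


def text_to_sequence(text):
    # Lowercase the text, drop characters not in the table, map the rest to indices.
    text = text.lower()
    return [look_up[s] for s in text if s in symbols]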


text = "Hello world! Text to speech!"
print(text_to_sequence(text))

Out:

[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 11, 31, 16, 35, 31, 11, 31, 26, 11, 30, 27, 16, 16, 14, 19, 2]
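Because the text is lowercased and characters outside the table are dropped, this encoding is lossy. The sketch below reuses the same table and adds a hypothetical sequence_to_text helper (not part of torchaudio) that decodes the indices back, to show what survives a round trip.

```python
# Same symbol table as in the example above.
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)}


def text_to_sequence(text):
    # Lowercase, then keep only characters that appear in the table.
    return [look_up[s] for s in text.lower() if s in look_up]


def sequence_to_text(sequence):
    # Hypothetical inverse: map each index back to its symbol.
    return "".join(symbols[i] for i in sequence)


decoded = sequence_to_text(text_to_sequence("Hello, World #1!"))
print(repr(decoded))  # 'hello, world !' (case folded, '#' and '1' dropped)
```

Anything the model should pronounce, such as digits, therefore has to be spelled out before encoding when using this table.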

As mentioned above, the symbol table and indices must match what the pretrained Tacotron2 model expects. torchaudio provides this transform along with the pretrained model. For example, you can instantiate and use such a transform as follows.

processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()

text = "Hello world! Text to speech!"
processed, lengths = processor(text)

print(processed)
print(lengths)

Out:

tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15,  2, 11, 31, 16, 35, 31, 11,
         31, 26, 11, 30, 27, 16, 16, 14, 19,  2]])
tensor([28], dtype=torch.int32)

The processor object takes either a single text or a list of texts as input. When a list of texts is provided, the returned lengths variable holds the valid length of each processed token sequence in the output batch.
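To picture what a batched call returns, here is a small pure-Python sketch that pads several encoded texts into one rectangular batch alongside their lengths. It is an illustration, not torchaudio's implementation: batch_encode is a hypothetical helper, padding with index 0 (the '_' symbol) is an assumption, and the real processor returns tensors rather than lists.

```python
# Same character table as in the earlier example.
symbols = "_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"
look_up = {s: i for i, s in enumerate(symbols)}


def encode(text):
    return [look_up[s] for s in text.lower() if s in look_up]


def batch_encode(texts, pad_value=0):
    """Hypothetical helper: encode texts and right-pad them to one length.

    Returns (batch, lengths), where lengths[i] counts the valid
    (non-padding) tokens in batch[i]. Padding with index 0 ('_') is an
    assumption made for this sketch.
    """
    sequences = [encode(t) for t in texts]
    lengths = [len(s) for s in sequences]
    max_len = max(lengths)
    batch = [s + [pad_value] * (max_len - len(s)) for s in sequences]
    return batch, lengths


batch, lengths = batch_encode(["Hi!", "Hello world!"])
for row, n in zip(batch, lengths):
    # Slicing with the length recovers the unpadded sequence.
    print(row[:n], "padding:", row[n:])
```

Downstream code relies on lengths, not on the padded width, to know where each sequence ends; this is the same slicing pattern as processed[0, : lengths[0]] in the next example.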

The intermediate representation can be retrieved as follows.

print([processor.tokens[i] for i in processed[0, : lengths[0]]])

Out:

['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', ' ', 't', 'e', 'x', 't', ' ', 't', 'o', ' ', 's', 'p', 'e', 'e', 'c', 'h', '!']

Phoneme-based encoding

Phoneme-based encoding is similar to character-based encoding, but it uses a symbol table based on phonemes and a G2P (Grapheme-to-Phoneme) model.

The details of the G2P model are out of the scope of this tutorial; we will just look at what the conversion looks like.
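To make the idea concrete, the sketch below replaces the G2P model with a tiny hand-written pronunciation dictionary. It is only an illustration: LEXICON and toy_g2p are hypothetical, the dictionary covers just the words of the example sentence, and unknown words are silently dropped, whereas a real G2P model (such as the one DeepPhonemizer provides) predicts phonemes for arbitrary words. The phoneme labels were chosen to match the token output shown later in this section.

```python
# Hypothetical toy lexicon standing in for a G2P model.
LEXICON = {
    "hello": ["HH", "AH", "L", "OW"],
    "world": ["W", "ER", "L", "D"],
    "text": ["T", "EH", "K", "S", "T"],
    "to": ["T", "AH"],
    "speech": ["S", "P", "IY", "CH"],
}
PUNCTUATION = set("!'(),.:;?-")


def toy_g2p(text):
    """Convert text to phoneme tokens by dictionary lookup.

    Words are separated by a ' ' token and trailing punctuation is kept
    as its own token. Words missing from LEXICON are dropped.
    """
    tokens = []
    for i, word in enumerate(text.lower().split()):
        if i > 0:
            tokens.append(" ")
        # Split off trailing punctuation such as '!' in 'world!'.
        trailing = []
        while word and word[-1] in PUNCTUATION:
            trailing.insert(0, word[-1])
            word = word[:-1]
        tokens.extend(LEXICON.get(word, []))
        tokens.extend(trailing)
    return tokens


print(toy_g2p("Hello world! Text to speech!"))
```

Each phoneme token would then be mapped to an index in a phoneme symbol table, exactly as characters were in the previous section.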

As with character-based encoding, the encoding process is expected to match what the pretrained Tacotron2 model was trained on. torchaudio provides an interface to create this process.

The following code illustrates how to create and use the process. Behind the scenes, a G2P model is created using the DeepPhonemizer package, and the pretrained weights published by the author of DeepPhonemizer are fetched.

bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH

processor = bundle.get_text_processor()

text = "Hello world! Text to speech!"
with torch.inference_mode():
    processed, lengths = processor(text)

print(processed)
print(lengths)

Out:

100%|##########| 63.6M/63.6M [00:05<00:00, 11.6MB/s]
tensor([[54, 20, 65, 69, 11, 92, 44, 65, 38,  2, 11, 81, 40, 64, 79, 81, 11, 81,
         20, 11, 79, 77, 59, 37,  2]])
tensor([25], dtype=torch.int32)

Notice that the encoded values are different from the example of character-based encoding.

The intermediate representation looks like the following.

print([processor.tokens[i] for i in processed[0, : lengths[0]]])

Out:

['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!', ' ', 'T', 'EH', 'K', 'S', 'T', ' ', 'T', 'AH', ' ', 'S', 'P', 'IY', 'CH', '!']

Spectrogram Generation

Tacotron2 is the model we use to generate a spectrogram from the encoded text. For details of the model, please refer to the paper.

It is easy to instantiate a Tacotron2 model with pretrained weights; however, note that the input to Tacotron2 models needs to be processed by the matching text processor.

torchaudio.pipelines.Tacotron2TTSBundle bundles the matching models and processors together so that it is easy to create the pipeline.

For the available bundles and their usage, please refer to the torchaudio.pipelines documentation.

bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)

text = "Hello world! Text to speech!"

with torch.inference_mode():
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)
    spec, _, _ = tacotron2.infer(processed, lengths)


plt.imshow(spec[0].cpu().detach())

Out:

Downloading: "https://download.pytorch.org/torchaudio/models/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth

100%|##########| 107M/107M [00:00<00:00, 257MB/s]

<matplotlib.image.AxesImage object at 0x7fce43221910>

Note that the Tacotron2.infer method is not deterministic, so generating the spectrogram involves randomness: repeated calls on the same input produce spectrograms of different lengths, as the following example shows.

fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3))
for i in range(3):
    with torch.inference_mode():
        spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
    print(spec[0].shape)
    ax[i].imshow(spec[0].cpu().detach())
plt.show()

Out:

torch.Size([80, 155])
torch.Size([80, 167])
torch.Size([80, 164])
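The varying frame counts translate directly into audio of different durations. The arithmetic is frames × hop length / sample rate. The sketch below assumes a hop length of 256 samples at 22,050 Hz; that is a common setting for LJSpeech models but an assumption here, since this tutorial does not state the hop length each bundle uses.

```python
def frames_to_seconds(num_frames, hop_length=256, sample_rate=22050):
    """Approximate duration covered by a spectrogram with num_frames frames.

    hop_length=256 is an assumed value; check the feature settings of the
    bundle you actually use.
    """
    return num_frames * hop_length / sample_rate


# Frame counts printed by the three runs above.
for frames in (155, 167, 164):
    print(frames, "frames ->", round(frames_to_seconds(frames), 2), "s")
```

Under this assumption, the three runs differ by roughly 0.1 s of speech.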

Waveform Generation

Once the spectrogram is generated, the last step is to recover the waveform from the spectrogram.

torchaudio provides vocoders based on Griffin-Lim and WaveRNN.
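Griffin-Lim recovers a waveform from a magnitude-only spectrogram by repeatedly estimating the missing phase: synthesize a signal from the target magnitudes and the current phases, re-analyse it, keep the new phases, and repeat. The following is a minimal, dependency-free sketch of that classic algorithm, not torchaudio's implementation (which works on tensors and exposes options such as momentum). It uses a naive DFT, so it is only suitable for toy inputs, and it operates on a linear-frequency STFT; because Tacotron2 outputs an 80-band mel spectrogram, a mel-to-linear step such as torchaudio.transforms.InverseMelScale would be needed first.

```python
import cmath
import math
import random


def hann(n):
    # Periodic Hann window of length n.
    return [0.5 - 0.5 * math.cos(2 * math.pi * k / n) for k in range(n)]


def dft(frame):
    # Naive O(n^2) discrete Fourier transform (fine for tiny frames only).
    n = len(frame)
    return [sum(frame[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
            for k in range(n)]


def idft(spectrum):
    # Inverse DFT; only the real part is kept because the signal must be real.
    n = len(spectrum)
    return [sum(spectrum[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)).real / n
            for t in range(n)]


def stft(x, n_fft, hop):
    w = hann(n_fft)
    return [dft([x[s + t] * w[t] for t in range(n_fft)])
            for s in range(0, len(x) - n_fft + 1, hop)]


def istft(frames, n_fft, hop, length):
    # Weighted overlap-add: the least-squares signal estimate used by Griffin-Lim.
    w = hann(n_fft)
    out, norm = [0.0] * length, [0.0] * length
    for f, spectrum in enumerate(frames):
        seg = idft(spectrum)
        for t in range(n_fft):
            out[f * hop + t] += w[t] * seg[t]
            norm[f * hop + t] += w[t] ** 2
    return [o / z if z > 1e-8 else 0.0 for o, z in zip(out, norm)]


def griffin_lim(magnitude, n_fft, hop, length, n_iter=20, seed=0):
    """Estimate a signal whose STFT magnitude approximates `magnitude`."""
    rng = random.Random(seed)
    # Start from random phases.
    phase = [[cmath.exp(2j * math.pi * rng.random()) for _ in row] for row in magnitude]

    def synthesize():
        return istft([[m * p for m, p in zip(mr, pr)] for mr, pr in zip(magnitude, phase)],
                      n_fft, hop, length)

    for _ in range(n_iter):
        # Keep the target magnitude, but adopt the phase of the re-analysed signal.
        rebuilt = stft(synthesize(), n_fft, hop)
        phase = [[c / abs(c) if abs(c) > 1e-12 else 1 + 0j for c in row] for row in rebuilt]
    return synthesize()


def spectral_error(x, magnitude, n_fft, hop):
    # Relative distance between the target magnitude and |STFT(x)|.
    rebuilt = stft(x, n_fft, hop)
    num = sum((abs(c) - m) ** 2 for rc, rm in zip(rebuilt, magnitude) for c, m in zip(rc, rm))
    den = sum(m ** 2 for row in magnitude for m in row)
    return math.sqrt(num / den)


n_fft, hop = 16, 4
signal = [math.sin(2 * math.pi * 3 * t / 32) for t in range(96)]
target = [[abs(c) for c in row] for row in stft(signal, n_fft, hop)]
rough = griffin_lim(target, n_fft, hop, len(signal), n_iter=0)
refined = griffin_lim(target, n_fft, hop, len(signal), n_iter=20)
print(spectral_error(rough, target, n_fft, hop), spectral_error(refined, target, n_fft, hop))
```

Each iteration cannot increase the mismatch between the target magnitude and the magnitude of the re-analysed signal (the monotonic property shown by Griffin and Lim), which is why the refined estimate scores better than the random-phase starting point.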

WaveRNN

Continuing from the previous section, we can instantiate the matching WaveRNN model from the same bundle.

bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH

processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device)

text = "Hello world! Text to speech!"

with torch.inference_mode():
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)
    spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
    waveforms, lengths = vocoder(spec, spec_lengths)

fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())

torchaudio.save("_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
IPython.display.Audio("_assets/output_wavernn.wav")

Out:

Downloading: "https://download.pytorch.org/torchaudio/models/wavernn_10k_epochs_8bits_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/wavernn_10k_epochs_8bits_ljspeech.pth

100%|##########| 16.7M/16.7M [00:00<00:00, 66.6MB/s]


Griffin-Lim

Using the Griffin-Lim vocoder is the same as using WaveRNN. You can instantiate the vocoder object with the get_vocoder method and pass the spectrogram to it.

bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH

processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device)

with torch.inference_mode():
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)
    spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths)

fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())

torchaudio.save(
    "_assets/output_griffinlim.wav",
    waveforms[0:1].cpu(),
    sample_rate=vocoder.sample_rate,
)
IPython.display.Audio("_assets/output_griffinlim.wav")

Out:

Downloading: "https://download.pytorch.org/torchaudio/models/tacotron2_english_phonemes_1500_epochs_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/tacotron2_english_phonemes_1500_epochs_ljspeech.pth

100%|##########| 107M/107M [00:00<00:00, 228MB/s]


WaveGlow

WaveGlow is a vocoder published by Nvidia. The model definition is published on Torch Hub, and the pretrained weights are downloaded from NVIDIA NGC. One can instantiate the model using the torch.hub module.

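# Workaround to load a checkpoint saved on GPU: build the model with
# pretrained=False, then load the weights with map_location set to `device`.
# https://stackoverflow.com/a/61840832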
waveglow = torch.hub.load(
    "NVIDIA/DeepLearningExamples:torchhub",
    "nvidia_waveglow",
    model_math="fp32",
    pretrained=False,
)
checkpoint = torch.hub.load_state_dict_from_url(
    "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth",  # noqa: E501
    progress=False,
    map_location=device,
)
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}

waveglow.load_state_dict(state_dict)
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow = waveglow.to(device)
waveglow.eval()

with torch.no_grad():
    waveforms = waveglow.infer(spec)

fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())

torchaudio.save("_assets/output_waveglow.wav", waveforms[0:1].cpu(), sample_rate=22050)
IPython.display.Audio("_assets/output_waveglow.wav")

Out:

Downloading: "https://github.com/NVIDIA/DeepLearningExamples/archive/torchhub.zip" to /root/.cache/torch/hub/torchhub.zip
/root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/Classification/ConvNets/image_classification/models/common.py:13: UserWarning: pytorch_quantization module not found, quantization will not be available
  warnings.warn(
/root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/Classification/ConvNets/image_classification/models/efficientnet.py:17: UserWarning: pytorch_quantization module not found, quantization will not be available
  warnings.warn(
/root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub/PyTorch/SpeechSynthesis/Tacotron2/waveglow/model.py:55: UserWarning: torch.qr is deprecated in favor of torch.linalg.qr and will be removed in a future PyTorch release.
The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1980.)
  W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
Downloading: "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth" to /root/.cache/torch/hub/checkpoints/nvidia_waveglowpyt_fp32_20190306.pth


Total running time of the script: ( 3 minutes 5.689 seconds)
