Shortcuts

Source code for torchaudio.models.squim.objective

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def transform_wb_pesq_range(x: float) -> float:
    """Map a PESQ score onto the wide-band PESQ scale.

    ITU-T P.862 defines the metric that is usually called 'PESQ score'; it is meant for
    narrow-band signals and its values lie exactly in [-0.5, 4.5]. This module uses the
    metric of ITU-T P.862.2 instead, commonly known as 'wide-band PESQ', and refers to it
    as "PESQ score". The mapping applied here is a logistic curve bounded by 0.999 and 4.999.

    Args:
        x (float): Narrow-band PESQ score.

    Returns:
        (float): Wide-band PESQ score.
    """
    floor, ceiling = 0.999, 4.999
    exponent = -1.3669 * x + 3.8224
    return floor + (ceiling - floor) / (1 + math.exp(exponent))


# (lower, upper) bounds used for the predicted PESQ score.
# Lower bound: P.862.2 uses a different input filter than P.862, so the lower bound of
# the raw score is not -0.5 anymore. It's hard to figure out the true lower bound, and
# 1.0 is used as a reasonable approximation.
# Upper bound: the P.862 maximum of 4.5 passed through `transform_wb_pesq_range`.
PESQRange: Tuple[float, float] = (1.0, transform_wb_pesq_range(4.5))


class RangeSigmoid(nn.Module):
    """Sigmoid activation whose output is rescaled from (0, 1) to a given value range.

    Args:
        val_range (Tuple[float, float], optional): ``(lower, upper)`` bounds of the output.
            Must be a tuple of length 2. (Default: ``(0.0, 1.0)``)
    """

    def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
        super().__init__()
        assert isinstance(val_range, tuple) and len(val_range) == 2
        self.val_range: Tuple[float, float] = val_range
        self.sigmoid: nn.modules.Module = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sigmoid, then stretch and shift the result into ``val_range``.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            (torch.Tensor): Tensor of the same shape as ``x``.
        """
        lower, upper = self.val_range
        return self.sigmoid(x) * (upper - lower) + lower


class Encoder(nn.Module):
    """Encoder module that transforms a 1D waveform into a 2D representation.

    A single bias-free 1D convolution slides over the waveform with a hop of half the
    kernel size, followed by a ReLU.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
        win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
    """

    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
        super().__init__()

        hop_len = win_len // 2
        self.conv1d = nn.Conv1d(
            in_channels=1,
            out_channels=feat_dim,
            kernel_size=win_len,
            stride=hop_len,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass waveforms through the convolutional layer and a ReLU.

        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
        """
        # Conv1d expects a channel axis, so insert a singleton one after the batch axis.
        framed = self.conv1d(x.unsqueeze(1))
        return F.relu(framed)


class SingleRNN(nn.Module):
    """Single-layer bidirectional RNN followed by a linear projection back to the input size.

    Args:
        rnn_type (str): Name of the recurrent class to look up on ``torch.nn`` (e.g. "LSTM").
        input_size (int): Feature size of the input, and of the projected output.
        hidden_size (int): Hidden size of each RNN direction.
        dropout (float, optional): Dropout value forwarded to the RNN constructor. (Default: 0.0)
    """

    def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
        super().__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size

        rnn_cls = getattr(nn, rnn_type)
        self.rnn: nn.modules.Module = rnn_cls(
            input_size,
            hidden_size,
            num_layers=1,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

        # Both directions are concatenated, hence the doubled input width.
        self.proj = nn.Linear(2 * hidden_size, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the RNN over ``x`` and project its output.

        Args:
            x (torch.Tensor): Tensor with dimensions `(batch, seq, dim)`.

        Returns:
            (torch.Tensor): Tensor with dimensions `(batch, seq, dim)`.
        """
        hidden, _ = self.rnn(x)
        return self.proj(hidden)


class DPRNN(nn.Module):
    """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.

    The input sequence is cut into overlapping chunks; every block then runs one RNN along
    the axis inside each chunk and one RNN across chunks, each followed by a normalization
    layer and a residual connection. A 1x1 convolution maps the result to ``d_model``
    channels before the chunks are merged back into a sequence.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
        hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
        num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
        rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
        d_model (int, optional): The number of expected features in the input. (Default: 256)
        chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
        chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
    """

    def __init__(
        self,
        feat_dim: int = 64,
        hidden_dim: int = 128,
        num_blocks: int = 6,
        rnn_type: str = "LSTM",
        d_model: int = 256,
        chunk_size: int = 100,
        chunk_stride: int = 50,
    ) -> None:
        super().__init__()

        self.num_blocks = num_blocks

        self.row_rnn = nn.ModuleList()
        self.col_rnn = nn.ModuleList()
        self.row_norm = nn.ModuleList()
        self.col_norm = nn.ModuleList()
        # The per-block construction order (row RNN, column RNN, then the norms) is kept
        # as is so that parameter initialisation under a fixed seed does not change.
        for _ in range(num_blocks):
            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.row_norm.append(nn.GroupNorm(num_groups=1, num_channels=feat_dim, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(num_groups=1, num_channels=feat_dim, eps=1e-8))
        self.conv = nn.Sequential(
            nn.Conv2d(feat_dim, d_model, kernel_size=1),
            nn.PReLU(),
        )
        self.chunk_size = chunk_size
        self.chunk_stride = chunk_stride

    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Zero-pad the last axis of ``x`` so it can be cut into overlapping chunks.

        Args:
            x (torch.Tensor): Tensor with dimensions `(B, N, T)`.

        Returns:
            (torch.Tensor, int): The padded tensor, and the number of extra frames
            (beyond ``chunk_stride`` on each side) appended at the end.
        """
        size, stride = self.chunk_size, self.chunk_stride
        n_frames = x.shape[-1]

        # Extra right padding that makes (stride + n_frames + rest) a multiple of the chunk size.
        rest = size - (stride + n_frames % size) % size
        padded = F.pad(x, [stride, rest + stride])

        return padded, rest

    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Split ``x`` into chunks of ``chunk_size`` frames that start ``chunk_stride`` apart.

        Args:
            x (torch.Tensor): Tensor with dimensions `(B, N, T)`.

        Returns:
            (torch.Tensor, int): Tensor with dimensions `(B, N, chunk_size, num_chunks)`,
            and the ``rest`` value returned by :meth:`pad_chunk`.
        """
        padded, rest = self.pad_chunk(x)
        n_batch, n_feat, _ = padded.shape
        size, stride = self.chunk_size, self.chunk_stride

        # Two non-overlapping tilings of the padded signal, offset by one stride.
        tiling_a = padded[:, :, :-stride].contiguous().view(n_batch, n_feat, -1, size)
        tiling_b = padded[:, :, stride:].contiguous().view(n_batch, n_feat, -1, size)
        # Concatenate along the last axis, then re-view so the two tilings alternate chunk by chunk.
        interleaved = torch.cat([tiling_a, tiling_b], dim=3)
        interleaved = interleaved.view(n_batch, n_feat, -1, size)
        chunks = interleaved.transpose(2, 3).contiguous()

        return chunks, rest

    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
        """Undo :meth:`chunking` by summing the two tilings and trimming the padding.

        Args:
            x (torch.Tensor): Tensor with dimensions `(B, D, chunk_size, num_chunks)`.
            rest (int): The ``rest`` value returned by :meth:`chunking`.

        Returns:
            (torch.Tensor): Tensor with dimensions `(B, D, T)`.
        """
        n_batch, n_feat, _, _ = x.shape
        size, stride = self.chunk_size, self.chunk_stride

        # Each row of `paired` holds one chunk from each tiling, side by side.
        paired = x.transpose(2, 3).contiguous().view(n_batch, n_feat, -1, size * 2)
        flat_a = paired[:, :, :, :size].contiguous().view(n_batch, n_feat, -1)
        flat_b = paired[:, :, :, size:].contiguous().view(n_batch, n_feat, -1)
        # Drop the stride-sized padding from the front of the first tiling and from the
        # back of the second, then add them.
        merged = flat_a[:, :, stride:] + flat_b[:, :, :-stride]
        if rest > 0:
            merged = merged[:, :, :-rest]
        return merged.contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all dual-path blocks to ``x``.

        Args:
            x (torch.Tensor): Tensor with dimensions `(B, N, T)`.

        Returns:
            (torch.Tensor): Tensor with dimensions `(B, T, d_model)`.
        """
        chunks, rest = self.chunking(x)
        n_batch, _, dim1, dim2 = chunks.shape
        out = chunks
        blocks = zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm)
        for row_rnn, row_norm, col_rnn, col_norm in blocks:
            # RNN along axis 2 of `out`: fold axis 3 into the batch axis.
            row_in = out.permute(0, 3, 2, 1).contiguous().view(n_batch * dim2, dim1, -1)
            row_out = row_rnn(row_in).view(n_batch, dim2, dim1, -1)
            row_out = row_out.permute(0, 3, 2, 1).contiguous()
            out = out + row_norm(row_out)

            # RNN along axis 3 of `out`: fold axis 2 into the batch axis.
            col_in = out.permute(0, 2, 3, 1).contiguous().view(n_batch * dim1, dim2, -1)
            col_out = col_rnn(col_in).view(n_batch, dim1, dim2, -1)
            col_out = col_out.permute(0, 3, 1, 2).contiguous()
            out = out + col_norm(col_out)
        projected = self.conv(out)
        merged = self.merging(projected, rest)
        return merged.transpose(1, 2).contiguous()


class AutoPool(nn.Module):
    """Pooling along one dimension using softmax weights with a learnable scale ``alpha``.

    The weights are ``softmax(alpha * x)`` taken over ``pool_dim``; the output is the sum
    of ``x`` times these weights over that dimension. ``alpha`` is a single learnable
    value initialised to 1.

    Args:
        pool_dim (int, optional): Dimension to pool over. (Default: 1)
    """

    def __init__(self, pool_dim: int = 1) -> None:
        super().__init__()
        self.pool_dim: int = pool_dim
        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` over ``pool_dim``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): ``x`` with ``pool_dim`` summed out.
        """
        weight = self.softmax(x * self.alpha)
        return (x * weight).sum(dim=self.pool_dim)


class SquimObjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
    for speech enhancement (e.g., STOI, PESQ, and SI-SDR).

    Args:
        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
        dprnn (torch.nn.Module): DPRNN module to model sequential feature.
        branches (torch.nn.ModuleList): Transformer branches in which each branch estimates
            one objective metric score.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branches: nn.ModuleList,
    ):
        super(SquimObjective, self).__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branches = branches

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            List(torch.Tensor): List of score Tensors. Each Tensor is with dimension `(batch,)`.

        Raises:
            ValueError: If ``x`` is not a 2D Tensor.
        """
        if x.ndim != 2:
            raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
        # Scale every waveform by 1 / (20 * RMS) before encoding.
        scale = torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20
        # An all-zero waveform has zero RMS; dividing by it would turn the whole row (and
        # every predicted score) into NaN. Use a unit divisor for such rows so they stay
        # all-zero. Rows with a non-zero RMS are divided exactly as before.
        scale = torch.where(scale > 0, scale, torch.ones_like(scale))
        x = x / scale
        out = self.encoder(x)
        out = self.dprnn(out)
        # One score tensor per branch, in the order the branches were given.
        return [branch(out).squeeze(dim=1) for branch in self.branches]
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
    """Create branch module after DPRNN model for predicting metric score.

    Every branch is a Transformer encoder layer, an ``AutoPool`` layer, and a
    Linear-PReLU-Linear head with a single output. For "stoi" and "pesq" the head ends
    with a ``RangeSigmoid`` (default range for "stoi", ``PESQRange`` for "pesq"); for any
    other metric name the head output is left as is.

    Args:
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        metric (str): The metric name to predict.

    Returns:
        (nn.Module): Returned module to predict corresponding metric score.
    """
    attention = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
    pooling = AutoPool()

    head_layers = [
        nn.Linear(d_model, d_model),
        nn.PReLU(),
        nn.Linear(d_model, 1),
    ]
    if metric == "stoi":
        head_layers.append(RangeSigmoid())
    elif metric == "pesq":
        head_layers.append(RangeSigmoid(val_range=PESQRange))
    head: nn.modules.Module = nn.Sequential(*head_layers)

    return nn.Sequential(attention, pooling, head)
def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Build a custom :class:`SquimObjective` model.

    The model has three branches, in this order: "stoi", "pesq", "sisdr".

    Args:
        feat_dim (int): The feature dimension after Encoder module.
        win_len (int): Kernel size in the Encoder module.
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
        num_blocks (int): Number of DPRNN layers.
        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
        chunk_size (int): Chunk size of input for DPRNN.
        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
            If ``None``, ``chunk_size // 2`` is used. (Default: ``None``)

    Returns:
        SquimObjective: The assembled model.
    """
    stride = chunk_size // 2 if chunk_stride is None else chunk_stride

    encoder = Encoder(feat_dim, win_len)
    dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, stride)
    metric_names = ("stoi", "pesq", "sisdr")
    branches = nn.ModuleList([_create_branch(d_model, nhead, metric) for metric in metric_names])
    return SquimObjective(encoder, dprnn, branches)
def squim_objective_base() -> SquimObjective:
    """Build :class:`SquimObjective` model with default arguments."""
    base_config = {
        "feat_dim": 256,
        "win_len": 64,
        "d_model": 256,
        "nhead": 4,
        "hidden_dim": 256,
        "num_blocks": 2,
        "rnn_type": "LSTM",
        "chunk_size": 71,
    }
    return squim_objective_model(**base_config)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources