Source code for torchaudio.prototype.models._emformer_hubert

from typing import List, Optional, Tuple

import torch
from torchaudio.models import Wav2Vec2Model
from torchaudio.models.emformer import Emformer
from torchaudio.models.rnnt import _TimeReduction


class FeatureEncoder(torch.nn.Module):
    """Extract features from log-mel spectrogram input. Consists of linear layer and time reduction layer.

    Args:
        input_dim (int): The feature dimension of log-mel spectrogram feature.
        output_dim (int): The feature dimension after linear layer.
        use_bias (bool): If ``True``, enable bias parameter in the linear layer.
        stride (int): Number of frames to merge for the output frame.
    """

    def __init__(self, input_dim: int, output_dim: int, use_bias: bool, stride: int):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=use_bias)
        self.time_reduction = _TimeReduction(stride)

    def forward(
        self, input: torch.Tensor, lengths: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            input (torch.Tensor): The log-mel spectrogram input.
                Tensor with dimensions `(batch, time, input_dim)`.
            lengths (torch.Tensor or None): Valid length of each input sample.
                Tensor with dimension `(batch, )`.

        Returns:
            (torch.Tensor, torch.Tensor or None):
                torch.Tensor
                    Returned feature Tensor after linear layer and time reduction layer.
                    Tensor with dimensions `(batch, time // stride, output_dim * stride)`.
                torch.Tensor or None
                    The reduced lengths Tensor.
        """
        output = self.linear(input)
        if lengths is None:
            B, T, _ = input.shape
            dummy_lengths = torch.full((B,), T)
            output, _ = self.time_reduction(output, dummy_lengths)
        else:
            output, lengths = self.time_reduction(output, lengths)
        return output, lengths

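A minimal usage sketch of ``FeatureEncoder`` (illustrative, not part of the module). The dimensions below are hypothetical: 80-dim log-mel frames, a 128-dim linear projection, and a stride of 4. Note that the time reduction step concatenates ``stride`` consecutive frames, so the output feature dimension is ``output_dim * stride``.

    >>> extractor = FeatureEncoder(input_dim=80, output_dim=128, use_bias=False, stride=4)
    >>> specgram = torch.randn(2, 100, 80)  # (batch, time, input_dim)
    >>> lengths = torch.tensor([100, 80])  # valid frames per batch element
    >>> features, out_lengths = extractor(specgram, lengths)
    >>> # features: (batch, time // stride, output_dim * stride) -> (2, 25, 512)
    >>> # out_lengths: reduced valid lengths -> tensor([25, 20])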

class EmformerEncoder(torch.nn.Module):
    """Emformer Encoder class for HuBERT pre-training. Consists of emformer module,
        linear layer and layer normalization layer.

    Args:
        emformer (torch.nn.Module):
            :py:class:`torchaudio.models.Emformer` module that consists of a list of emformer layers.
        output_linear (torch.nn.Module):
            Linear layer after emformer module.
        layer_norm (torch.nn.Module):
            Apply layer normalization to the output.
    """

    def __init__(
        self,
        emformer: torch.nn.Module,
        output_linear: torch.nn.Module,
        layer_norm: torch.nn.Module,
    ):
        super().__init__()
        self.emformer = emformer
        self.output_linear = output_linear
        self.layer_norm = layer_norm

    def forward(
        self,
        input: torch.Tensor,
        lengths: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): The input feature for emformer encoder.
                Tensor with dimensions `(batch, time, feature_dim)`.
            lengths (torch.Tensor or None): Valid length of each input sample.
                Tensor with dimension `(batch, )`.

        Returns:
            torch.Tensor: The feature Tensor after emformer encoder.
        """
        if lengths is None:
            B, T, _ = input.shape
            dummy_lengths = torch.full((B,), T)
            output, _ = self.emformer(input, dummy_lengths)
        else:
            output, lengths = self.emformer(input, lengths)
        output = self.output_linear(output)
        output = self.layer_norm(output)
        return output

    def extract_features(
        self,
        input: torch.Tensor,
        lengths: Optional[torch.Tensor],
        num_layers: Optional[int] = None,
    ) -> List[torch.Tensor]:
        """Extract output Tensors of the emformer layers.

        Args:
            input (torch.Tensor): The input feature for emformer encoder.
                Tensor with dimensions `(batch, time, feature_dim)`.
            lengths (torch.Tensor or None): Valid length of each input sample.
                Tensor with dimension `(batch, )`.
            num_layers (int or None, optional): If not ``None``, returns the output
                Tensors of the first `num_layers` Emformer layers; otherwise, returns
                the output Tensors of all Emformer layers.

        Returns:
            List[torch.Tensor]:
                Output Tensors of selected emformer layers.
        """
        if num_layers is not None:
            if not 0 < num_layers <= len(self.emformer.emformer_layers):
                raise ValueError(f"`num_layers` must be between [1, {len(self.emformer.emformer_layers)}]")

        ret: List[torch.Tensor] = []

        input = input.permute(1, 0, 2)
        right_context = self.emformer._gen_right_context(input)
        utterance = input[: input.size(0) - self.emformer.right_context_length]
        attention_mask = self.emformer._gen_attention_mask(utterance)
        mems = (
            self.emformer.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
            if self.emformer.use_mem
            else torch.empty(0).to(dtype=input.dtype, device=input.device)
        )
        output = utterance
        if lengths is None:
            # ``input`` was permuted to `(time, batch, feature)` above.
            T, B, _ = input.shape
            lengths = torch.full((B,), T)
        for layer in self.emformer.emformer_layers:
            output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
            ret.append(output.permute(1, 0, 2))
            if num_layers is not None and len(ret) >= num_layers:
                return ret
        return ret

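A hedged sketch of exercising ``EmformerEncoder`` directly, assembled by hand from a :py:class:`torchaudio.models.Emformer`, a linear layer, and a layer norm; every hyperparameter below is illustrative. The input layout is `(batch, time, feature)`, where `time` covers the utterance plus the right-context frames.

    >>> emformer = Emformer(
    ...     input_dim=512, num_heads=8, ffn_dim=2048, num_layers=4,
    ...     segment_length=4, left_context_length=30, right_context_length=1,
    ... )
    >>> encoder = EmformerEncoder(emformer, torch.nn.Linear(512, 1024), torch.nn.LayerNorm(1024))
    >>> feats = torch.randn(2, 41, 512)  # 40 utterance frames + 1 right-context frame
    >>> encoded = encoder(feats, lengths=None)  # (batch, 40, 1024) after linear and layer norm
    >>> layer_outputs = encoder.extract_features(feats, lengths=None, num_layers=2)
    >>> # layer_outputs: outputs of the first two Emformer layers, each `(batch, time, feature)`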

def _get_emformer_feature_extractor(input_dim: int, output_dim: int, use_bias: bool, stride: int) -> FeatureEncoder:
    """Construct FeatureEncoder for emformer model.

    Args:
        input_dim (int): The feature dimension of log-mel spectrogram feature.
        output_dim (int): The feature dimension after linear layer.
        use_bias (bool): If ``True``, enable bias parameter in the linear layer.
        stride (int): Number of frames to merge for the output frame.

    Returns:
        FeatureEncoder: The resulting FeatureEncoder module.
    """
    return FeatureEncoder(input_dim, output_dim, use_bias, stride)


def _get_emformer_encoder(
    input_dim: int,
    output_dim: int,
    num_heads: int,
    ffn_dim: int,
    num_layers: int,
    segment_length: int,
    left_context_length: int,
    right_context_length: int,
    dropout: float,
    activation: str,
    max_memory_size: int,
    weight_init_scale_strategy: Optional[str],
    tanh_on_mem: bool,
) -> EmformerEncoder:
    """Construct EmformerEncoder for emformer model.

    Args:
        input_dim (int): The feature dimension of input Tensor.
        output_dim (int): The feature dimension after EmformerEncoder.
        num_heads (int): Number of attention heads in each Emformer layer.
        ffn_dim (int): Hidden layer dimension of feedforward network.
        num_layers (int): Number of Emformer layers to instantiate.
        segment_length (int): Length of each input segment.
        left_context_length (int): Length of left context.
        right_context_length (int): Length of right context.
        dropout (float): Dropout probability.
        activation (str): Activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        max_memory_size (int): Maximum number of memory elements to use.
        weight_init_scale_strategy (str or None): Per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        tanh_on_mem (bool): If ``True``, applies tanh to memory elements.

    Returns:
        EmformerEncoder: The resulting EmformerEncoder module.
    """
    emformer = Emformer(
        input_dim=input_dim,
        num_heads=num_heads,
        ffn_dim=ffn_dim,
        num_layers=num_layers,
        segment_length=segment_length,
        left_context_length=left_context_length,
        right_context_length=right_context_length,
        dropout=dropout,
        activation=activation,
        max_memory_size=max_memory_size,
        weight_init_scale_strategy=weight_init_scale_strategy,
        tanh_on_mem=tanh_on_mem,
    )
    output_linear = torch.nn.Linear(input_dim, output_dim)
    layer_norm = torch.nn.LayerNorm(output_dim)
    return EmformerEncoder(emformer, output_linear, layer_norm)

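As a quick illustration, the helper above can be used on its own to assemble an encoder; the hyperparameters here loosely mirror the ``base`` configuration further below but are otherwise arbitrary.

    >>> encoder = _get_emformer_encoder(
    ...     input_dim=512, output_dim=1024, num_heads=8, ffn_dim=2048, num_layers=4,
    ...     segment_length=4, left_context_length=30, right_context_length=1,
    ...     dropout=0.1, activation="gelu", max_memory_size=0,
    ...     weight_init_scale_strategy="depthwise", tanh_on_mem=True,
    ... )
    >>> feats = torch.randn(2, 41, 512)  # (batch, time, input_dim); time includes right context
    >>> encoded = encoder(feats, lengths=None)  # (batch, time - right_context_length, output_dim)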

def emformer_hubert_model(
    extractor_input_dim: int,
    extractor_output_dim: int,
    extractor_use_bias: bool,
    extractor_stride: int,
    encoder_input_dim: int,
    encoder_output_dim: int,
    encoder_num_heads: int,
    encoder_ffn_dim: int,
    encoder_num_layers: int,
    encoder_segment_length: int,
    encoder_left_context_length: int,
    encoder_right_context_length: int,
    encoder_dropout: float,
    encoder_activation: str,
    encoder_max_memory_size: int,
    encoder_weight_init_scale_strategy: Optional[str],
    encoder_tanh_on_mem: bool,
    aux_num_out: Optional[int],
) -> Wav2Vec2Model:
    """Build a custom Emformer HuBERT model.

    Args:
        extractor_input_dim (int): The input dimension for feature extractor.
        extractor_output_dim (int): The output dimension after feature extractor.
        extractor_use_bias (bool): If ``True``, enable bias parameter in the linear layer of feature extractor.
        extractor_stride (int): Number of frames to merge for the output frame in feature extractor.
        encoder_input_dim (int): The input dimension for Emformer layer.
        encoder_output_dim (int): The output dimension after EmformerEncoder.
        encoder_num_heads (int): Number of attention heads in each Emformer layer.
        encoder_ffn_dim (int): Hidden layer dimension of feedforward network in Emformer.
        encoder_num_layers (int): Number of Emformer layers to instantiate.
        encoder_segment_length (int): Length of each input segment.
        encoder_left_context_length (int): Length of left context.
        encoder_right_context_length (int): Length of right context.
        encoder_dropout (float): Dropout probability.
        encoder_activation (str): Activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        encoder_max_memory_size (int): Maximum number of memory elements to use.
        encoder_weight_init_scale_strategy (str or None): Per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        encoder_tanh_on_mem (bool): If ``True``, applies tanh to memory elements.
        aux_num_out (int or None): When provided, attach an extra linear layer on top of encoder,
            which can be used for fine-tuning.

    Returns:
        Wav2Vec2Model:
            The resulting :py:class:`torchaudio.models.Wav2Vec2Model` model
            with a :py:class:`torchaudio.models.Emformer` encoder.
    """
    feature_extractor = _get_emformer_feature_extractor(
        extractor_input_dim, extractor_output_dim, extractor_use_bias, extractor_stride
    )
    emformer = _get_emformer_encoder(
        encoder_input_dim,
        encoder_output_dim,
        encoder_num_heads,
        encoder_ffn_dim,
        encoder_num_layers,
        encoder_segment_length,
        encoder_left_context_length,
        encoder_right_context_length,
        encoder_dropout,
        encoder_activation,
        encoder_max_memory_size,
        encoder_weight_init_scale_strategy,
        encoder_tanh_on_mem,
    )
    aux = None
    if aux_num_out is not None:
        aux = torch.nn.Linear(in_features=encoder_output_dim, out_features=aux_num_out)
    return Wav2Vec2Model(feature_extractor, emformer, aux)
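
A hedged sketch of building a small custom model with this function and feeding it log-mel features; the hyperparameters are illustrative and intentionally smaller than the pre-training configuration.

    >>> model = emformer_hubert_model(
    ...     extractor_input_dim=80, extractor_output_dim=128, extractor_use_bias=False, extractor_stride=4,
    ...     encoder_input_dim=512, encoder_output_dim=1024, encoder_num_heads=8, encoder_ffn_dim=2048,
    ...     encoder_num_layers=4, encoder_segment_length=4, encoder_left_context_length=30,
    ...     encoder_right_context_length=1, encoder_dropout=0.1, encoder_activation="gelu",
    ...     encoder_max_memory_size=0, encoder_weight_init_scale_strategy="depthwise",
    ...     encoder_tanh_on_mem=True, aux_num_out=None,
    ... )
    >>> specgram = torch.randn(2, 404, 80)  # (batch, time, extractor_input_dim)
    >>> out, out_lengths = model(specgram, torch.tensor([404, 360]))
    >>> # out: (2, 100, 1024); out_lengths: time-reduced valid lengths, tensor([101, 90])
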
def emformer_hubert_base(
    extractor_input_dim: int = 80,
    extractor_output_dim: int = 128,
    encoder_dropout: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Build Emformer HuBERT Model with 20 Emformer layers.

    Args:
        extractor_input_dim (int, optional): The input dimension for feature extractor. (Default: 80)
        extractor_output_dim (int, optional): The output dimension after feature extractor. (Default: 128)
        encoder_dropout (float, optional): Dropout probability in Emformer. (Default: 0.1)
        aux_num_out (int or None, optional): Output dimension of aux layer for fine-tuning. (Default: ``None``)

    Returns:
        Wav2Vec2Model:
            The resulting :py:class:`torchaudio.models.Wav2Vec2Model` model
            with a :py:class:`torchaudio.models.Emformer` encoder.
    """
    return emformer_hubert_model(
        extractor_input_dim=extractor_input_dim,
        extractor_output_dim=extractor_output_dim,
        extractor_use_bias=False,
        extractor_stride=4,
        encoder_input_dim=512,
        encoder_output_dim=1024,
        encoder_num_heads=8,
        encoder_ffn_dim=2048,
        encoder_num_layers=20,
        encoder_segment_length=4,
        encoder_left_context_length=30,
        encoder_right_context_length=1,
        encoder_dropout=encoder_dropout,
        encoder_activation="gelu",
        encoder_max_memory_size=0,
        encoder_weight_init_scale_strategy="depthwise",
        encoder_tanh_on_mem=True,
        aux_num_out=aux_num_out,
    )
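
A usage sketch of the ``base`` builder. The expected input is a batch of 80-dimensional log-mel feature sequences (not raw waveforms), since the feature extractor here is a linear projection followed by time reduction.

    >>> model = emformer_hubert_base(aux_num_out=None)
    >>> logmel = torch.randn(1, 404, 80)  # (batch, time, 80)
    >>> hidden, _ = model(logmel)  # (1, 100, 1024)
    >>> layer_feats, _ = model.extract_features(logmel, num_layers=12)
    >>> # layer_feats: outputs of the first 12 of the 20 Emformer layers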
