• Docs >
  • Module code >
  • torchaudio.prototype.models.conv_emformer >
  • Nightly (unstable)
Shortcuts

Source code for torchaudio.prototype.models.conv_emformer

import math
from typing import List, Optional, Tuple

import torch
from torchaudio.models.emformer import _EmformerAttention, _EmformerImpl, _get_weight_init_gains


def _get_activation_module(activation: str) -> torch.nn.Module:
    """Instantiate the activation module named by ``activation``.

    Args:
        activation (str): one of ``"relu"``, ``"gelu"`` or ``"silu"``.

    Returns:
        torch.nn.Module: a freshly constructed activation module.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    supported = (
        ("relu", torch.nn.ReLU),
        ("gelu", torch.nn.GELU),
        ("silu", torch.nn.SiLU),
    )
    for name, module_cls in supported:
        if activation == name:
            return module_cls()
    raise ValueError(f"Unsupported activation {activation}")


class _ResidualContainer(torch.nn.Module):
    r"""Wraps ``module`` with a scaled residual connection.

    Computes ``module(input) * output_weight + input``.

    Args:
        module (torch.nn.Module): module whose output is scaled and added to its input.
            Its output must be broadcast-compatible with its input.
        output_weight (float): scale applied to ``module``'s output before the residual add.
    """

    def __init__(self, module: torch.nn.Module, output_weight: float):
        super().__init__()
        self.module = module
        self.output_weight = output_weight

    def forward(self, input: torch.Tensor):
        output = self.module(input)
        return output * self.output_weight + input


class _ConvolutionModule(torch.nn.Module):
    r"""Convolution module of a ConvEmformer layer.

    Computes ``LayerNorm -> Linear(D, 2D) -> GLU``, then a depthwise ``Conv1d`` with no
    padding, then ``LayerNorm -> activation -> Linear(D, D) -> Dropout``. The module input
    is added back as a residual. The convolution's left context comes from a cache
    (``state``) of the ``kernel_size - 1`` frames that preceded the current input.

    Args:
        input_dim (int): input dimension ``D``.
        segment_length (int): length of each input segment.
        right_context_length (int): length of the right context attached to each segment.
        kernel_size (int): size of the depthwise convolution kernel.
        activation (str, optional): activation applied after the convolution.
            Must be one of ("relu", "gelu", "silu"). (Default: "silu")
        dropout (float, optional): dropout probability. (Default: 0.0)
    """

    def __init__(
        self,
        input_dim: int,
        segment_length: int,
        right_context_length: int,
        kernel_size: int,
        activation: str = "silu",
        dropout: float = 0.0,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.segment_length = segment_length
        self.right_context_length = right_context_length
        # The conv uses padding=0, so it needs this many past frames to emit one output per input frame.
        self.state_size = kernel_size - 1

        self.pre_conv = torch.nn.Sequential(
            torch.nn.LayerNorm(input_dim), torch.nn.Linear(input_dim, 2 * input_dim, bias=True), torch.nn.GLU()
        )
        self.conv = torch.nn.Conv1d(
            in_channels=input_dim,
            out_channels=input_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
            groups=input_dim,
        )
        self.post_conv = torch.nn.Sequential(
            torch.nn.LayerNorm(input_dim),
            _get_activation_module(activation),
            torch.nn.Linear(input_dim, input_dim, bias=True),
            torch.nn.Dropout(p=dropout),
        )

    def _split_right_context(self, utterance: torch.Tensor, right_context: torch.Tensor) -> torch.Tensor:
        r"""Build one convolution input block per segment's right context.

        Args:
            utterance (torch.Tensor): cached state followed by utterance frames, with shape
                `(state_size + T, B, D)`.
            right_context (torch.Tensor): right context frames of all segments, stacked
                segment by segment, with shape `(num_segments * right_context_length, B, D)`.

        Returns:
            torch.Tensor: with shape `(num_segments * B, D, state_size + right_context_length)`.
            Each block is the ``state_size`` frames ending at the end of its segment,
            followed by that segment's right context frames.

        Raises:
            ValueError: if the right context length is not divisible by ``right_context_length``.
        """
        T, B, D = right_context.size()
        if T % self.right_context_length != 0:
            raise ValueError("Tensor length should be divisible by its right context length")
        num_segments = T // self.right_context_length
        # (num_segments, right context length, B, D)
        right_context_segments = right_context.reshape(num_segments, self.right_context_length, B, D)
        right_context_segments = right_context_segments.permute(0, 2, 1, 3).reshape(
            num_segments * B, self.right_context_length, D
        )

        pad_segments = []  # [(kernel_size - 1, B, D), ...]
        for seg_idx in range(num_segments):
            # End of segment ``seg_idx`` in the state-prefixed timeline, clamped for a short final segment.
            end_idx = min(self.state_size + (seg_idx + 1) * self.segment_length, utterance.size(0))
            start_idx = end_idx - self.state_size
            pad_segments.append(utterance[start_idx:end_idx, :, :])

        pad_segments = torch.cat(pad_segments, dim=1).permute(1, 0, 2)  # (num_segments * B, kernel_size - 1, D)
        return torch.cat([pad_segments, right_context_segments], dim=1).permute(0, 2, 1)

    def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor:
        r"""Invert the per-segment layout of convolved right context blocks.

        Args:
            right_context (torch.Tensor): with shape `(num_segments * B, D, right_context_length)`.
            B (int): batch size.

        Returns:
            torch.Tensor: with shape `(num_segments * right_context_length, B, D)`.
        """
        right_context = right_context.reshape(-1, B, self.input_dim, self.right_context_length)
        right_context = right_context.permute(0, 3, 1, 2)
        return right_context.reshape(-1, B, self.input_dim)  # (right_context_length * num_segments, B, D)

    def forward(
        self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Training-time pass over a full utterance and all segments' right contexts.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            state (torch.Tensor or None): cached previous frames, with shape
                `(B, D, kernel_size - 1)`. ``None`` means zeros.

        Returns:
            (Tensor, Tensor, Tensor): output utterance `(T, B, D)`, output right context
            `(R, B, D)`, and new state `(B, D, kernel_size - 1)`.
        """
        input = torch.cat((right_context, utterance))  # input: (T, B, D)
        x = self.pre_conv(input)
        x_right_context, x_utterance = x[: right_context.size(0), :, :], x[right_context.size(0) :, :, :]
        x_utterance = x_utterance.permute(1, 2, 0)  # (B, D, T_utterance)

        if state is None:
            state = torch.zeros(
                input.size(1),
                input.size(2),
                self.state_size,
                device=input.device,
                dtype=input.dtype,
            )  # (B, D, T)
        state_x_utterance = torch.cat([state, x_utterance], dim=2)

        conv_utterance = self.conv(state_x_utterance)  # (B, D, T_utterance)
        conv_utterance = conv_utterance.permute(2, 0, 1)

        if self.right_context_length > 0:
            # (B * num_segments, D, right_context_length + kernel_size - 1)
            right_context_block = self._split_right_context(state_x_utterance.permute(2, 0, 1), x_right_context)
            conv_right_context_block = self.conv(right_context_block)  # (B * num_segments, D, right_context_length)
            # (T_right_context, B, D)
            conv_right_context = self._merge_right_context(conv_right_context_block, input.size(1))
            y = torch.cat([conv_right_context, conv_utterance], dim=0)
        else:
            y = conv_utterance

        output = self.post_conv(y) + input
        # Keep the last ``state_size`` conv inputs. An explicit start index is used because
        # ``[-0:]`` would return the whole tensor when ``state_size == 0``.
        total_length = state_x_utterance.size(2)
        new_state = state_x_utterance[:, :, total_length - self.state_size :]
        return output[right_context.size(0) :], output[: right_context.size(0)], new_state

    def infer(
        self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Inference-time pass over one chunk (utterance followed by its right context).

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`;
                ``R`` may be 0.
            state (torch.Tensor or None): cached previous frames, with shape
                `(B, D, kernel_size - 1)`. ``None`` means zeros.

        Returns:
            (Tensor, Tensor, Tensor): output utterance `(T, B, D)`, output right context
            `(R, B, D)`, and new state `(B, D, kernel_size - 1)` holding the last
            ``kernel_size - 1`` frames before the right context.
        """
        input = torch.cat((utterance, right_context))
        x = self.pre_conv(input)  # (T, B, D)
        x = x.permute(1, 2, 0)  # (B, D, T)

        if state is None:
            state = torch.zeros(
                input.size(1),
                input.size(2),
                self.state_size,
                device=input.device,
                dtype=input.dtype,
            )  # (B, D, T)
        state_x = torch.cat([state, x], dim=2)
        conv_out = self.conv(state_x)
        conv_out = conv_out.permute(2, 0, 1)  # T, B, D
        output = self.post_conv(conv_out) + input
        # Cache the ``state_size`` frames that precede the right context. Right-context frames
        # are excluded, presumably because they are fed again as the next chunk's utterance.
        # Explicit indices are used because ``[..:-0]`` is empty when there is no right context.
        state_end = state_x.size(2) - right_context.size(0)
        new_state = state_x[:, :, state_end - self.state_size : state_end]
        return output[: utterance.size(0)], output[utterance.size(0) :], new_state


class _ConvEmformerLayer(torch.nn.Module):
    r"""Convolution-augmented Emformer layer that constitutes ConvEmformer.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads.
        ffn_dim: (int): hidden layer dimension of feedforward network.
        segment_length (int): length of each input segment.
        kernel_size (int): size of kernel to use in convolution module.
        dropout (float, optional): dropout probability. (Default: 0.0)
        ffn_activation (str, optional): activation function to use in feedforward network.
            Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        right_context_length (int, optional): length of right context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_gain (float or None, optional): scale factor to apply when initializing
            attention module parameters. (Default: ``None``)
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
        conv_activation (str, optional): activation function to use in convolution module.
            Must be one of ("relu", "gelu", "silu"). (Default: "silu")
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        segment_length: int,
        kernel_size: int,
        dropout: float = 0.0,
        ffn_activation: str = "relu",
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
        conv_activation: str = "silu",
    ):
        super().__init__()
        # TODO: implement talking heads attention.
        self.attention = _EmformerAttention(
            input_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            weight_init_gain=weight_init_gain,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
        self.dropout = torch.nn.Dropout(dropout)
        # Averages each segment_length-long window of frames into one summary vector;
        # ceil_mode keeps a trailing partial window.
        self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)

        # Both feedforward blocks reuse this (parameter-free) activation instance.
        activation_module = _get_activation_module(ffn_activation)
        # Feedforward applied before attention; its output is scaled by 0.5 before the residual add.
        self.ffn0 = _ResidualContainer(
            torch.nn.Sequential(
                torch.nn.LayerNorm(input_dim),
                torch.nn.Linear(input_dim, ffn_dim),
                activation_module,
                torch.nn.Dropout(dropout),
                torch.nn.Linear(ffn_dim, input_dim),
                torch.nn.Dropout(dropout),
            ),
            0.5,
        )
        # Feedforward applied after the convolution module, also with a 0.5-scaled residual.
        self.ffn1 = _ResidualContainer(
            torch.nn.Sequential(
                torch.nn.LayerNorm(input_dim),
                torch.nn.Linear(input_dim, ffn_dim),
                activation_module,
                torch.nn.Dropout(dropout),
                torch.nn.Linear(ffn_dim, input_dim),
                torch.nn.Dropout(dropout),
            ),
            0.5,
        )
        self.layer_norm_input = torch.nn.LayerNorm(input_dim)
        self.layer_norm_output = torch.nn.LayerNorm(input_dim)

        self.conv = _ConvolutionModule(
            input_dim=input_dim,
            kernel_size=kernel_size,
            activation=conv_activation,
            dropout=dropout,
            segment_length=segment_length,
            right_context_length=right_context_length,
        )

        self.left_context_length = left_context_length
        self.segment_length = segment_length
        self.max_memory_size = max_memory_size
        self.input_dim = input_dim
        self.kernel_size = kernel_size
        self.use_mem = max_memory_size > 0

    def _init_state(
        self, batch_size: int, device: Optional[torch.device], dtype: Optional[torch.dtype] = None
    ) -> List[torch.Tensor]:
        r"""Create zero-initialized streaming state for ``infer``.

        Args:
            batch_size (int): batch size ``B``.
            device (torch.device or None): device on which to allocate the state.
            dtype (torch.dtype or None, optional): dtype of the memory, left-context and
                convolution-cache tensors. It should match the activations so that
                concatenation does not promote them to another dtype. ``None`` uses the
                default dtype. (Default: ``None``)

        Returns:
            List[torch.Tensor]: ``[memory (max_memory_size, B, D), left_context_key
            (left_context_length, B, D), left_context_val (left_context_length, B, D),
            past_length (1, B) int32, conv_cache (B, D, kernel_size - 1)]``.
        """
        empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device, dtype=dtype)
        left_context_key = torch.zeros(
            self.left_context_length, batch_size, self.input_dim, device=device, dtype=dtype
        )
        left_context_val = torch.zeros(
            self.left_context_length, batch_size, self.input_dim, device=device, dtype=dtype
        )
        past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
        conv_cache = torch.zeros(
            batch_size,
            self.input_dim,
            self.kernel_size - 1,
            device=device,
            dtype=dtype,
        )
        return [empty_memory, left_context_key, left_context_val, past_length, conv_cache]

    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Extract the valid portions of ``state`` for the next ``infer`` call.

        The number of frames processed so far is read from ``state[3][0][0]`` (the first
        batch element). Memory and left-context tensors are trimmed to the entries that
        have actually been filled.

        Returns:
            (Tensor, Tensor, Tensor, Tensor): memory, left-context keys, left-context
            values, and convolution cache.
        """
        past_length = state[3][0][0].item()
        past_left_context_length = min(self.left_context_length, past_length)
        past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
        pre_mems = state[0][self.max_memory_size - past_mem_length :]
        lc_key = state[1][self.left_context_length - past_left_context_length :]
        lc_val = state[2][self.left_context_length - past_left_context_length :]
        conv_cache = state[4]
        return pre_mems, lc_key, lc_val, conv_cache

    def _pack_state(
        self,
        next_k: torch.Tensor,
        next_v: torch.Tensor,
        update_length: int,
        mems: torch.Tensor,
        conv_cache: torch.Tensor,
        state: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        r"""Update ``state`` in place with the results of one ``infer`` call and return it.

        Memory, keys and values are appended and then truncated to their most recent
        ``max_memory_size`` / ``left_context_length`` entries. ``update_length`` is added
        to the processed-frame counter, and the convolution cache is replaced.
        """
        new_k = torch.cat([state[1], next_k])
        new_v = torch.cat([state[2], next_v])
        state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
        state[1] = new_k[new_k.shape[0] - self.left_context_length :]
        state[2] = new_v[new_v.shape[0] - self.left_context_length :]
        state[3] = state[3] + update_length
        state[4] = conv_cache
        return state

    def _apply_pre_attention(
        self, utterance: torch.Tensor, right_context: torch.Tensor, summary: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Apply the first feedforward block and input layer norm.

        The input is processed as one concatenated sequence
        ``[right_context, utterance, summary]``.

        Returns:
            (Tensor, Tensor, Tensor, Tensor): the feedforward output (kept for the residual
            after attention), followed by the layer-normed right context, utterance and
            summary slices.
        """
        x = torch.cat([right_context, utterance, summary])
        ffn0_out = self.ffn0(x)
        layer_norm_input_out = self.layer_norm_input(ffn0_out)
        layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary = (
            layer_norm_input_out[: right_context.size(0)],
            layer_norm_input_out[right_context.size(0) : right_context.size(0) + utterance.size(0)],
            layer_norm_input_out[right_context.size(0) + utterance.size(0) :],
        )
        return ffn0_out, layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary

    def _apply_post_attention(
        self,
        rc_output: torch.Tensor,
        ffn0_out: torch.Tensor,
        conv_cache: Optional[torch.Tensor],
        rc_length: int,
        utterance_length: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Apply the attention residual, convolution module, second feedforward and output norm.

        Args:
            rc_output (torch.Tensor): attention output for ``[right_context, utterance]``.
            ffn0_out (torch.Tensor): output of ``_apply_pre_attention``; the summary rows are
                dropped from it before the residual add.
            conv_cache (torch.Tensor or None): convolution state. ``None`` selects the
                training path (``self.conv.forward``) with a zero state.
            rc_length (int): number of right context frames.
            utterance_length (int): number of utterance frames.

        Returns:
            (Tensor, Tensor, Tensor): output utterance, output right context, and new
            convolution cache.
        """
        result = self.dropout(rc_output) + ffn0_out[: rc_length + utterance_length]
        conv_utterance, conv_right_context, conv_cache = self.conv(result[rc_length:], result[:rc_length], conv_cache)
        result = torch.cat([conv_right_context, conv_utterance])
        result = self.ffn1(result)
        result = self.layer_norm_output(result)
        output_utterance, output_right_context = result[rc_length:], result[:rc_length]
        return output_utterance, output_right_context, conv_cache

    def forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            attention_mask (torch.Tensor): attention mask for underlying attention module.

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    encoded utterance frames, with shape `(T, B, D)`.
                Tensor
                    updated right context frames, with shape `(R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        if self.use_mem:
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)

        (
            ffn0_out,
            layer_norm_input_right_context,
            layer_norm_input_utterance,
            layer_norm_input_summary,
        ) = self._apply_pre_attention(utterance, right_context, summary)

        rc_output, output_mems = self.attention(
            utterance=layer_norm_input_utterance,
            lengths=lengths,
            right_context=layer_norm_input_right_context,
            summary=layer_norm_input_summary,
            mems=mems,
            attention_mask=attention_mask,
        )

        output_utterance, output_right_context, _ = self._apply_post_attention(
            rc_output, ffn0_out, None, right_context.size(0), utterance.size(0)
        )

        return output_utterance, output_right_context, output_mems

    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        state: Optional[List[torch.Tensor]],
        mems: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
        r"""Forward pass for inference.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            state (List[torch.Tensor] or None): list of tensors representing layer internal state
                generated in preceding invocation of ``infer``.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.

        Returns:
            (Tensor, Tensor, List[torch.Tensor], Tensor):
                Tensor
                    encoded utterance frames, with shape `(T, B, D)`.
                Tensor
                    updated right context frames, with shape `(R, B, D)`.
                List[Tensor]
                    list of tensors representing layer internal state
                    generated in current invocation of ``infer``.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        if self.use_mem:
            # Only one summary vector is produced per inference chunk.
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:1]
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)

        (
            ffn0_out,
            layer_norm_input_right_context,
            layer_norm_input_utterance,
            layer_norm_input_summary,
        ) = self._apply_pre_attention(utterance, right_context, summary)

        if state is None:
            # Allocate the state in the activation dtype so later concatenations do not
            # promote half/bfloat16 activations to float32.
            state = self._init_state(
                layer_norm_input_utterance.size(1),
                device=layer_norm_input_utterance.device,
                dtype=layer_norm_input_utterance.dtype,
            )
        pre_mems, lc_key, lc_val, conv_cache = self._unpack_state(state)

        rc_output, next_m, next_k, next_v = self.attention.infer(
            utterance=layer_norm_input_utterance,
            lengths=lengths,
            right_context=layer_norm_input_right_context,
            summary=layer_norm_input_summary,
            mems=pre_mems,
            left_context_key=lc_key,
            left_context_val=lc_val,
        )

        output_utterance, output_right_context, conv_cache = self._apply_post_attention(
            rc_output, ffn0_out, conv_cache, right_context.size(0), utterance.size(0)
        )
        output_state = self._pack_state(next_k, next_v, utterance.size(0), mems, conv_cache, state)
        return output_utterance, output_right_context, output_state, next_m


class ConvEmformer(_EmformerImpl):
    r"""Implements the convolution-augmented streaming transformer architecture introduced in
    *Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution*
    :cite:`9747706`.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each ConvEmformer layer.
        ffn_dim (int): hidden layer dimension of each ConvEmformer layer's feedforward network.
        num_layers (int): number of ConvEmformer layers to instantiate.
        segment_length (int): length of each input segment.
        kernel_size (int): size of kernel to use in convolution modules.
        dropout (float, optional): dropout probability. (Default: 0.0)
        ffn_activation (str, optional): activation function to use in feedforward networks.
            Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        right_context_length (int, optional): length of right context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
        conv_activation (str, optional): activation function to use in convolution modules.
            Must be one of ("relu", "gelu", "silu"). (Default: "silu")

    Examples:
        >>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
        >>> input = torch.rand(10, 200, 80)
        >>> lengths = torch.randint(1, 200, (10,))
        >>> output, lengths = conv_emformer(input, lengths)
        >>> input = torch.rand(4, 20, 80)
        >>> lengths = torch.ones(4) * 20
        >>> output, lengths, states = conv_emformer.infer(input, lengths, None)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        segment_length: int,
        kernel_size: int,
        dropout: float = 0.0,
        ffn_activation: str = "relu",
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_scale_strategy: Optional[str] = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
        conv_activation: str = "silu",
    ):
        # One attention init gain per layer, per ``weight_init_scale_strategy``.
        weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
        emformer_layers = torch.nn.ModuleList(
            [
                _ConvEmformerLayer(
                    input_dim,
                    num_heads,
                    ffn_dim,
                    segment_length,
                    kernel_size,
                    dropout=dropout,
                    ffn_activation=ffn_activation,
                    left_context_length=left_context_length,
                    right_context_length=right_context_length,
                    max_memory_size=max_memory_size,
                    weight_init_gain=weight_init_gains[layer_idx],
                    tanh_on_mem=tanh_on_mem,
                    negative_inf=negative_inf,
                    conv_activation=conv_activation,
                )
                for layer_idx in range(num_layers)
            ]
        )
        super().__init__(
            emformer_layers,
            segment_length,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=max_memory_size,
        )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources