# Source code for torchaudio.models.rnnt
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
from torchaudio.models import Emformer
# Public API of this module: the RNNT model class and the two Emformer-based factory functions defined below.
__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
class _TimeReduction(torch.nn.Module):
    r"""Stacks groups of consecutive frames into single, wider frames.

    Every ``stride`` adjacent frames along the time axis are concatenated feature-wise
    into one output frame. Trailing frames that do not fill a complete group are dropped.

    Args:
        stride (int): number of input frames merged into each output frame.
    """

    def __init__(self, stride: int) -> None:
        super().__init__()
        self.stride = stride

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame.

        Args:
            input (torch.Tensor): input sequences, with shape `(B, T, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output sequences, with shape `(B, T // stride, D * stride)`.
                torch.Tensor
                    output lengths, with shape `(B,)`: each input length divided by
                    ``stride`` with truncation.
        """
        batch_size, num_in_frames, feature_dim = input.shape
        num_out_frames = num_in_frames // self.stride
        # Drop the incomplete tail group so the time axis divides evenly by stride.
        kept = input[:, : num_out_frames * self.stride, :]
        stacked = kept.reshape(batch_size, num_out_frames, feature_dim * self.stride).contiguous()
        out_lengths = lengths.div(self.stride, rounding_mode="trunc")
        return stacked, out_lengths
class _CustomLSTM(torch.nn.Module):
    r"""Single-layer LSTM with optional layer normalization on the gate
    pre-activations and on the cell state.

    Args:
        input_dim (int): input dimension.
        hidden_dim (int): hidden dimension.
        layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
        layer_norm_epsilon (float, optional): value of epsilon to use in
            layer normalization layers (Default: 1e-5)
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        layer_norm: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        gate_width = 4 * hidden_dim
        # Input projection carries a bias only without layer norm (presumably because the
        # gate LayerNorm's own shift makes it redundant).
        self.x2g = torch.nn.Linear(input_dim, gate_width, bias=(not layer_norm))
        # Recurrent projection from the previous hidden state; never biased.
        self.p2g = torch.nn.Linear(hidden_dim, gate_width, bias=False)
        if not layer_norm:
            self.c_norm = torch.nn.Identity()
            self.g_norm = torch.nn.Identity()
        else:
            self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
            self.g_norm = torch.nn.LayerNorm(gate_width, eps=layer_norm_epsilon)
        self.hidden_dim = hidden_dim

    def forward(
        self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        r"""Runs the recurrence over every time step of ``input``.

        B: batch size;
        T: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): with shape `(T, B, D)`.
            state (List[torch.Tensor] or None): ``[h, c]`` returned by a preceding
                invocation of ``forward``; ``None`` starts from zero-valued states.

        Returns:
            (torch.Tensor, List[torch.Tensor]):
                torch.Tensor
                    hidden state at every step, with shape `(T, B, hidden_dim)`.
                List[torch.Tensor]
                    ``[h, c]`` after the last step, for use in the next invocation.
        """
        if state is not None:
            h, c = state
        else:
            batch_size = input.size(1)
            h = torch.zeros(batch_size, self.hidden_dim, device=input.device, dtype=input.dtype)
            c = torch.zeros(batch_size, self.hidden_dim, device=input.device, dtype=input.dtype)

        # The input projection does not depend on the recurrence, so do it for all steps at once.
        projected = self.x2g(input)
        hidden_per_step: List[torch.Tensor] = []
        for step_proj in projected.unbind(0):
            pre_act = self.g_norm(step_proj + self.p2g(h))
            i_raw, f_raw, g_raw, o_raw = pre_act.chunk(4, 1)
            c = self.c_norm(f_raw.sigmoid() * c + i_raw.sigmoid() * g_raw.tanh())
            h = o_raw.sigmoid() * c.tanh()
            hidden_per_step.append(h)
        return torch.stack(hidden_per_step, dim=0), [h, c]
class _Transcriber(ABC):
    r"""Abstract interface for RNN-T transcription networks.

    Implementations provide ``forward`` for whole sequences and ``infer``, which
    consumes state from a previous ``infer`` call and returns updated state.
    """

    @abstractmethod
    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Encode ``input``; return ``(encodings, encoding_lengths)``."""

    @abstractmethod
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Encode ``input`` given prior ``states``; return ``(encodings, encoding_lengths, new_states)``."""
class _EmformerEncoder(torch.nn.Module, _Transcriber):
    r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).

    Pipeline: linear projection -> time reduction -> Emformer -> linear projection -> layer norm.

    Args:
        input_dim (int): feature dimension of each input sequence element.
        output_dim (int): feature dimension of each output sequence element.
        segment_length (int): length of input segment expressed as number of frames
            (divided by ``time_reduction_stride`` before being passed to Emformer).
        right_context_length (int): length of right context expressed as number of frames
            (divided by ``time_reduction_stride`` before being passed to Emformer).
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context.
        transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
        transformer_activation (str, optional): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        transformer_weight_init_scale_strategy (str or None, optional): per-layer weight initialization
            scaling strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        segment_length: int,
        right_context_length: int,
        time_reduction_input_dim: int,
        time_reduction_stride: int,
        transformer_num_heads: int,
        transformer_ffn_dim: int,
        transformer_num_layers: int,
        transformer_left_context_length: int,
        transformer_dropout: float = 0.0,
        transformer_activation: str = "relu",
        transformer_max_memory_size: int = 0,
        transformer_weight_init_scale_strategy: Optional[str] = "depthwise",
        transformer_tanh_on_mem: bool = False,
    ) -> None:
        super().__init__()
        self.input_linear = torch.nn.Linear(
            input_dim,
            time_reduction_input_dim,
            bias=False,
        )
        self.time_reduction = _TimeReduction(time_reduction_stride)
        # Time reduction concatenates `stride` frames, widening the feature dimension accordingly.
        transformer_input_dim = time_reduction_input_dim * time_reduction_stride
        self.transformer = Emformer(
            transformer_input_dim,
            transformer_num_heads,
            transformer_ffn_dim,
            transformer_num_layers,
            # Emformer operates on time-reduced frames, so frame counts are scaled down too.
            segment_length // time_reduction_stride,
            dropout=transformer_dropout,
            activation=transformer_activation,
            left_context_length=transformer_left_context_length,
            right_context_length=right_context_length // time_reduction_stride,
            max_memory_size=transformer_max_memory_size,
            weight_init_scale_strategy=transformer_weight_init_scale_strategy,
            tanh_on_mem=transformer_tanh_on_mem,
        )
        self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for inference.

        B: batch size;
        T: maximum input sequence segment length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            states (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``infer``; forwarded unchanged to ``Emformer.infer``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation
                    of ``infer``.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        (
            transformer_out,
            transformer_lengths,
            transformer_states,
        ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths, transformer_states
class _Predictor(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) prediction network.

    Pipeline: embedding -> layer norm -> stack of LSTM layers (dropout after each) ->
    linear projection -> layer norm.

    Args:
        num_symbols (int): size of target token lexicon.
        output_dim (int): feature dimension of each output sequence element.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
            for LSTM layers. (Default: ``False``)
        lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
            LSTM layer normalization layers. (Default: 1e-5)
        lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
    """

    def __init__(
        self,
        num_symbols: int,
        output_dim: int,
        symbol_embedding_dim: int,
        num_lstm_layers: int,
        lstm_hidden_dim: int,
        lstm_layer_norm: bool = False,
        lstm_layer_norm_epsilon: float = 1e-5,
        lstm_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
        self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
        stacked_layers = []
        for layer_no in range(num_lstm_layers):
            # Only the first layer consumes embeddings; later layers consume hidden states.
            layer_in_dim = lstm_hidden_dim if layer_no > 0 else symbol_embedding_dim
            stacked_layers.append(
                _CustomLSTM(
                    layer_in_dim,
                    lstm_hidden_dim,
                    layer_norm=lstm_layer_norm,
                    layer_norm_epsilon=lstm_layer_norm_epsilon,
                )
            )
        self.lstm_layers = torch.nn.ModuleList(stacked_layers)
        self.dropout = torch.nn.Dropout(p=lstm_dropout)
        self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
        self.output_layer_norm = torch.nn.LayerNorm(output_dim)
        self.lstm_dropout = lstm_dropout

    def forward(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass.

        B: batch size;
        U: maximum sequence length in batch.

        Args:
            input (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid symbols for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None, optional): one ``[h, c]`` list per LSTM
                layer, as returned by a preceding invocation of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output encoding sequences, with shape `(B, U, output_dim)`
                torch.Tensor
                    ``lengths``, returned unchanged.
                List[List[torch.Tensor]]
                    one ``[h, c]`` list per LSTM layer from the current invocation.
        """
        # LSTM layers expect time-major input: (B, U) -> (U, B).
        time_major = input.permute(1, 0)
        hidden = self.input_layer_norm(self.embedding(time_major))
        new_state: List[List[torch.Tensor]] = []
        for layer_no, lstm_layer in enumerate(self.lstm_layers):
            layer_state_in: Optional[List[torch.Tensor]] = None
            if state is not None:
                layer_state_in = state[layer_no]
            hidden, layer_state_out = lstm_layer(hidden, layer_state_in)
            hidden = self.dropout(hidden)
            new_state.append(layer_state_out)
        projected = self.output_layer_norm(self.linear(hidden))
        # Back to batch-major: (U, B, output_dim) -> (B, U, output_dim).
        return projected.permute(1, 0, 2), lengths, new_state
class _Joiner(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) joint network.

    Combines every source position with every target position by broadcast addition,
    then applies an activation followed by a linear projection.

    Args:
        input_dim (int): source and target input dimension.
        output_dim (int): output dimension.
        activation (str, optional): activation function to use in the joiner.
            Must be one of ("relu", "tanh"). (Default: "relu")
    """

    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
        if activation not in ("relu", "tanh"):
            raise ValueError(f"Unsupported activation {activation}")
        self.activation = torch.nn.ReLU() if activation == "relu" else torch.nn.Tanh()

    def forward(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    ``source_lengths``, returned unchanged.
                torch.Tensor
                    ``target_lengths``, returned unchanged.
        """
        # (B, T, 1, D) + (B, 1, U, D) broadcasts to (B, T, U, D).
        source_grid = source_encodings.unsqueeze(2).contiguous()
        target_grid = target_encodings.unsqueeze(1).contiguous()
        activated = self.activation(source_grid + target_grid)
        logits = self.linear(activated)
        return logits, source_lengths, target_lengths
class RNNT(torch.nn.Module):
    r"""torchaudio.models.RNNT()

    Recurrent neural network transducer (RNN-T) model.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.

    Args:
        transcriber (torch.nn.Module): transcription network.
        predictor (torch.nn.Module): prediction network.
        joiner (torch.nn.Module): joint network.
    """

    def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
        super().__init__()
        self.transcriber = transcriber
        self.predictor = predictor
        self.joiner = joiner

    def forward(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        predictor_state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for training.

        Runs the transcription network on ``sources``, the prediction network on
        ``targets``, and combines both with the joint network.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: feature dimension of each source sequence element.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid symbols for i-th batch element in ``targets``.
            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing prediction network internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    joint network output, with shape
                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing prediction network internal state generated in current invocation
                    of ``forward``.
        """
        source_encodings, source_lengths = self.transcriber(
            input=sources,
            lengths=source_lengths,
        )
        target_encodings, target_lengths, predictor_state = self.predictor(
            input=targets,
            lengths=target_lengths,
            state=predictor_state,
        )
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return (
            output,
            source_lengths,
            target_lengths,
            predictor_state,
        )

    @torch.jit.export
    def transcribe_streaming(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies transcription network to sources in streaming mode.

        Delegates to the transcription network's ``infer`` method.

        B: batch size;
        T: maximum source sequence segment length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing transcription network internal state generated in preceding invocation
                of ``transcribe_streaming``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing transcription network internal state generated in current invocation
                    of ``transcribe_streaming``.
        """
        return self.transcriber.infer(sources, source_lengths, state)

    @torch.jit.export
    def transcribe(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Applies transcription network to sources in non-streaming mode.

        B: batch size;
        T: maximum source sequence length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        return self.transcriber(sources, source_lengths)

    @torch.jit.export
    def predict(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies prediction network to targets.

        B: batch size;
        U: maximum target sequence length in batch;
        D: feature dimension of each target sequence frame.

        Args:
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid symbols for i-th batch element in ``targets``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``predict``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with shape `(B, U, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``predict``.
        """
        return self.predictor(input=targets, lengths=target_lengths, state=state)

    @torch.jit.export
    def join(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Applies joint network to source and target encodings.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
        """
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return output, source_lengths, target_lengths
def emformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    num_symbols: int,
    segment_length: int,
    right_context_length: int,
    time_reduction_input_dim: int,
    time_reduction_stride: int,
    transformer_num_heads: int,
    transformer_ffn_dim: int,
    transformer_num_layers: int,
    transformer_dropout: float,
    transformer_activation: str,
    transformer_left_context_length: int,
    transformer_max_memory_size: int,
    transformer_weight_init_scale_strategy: Optional[str],
    transformer_tanh_on_mem: bool,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
) -> RNNT:
    r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.

    Note:
        For non-streaming inference, the expectation is for `transcribe` to be called on input
        sequences right-concatenated with `right_context_length` frames.

        For streaming inference, the expectation is for `transcribe_streaming` to be called
        on input chunks comprising `segment_length` frames right-concatenated with `right_context_length`
        frames.

    Note:
        The prediction network's LSTM hidden dimension is set equal to ``symbol_embedding_dim``,
        and the joint network uses its default ("relu") activation.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        num_symbols (int): cardinality of set of target tokens.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context considered by Emformer.
        transformer_dropout (float): Emformer dropout probability.
        transformer_activation (str): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        transformer_max_memory_size (int): maximum number of memory elements to use.
        transformer_weight_init_scale_strategy (str or None): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    encoder = _EmformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        segment_length=segment_length,
        right_context_length=right_context_length,
        time_reduction_input_dim=time_reduction_input_dim,
        time_reduction_stride=time_reduction_stride,
        transformer_num_heads=transformer_num_heads,
        transformer_ffn_dim=transformer_ffn_dim,
        transformer_num_layers=transformer_num_layers,
        transformer_dropout=transformer_dropout,
        transformer_activation=transformer_activation,
        transformer_left_context_length=transformer_left_context_length,
        transformer_max_memory_size=transformer_max_memory_size,
        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
        transformer_tanh_on_mem=transformer_tanh_on_mem,
    )
    predictor = _Predictor(
        num_symbols,
        encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=symbol_embedding_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    joiner = _Joiner(encoding_dim, num_symbols)
    return RNNT(encoder, predictor, joiner)
def emformer_rnnt_base(num_symbols: int) -> RNNT:
    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.

    The model expects 80-dimensional input frames and produces 1024-dimensional
    encodings. It uses a 20-layer Emformer and a 3-layer LSTM prediction network;
    see the arguments passed to :func:`emformer_rnnt_model` below for the full configuration.

    Args:
        num_symbols (int): The size of target token lexicon.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    return emformer_rnnt_model(
        input_dim=80,
        encoding_dim=1024,
        num_symbols=num_symbols,
        segment_length=16,
        right_context_length=4,
        time_reduction_input_dim=128,
        time_reduction_stride=4,
        transformer_num_heads=8,
        transformer_ffn_dim=2048,
        transformer_num_layers=20,
        transformer_dropout=0.1,
        transformer_activation="gelu",
        transformer_left_context_length=30,
        transformer_max_memory_size=0,
        transformer_weight_init_scale_strategy="depthwise",
        transformer_tanh_on_mem=True,
        symbol_embedding_dim=512,
        num_lstm_layers=3,
        lstm_layer_norm=True,
        lstm_layer_norm_epsilon=1e-3,
        lstm_dropout=0.3,
    )