Emformer
- class torchaudio.models.Emformer(input_dim: int, num_heads: int, ffn_dim: int, num_layers: int, segment_length: int, dropout: float = 0.0, activation: str = 'relu', left_context_length: int = 0, right_context_length: int = 0, max_memory_size: int = 0, weight_init_scale_strategy: Optional[str] = 'depthwise', tanh_on_mem: bool = False, negative_inf: float = -100000000.0)
Emformer architecture introduced in Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition [Shi et al., 2021].
See also
emformer_rnnt_model(), emformer_rnnt_base(): factory functions.
torchaudio.pipelines.RNNTBundle: ASR pipelines with pretrained model.
- Parameters:
input_dim (int) – input dimension.
num_heads (int) – number of attention heads in each Emformer layer.
ffn_dim (int) – hidden layer dimension of each Emformer layer’s feedforward network.
num_layers (int) – number of Emformer layers to instantiate.
segment_length (int) – length of each input segment.
dropout (float, optional) – dropout probability. (Default: 0.0)
activation (str, optional) – activation function to use in each Emformer layer’s feedforward network. Must be one of (“relu”, “gelu”, “silu”). (Default: “relu”)
left_context_length (int, optional) – length of left context. (Default: 0)
right_context_length (int, optional) – length of right context. (Default: 0)
max_memory_size (int, optional) – maximum number of memory elements to use. (Default: 0)
weight_init_scale_strategy (str or None, optional) – per-layer weight initialization scaling strategy. Must be one of (“depthwise”, “constant”, None). (Default: “depthwise”)
tanh_on_mem (bool, optional) – if True, applies tanh to memory elements. (Default: False)
negative_inf (float, optional) – value to use for negative infinity in attention weights. (Default: -1e8)
Examples
>>> import torch
>>> from torchaudio.models import Emformer
>>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
>>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
>>> lengths = torch.randint(1, 200, (128,))  # batch
>>> output, lengths = emformer(input, lengths)
>>> input = torch.rand(128, 5, 512)  # batch, segment_length + right_context_length, feature_dim
>>> lengths = torch.ones(128) * 5
>>> output, lengths, states = emformer.infer(input, lengths, None)
forward
- Emformer.forward(input: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]
Forward pass for training and non-streaming inference.
B: batch size; T: max number of input frames in batch; D: feature dimension of each frame.
- Parameters:
input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T + right_context_length, D).
lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in input.
- Returns:
- Tensor
output frames, with shape (B, T, D).
- Tensor
output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
- Return type:
(Tensor, Tensor)
infer
- Emformer.infer(input: Tensor, lengths: Tensor, states: Optional[List[List[Tensor]]] = None) -> Tuple[Tensor, Tensor, List[List[Tensor]]]
Forward pass for streaming inference.
B: batch size; D: feature dimension of each frame.
- Parameters:
input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, segment_length + right_context_length, D).
lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in input.
states (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing internal state generated in preceding invocation of infer. (Default: None)
- Returns:
- Tensor
output frames, with shape (B, segment_length, D).
- Tensor
output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
- List[List[Tensor]]
output states; list of lists of tensors representing internal state generated in current invocation of infer.
- Return type:
(Tensor, Tensor, List[List[Tensor]])