ConvEmformer
- class torchaudio.prototype.models.ConvEmformer(input_dim: int, num_heads: int, ffn_dim: int, num_layers: int, segment_length: int, kernel_size: int, dropout: float = 0.0, ffn_activation: str = 'relu', left_context_length: int = 0, right_context_length: int = 0, max_memory_size: int = 0, weight_init_scale_strategy: Optional[str] = 'depthwise', tanh_on_mem: bool = False, negative_inf: float = -100000000.0, conv_activation: str = 'silu')
Implements the convolution-augmented streaming transformer architecture introduced in Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution [Shi et al., 2022].
- Parameters:
input_dim (int) – input dimension.
num_heads (int) – number of attention heads in each ConvEmformer layer.
ffn_dim (int) – hidden layer dimension of each ConvEmformer layer’s feedforward network.
num_layers (int) – number of ConvEmformer layers to instantiate.
segment_length (int) – length of each input segment.
kernel_size (int) – size of kernel to use in convolution modules.
dropout (float, optional) – dropout probability. (Default: 0.0)
ffn_activation (str, optional) – activation function to use in feedforward networks. Must be one of (“relu”, “gelu”, “silu”). (Default: “relu”)
left_context_length (int, optional) – length of left context. (Default: 0)
right_context_length (int, optional) – length of right context. (Default: 0)
max_memory_size (int, optional) – maximum number of memory elements to use. (Default: 0)
weight_init_scale_strategy (str or None, optional) – per-layer weight initialization scaling strategy. Must be one of (“depthwise”, “constant”, None). (Default: “depthwise”)
tanh_on_mem (bool, optional) – if True, applies tanh to memory elements. (Default: False)
negative_inf (float, optional) – value to use for negative infinity in attention weights. (Default: -1e8)
conv_activation (str, optional) – activation function to use in convolution modules. Must be one of (“relu”, “gelu”, “silu”). (Default: “silu”)
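The "must be one of" constraints above can be checked before constructing the model. The following is a minimal sketch, not part of torchaudio: `check_conv_emformer_args` is a hypothetical helper, and it covers only the allowed values listed in this parameter list.

```python
from typing import Optional

# Allowed values, copied from the parameter descriptions above.
ACTIVATIONS = ("relu", "gelu", "silu")
INIT_STRATEGIES = ("depthwise", "constant", None)


def check_conv_emformer_args(
    ffn_activation: str = "relu",
    conv_activation: str = "silu",
    weight_init_scale_strategy: Optional[str] = "depthwise",
) -> None:
    """Hypothetical pre-check: raise ValueError for a value not listed in the docs."""
    if ffn_activation not in ACTIVATIONS:
        raise ValueError(f"unsupported ffn_activation: {ffn_activation!r}")
    if conv_activation not in ACTIVATIONS:
        raise ValueError(f"unsupported conv_activation: {conv_activation!r}")
    if weight_init_scale_strategy not in INIT_STRATEGIES:
        raise ValueError(
            f"unsupported weight_init_scale_strategy: {weight_init_scale_strategy!r}"
        )


# Passes silently: every value is among the documented options.
check_conv_emformer_args(ffn_activation="gelu", weight_init_scale_strategy=None)
```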
Examples
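>>> import torch
>>> from torchaudio.prototype.models import ConvEmformer
>>> # Training / non-streaming inference
>>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
>>> input = torch.rand(10, 200, 80)
>>> lengths = torch.randint(1, 200, (10,))
>>> output, lengths = conv_emformer(input, lengths)
>>> # Streaming inference: segment_length + right_context_length = 20 frames per call
>>> input = torch.rand(4, 20, 80)
>>> lengths = torch.ones(4) * 20
>>> output, lengths, states = conv_emformer.infer(input, lengths, None)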
Methods
forward
- ConvEmformer.forward(input: Tensor, lengths: Tensor) → Tuple[Tensor, Tensor]
Forward pass for training and non-streaming inference.
B: batch size; T: max number of input frames in batch; D: feature dimension of each frame.
- Parameters:
input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T + right_context_length, D).
lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in input.
- Returns:
- Tensor
output frames, with shape (B, T, D).
- Tensor
output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
- Return type:
(Tensor, Tensor)
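The shape relationship documented for `forward` can be written out as plain arithmetic. This is only an illustrative sketch: `forward_shapes` is a hypothetical helper, not a torchaudio function, and it simply restates the documented shapes (B, T + right_context_length, D) and (B, T, D).

```python
from typing import Tuple

Shape = Tuple[int, int, int]


def forward_shapes(
    batch_size: int, num_frames: int, feature_dim: int, right_context_length: int
) -> Tuple[Shape, Shape]:
    """Return (input_shape, output_shape) as documented for forward().

    num_frames is T, the number of utterance frames excluding the
    right-context padding appended to the input.
    """
    input_shape = (batch_size, num_frames + right_context_length, feature_dim)
    output_shape = (batch_size, num_frames, feature_dim)
    return input_shape, output_shape


# B=10, T=196, D=80, right_context_length=4
print(forward_shapes(10, 196, 80, 4))  # ((10, 200, 80), (10, 196, 80))
```

Read the other way round, an input of shape (10, 200, 80) passed to a model built with `right_context_length=4` corresponds to T = 196 output frames.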
infer
- ConvEmformer.infer(input: Tensor, lengths: Tensor, states: Optional[List[List[Tensor]]] = None) → Tuple[Tensor, Tensor, List[List[Tensor]]]
Forward pass for streaming inference.
B: batch size; D: feature dimension of each frame.
- Parameters:
input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, segment_length + right_context_length, D).
lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in input.
states (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing internal state generated in preceding invocation of infer. (Default: None)
- Returns:
- Tensor
output frames, with shape (B, segment_length, D).
- Tensor
output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
- List[List[Tensor]]
output states; list of lists of tensors representing internal state generated in current invocation of infer.
- Return type:
(Tensor, Tensor, List[List[Tensor]])
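A streaming loop over `infer` needs the frame range for each call. The sketch below computes those ranges in plain Python. `iter_segments` is a hypothetical helper, not part of torchaudio. It assumes that consecutive calls advance by `segment_length` frames and that each chunk carries `right_context_length` look-ahead frames. How a final, shorter chunk should be padded is not specified here, so the sketch simply drops trailing frames that cannot fill a whole chunk.

```python
from typing import Iterator, Tuple


def iter_segments(
    num_frames: int, segment_length: int, right_context_length: int
) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) frame ranges, one per infer() call.

    Each range spans segment_length + right_context_length frames, and
    consecutive ranges advance by segment_length, so the look-ahead frames
    of one chunk are the first frames of the next chunk's segment.
    Trailing frames that cannot fill a whole chunk are dropped (assumption).
    """
    chunk_length = segment_length + right_context_length
    start = 0
    while start + chunk_length <= num_frames:
        yield start, start + chunk_length
        start += segment_length


# 56 frames, segment_length=16, right_context_length=4
print(list(iter_segments(56, 16, 4)))  # [(0, 20), (16, 36), (32, 52)]
```

With a model and a tensor `x` of shape (B, num_frames, D), each range would be consumed as `output, out_lengths, states = conv_emformer.infer(x[:, start:end], lengths, states)`, starting from `states = None` and passing the returned states into the next call.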