
ConvEmformer

class torchaudio.prototype.models.ConvEmformer(input_dim: int, num_heads: int, ffn_dim: int, num_layers: int, segment_length: int, kernel_size: int, dropout: float = 0.0, ffn_activation: str = 'relu', left_context_length: int = 0, right_context_length: int = 0, max_memory_size: int = 0, weight_init_scale_strategy: Optional[str] = 'depthwise', tanh_on_mem: bool = False, negative_inf: float = -100000000.0, conv_activation: str = 'silu')

Implements the convolution-augmented streaming transformer architecture introduced in Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution [Shi et al., 2022].

Parameters:
  • input_dim (int) – input feature dimension, i.e. the size D of each input frame.

  • num_heads (int) – number of attention heads in each ConvEmformer layer.

  • ffn_dim (int) – hidden layer dimension of each ConvEmformer layer’s feedforward network.

  • num_layers (int) – number of ConvEmformer layers to instantiate.

  • segment_length (int) – number of frames in each input segment.

  • kernel_size (int) – size of kernel to use in convolution modules.

  • dropout (float, optional) – dropout probability. (Default: 0.0)

  • ffn_activation (str, optional) – activation function to use in feedforward networks. Must be one of (“relu”, “gelu”, “silu”). (Default: “relu”)

  • left_context_length (int, optional) – number of left-context frames. (Default: 0)

  • right_context_length (int, optional) – number of right-context (look-ahead) frames. (Default: 0)

  • max_memory_size (int, optional) – maximum number of memory elements to use. (Default: 0)

  • weight_init_scale_strategy (str or None, optional) – per-layer weight initialization scaling strategy. Must be one of (“depthwise”, “constant”, None). (Default: “depthwise”)

  • tanh_on_mem (bool, optional) – if True, applies tanh to memory elements. (Default: False)

  • negative_inf (float, optional) – value to use for negative infinity in attention weights. (Default: -1e8)

  • conv_activation (str, optional) – activation function to use in convolution modules. Must be one of (“relu”, “gelu”, “silu”). (Default: “silu”)
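weight_init_scale_strategy can be read as a rule that assigns one scaling factor to each of the num_layers layers. The sketch below assumes the rule used by torchaudio's Emformer implementation as we understand it:

* "depthwise" gives layer l a gain of 1/sqrt(l + 1).
* "constant" gives every layer a gain of 1/sqrt(2).
* None applies no extra scaling.

This page does not state these formulas, so treat them as an assumption. `weight_init_gains` is a hypothetical helper, not part of the public API, and the sketch does not show how each layer applies its gain.

```python
import math
from typing import List, Optional


def weight_init_gains(strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
    """Per-layer init scaling factors (sketch; formulas are an assumption)."""
    if strategy is None:
        # No extra scaling: every layer keeps its default initialization.
        return [None] * num_layers
    if strategy == "depthwise":
        # Deeper layers get smaller gains: 1 / sqrt(layer_index + 1).
        return [1.0 / math.sqrt(i + 1) for i in range(num_layers)]
    if strategy == "constant":
        # Every layer gets the same gain of 1 / sqrt(2).
        return [1.0 / math.sqrt(2) for _ in range(num_layers)]
    raise ValueError(f"Unsupported weight_init_scale_strategy: {strategy}")


gains = weight_init_gains("depthwise", 3)
```

Under this assumption, weight_init_gains("depthwise", 3) yields gains of 1, 1/sqrt(2) and 1/sqrt(3).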

Examples

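>>> import torch
>>> from torchaudio.prototype.models import ConvEmformer
>>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)  # D=80, segment_length=16
>>> input = torch.rand(10, 200, 80)  # (B, T + right_context_length, D), so T = 196
>>> lengths = torch.randint(1, 197, (10,))  # valid utterance frames, each at most T
>>> output, lengths = conv_emformer(input, lengths)
>>> input = torch.rand(4, 20, 80)  # (B, segment_length + right_context_length, D)
>>> lengths = torch.full((4,), 20)  # all 20 input frames are valid
>>> output, lengths, states = conv_emformer.infer(input, lengths, None)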
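In the first example, input has 200 frames and right_context_length=4, so T = 196 utterance frames are processed in segments of segment_length=16 frames. The sketch below lists the frame ranges each segment would use. It assumes the Emformer-style partitioning we believe torchaudio uses:

* Every segment except the last takes the right_context_length frames immediately after it as right context.
* The last segment takes the trailing right_context_length frames of input.

`segment_layout` is a hypothetical helper. It only computes indices and does not call the model.

```python
import math
from typing import Dict, List, Tuple

Range = Tuple[int, int]  # (start, end): start inclusive, end exclusive


def segment_layout(
    num_input_frames: int,  # input.size(1) == T + right_context_length
    segment_length: int,
    right_context_length: int,
    left_context_length: int = 0,
) -> List[Dict[str, Range]]:
    """Frame ranges seen by each segment (illustrative sketch, not library code)."""
    T = num_input_frames - right_context_length  # utterance frames only
    num_segments = math.ceil(T / segment_length)
    layout = []
    for i in range(num_segments):
        seg_start = i * segment_length
        seg_end = min(seg_start + segment_length, T)  # last segment may be shorter
        if i < num_segments - 1:
            # Look-ahead frames that directly follow a full segment.
            right = (seg_end, seg_end + right_context_length)
        else:
            # Last segment: the trailing right-context frames of the input.
            right = (T, num_input_frames)
        layout.append({
            "left_context": (max(0, seg_start - left_context_length), seg_start),
            "segment": (seg_start, seg_end),
            "right_context": right,
        })
    return layout


layout = segment_layout(200, segment_length=16, right_context_length=4)
```

With the numbers from the example, this gives 13 segments. The last one covers frames 192 to 195 and uses input frames 196 to 199 as its right context.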

Methods

forward

ConvEmformer.forward(input: Tensor, lengths: Tensor) → Tuple[Tensor, Tensor]

Forward pass for training and non-streaming inference.

B: batch size; T: maximum number of utterance frames in the batch (excluding right-context frames); D: feature dimension of each frame.

Parameters:
  • input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T + right_context_length, D).

  • lengths (torch.Tensor) – tensor with shape (B,), where the i-th element is the number of valid utterance frames for the i-th batch element in input.
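When utterances in a batch differ in length, you can assemble input and lengths in two steps. First, right-pad every utterance to the longest one (T frames). Then append right_context_length further frames.

The sketch below uses plain Python lists and fills both kinds of padding with zero frames. Zero-filling is our assumption, since real look-ahead frames could be used instead. `pad_batch` is a hypothetical helper. Passing its outputs to torch.tensor would give tensors of shape (B, T + right_context_length, D) and (B,).

```python
from typing import List, Sequence, Tuple

Frame = List[float]  # one feature vector of length D


def pad_batch(
    utterances: Sequence[Sequence[Frame]], right_context_length: int
) -> Tuple[List[List[Frame]], List[int]]:
    """Build a (B, T + right_context_length, D) nested list plus per-utterance lengths."""
    dim = len(utterances[0][0])
    T = max(len(utt) for utt in utterances)  # longest utterance in the batch
    batch, lengths = [], []
    for utt in utterances:
        num_pad = T + right_context_length - len(utt)
        # Copy the real frames, then pad with zero frames (utterance padding + right context).
        batch.append([list(frame) for frame in utt] + [[0.0] * dim for _ in range(num_pad)])
        lengths.append(len(utt))  # counts valid utterance frames only
    return batch, lengths


batch, lengths = pad_batch([[[1.0, 1.0], [2.0, 2.0]], [[3.0, 3.0]]], right_context_length=1)
```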

Returns:
  • Tensor – output frames, with shape (B, T, D).

  • Tensor – output lengths, with shape (B,), where the i-th element is the number of valid frames for the i-th batch element in the output frames.

Return type:

(Tensor, Tensor)

infer

ConvEmformer.infer(input: Tensor, lengths: Tensor, states: Optional[List[List[Tensor]]] = None) → Tuple[Tensor, Tensor, List[List[Tensor]]]

Forward pass for streaming inference.

B: batch size; D: feature dimension of each frame.

Parameters:
  • input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, segment_length + right_context_length, D).

  • lengths (torch.Tensor) – tensor with shape (B,), where the i-th element is the number of valid frames for the i-th batch element in input.

  • states (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing the internal state generated in the preceding invocation of infer; pass None on the first call. (Default: None)
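Each call to infer consumes segment_length + right_context_length frames. In the examples above, with segment_length=16 and right_context_length=4, that is 20 frames. Successive calls would then advance by segment_length frames, so consecutive windows overlap by right_context_length frames.

The sketch below cuts a single utterance into such windows and zero-pads the final one. It records how many real frames each window holds, for use as lengths. The segment_length hop and the zero padding are our reading of the input shape described above. `streaming_chunks` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

Frame = List[float]  # one feature vector of length D


def streaming_chunks(
    frames: Sequence[Frame], segment_length: int, right_context_length: int
) -> List[Tuple[List[Frame], int]]:
    """Split one utterance into (window, num_valid_frames) pairs sized for infer."""
    window = segment_length + right_context_length
    dim = len(frames[0])
    chunks = []
    # Hop by segment_length so each window's right context is re-read as the
    # start of the next window's segment.
    for start in range(0, len(frames), segment_length):
        chunk = [list(f) for f in frames[start:start + window]]
        num_valid = len(chunk)
        chunk += [[0.0] * dim for _ in range(window - num_valid)]  # zero-pad the tail
        chunks.append((chunk, num_valid))
    return chunks


frames = [[float(t)] for t in range(20)]  # 20 one-dimensional frames
chunks = streaming_chunks(frames, segment_length=16, right_context_length=4)
```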

Returns:
  • Tensor – output frames, with shape (B, segment_length, D).

  • Tensor – output lengths, with shape (B,), where the i-th element is the number of valid frames for the i-th batch element in the output frames.

  • List[List[Tensor]] – output states; list of lists of tensors representing the internal state generated in the current invocation of infer.

Return type:

(Tensor, Tensor, List[List[Tensor]])
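The states returned by one infer call are meant to be passed to the next, starting from None. The sketch below shows that loop for a generic infer-like callable. Both `run_stream` and `toy_infer` are hypothetical:

* `toy_infer` stands in for ConvEmformer.infer on a single utterance.
* Its "state" is just a call counter rather than List[List[Tensor]].
* Its output-length rule (valid input frames minus right_context_length, floored at zero) is an assumption that this page does not specify.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

# Assumed shape of an infer-like callable for a single utterance:
# (window, num_valid_frames, states) -> (output_frames, num_output_frames, new_states)
InferFn = Callable[[List[Any], int, Optional[Any]], Tuple[List[Any], int, Any]]


def run_stream(infer: InferFn, chunks: Sequence[Tuple[List[Any], int]]) -> Tuple[List[Any], Any]:
    """Call infer on each window in order, threading states between calls."""
    states = None  # no preceding invocation before the first window
    outputs: List[Any] = []
    for window, num_valid in chunks:
        out, num_out, states = infer(window, num_valid, states)
        outputs.extend(out[:num_out])  # keep only the valid output frames
    return outputs, states


def toy_infer(window, num_valid, states, segment_length=16, right_context_length=4):
    """Hypothetical stand-in for ConvEmformer.infer; its state is a call counter."""
    calls = 0 if states is None else states
    num_out = max(0, num_valid - right_context_length)  # assumed output-length rule
    return window[:segment_length], num_out, calls + 1


windows = [(list(range(20)), 20), (list(range(16, 36)), 20)]
outputs, states = run_stream(toy_infer, windows)
```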
