torchaudio.prototype

torchaudio.prototype provides prototype features; see here for more information on prototype features. The module is available only within nightly builds and must be imported explicitly, e.g. import torchaudio.prototype.

Emformer

class torchaudio.prototype.Emformer(input_dim: int, num_heads: int, ffn_dim: int, num_layers: int, dropout: float = 0.0, activation: str = 'relu', left_context_length: int = 0, right_context_length: int = 0, segment_length: int = 128, max_memory_size: int = 0, weight_init_scale_strategy: str = 'depthwise', tanh_on_mem: bool = False, negative_inf: float = -100000000.0)[source]

Implements the Emformer architecture introduced in Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition [1].

Parameters
  • input_dim (int) – input dimension.

  • num_heads (int) – number of attention heads in each Emformer layer.

  • ffn_dim (int) – hidden layer dimension of each Emformer layer’s feedforward network.

  • num_layers (int) – number of Emformer layers to instantiate.

  • dropout (float, optional) – dropout probability. (Default: 0.0)

  • activation (str, optional) – activation function to use in each Emformer layer’s feedforward network. Must be one of (“relu”, “gelu”, “silu”). (Default: “relu”)

  • left_context_length (int, optional) – length of left context. (Default: 0)

  • right_context_length (int, optional) – length of right context. (Default: 0)

  • segment_length (int, optional) – length of each input segment. (Default: 128)

  • max_memory_size (int, optional) – maximum number of memory elements to use. (Default: 0)

  • weight_init_scale_strategy (str, optional) – per-layer weight initialization scaling strategy. Must be one of (“depthwise”, “constant”, None). (Default: “depthwise”)

  • tanh_on_mem (bool, optional) – if True, applies tanh to memory elements. (Default: False)

  • negative_inf (float, optional) – value to use for negative infinity in attention weights. (Default: -1e8)

Examples

>>> emformer = Emformer(512, 8, 2048, 20)
>>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
>>> lengths = torch.randint(1, 200, (128,))  # batch
>>> output, output_lengths = emformer(input, lengths)  # training-mode forward
>>> segment = torch.rand(128, 128, 512)  # one streaming segment (segment_length frames)
>>> segment_lengths = torch.full((128,), 128)
>>> output, output_lengths, states = emformer.infer(segment, segment_lengths, None)
forward(input: torch.Tensor, lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Forward pass for training.

B: batch size; T: number of frames; D: feature dimension of each frame.

Parameters
  • input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T, D).

  • lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in input.

Returns

Tensor

output frames, with shape (B, T - right_context_length, D).

Tensor

output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.

Return type

(Tensor, Tensor)

infer(input: torch.Tensor, lengths: torch.Tensor, states: Optional[List[List[torch.Tensor]]] = None) → Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]][source]

Forward pass for inference.

B: batch size; T: number of frames; D: feature dimension of each frame.

Parameters
  • input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T, D).

  • lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in input.

  • states (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing Emformer internal state generated in preceding invocation of infer. (Default: None)

Returns

Tensor

output frames, with shape (B, T - right_context_length, D).

Tensor

output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.

List[List[Tensor]]

output states; list of lists of tensors representing Emformer internal state generated in current invocation of infer.

Return type

(Tensor, Tensor, List[List[Tensor]])
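
A possible way to drive infer over a long utterance is to feed it one segment at a time and thread the returned state between calls. The sketch below is illustrative rather than part of the torchaudio documentation; it assumes each call receives exactly segment_length + right_context_length frames, mirroring the streaming expectation described for the Emformer RNN-T factory further down this page.

import torch
from torchaudio.prototype import Emformer

segment_length = 16
right_context_length = 4
emformer = Emformer(
    input_dim=512,
    num_heads=8,
    ffn_dim=2048,
    num_layers=4,
    segment_length=segment_length,
    right_context_length=right_context_length,
)

# 10 segments of 16 frames plus a trailing 4-frame right context.
features = torch.rand(1, 10 * segment_length + right_context_length, 512)
state = None
outputs = []
for start in range(0, features.size(1) - right_context_length, segment_length):
    # Each call consumes one segment together with its right-context frames.
    chunk = features[:, start : start + segment_length + right_context_length]
    chunk_lengths = torch.full((chunk.size(0),), chunk.size(1))
    out, out_lengths, state = emformer.infer(chunk, chunk_lengths, state)
    outputs.append(out)
streamed = torch.cat(outputs, dim=1)  # (1, 10 * segment_length, 512)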

RNNT

class torchaudio.prototype.RNNT(transcriber: torchaudio.prototype.rnnt._Transcriber, predictor: torchaudio.prototype.rnnt._Predictor, joiner: torchaudio.prototype.rnnt._Joiner)[source]

Recurrent neural network transducer (RNN-T) model.

Note

To build the model, please use one of the factory functions.

Parameters
  • transcriber (torchaudio.prototype.rnnt._Transcriber) – transcription network.

  • predictor (torchaudio.prototype.rnnt._Predictor) – prediction network.

  • joiner (torchaudio.prototype.rnnt._Joiner) – joint network.

forward(sources: torch.Tensor, source_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, predictor_state: Optional[List[List[torch.Tensor]]] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]][source]

Forward pass for training.

B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: feature dimension of each source sequence element.

Parameters
  • sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T, D).

  • source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.

  • targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol.

  • target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid symbols for i-th batch element in targets.

  • predictor_state (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing prediction network internal state generated in preceding invocation of forward. (Default: None)

Returns

torch.Tensor

joint network output, with shape (B, max output source length, max output target length, number of target symbols).

torch.Tensor

output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.

torch.Tensor

output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.

List[List[torch.Tensor]]

output states; list of lists of tensors representing prediction network internal state generated in current invocation of forward.

Return type

(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

transcribe_streaming(sources: torch.Tensor, source_lengths: torch.Tensor, state: Optional[List[List[torch.Tensor]]]) → Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]][source]

Applies transcription network to sources in streaming mode.

B: batch size; T: maximum source sequence segment length in batch; D: feature dimension of each source sequence frame.

Parameters
  • sources (torch.Tensor) – source frame sequence segments right-padded with right context, with shape (B, T + right context length, D).

  • source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.

  • state (List[List[torch.Tensor]] or None) – list of lists of tensors representing transcription network internal state generated in preceding invocation of transcribe_streaming.

Returns

torch.Tensor

output frame sequences, with shape (B, T // time_reduction_stride, output_dim).

torch.Tensor

output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.

List[List[torch.Tensor]]

output states; list of lists of tensors representing transcription network internal state generated in current invocation of transcribe_streaming.

Return type

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

transcribe(sources: torch.Tensor, source_lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Applies transcription network to sources in non-streaming mode.

B: batch size; T: maximum source sequence length in batch; D: feature dimension of each source sequence frame.

Parameters
  • sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T + right context length, D).

  • source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.

Returns

torch.Tensor

output frame sequences, with shape (B, T // time_reduction_stride, output_dim).

torch.Tensor

output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output frame sequences.

Return type

(torch.Tensor, torch.Tensor)

predict(targets: torch.Tensor, target_lengths: torch.Tensor, state: Optional[List[List[torch.Tensor]]]) → Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]][source]

Applies prediction network to targets.

B: batch size; U: maximum target sequence length in batch; D: feature dimension of each target sequence frame.

Parameters
  • targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol, i.e. in range [0, num_symbols).

  • target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid symbols for i-th batch element in targets.

  • state (List[List[torch.Tensor]] or None) – list of lists of tensors representing internal state generated in preceding invocation of predict.

Returns

torch.Tensor

output frame sequences, with shape (B, U, output_dim).

torch.Tensor

output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.

List[List[torch.Tensor]]

output states; list of lists of tensors representing internal state generated in current invocation of predict.

Return type

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

join(source_encodings: torch.Tensor, source_lengths: torch.Tensor, target_encodings: torch.Tensor, target_lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor][source]

Applies joint network to source and target encodings.

B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: dimension of each source and target sequence encoding.

Parameters
  • source_encodings (torch.Tensor) – source encoding sequences, with shape (B, T, D).

  • source_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in source_encodings.

  • target_encodings (torch.Tensor) – target encoding sequences, with shape (B, U, D).

  • target_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in target_encodings.

Returns

torch.Tensor

joint network output, with shape (B, T, U, D).

torch.Tensor

output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.

torch.Tensor

output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.

Return type

(torch.Tensor, torch.Tensor, torch.Tensor)

emformer_rnnt_base

torchaudio.prototype.emformer_rnnt_base() → torchaudio.prototype.rnnt.RNNT[source]

Builds a basic version of the Emformer RNN-T model.

Returns

Emformer RNN-T model.

Return type

RNNT

emformer_rnnt_model

torchaudio.prototype.emformer_rnnt_model(*, input_dim: int, encoding_dim: int, num_symbols: int, segment_length: int, right_context_length: int, time_reduction_input_dim: int, time_reduction_stride: int, transformer_num_heads: int, transformer_ffn_dim: int, transformer_num_layers: int, transformer_dropout: float, transformer_activation: str, transformer_left_context_length: int, transformer_max_memory_size: int, transformer_weight_init_scale_strategy: str, transformer_tanh_on_mem: bool, symbol_embedding_dim: int, num_lstm_layers: int, lstm_layer_norm: bool, lstm_layer_norm_epsilon: float, lstm_dropout: float) → torchaudio.prototype.rnnt.RNNT[source]

Builds an Emformer-based recurrent neural network transducer (RNN-T) model.

Note

For non-streaming inference, the expectation is for transcribe to be called on input sequences right-concatenated with right_context_length frames.

For streaming inference, the expectation is for transcribe_streaming to be called on input chunks comprising segment_length frames right-concatenated with right_context_length frames.

Parameters
  • input_dim (int) – dimension of input sequence frames passed to transcription network.

  • encoding_dim (int) – dimension of transcription- and prediction-network-generated encodings passed to joint network.

  • num_symbols (int) – cardinality of set of target tokens.

  • segment_length (int) – length of input segment expressed as number of frames.

  • right_context_length (int) – length of right context expressed as number of frames.

  • time_reduction_input_dim (int) – dimension to scale each element in input sequences to prior to applying time reduction block.

  • time_reduction_stride (int) – factor by which to reduce length of input sequence.

  • transformer_num_heads (int) – number of attention heads in each Emformer layer.

  • transformer_ffn_dim (int) – hidden layer dimension of each Emformer layer’s feedforward network.

  • transformer_num_layers (int) – number of Emformer layers to instantiate.

  • transformer_left_context_length (int) – length of left context considered by Emformer.

  • transformer_dropout (float) – Emformer dropout probability.

  • transformer_activation (str) – activation function to use in each Emformer layer’s feedforward network. Must be one of (“relu”, “gelu”, “silu”).

  • transformer_max_memory_size (int) – maximum number of memory elements to use.

  • transformer_weight_init_scale_strategy (str) – per-layer weight initialization scaling strategy. Must be one of (“depthwise”, “constant”, None).

  • transformer_tanh_on_mem (bool) – if True, applies tanh to memory elements.

  • symbol_embedding_dim (int) – dimension of each target token embedding.

  • num_lstm_layers (int) – number of LSTM layers to instantiate.

  • lstm_layer_norm (bool) – if True, enables layer normalization for LSTM layers.

  • lstm_layer_norm_epsilon (float) – value of epsilon to use in LSTM layer normalization layers.

  • lstm_dropout (float) – LSTM dropout probability.

Returns

Emformer RNN-T model.

Return type

RNNT
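
The sketch below shows one way the factory could be exercised end to end. It is illustrative, not part of the torchaudio documentation: it builds a deliberately small model (keyword names are taken from the parameter list above; the values are arbitrary choices for illustration, not recommended settings), then runs a training-style forward pass and a non-streaming transcribe call. Per the note above, sources carry right_context_length extra frames.

import torch
from torchaudio.prototype import emformer_rnnt_model

num_symbols = 128
right_context_length = 4
model = emformer_rnnt_model(
    input_dim=80,
    encoding_dim=256,
    num_symbols=num_symbols,
    segment_length=16,
    right_context_length=right_context_length,
    time_reduction_input_dim=128,
    time_reduction_stride=4,
    transformer_num_heads=4,
    transformer_ffn_dim=512,
    transformer_num_layers=2,
    transformer_dropout=0.1,
    transformer_activation="gelu",
    transformer_left_context_length=30,
    transformer_max_memory_size=0,
    transformer_weight_init_scale_strategy="depthwise",
    transformer_tanh_on_mem=True,
    symbol_embedding_dim=64,
    num_lstm_layers=2,
    lstm_layer_norm=True,
    lstm_layer_norm_epsilon=1e-5,
    lstm_dropout=0.1,
)

# Training-style forward pass; sources are right-padded with right context.
sources = torch.rand(2, 64 + right_context_length, 80)  # (B, T + right context, D)
source_lengths = torch.full((2,), sources.size(1))
targets = torch.randint(0, num_symbols, (2, 10))        # (B, U), symbol indices
target_lengths = torch.full((2,), targets.size(1))
joint_out, out_src_lengths, out_tgt_lengths, pred_state = model(
    sources, source_lengths, targets, target_lengths
)

# Non-streaming acoustic transcription only.
encodings, encoding_lengths = model.transcribe(sources, source_lengths)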

RNNTBeamSearch

class torchaudio.prototype.RNNTBeamSearch(model: torchaudio.prototype.rnnt.RNNT, blank: int, temperature: float = 1.0, hypo_sort_key: Optional[Callable[[torchaudio.prototype.rnnt_decoder.Hypothesis], float]] = None, step_max_tokens: int = 100)[source]

Beam search decoder for RNN-T model.

Parameters
  • model (RNNT) – RNN-T model to use.

  • blank (int) – index of blank token in vocabulary.

  • temperature (float, optional) – temperature to apply to joint network output. Larger values yield more uniform samples. (Default: 1.0)

  • hypo_sort_key (Callable[[Hypothesis], float] or None, optional) – callable that computes a score for a given hypothesis to rank hypotheses by. If None, defaults to callable that returns hypothesis score normalized by token sequence length. (Default: None)

  • step_max_tokens (int, optional) – maximum number of tokens to emit per input time step. (Default: 100)

forward(input: torch.Tensor, length: torch.Tensor, beam_width: int) → List[torchaudio.prototype.rnnt_decoder.Hypothesis][source]

Performs beam search for the given input sequence.

T: number of frames; D: feature dimension of each frame.

Parameters
  • input (torch.Tensor) – sequence of input frames, with shape (T, D) or (1, T, D).

  • length (torch.Tensor) – number of valid frames in input sequence, with shape () or (1,).

  • beam_width (int) – beam size to use during search.

Returns

top-beam_width hypotheses found by beam search.

Return type

List[Hypothesis]

infer(input: torch.Tensor, length: torch.Tensor, beam_width: int, state: Optional[List[List[torch.Tensor]]] = None, hypothesis: Optional[torchaudio.prototype.rnnt_decoder.Hypothesis] = None) → Tuple[List[torchaudio.prototype.rnnt_decoder.Hypothesis], List[List[torch.Tensor]]][source]

Performs beam search for the given input sequence in streaming mode.

T: number of frames; D: feature dimension of each frame.

Parameters
  • input (torch.Tensor) – sequence of input frames, with shape (T, D) or (1, T, D).

  • length (torch.Tensor) – number of valid frames in input sequence, with shape () or (1,).

  • beam_width (int) – beam size to use during search.

  • state (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing transcription network internal state generated in preceding invocation. (Default: None)

  • hypothesis (Hypothesis or None) – hypothesis from preceding invocation to seed search with. (Default: None)

Returns

List[Hypothesis]

top-beam_width hypotheses found by beam search.

List[List[torch.Tensor]]

list of lists of tensors representing transcription network internal state generated in current invocation.

Return type

(List[Hypothesis], List[List[torch.Tensor]])
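
For streaming decoding, infer is typically driven in a loop that mirrors the chunking used by the transcription network. The helper below is a hypothetical sketch, not part of torchaudio: it assumes the feature sequence is chunk-aligned to segment_length plus right_context_length frames, threads the returned state between calls, and seeds each call with the best-scoring hypothesis from the previous one. Mapping token indices back to text is left to the caller's tokenizer.

from typing import List

import torch
from torchaudio.prototype import RNNT, RNNTBeamSearch, Hypothesis


def stream_decode(
    model: RNNT,
    features: torch.Tensor,  # (T, D) feature frames for one utterance
    blank: int,
    segment_length: int,
    right_context_length: int,
    beam_width: int = 10,
) -> List[int]:
    """Decode chunk by chunk, threading state and a seed hypothesis."""
    decoder = RNNTBeamSearch(model, blank)
    state, hypothesis = None, None
    for start in range(0, features.size(0) - right_context_length, segment_length):
        # Each chunk carries one segment plus its right-context frames.
        chunk = features[start : start + segment_length + right_context_length]
        length = torch.tensor(chunk.size(0))
        hypos, state = decoder.infer(
            chunk, length, beam_width, state=state, hypothesis=hypothesis
        )
        # Seeding the next call with the best-scoring hypothesis is one
        # plausible choice; other strategies are possible.
        hypothesis = max(hypos, key=lambda h: h.score)
    # Drop blank tokens from the winning token sequence.
    return [token for token in hypothesis.tokens if token != blank]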

Hypothesis

class torchaudio.prototype.Hypothesis(tokens: List[int], predictor_out: torch.Tensor, state: List[List[torch.Tensor]], score: float, alignment: List[int], blank: int, key: str)[source]

Represents a hypothesis generated by the beam search decoder RNNTBeamSearch.

Variables
  • tokens (List[int]) – Predicted sequence of tokens.

  • predictor_out (torch.Tensor) – Prediction network output.

  • state (List[List[torch.Tensor]]) – Prediction network internal state.

  • score (float) – Score of hypothesis.

  • alignment (List[int]) – Sequence of timesteps, with the i-th value mapping to the i-th predicted token in tokens.

  • blank (int) – Token index corresponding to blank token.

  • key (str) – Value used to determine equivalence in token sequences between Hypothesis instances.
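
These fields are sufficient for simple post-processing of search output. The helper below is hypothetical (not part of torchaudio); it selects the highest-scoring hypothesis and pairs each predicted token with its emission timestep, using the documented correspondence between alignment and tokens.

from typing import List, Tuple

from torchaudio.prototype import Hypothesis


def best_alignment(hypos: List[Hypothesis]) -> List[Tuple[int, int]]:
    """Return (timestep, token) pairs for the highest-scoring hypothesis."""
    best = max(hypos, key=lambda h: h.score)
    # alignment[i] is the timestep at which tokens[i] was emitted.
    return list(zip(best.alignment, best.tokens))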

References

[1] Yangyang Shi, Yongqiang Wang, Chunyang Wu, Ching-Feng Yeh, Julian Chan, Frank Zhang, Duc Le, and Mike Seltzer. Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition. In ICASSP 2021 - 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 6783–6787. 2021.
