RNNT

class torchaudio.models.RNNT

Recurrent neural network transducer (RNN-T) model.

Note

To build the model, please use one of the factory functions, emformer_rnnt_model() or emformer_rnnt_base().

See also

torchaudio.pipelines.RNNTBundle: ASR pipeline with pre-trained models.

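Parameters:
  • transcriber (torch.nn.Module) – transcription network.

  • predictor (torch.nn.Module) – prediction network.

  • joiner (torch.nn.Module) – joint network.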

forward

RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) -> Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]]

Forward pass for training.

B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: feature dimension of each source sequence element.

Parameters:
  • sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T, D).

  • source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.

  • targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol.

  • target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid elements for i-th batch element in targets.

  • predictor_state (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing prediction network internal state generated in preceding invocation of forward. (Default: None)

Returns:
  • torch.Tensor – joint network output, with shape (B, max output source length, max output target length, output_dim (number of target symbols)).

  • torch.Tensor – output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.

  • torch.Tensor – output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.

  • List[List[torch.Tensor]] – output states; list of lists of tensors representing prediction network internal state generated in current invocation of forward.

Return type:

(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

transcribe_streaming

RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) -> Tuple[Tensor, Tensor, List[List[Tensor]]]

Applies transcription network to sources in streaming mode.

B: batch size; T: maximum source sequence segment length in batch; D: feature dimension of each source sequence frame.

Parameters:
  • sources (torch.Tensor) – source frame sequence segments right-padded with right context, with shape (B, T + right context length, D).

  • source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.

  • state (List[List[torch.Tensor]] or None) – list of lists of tensors representing transcription network internal state generated in preceding invocation of transcribe_streaming.

Returns:
  • torch.Tensor – output frame sequences, with shape (B, T // time_reduction_stride, output_dim).

  • torch.Tensor – output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.

  • List[List[torch.Tensor]] – output states; list of lists of tensors representing transcription network internal state generated in current invocation of transcribe_streaming.

Return type:

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

transcribe

RNNT.transcribe(sources: Tensor, source_lengths: Tensor) -> Tuple[Tensor, Tensor]

Applies transcription network to sources in non-streaming mode.

B: batch size; T: maximum source sequence length in batch; D: feature dimension of each source sequence frame.

Parameters:
  • sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T + right context length, D).

  • source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.

Returns:
  • torch.Tensor – output frame sequences, with shape (B, T // time_reduction_stride, output_dim).

  • torch.Tensor – output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output frame sequences.

Return type:

(torch.Tensor, torch.Tensor)

predict

RNNT.predict(targets: Tensor, target_lengths: Tensor, state: Optional[List[List[Tensor]]]) -> Tuple[Tensor, Tensor, List[List[Tensor]]]

Applies prediction network to targets.

B: batch size; U: maximum target sequence length in batch; D: feature dimension of each target sequence frame.

Parameters:
  • targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol, i.e. in range [0, num_symbols).

  • target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid elements for i-th batch element in targets.

  • state (List[List[torch.Tensor]] or None) – list of lists of tensors representing internal state generated in preceding invocation of predict.

Returns:
  • torch.Tensor – output frame sequences, with shape (B, U, output_dim).

  • torch.Tensor – output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.

  • List[List[torch.Tensor]] – output states; list of lists of tensors representing internal state generated in current invocation of predict.

Return type:

(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])

join

RNNT.join(source_encodings: Tensor, source_lengths: Tensor, target_encodings: Tensor, target_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]

Applies joint network to source and target encodings.

B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: dimension of each source and target sequence encoding.

Parameters:
  • source_encodings (torch.Tensor) – source encoding sequences, with shape (B, T, D).

  • source_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in source_encodings.

  • target_encodings (torch.Tensor) – target encoding sequences, with shape (B, U, D).

  • target_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in target_encodings.

Returns:
  • torch.Tensor – joint network output, with shape (B, T, U, output_dim).

  • torch.Tensor – output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.

  • torch.Tensor – output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.

Return type:

(torch.Tensor, torch.Tensor, torch.Tensor)
