RNNT
 class torchaudio.models.RNNT
Recurrent neural network transducer (RNNT) model.
Note
To build the model, please use one of the factory functions.
See also
torchaudio.pipelines.RNNTBundle: ASR pipeline with pretrained models.
Parameters:
transcriber (torch.nn.Module) – transcription network.
predictor (torch.nn.Module) – prediction network.
joiner (torch.nn.Module) – joint network.
Methods
forward
 RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) -> Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]]
Forward pass for training.
B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: feature dimension of each source sequence element.
 Parameters:
sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.
targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol.
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in targets.
predictor_state (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing prediction network internal state generated in preceding invocation of forward. (Default: None)
 Returns:
 torch.Tensor
joint network output, with shape (B, max output source length, max output target length, output_dim (number of target symbols)).
 torch.Tensor
output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.
 torch.Tensor
output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.
 List[List[torch.Tensor]]
output states; list of lists of tensors representing prediction network internal state generated in current invocation of forward.
 Return type:
(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
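The two length outputs delimit the valid region of the joint network output for each batch element. A minimal, dependency-free sketch of that relationship is shown here, with plain Python lists standing in for tensors; the helper name `valid_region_mask` is hypothetical and not part of torchaudio.

```python
from typing import List


def valid_region_mask(
    source_lengths: List[int], target_lengths: List[int]
) -> List[List[List[bool]]]:
    """Return a (B, max_T, max_U) nested list; True marks a valid joint-output cell.

    Hypothetical helper: it mirrors the documented meaning of the output
    source lengths (dim 1) and output target lengths (dim 2).
    """
    max_t = max(source_lengths)
    max_u = max(target_lengths)
    return [
        [[t < t_len and u < u_len for u in range(max_u)] for t in range(max_t)]
        for t_len, u_len in zip(source_lengths, target_lengths)
    ]


# Two batch elements: the second is shorter along both dims, so it is padded.
mask = valid_region_mask(source_lengths=[3, 2], target_lengths=[2, 1])
```

Cells where the mask is False correspond to padding and would normally be ignored by a downstream loss.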
transcribe_streaming
 RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) -> Tuple[Tensor, Tensor, List[List[Tensor]]]
Applies transcription network to sources in streaming mode.
B: batch size; T: maximum source sequence segment length in batch; D: feature dimension of each source sequence frame.
 Parameters:
sources (torch.Tensor) – source frame sequence segments right-padded with right context, with shape (B, T + right context length, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.
state (List[List[torch.Tensor]] or None) – list of lists of tensors representing transcription network internal state generated in preceding invocation of transcribe_streaming.
 Returns:
 torch.Tensor
output frame sequences, with shape (B, T // time_reduction_stride, output_dim).
 torch.Tensor
output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.
 List[List[torch.Tensor]]
output states; list of lists of tensors representing transcription network internal state generated in current invocation of transcribe_streaming.
 Return type:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
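Streaming use implies cutting a long frame sequence into segments of T frames, each followed by its right-context frames. The sketch below shows one way to do that segmentation on a plain Python sequence. The function name `iter_segments` and the parameter names `segment_length` and `right_context_length` are assumptions made for illustration, and trailing segments may come out shorter and would need padding before being batched.

```python
from typing import Iterator, List, Sequence


def iter_segments(
    frames: Sequence, segment_length: int, right_context_length: int
) -> Iterator[List]:
    """Yield overlapping windows of segment_length + right_context_length frames.

    Consecutive windows start segment_length frames apart, so the right
    context of one window overlaps the start of the next one.
    """
    window = segment_length + right_context_length
    for start in range(0, len(frames), segment_length):
        yield list(frames[start:start + window])


# Ten dummy "frames", segments of 4 frames with 2 frames of right context.
segments = list(iter_segments(list(range(10)), segment_length=4, right_context_length=2))

# Per the signature above, each segment would then be passed along with the
# state returned by the previous call (None for the first call), e.g.:
#   output, lengths, state = model.transcribe_streaming(segment, segment_lengths, state)
```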
transcribe
 RNNT.transcribe(sources: Tensor, source_lengths: Tensor) -> Tuple[Tensor, Tensor]
Applies transcription network to sources in non-streaming mode.
B: batch size; T: maximum source sequence length in batch; D: feature dimension of each source sequence frame.
 Parameters:
sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T + right context length, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.
 Returns:
 torch.Tensor
output frame sequences, with shape (B, T // time_reduction_stride, output_dim).
 torch.Tensor
output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output frame sequences.
 Return type:
(torch.Tensor, torch.Tensor)
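The shapes documented for transcribe can be checked with simple arithmetic. The following sketch derives the output shape (B, T // time_reduction_stride, output_dim) from an input of shape (B, T + right context length, D). The helper name and the example numbers are hypothetical and do not describe any particular pretrained model.

```python
from typing import Tuple


def transcribe_output_shape(
    source_shape: Tuple[int, int, int],
    right_context_length: int,
    time_reduction_stride: int,
    output_dim: int,
) -> Tuple[int, int, int]:
    """Apply the documented shape rule to a (B, T + right context, D) input."""
    batch, padded_length, _feature_dim = source_shape
    # Strip the right-context frames to recover T.
    num_frames = padded_length - right_context_length
    return (batch, num_frames // time_reduction_stride, output_dim)


# Hypothetical values: 160 frames plus 4 right-context frames, stride 4.
shape = transcribe_output_shape(
    (8, 164, 80), right_context_length=4, time_reduction_stride=4, output_dim=1024
)
```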
predict
 RNNT.predict(targets: Tensor, target_lengths: Tensor, state: Optional[List[List[Tensor]]]) -> Tuple[Tensor, Tensor, List[List[Tensor]]]
Applies prediction network to targets.
B: batch size; U: maximum target sequence length in batch; D: feature dimension of each target sequence frame.
 Parameters:
targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol, i.e. in range [0, num_symbols).
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in targets.
state (List[List[torch.Tensor]] or None) – list of lists of tensors representing internal state generated in preceding invocation of predict.
 Returns:
 torch.Tensor
output frame sequences, with shape (B, U, output_dim).
 torch.Tensor
output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.
 List[List[torch.Tensor]]
output states; list of lists of tensors representing internal state generated in current invocation of predict.
 Return type:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
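Because targets must form a rectangular (B, U) batch of symbols in [0, num_symbols), variable-length target sequences need to be right-padded and accompanied by their true lengths. A minimal sketch of that preparation step follows, using plain lists in place of tensors. The helper `pad_targets` and the choice of 0 as padding value are assumptions; padded positions are expected to be ignored via target_lengths.

```python
from typing import List, Sequence, Tuple


def pad_targets(
    sequences: Sequence[Sequence[int]], num_symbols: int, pad_value: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad symbol sequences to a common length U and return (targets, lengths)."""
    for sequence in sequences:
        for symbol in sequence:
            # Enforce the documented symbol range [0, num_symbols).
            if not 0 <= symbol < num_symbols:
                raise ValueError(f"symbol {symbol} outside [0, {num_symbols})")
    lengths = [len(sequence) for sequence in sequences]
    max_length = max(lengths)
    padded = [
        list(sequence) + [pad_value] * (max_length - len(sequence))
        for sequence in sequences
    ]
    return padded, lengths


targets, target_lengths = pad_targets([[5, 2, 7], [1]], num_symbols=10)
```

The resulting nested list and length list could then be converted to tensors of shape (B, U) and (B,) respectively.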
join
 RNNT.join(source_encodings: Tensor, source_lengths: Tensor, target_encodings: Tensor, target_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]
Applies joint network to source and target encodings.
B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: dimension of each source and target sequence encoding.
 Parameters:
source_encodings (torch.Tensor) – source encoding sequences, with shape (B, T, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in source_encodings.
target_encodings (torch.Tensor) – target encoding sequences, with shape (B, U, D).
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in target_encodings.
 Returns:
 torch.Tensor
joint network output, with shape (B, T, U, output_dim).
 torch.Tensor
output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.
 torch.Tensor
output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.
 Return type:
(torch.Tensor, torch.Tensor, torch.Tensor)
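To make the (B, T, U, ...) layout of the joint output concrete, here is a toy, dependency-free sketch that pairs every source frame with every target frame. It assumes a common transducer formulation (element-wise addition followed by ReLU) and omits the final projection to output_dim, so its last dimension stays D. It is an illustration only and is not claimed to match the joiner shipped with torchaudio; `toy_join` is a hypothetical name.

```python
from typing import List

Encoding = List[List[List[float]]]  # nested lists standing in for a (B, length, D) tensor


def toy_join(source_encodings: Encoding, target_encodings: Encoding) -> List[List[List[List[float]]]]:
    """Combine (B, T, D) and (B, U, D) encodings into a (B, T, U, D) nested list."""
    return [
        [
            [
                # Assumed combination rule: add the two encodings, then apply ReLU.
                [max(0.0, s + t) for s, t in zip(source_frame, target_frame)]
                for target_frame in target_sequence
            ]
            for source_frame in source_sequence
        ]
        for source_sequence, target_sequence in zip(source_encodings, target_encodings)
    ]


source = [[[1.0, -2.0], [0.5, 0.5]]]                 # B=1, T=2, D=2
target = [[[0.0, 1.0], [1.0, 1.0], [-1.0, -1.0]]]    # B=1, U=3, D=2
joint = toy_join(source, target)
```

Indexing `joint[b][t][u]` gives the combined vector for source frame t and target position u of batch element b, which is the layout the dim 1 and dim 2 length outputs refer to.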
Factory Functions
Builds Emformer-based RNNT model.

Builds basic version of Emformer-based RNNT model.
Prototype Factory Functions
Builds Conformer-based recurrent neural network transducer (RNNT) model.

Builds basic version of Conformer RNNT model.