RNNT
- class torchaudio.models.RNNT(transcriber, predictor, joiner)
Recurrent neural network transducer (RNN-T) model.
Note
To build the model, please use one of the factory functions, emformer_rnnt_model() or emformer_rnnt_base().
See also
torchaudio.pipelines.RNNTBundle: ASR pipeline with pre-trained models.
- Parameters:
transcriber (torch.nn.Module) – transcription network.
predictor (torch.nn.Module) – prediction network.
joiner (torch.nn.Module) – joint network.
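The three submodules are chained during training: the transcriber encodes the source frames, the predictor encodes the target symbols, and the joiner combines the two encodings. The shape-level sketch below illustrates that data flow under stated assumptions. ShapeOnlyRNNT, the toy_* functions, and the constants STRIDE, ENC_DIM and NUM_SYMBOLS are hypothetical stand-ins that track tensor shapes only; they are not part of torchaudio.

```python
from typing import Callable, Tuple

Shape = Tuple[int, ...]

# Hypothetical configuration values, chosen only for illustration.
STRIDE = 4        # time reduction stride applied by the transcriber
ENC_DIM = 8       # encoding dimension shared by transcriber and predictor
NUM_SYMBOLS = 5   # output_dim of the joiner (number of target symbols)


def toy_transcriber(sources: Shape) -> Shape:
    # (B, T, D) -> (B, T // STRIDE, ENC_DIM)
    b, t, _ = sources
    return (b, t // STRIDE, ENC_DIM)


def toy_predictor(targets: Shape) -> Shape:
    # (B, U) -> (B, U, ENC_DIM)
    b, u = targets
    return (b, u, ENC_DIM)


def toy_joiner(src_enc: Shape, tgt_enc: Shape) -> Shape:
    # (B, T', D) and (B, U, D) -> (B, T', U, NUM_SYMBOLS)
    b, t, d_src = src_enc
    b_tgt, u, d_tgt = tgt_enc
    assert b == b_tgt and d_src == d_tgt, "batch size and encoding dim must agree"
    return (b, t, u, NUM_SYMBOLS)


class ShapeOnlyRNNT:
    """Hypothetical shape-level model: transcriber -> predictor -> joiner."""

    def __init__(self, transcriber: Callable, predictor: Callable, joiner: Callable) -> None:
        self.transcriber = transcriber
        self.predictor = predictor
        self.joiner = joiner

    def forward(self, sources: Shape, targets: Shape) -> Shape:
        return self.joiner(self.transcriber(sources), self.predictor(targets))


model = ShapeOnlyRNNT(toy_transcriber, toy_predictor, toy_joiner)
print(model.forward((2, 40, 80), (2, 7)))  # (2, 10, 7, 5)
```

Lengths are omitted for brevity; the real forward also passes source_lengths, target_lengths and predictor_state through the same three stages.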
forward
- RNNT.forward(sources: Tensor, source_lengths: Tensor, targets: Tensor, target_lengths: Tensor, predictor_state: Optional[List[List[Tensor]]] = None) -> Tuple[Tensor, Tensor, Tensor, List[List[Tensor]]]
Forward pass for training.
B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: feature dimension of each source sequence element.
- Parameters:
sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.
targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol.
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid elements for i-th batch element in targets.
predictor_state (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing prediction network internal state generated in preceding invocation of forward. (Default: None)
- Returns:
- torch.Tensor
joint network output, with shape (B, max output source length, max output target length, output_dim (number of target symbols)).
- torch.Tensor
output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.
- torch.Tensor
output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.
- List[List[torch.Tensor]]
output states; list of lists of tensors representing prediction network internal state generated in current invocation of forward.
- Return type:
(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
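Because the joint network output is padded along dims 1 and 2, an entry (t, u) of batch element b is meaningful only when t is below the b-th output source length and u is below the b-th output target length. A minimal pure-Python sketch of that rule follows; joint_valid_mask is a hypothetical helper, and nested lists stand in for tensors.

```python
from typing import List


def joint_valid_mask(
    src_lens: List[int], tgt_lens: List[int], max_t: int, max_u: int
) -> List[List[List[bool]]]:
    """Return mask[b][t][u], True where joint output position (t, u) is not padding."""
    if len(src_lens) != len(tgt_lens):
        raise ValueError("src_lens and tgt_lens must have the same batch size")
    return [
        [[t < n_src and u < n_tgt for u in range(max_u)] for t in range(max_t)]
        for n_src, n_tgt in zip(src_lens, tgt_lens)
    ]


# Batch of two: element 0 has 3 valid source and 2 valid target positions,
# element 1 has 2 and 1.
mask = joint_valid_mask([3, 2], [2, 1], max_t=3, max_u=2)
num_valid = sum(v for plane in mask for row in plane for v in row)
print(num_valid)  # 8 (3 * 2 + 2 * 1)
```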
transcribe_streaming
- RNNT.transcribe_streaming(sources: Tensor, source_lengths: Tensor, state: Optional[List[List[Tensor]]]) -> Tuple[Tensor, Tensor, List[List[Tensor]]]
Applies transcription network to sources in streaming mode.
B: batch size; T: maximum source sequence segment length in batch; D: feature dimension of each source sequence frame.
- Parameters:
sources (torch.Tensor) – source frame sequence segments right-padded with right context, with shape (B, T + right context length, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.
state (List[List[torch.Tensor]] or None) – list of lists of tensors representing transcription network internal state generated in preceding invocation of transcribe_streaming.
- Returns:
- torch.Tensor
output frame sequences, with shape (B, T // time_reduction_stride, output_dim).
- torch.Tensor
output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.
- List[List[torch.Tensor]]
output states; list of lists of tensors representing transcription network internal state generated in current invocation of transcribe_streaming.
- Return type:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
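In streaming mode, each call receives one segment followed by its right-context frames, and the state returned by one call is passed to the next. A minimal sketch of the segmentation step follows, assuming features are already available as a list of frames. streaming_segments is a hypothetical helper, and zero-padding the final chunk is an assumption made for illustration.

```python
from typing import List, Sequence, Tuple

Frame = List[float]


def streaming_segments(
    frames: Sequence[Frame],
    segment_length: int,
    right_context_length: int,
    feature_dim: int,
) -> List[Tuple[List[Frame], int]]:
    """Split frames into overlapping chunks of segment_length + right_context_length.

    Chunk i covers frames [i * segment_length, (i + 1) * segment_length) followed by
    the next right_context_length frames. Returns (chunk, num_valid_frames) pairs;
    chunks that run past the end are padded with zero frames (an assumption).
    """
    chunk_len = segment_length + right_context_length
    zero_frame = [0.0] * feature_dim
    chunks = []
    for start in range(0, len(frames), segment_length):
        chunk = [list(f) for f in frames[start:start + chunk_len]]
        num_valid = len(chunk)
        chunk += [list(zero_frame) for _ in range(chunk_len - num_valid)]
        chunks.append((chunk, num_valid))
    return chunks


# 10 one-dimensional frames, segments of 4 frames, 2 frames of right context.
frames = [[float(i)] for i in range(10)]
for chunk, num_valid in streaming_segments(frames, 4, 2, feature_dim=1):
    print([f[0] for f in chunk], num_valid)
# [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] 6
# [4.0, 5.0, 6.0, 7.0, 8.0, 9.0] 6
# [8.0, 9.0, 0.0, 0.0, 0.0, 0.0] 2
```

Each chunk would then be stacked into a (B, T + right context length, D) tensor, with T and the right context length assumed to match the model's configuration, and passed to transcribe_streaming together with the state returned by the previous call, using state=None for the first segment.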
transcribe
- RNNT.transcribe(sources: Tensor, source_lengths: Tensor) -> Tuple[Tensor, Tensor]
Applies transcription network to sources in non-streaming mode.
B: batch size; T: maximum source sequence length in batch; D: feature dimension of each source sequence frame.
- Parameters:
sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T + right context length, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.
- Returns:
- torch.Tensor
output frame sequences, with shape (B, T // time_reduction_stride, output_dim).
- torch.Tensor
output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output frame sequences.
- Return type:
(torch.Tensor, torch.Tensor)
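The output time dimension follows from the documented shapes: the right-context frames are not counted, and the remaining T frames are reduced by time_reduction_stride with integer division. The worked example below applies the same rule to each element of source_lengths; doing so, and requiring right_context_length to be a multiple of time_reduction_stride, are assumptions, and transcribe_output_lengths is a hypothetical helper.

```python
from typing import List


def transcribe_output_lengths(
    source_lengths: List[int], right_context_length: int, time_reduction_stride: int
) -> List[int]:
    """Per-element output lengths for inputs right-padded with right context.

    Mirrors the documented shape rule (B, T + right_context_length, D) ->
    (B, T // time_reduction_stride, output_dim), applied to each valid length.
    """
    out = []
    for length in source_lengths:
        t = length - right_context_length
        if t < 0:
            raise ValueError("each length must include the right context frames")
        out.append(t // time_reduction_stride)
    return out


# Right context of 4 frames and a stride of 4; 103 and 64 utterance frames.
print(transcribe_output_lengths([107, 68], right_context_length=4, time_reduction_stride=4))
# [25, 16]
```

With 103 utterance frames, the last 3 frames do not fill a complete stride, so only 25 output frames result.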
predict
- RNNT.predict(targets: Tensor, target_lengths: Tensor, state: Optional[List[List[Tensor]]]) -> Tuple[Tensor, Tensor, List[List[Tensor]]]
Applies prediction network to targets.
B: batch size; U: maximum target sequence length in batch; D: feature dimension of each target sequence frame.
- Parameters:
targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol, i.e. in range [0, num_symbols).
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid elements for i-th batch element in targets.
state (List[List[torch.Tensor]] or None) – list of lists of tensors representing internal state generated in preceding invocation of predict.
- Returns:
- torch.Tensor
output frame sequences, with shape (B, U, output_dim).
- torch.Tensor
output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.
- List[List[torch.Tensor]]
output states; list of lists of tensors representing internal state generated in current invocation of predict.
- Return type:
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
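The state argument lets the prediction network consume targets incrementally, for example one symbol at a time during decoding. The sketch below shows, for a recurrent predictor, that feeding symbols one per call while threading the returned state reproduces a single call over the whole sequence. toy_predict is a hypothetical scalar recurrence standing in for the prediction network; it is not torchaudio's predictor.

```python
from typing import List, Optional, Tuple


def toy_predict(targets: List[int], state: Optional[float]) -> Tuple[List[float], float]:
    """Hypothetical predictor: h_u = 0.5 * h_{u-1} + targets[u], with h_{-1} = state or 0."""
    h = 0.0 if state is None else state
    outputs = []
    for symbol in targets:
        h = 0.5 * h + symbol  # stands in for one recurrent (e.g. LSTM) step
        outputs.append(h)
    return outputs, h


targets = [1, 2, 3]

# One call over the whole sequence.
full_out, full_state = toy_predict(targets, state=None)

# One symbol per call, threading the returned state into the next call.
step_out, state = [], None
for symbol in targets:
    out, state = toy_predict([symbol], state)
    step_out.extend(out)

print(full_out, step_out)  # [1.0, 2.5, 4.25] [1.0, 2.5, 4.25]
```

With RNNT.predict, the analogous incremental call would pass targets of shape (B, 1), target_lengths filled with ones, and the state returned by the previous call.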
join
- RNNT.join(source_encodings: Tensor, source_lengths: Tensor, target_encodings: Tensor, target_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]
Applies joint network to source and target encodings.
B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: dimension of each source and target sequence encoding.
- Parameters:
source_encodings (torch.Tensor) – source encoding sequences, with shape (B, T, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in source_encodings.
target_encodings (torch.Tensor) – target encoding sequences, with shape (B, U, D).
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in target_encodings.
- Returns:
- torch.Tensor
joint network output, with shape (B, T, U, output_dim).
- torch.Tensor
output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.
- torch.Tensor
output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.
- Return type:
(torch.Tensor, torch.Tensor, torch.Tensor)
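For every batch element, the joint network combines each of the T source encodings with each of the U target encodings, which is where the (B, T, U, output_dim) shape comes from. A minimal pure-Python sketch of one common joiner design follows: add the two encodings, apply ReLU, then project to output_dim with a weight matrix. This design, the helper name toy_join, and the omission of a bias term are assumptions for illustration; the joiner in a given model may differ.

```python
from typing import List

Vector = List[float]


def toy_join(
    source_encodings: List[List[Vector]],  # (B, T, D)
    target_encodings: List[List[Vector]],  # (B, U, D)
    weight: List[Vector],                  # (output_dim, D)
) -> List[List[List[Vector]]]:             # (B, T, U, output_dim)
    """output[b][t][u] = weight @ relu(source[b][t] + target[b][u]) (no bias)."""

    def project(v: Vector) -> Vector:
        # Matrix-vector product: one dot product per output symbol.
        return [sum(w * x for w, x in zip(row, v)) for row in weight]

    return [
        [
            [project([max(0.0, s + g) for s, g in zip(src_t, tgt_u)]) for tgt_u in tgt_b]
            for src_t in src_b
        ]
        for src_b, tgt_b in zip(source_encodings, target_encodings)
    ]


# B = 1, T = 2, U = 3, D = 2; identity weight so output_dim = 2.
source = [[[1.0, -1.0], [0.0, 2.0]]]
target = [[[0.5, 0.5], [-2.0, 0.0], [1.0, 1.0]]]
identity = [[1.0, 0.0], [0.0, 1.0]]
out = toy_join(source, target, identity)
print(len(out), len(out[0]), len(out[0][0]), len(out[0][0][0]))  # 1 2 3 2
print(out[0][0][0], out[0][1][2])  # [1.5, 0.0] [1.0, 3.0]
```

Because every (t, u) pair is materialized, the joint output grows as B * T * U * output_dim, so its memory cost rises quickly with sequence lengths and vocabulary size.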