RNNTLoss
- class torchaudio.transforms.RNNTLoss(blank: int = -1, clamp: float = -1.0, reduction: str = 'mean', fused_log_softmax: bool = True)
Compute the RNN Transducer loss from Sequence Transduction with Recurrent Neural Networks [Graves, 2012].
The RNN Transducer loss extends the CTC loss by defining a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output dependencies.
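To make the "distribution over output sequences" concrete, the following is a minimal, dependency-free sketch of the forward (alpha) recursion described by Graves (2012) for a single utterance. It is an illustration only, not the library's implementation: the function name `rnnt_nll` and its helpers are hypothetical, the blank index is assumed to be 0, and the log-softmax over the class dimension is applied inside the function.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one list of class scores."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def logaddexp(a, b):
    """log(exp(a) + exp(b)) that tolerates -inf operands."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_nll(logits, target, blank=0):
    """Negative log-likelihood of `target` for one utterance.

    logits: nested lists of shape (T, U + 1, num_classes), joiner outputs.
    target: list of U label indices (no blanks).
    """
    T, U = len(logits), len(target)
    lp = [[log_softmax(logits[t][u]) for u in range(U + 1)] for t in range(T)]

    # alpha[t][u]: log-probability of having emitted target[:u]
    # by the time the lattice node (t, u) is reached.
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:  # arrive by emitting blank at (t - 1, u)
                alpha[t][u] = logaddexp(
                    alpha[t][u], alpha[t - 1][u] + lp[t - 1][u][blank]
                )
            if u > 0:  # arrive by emitting label target[u - 1] at (t, u - 1)
                alpha[t][u] = logaddexp(
                    alpha[t][u], alpha[t][u - 1] + lp[t][u - 1][target[u - 1]]
                )

    # Every alignment ends with a final blank emitted from node (T - 1, U).
    return -(alpha[T - 1][U] + lp[T - 1][U][blank])


# Uniform scores over 5 classes, T = 2 frames, U = 2 target labels.
uniform = [[[0.0] * 5 for _ in range(3)] for _ in range(2)]
loss = rnnt_nll(uniform, [1, 2])
```

With uniform scores every alignment is equally likely, so the result can be checked by hand: there are 3 alignments of 4 emissions each, giving a likelihood of 3/625 and a loss of log(625/3), about 5.34.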
- Parameters:
blank (int, optional) – blank label (Default: -1)
clamp (float, optional) – clamp for gradients (Default: -1)
reduction (string, optional) – Specifies the reduction to apply to the output: "none" | "mean" | "sum". (Default: "mean")
fused_log_softmax (bool) – set to False if calling log_softmax outside of loss (Default: True)
- Example
>>> import torch
>>> from torchaudio import transforms
>>> # Hypothetical values
>>> logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.6, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.2, 0.8, 0.1]],
>>>                         [[0.1, 0.6, 0.1, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.2, 0.1, 0.1],
>>>                          [0.7, 0.1, 0.2, 0.1, 0.1]]]],
>>>                       dtype=torch.float32,
>>>                       requires_grad=True)
>>> targets = torch.tensor([[1, 2]], dtype=torch.int)
>>> logit_lengths = torch.tensor([2], dtype=torch.int)
>>> target_lengths = torch.tensor([2], dtype=torch.int)
>>> transform = transforms.RNNTLoss(blank=0)
>>> loss = transform(logits, targets, logit_lengths, target_lengths)
>>> loss.backward()
- forward(logits: Tensor, targets: Tensor, logit_lengths: Tensor, target_lengths: Tensor)
- Parameters:
logits (Tensor) – Tensor of dimension (batch, max seq length, max target length + 1, class) containing the output from the joiner
targets (Tensor) – Tensor of dimension (batch, max target length) containing zero-padded targets
logit_lengths (Tensor) – Tensor of dimension (batch) containing the length of each sequence from the encoder
target_lengths (Tensor) – Tensor of dimension (batch) containing the length of the targets for each sequence
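The relationship between these four arguments can be sketched in plain Python. The helpers `pad_targets` and `expected_logits_shape` are hypothetical names used only for illustration, and the batch values are made up. In real use the resulting lists would be converted to tensors, as in the example above.

```python
def pad_targets(targets, pad_value=0):
    """Zero-pad variable-length label sequences into a rectangular batch."""
    target_lengths = [len(t) for t in targets]
    max_len = max(target_lengths)
    padded = [list(t) + [pad_value] * (max_len - len(t)) for t in targets]
    return padded, target_lengths


def expected_logits_shape(logit_lengths, target_lengths, num_classes):
    """Shape the joiner output needs for this batch:
    (batch, max seq length, max target length + 1, class)."""
    return (
        len(logit_lengths),
        max(logit_lengths),
        max(target_lengths) + 1,
        num_classes,
    )


# Hypothetical batch of two utterances.
targets = [[1, 2, 3], [4]]   # label sequences of length 3 and 1
logit_lengths = [6, 4]       # encoder frames per utterance

padded, target_lengths = pad_targets(targets)
shape = expected_logits_shape(logit_lengths, target_lengths, num_classes=5)
```

Here `padded` is `[[1, 2, 3], [4, 0, 0]]`, `target_lengths` is `[3, 1]`, and `shape` is `(2, 6, 4, 5)`. The extra position along the target dimension reflects the `max target length + 1` in the `logits` description.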
- Returns:
Loss with the reduction option applied. If reduction is "none", then size (batch), otherwise scalar.
- Return type:
Tensor
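The effect of the reduction option on the output size can be sketched as follows. This assumes that "mean" is a plain average over the batch and "sum" a plain total, following the usual PyTorch convention; the text above does not spell this out, and `apply_reduction` is a hypothetical helper rather than part of the library.

```python
def apply_reduction(per_sequence_losses, reduction="mean"):
    """Combine one loss value per batch element according to `reduction`."""
    if reduction == "none":
        # One value per sequence: size (batch).
        return list(per_sequence_losses)
    if reduction == "sum":
        return sum(per_sequence_losses)
    if reduction == "mean":
        # Assumption: unweighted average over the batch.
        return sum(per_sequence_losses) / len(per_sequence_losses)
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Hypothetical per-sequence losses for a batch of three utterances.
costs = [1.0, 2.0, 3.0]
unreduced = apply_reduction(costs, "none")
total = apply_reduction(costs, "sum")
average = apply_reduction(costs, "mean")
```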