ForwardKLLoss
- class torchtune.modules.loss.ForwardKLLoss(ignore_index: int = -100)
The Kullback-Leibler (KL) divergence loss, computed only over valid (non-ignored) token indices. Based on the implementation at https://github.com/jongwooko/distillm/blob/17c0f98bc263b1861a02d5df578c84aea652ee65/distillm/losses.py
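As a worked definition of "KL divergence over valid indices", here is the loss written out. This assumes that "forward" means KL(teacher ‖ student), the usual convention in distillation, and that the loss is averaged over positions whose label is not `ignore_index`, as the parameter description says. The symbols s_i, t_i (student and teacher logits at position i), y_i (its label), V (vocabulary size), and the valid set M are notation introduced here, not names from the library.

```latex
p_S^{(i)} = \operatorname{softmax}(s_i), \qquad
p_T^{(i)} = \operatorname{softmax}(t_i), \qquad
\mathcal{M} = \{\, i : y_i \neq \texttt{ignore\_index} \,\}

\mathcal{L} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \sum_{v=1}^{V}
  p_T^{(i)}(v) \left[ \log p_T^{(i)}(v) - \log p_S^{(i)}(v) \right]
```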
- Parameters:
ignore_index (int) – Specifies a target value that is ignored and does not contribute to the input gradient. The loss is averaged over the non-ignored targets. Default: -100.
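The masked computation described above can be sketched in plain PyTorch. This is a minimal, hypothetical sketch: `forward_kl_sketch` is an illustrative name, not a torchtune API. It assumes finite logits already flattened to (batch_size*num_tokens, vocab_size) and labels holding batch_size*num_tokens token IDs. The function computes the full per-token KL(teacher ‖ student) and averages it over tokens whose label differs from `ignore_index`.

```python
import torch
import torch.nn.functional as F


def forward_kl_sketch(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Hypothetical sketch of a masked forward KL(teacher || student).

    student_logits, teacher_logits: (batch_size*num_tokens, vocab_size), finite values.
    labels: any shape with batch_size*num_tokens elements.
    """
    # Compute in float32 for numerical stability of log_softmax.
    teacher_logprob = F.log_softmax(teacher_logits.float(), dim=-1)
    student_logprob = F.log_softmax(student_logits.float(), dim=-1)
    teacher_prob = teacher_logprob.exp()

    # Per-token KL: sum_v p_T(v) * (log p_T(v) - log p_S(v)).
    per_token_kl = (teacher_prob * (teacher_logprob - student_logprob)).sum(dim=-1)

    flat_labels = labels.reshape(-1)
    if flat_labels.numel() != per_token_kl.numel():
        raise ValueError("labels must contain batch_size*num_tokens elements")

    # 1.0 for tokens that count toward the loss, 0.0 for ignore_index positions.
    mask = (flat_labels != ignore_index).float()
    num_valid = mask.sum()
    if num_valid == 0:
        # No valid tokens: return a zero scalar instead of dividing by zero.
        return per_token_kl.new_zeros(())
    return (per_token_kl * mask).sum() / num_valid


# Example: batch_size=2, num_tokens=3, vocab_size=5, logits already flattened.
student_logits = torch.randn(2 * 3, 5)
teacher_logits = torch.randn(2 * 3, 5)
labels = torch.tensor([[1, 2, -100], [4, -100, -100]])  # (batch_size, num_tokens)
loss = forward_kl_sketch(student_logits, teacher_logits, labels)
print(loss)  # a 0-dim tensor, averaged over the 3 non-ignored tokens
```

The teacher entropy term Σ p_T log p_T does not depend on the student logits. An implementation that computes only the cross-entropy term -Σ p_T log p_S, as the referenced distillm code appears to do, therefore returns a different value with the same gradient with respect to the student. The documented class is called with the same arguments: `ForwardKLLoss(ignore_index=-100)(student_logits, teacher_logits, labels)`.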
- forward(student_logits: Tensor, teacher_logits: Tensor, labels: Tensor) → Tensor
- Parameters:
student_logits (torch.Tensor) – Logits from the student model, of shape (batch_size*num_tokens, vocab_size).
teacher_logits (torch.Tensor) – Logits from the teacher model, of shape (batch_size*num_tokens, vocab_size).
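labels (torch.Tensor) – Ground-truth token labels with batch_size*num_tokens elements, e.g. of shape (batch_size, num_tokens). Positions equal to ignore_index are excluded from the loss.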
- Returns:
Scalar KL divergence loss, averaged over the non-ignored tokens.
- Return type: