ForwardKLLoss

class torchtune.modules.loss.ForwardKLLoss(ignore_index: int = -100)[source]

The forward Kullback-Leibler divergence loss, computed only over valid (non-ignored) target indices. Implementation follows https://github.com/jongwooko/distillm/blob/17c0f98bc263b1861a02d5df578c84aea652ee65/distillm/losses.py.
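Concretely, a sketch of the computation based on the linked distillm losses (not a verbatim restatement of the code): with M the set of token positions whose label is not ignore_index, the normalized loss is the teacher-weighted negative student log-likelihood,

\[
\mathcal{L} \;=\; -\frac{1}{|M|} \sum_{i \in M} \sum_{v=1}^{V} p_T(v \mid i)\,\log p_S(v \mid i),
\]

where p_T and p_S are the teacher and student softmax distributions over a vocabulary of size V. This differs from the forward KL divergence KL(p_T || p_S) only by the teacher entropy term, which is constant with respect to the student, so minimizing it is equivalent to minimizing the forward KL.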

Parameters:

ignore_index (int) – Specifies a target value that is ignored and does not contribute to the input gradient. The loss is divided over non-ignored targets. Default: -100.

forward(student_logits: Tensor, teacher_logits: Tensor, labels: Tensor, normalize: bool = True) → Tensor[source]
Parameters:
  • student_logits (torch.Tensor) – logits from student model of shape (batch_size*num_tokens, vocab_size).

  • teacher_logits (torch.Tensor) – logits from teacher model of shape (batch_size*num_tokens, vocab_size).

  • labels (torch.Tensor) – Ground truth labels of shape (batch_size*num_tokens,). Positions equal to ignore_index are excluded from the loss.

  • normalize (bool) – Whether to normalize the loss by the number of unmasked elements. Default: True.

Returns:

KL divergence loss of shape (1,).

Return type:

torch.Tensor
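
A minimal usage sketch (the batch, sequence, and vocabulary sizes here are illustrative, and the random tensors are stand-ins for real model outputs):

```python
import torch
from torchtune.modules.loss import ForwardKLLoss

batch_size, num_tokens, vocab_size = 2, 8, 128  # illustrative sizes

loss_fn = ForwardKLLoss(ignore_index=-100)

# Logits are flattened over the batch and sequence dimensions.
student_logits = torch.randn(batch_size * num_tokens, vocab_size)
teacher_logits = torch.randn(batch_size * num_tokens, vocab_size)

# Labels share the flattened token dimension; ignore_index marks
# positions (e.g. padding) that are excluded from the loss.
labels = torch.randint(0, vocab_size, (batch_size * num_tokens,))
labels[::4] = -100  # mask out every fourth position, for example

loss = loss_fn(student_logits, teacher_logits, labels)
print(loss)  # loss averaged over the non-ignored positions
```

In a distillation recipe, student_logits and teacher_logits would come from the student and teacher forward passes over the same batch, and labels would be the shifted ground-truth token ids used for the standard cross-entropy objective.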
