ForwardKLWithChunkedOutputLoss

class torchtune.modules.loss.ForwardKLWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = -100)[source]

Forward KL with chunked outputs that saves memory by only upcasting one chunk at a time.

Since the model is trained in bf16, we upcast its output to fp32 before computing the KL divergence, for better accuracy and numerical stability. Upcasting doubles the memory used by the output, and models like Llama3 have a large vocabulary size and therefore a large output tensor of shape (bsz, num_tokens, vocab_size). By chunking on the token dimension, the loss can still be computed exactly, but only one chunk is upcast at a time, which saves considerable memory.
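As a rough sketch of the idea (this is an illustration, not torchtune's actual implementation; the masking and normalization details are assumptions based on the ignore_index description below), the per-chunk computation looks like this:

import torch
import torch.nn.functional as F

def chunked_forward_kl(student_chunks, teacher_chunks, labels, ignore_index=-100):
    # Count non-ignored targets once for the final normalization.
    n_valid = (labels != ignore_index).sum().clamp(min=1)
    labels_chunks = labels.chunk(len(student_chunks), dim=1)
    total = labels.new_zeros((), dtype=torch.float32)
    for s, t, y in zip(student_chunks, teacher_chunks, labels_chunks):
        # Upcast only this chunk from bf16 to fp32; the full
        # (bsz, num_tokens, vocab_size) fp32 tensor never materializes.
        p_t = F.softmax(t.float(), dim=-1)
        log_p_t = F.log_softmax(t.float(), dim=-1)
        log_p_s = F.log_softmax(s.float(), dim=-1)
        # Forward KL per token: sum_v p_teacher * (log p_teacher - log p_student)
        kl = (p_t * (log_p_t - log_p_s)).sum(dim=-1)  # (bsz, chunk_tokens)
        total = total + kl.masked_fill(y == ignore_index, 0.0).sum()
    return total / n_valid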

Parameters:
  • num_output_chunks (int) – Number of chunks to chunk the output into. Each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size). Default: 8

  • ignore_index (int) – Specifies a target value that is ignored and does not contribute to the input gradient. The loss is divided by the number of non-ignored targets. Default: -100

forward(student_logits: List[Tensor], teacher_logits: List[Tensor], labels: Tensor) → Tensor[source]
Parameters:
  • student_logits (List[torch.Tensor]) – List of chunked logits from student model of length self.num_output_chunks, where each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size).

  • teacher_logits (List[torch.Tensor]) – List of chunked logits from teacher model of length self.num_output_chunks, where each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size).

  • labels (torch.Tensor) – Ground truth labels of shape (batch_size, num_tokens).

Returns:
  KL divergence loss of shape (1,).

Return type:
  torch.Tensor

Example

>>> loss_fn = ForwardKLWithChunkedOutputLoss()
>>>
>>> # Hidden states from the student; bsz, num_tokens, dim assumed defined
>>> h = torch.randn(bsz, num_tokens, dim, dtype=torch.bfloat16)
>>> # The chunk count must match loss_fn.num_output_chunks
>>> output_chunks = [model.output(chunk) for chunk in h.chunk(loss_fn.num_output_chunks, dim=1)]
>>> teacher_chunks = [teacher_model.output(chunk) for chunk in h.chunk(loss_fn.num_output_chunks, dim=1)]
>>> labels = torch.randint(0, vocab_size, (bsz, num_tokens))
>>> loss = loss_fn(output_chunks, teacher_chunks, labels)
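
Note that the chunk count used to split h must match loss_fn.num_output_chunks (8 by default), since forward expects student and teacher logit lists of exactly that length.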
