ForwardKLWithChunkedOutputLoss
- class torchtune.modules.loss.ForwardKLWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = -100)
Forward KL divergence loss with chunked outputs, which saves memory by upcasting only one chunk at a time.
Because the model is trained in bf16, its logits have to be upcast to fp32 before the KL divergence is computed, for better accuracy and numerical stability. An fp32 copy takes twice the memory of the bf16 original. Models like Llama 3 have a large vocabulary and therefore produce a large output tensor of shape (bsz, num_tokens, vocab_size). If the output is chunked along the token dimension, the loss can still be computed normally, but upcasting only one chunk at a time saves considerable memory.
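A rough back-of-the-envelope sketch of the savings follows. The helper names (`logits_bytes`, `fp32_upcast_bytes`) are hypothetical, and the batch size of 4 and sequence length of 2,048 are assumed values; only the Llama 3 vocabulary size (128,256) is a real figure. It counts just the fp32 copies of the student and teacher logits and ignores every other buffer, so the numbers illustrate how memory scales with `num_output_chunks` rather than measure it.

```python
# Hypothetical helpers for estimating logits memory; not part of torchtune.
# Assumes logits of shape (bsz, num_tokens, vocab_size), bf16 = 2 bytes,
# fp32 = 4 bytes, and ignores allocator overhead and all other buffers.

BF16_BYTES = 2
FP32_BYTES = 4


def logits_bytes(bsz: int, num_tokens: int, vocab_size: int, dtype_bytes: int) -> int:
    """Bytes needed to store one logits tensor of the given shape and dtype."""
    return bsz * num_tokens * vocab_size * dtype_bytes


def fp32_upcast_bytes(
    bsz: int,
    num_tokens: int,
    vocab_size: int,
    num_output_chunks: int = 1,
    num_models: int = 2,
) -> int:
    """fp32 bytes alive at once when upcasting one token chunk at a time.

    num_models=2 counts both student and teacher logits (assumption).
    num_output_chunks=1 means the whole output is upcast in one go.
    """
    # The largest chunk holds ceil(num_tokens / num_output_chunks) tokens.
    tokens_per_chunk = -(-num_tokens // num_output_chunks)
    return num_models * logits_bytes(bsz, tokens_per_chunk, vocab_size, FP32_BYTES)


if __name__ == "__main__":
    GIB = 1024 ** 3
    # Assumed sizes: Llama 3 vocabulary, batch of 4, 2,048 tokens per sequence.
    bsz, num_tokens, vocab_size = 4, 2048, 128_256
    bf16 = logits_bytes(bsz, num_tokens, vocab_size, BF16_BYTES)
    full = fp32_upcast_bytes(bsz, num_tokens, vocab_size, num_output_chunks=1)
    chunked = fp32_upcast_bytes(bsz, num_tokens, vocab_size, num_output_chunks=8)
    print(f"bf16 logits per model:    {bf16 / GIB:.2f} GiB")
    print(f"fp32 upcast, no chunking: {full / GIB:.2f} GiB")
    print(f"fp32 upcast, 8 chunks:    {chunked / GIB:.2f} GiB")
```

With evenly divisible token counts, the peak fp32 footprint shrinks by exactly a factor of `num_output_chunks`.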
- Parameters:
num_output_chunks (int) – Number of chunks to split the output into. Each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size). Default: 8
ignore_index (int) – Specifies a target value that is ignored and does not contribute to the input gradient. The loss is averaged over non-ignored targets. Default: -100
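The chunk shape (batch_size, num_tokens / num_output_chunks, vocab_size) is exact only when num_tokens is divisible by num_output_chunks. Since the labels are passed unchunked, with shape (batch_size, num_tokens), the student and teacher chunks presumably need to follow one consistent split. The hypothetical helper below mimics the documented behavior of `torch.chunk` in pure Python: every chunk gets ceil(num_tokens / num_output_chunks) tokens except possibly a smaller last one, and fewer chunks than requested may come back.

```python
from typing import List


def chunk_lengths(num_tokens: int, num_output_chunks: int) -> List[int]:
    """Hypothetical helper: per-chunk token counts under torch.chunk semantics.

    Each chunk holds ceil(num_tokens / num_output_chunks) tokens, except
    possibly a smaller last chunk; fewer chunks than requested may result.
    """
    if num_tokens < 1 or num_output_chunks < 1:
        raise ValueError("num_tokens and num_output_chunks must be positive")
    size = -(-num_tokens // num_output_chunks)  # ceil division
    return [min(size, num_tokens - start) for start in range(0, num_tokens, size)]


if __name__ == "__main__":
    print(chunk_lengths(16, 8))  # [2, 2, 2, 2, 2, 2, 2, 2]  (evenly divisible)
    print(chunk_lengths(10, 4))  # [3, 3, 3, 1]  (last chunk is smaller)
    print(chunk_lengths(6, 4))   # [2, 2, 2]  (only 3 chunks come back)
```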
- forward(student_logits: List[Tensor], teacher_logits: List[Tensor], labels: Tensor) → Tensor
- Parameters:
student_logits (List[torch.Tensor]) – List of chunked logits from the student model, of length self.num_output_chunks, where each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size).
teacher_logits (List[torch.Tensor]) – List of chunked logits from the teacher model, of length self.num_output_chunks, where each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size).
labels (torch.Tensor) – Ground truth labels of shape (batch_size, num_tokens).
- Returns:
KL divergence loss of shape (1,).
- Return type:
torch.Tensor
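To make the return value concrete, here is a minimal pure-Python sketch of a chunked forward KL with `ignore_index` handling. It assumes the textbook definition KL(p_teacher || p_student) = sum over v of p_teacher(v) * (log p_teacher(v) - log p_student(v)) per token, with the batch and token dimensions already flattened inside each chunk. All names are hypothetical and this is not torchtune's implementation; it only shows one way chunking can leave the result unchanged: per-token terms are summed across all chunks and divided by the number of non-ignored tokens once, at the end.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # mirrors the documented default for ignore_index


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def token_forward_kl(student_row: Sequence[float], teacher_row: Sequence[float]) -> float:
    """KL(p_teacher || p_student) for a single token position."""
    log_p_s = log_softmax(student_row)
    log_p_t = log_softmax(teacher_row)
    return sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))


def chunked_forward_kl(student_chunks, teacher_chunks, label_chunks,
                       ignore_index: int = IGNORE_INDEX) -> float:
    """Sum unnormalized KL chunk by chunk, then divide once by the
    number of non-ignored tokens. Returns 0.0 if every token is ignored."""
    if not (len(student_chunks) == len(teacher_chunks) == len(label_chunks)):
        raise ValueError("student, teacher, and label chunk counts must match")
    total, count = 0.0, 0
    for s_chunk, t_chunk, l_chunk in zip(student_chunks, teacher_chunks, label_chunks):
        if not (len(s_chunk) == len(t_chunk) == len(l_chunk)):
            raise ValueError("rows within a chunk must line up")
        for s_row, t_row, label in zip(s_chunk, t_chunk, l_chunk):
            if label == ignore_index:
                continue  # ignored tokens add nothing to the sum or the count
            total += token_forward_kl(s_row, t_row)
            count += 1
    return total / count if count else 0.0


if __name__ == "__main__":
    student = [[1.0, 2.0, 0.5], [0.0, 0.0, 0.0], [2.0, -1.0, 1.0], [0.3, 0.2, 0.1]]
    teacher = [[0.5, 2.5, 0.0], [1.0, 0.0, -1.0], [2.0, -1.0, 1.0], [3.0, 0.0, 0.0]]
    labels = [1, IGNORE_INDEX, 0, 2]
    one_chunk = chunked_forward_kl([student], [teacher], [labels])
    two_chunks = chunked_forward_kl([student[:2], student[2:]],
                                    [teacher[:2], teacher[2:]],
                                    [labels[:2], labels[2:]])
    print(one_chunk, two_chunks)
```

Splitting the same four rows into one or two chunks gives the same value because the division happens only after every chunk has been summed. Averaging each chunk on its own and then averaging those means would not match in general when chunks contain different numbers of non-ignored tokens.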
Example
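>>> import torch
>>> from torchtune.modules.loss import ForwardKLWithChunkedOutputLoss
>>> bsz, num_tokens, dim, vocab_size, num_chunks = 2, 16, 32, 128, 8
>>> loss_fn = ForwardKLWithChunkedOutputLoss(num_output_chunks=num_chunks)
>>> # Stand-ins for the output projections of the student and teacher models
>>> student_output = torch.nn.Linear(dim, vocab_size, bias=False)
>>> teacher_output = torch.nn.Linear(dim, vocab_size, bias=False)
>>> h = torch.randn(bsz, num_tokens, dim)  # student hidden states
>>> h_teacher = torch.randn(bsz, num_tokens, dim)  # teacher hidden states
>>> output_chunks = [student_output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>> teacher_chunks = [teacher_output(chunk) for chunk in h_teacher.chunk(num_chunks, dim=1)]
>>> labels = torch.randint(0, vocab_size, (bsz, num_tokens))
>>> loss = loss_fn(output_chunks, teacher_chunks, labels)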