CEWithChunkedOutputLoss
- class torchtune.modules.loss.CEWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = -100)
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.
Whenever the model is trained with bf16, the logits have to be upcast to fp32 before running CE for better accuracy and stability. The upcast copy takes twice the memory of the bf16 logits. Models like llama3 have a large vocabulary size and, therefore, a large output tensor of shape (bsz, num_tokens, vocab_size). If we chunk on the token level, the cross entropy can still be computed normally, but upcasting only one chunk at a time saves considerable memory.
The CE and upcasting have to be compiled together for better performance. When using this class, we recommend using torch.compile() only on the method compute_cross_entropy. The gains from chunking won’t be realized if you compile the entire class.
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
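To make the memory argument concrete, here is a rough estimate. It is a back-of-the-envelope sketch, not a measurement: it assumes 2 bytes per bf16 element and 4 bytes per fp32 element, counts only the logits tensor and its upcast copy, and assumes num_tokens divides evenly by the number of chunks. The helper logits_memory_gib is hypothetical and not part of torchtune.

```python
# Back-of-the-envelope memory estimate for the logits tensor and its fp32 upcast.
# Assumptions (illustrative only): bf16 = 2 bytes/element, fp32 = 4 bytes/element,
# chunks are processed one at a time, and each chunk's fp32 copy is freed
# before the next chunk is upcast.

BF16_BYTES = 2
FP32_BYTES = 4
GIB = 1024 ** 3


def logits_memory_gib(
    bsz: int, num_tokens: int, vocab_size: int, num_chunks: int = 1
) -> dict:
    """Return GiB used by the bf16 logits and by the peak fp32 upcast copy.

    num_chunks=1 means the whole tensor is upcast at once; num_chunks > 1
    means only one chunk's fp32 copy is alive at any moment.
    """
    elements = bsz * num_tokens * vocab_size
    bf16_logits = elements * BF16_BYTES / GIB
    fp32_upcast_peak = elements * FP32_BYTES / num_chunks / GIB
    return {"bf16_logits": bf16_logits, "fp32_upcast_peak": fp32_upcast_peak}
```

For example, with bsz=4, num_tokens=8192 and Llama 3's 128,256-token vocabulary, logits_memory_gib(4, 8192, 128256) gives about 7.83 GiB of bf16 logits and a 15.66 GiB fp32 copy when everything is upcast at once; with num_chunks=8 the fp32 peak drops to about 1.96 GiB.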
- compute_cross_entropy(logits: Tensor, labels: Tensor) → Tensor
Upcast logits to fp32 and compute cross entropy loss.
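The per-chunk step described above can be sketched as a standalone function. This is an illustration under assumptions, not torchtune's source: the name compute_cross_entropy_sketch is hypothetical, and using reduction="sum" (so that the caller can normalize once over all chunks) is our assumption about how per-chunk results are combined.

```python
import torch
import torch.nn.functional as F


def compute_cross_entropy_sketch(
    logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    """Upcast one chunk of logits to fp32 and return its summed CE loss.

    logits: (num_rows, vocab_size), typically bf16
    labels: (num_rows,), int64 class indices; rows equal to ignore_index are skipped
    """
    return F.cross_entropy(
        logits.float(),  # the upcast: only this chunk is materialized in fp32
        labels,
        ignore_index=ignore_index,
        reduction="sum",  # assumption: the caller divides by the token count later
    )
```

Following the recommendation above, torch.compile() would be applied to this per-chunk computation rather than to the whole module; with the real class, one way to do that is loss_fn.compute_cross_entropy = torch.compile(loss_fn.compute_cross_entropy).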
- forward(logits: List[Tensor], labels: Tensor) → Tensor
- Parameters:
logits (List[torch.Tensor]) – List of chunked logits of length self.num_output_chunks, where each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size).
labels (torch.Tensor) – Ground truth labels of shape (batch_size, num_tokens).
- Returns:
Cross entropy loss of shape (1,).
- Return type:
torch.Tensor
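Putting the parameters and return value together, the forward pass can be sketched as below. It is a hypothetical re-implementation for illustration (chunked_ce_forward_sketch is not a torchtune function). It assumes each chunk's loss is summed and the total is divided once by the number of non-ignored labels, which makes the chunked result equal to an unchunked mean cross entropy.

```python
from typing import List

import torch
import torch.nn.functional as F


def chunked_ce_forward_sketch(
    logits: List[torch.Tensor], labels: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    """Mean CE over all non-ignored tokens, upcasting one logit chunk at a time.

    logits: list of chunks, each (bsz, chunk_tokens, vocab_size)
    labels: (bsz, num_tokens), where num_tokens is the sum of the chunk_tokens
    """
    # Number of label positions that actually contribute to the loss.
    total_elements = (labels != ignore_index).sum()
    # Split labels along the token dimension so each piece lines up with a logit chunk.
    label_chunks = labels.split([chunk.size(1) for chunk in logits], dim=1)

    total_loss = torch.zeros((), dtype=torch.float32, device=labels.device)
    for logit_chunk, label_chunk in zip(logits, label_chunks):
        # Flatten (bsz, chunk_tokens, vocab) -> (bsz * chunk_tokens, vocab).
        flat_logits = logit_chunk.reshape(-1, logit_chunk.size(-1))
        flat_labels = label_chunk.reshape(-1)
        # Only this chunk is upcast to fp32 inside the loop.
        total_loss = total_loss + F.cross_entropy(
            flat_logits.float(), flat_labels, ignore_index=ignore_index, reduction="sum"
        )
    # Normalize once over the whole batch, not per chunk.
    return total_loss / total_elements
```

Because the normalization happens once over all chunks, chunking changes peak memory but not the value: the result should match F.cross_entropy(full_logits.reshape(-1, vocab_size), labels.reshape(-1), ignore_index=-100) up to floating-point rounding.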
Example
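>>> import torch
>>> from torchtune.modules.loss import CEWithChunkedOutputLoss
>>> bsz, num_tokens, dim, vocab_size, num_chunks = 2, 16, 32, 128, 8
>>> loss_fn = CEWithChunkedOutputLoss(num_output_chunks=num_chunks)
>>> output = torch.nn.Linear(dim, vocab_size, bias=False)  # hypothetical stand-in for model.output
>>> h = torch.randn(bsz, num_tokens, dim)
>>> output_chunks = [output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>> labels = torch.randint(0, vocab_size, (bsz, num_tokens))
>>> loss = loss_fn(output_chunks, labels)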