CEWithChunkedOutputLoss

class torchtune.modules.loss.CEWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = -100)

Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.

When a model is trained in bf16, its logits have to be upcast to fp32 before computing cross-entropy (CE) for better accuracy and stability. Upcasting doubles the memory used by the logits. Models like llama3 have a large vocabulary and therefore a large output tensor of shape (bsz, num_tokens, vocab_size). If the output is chunked at the token level, the cross-entropy can still be computed normally, but upcasting only one chunk at a time saves considerable memory.
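To make the saving concrete, here is a rough back-of-the-envelope calculation. The batch size and sequence length are assumed purely for illustration, and vocab_size is set to a llama3-sized vocabulary. It counts only the logits tensor and its fp32 copy, ignoring autograd buffers and other activations.

```python
def logits_bytes(bsz: int, num_tokens: int, vocab_size: int, bytes_per_elem: int) -> int:
    """Size in bytes of a dense (bsz, num_tokens, vocab_size) tensor."""
    return bsz * num_tokens * vocab_size * bytes_per_elem


# Assumed, illustrative sizes (not taken from the documentation above).
bsz, num_tokens, vocab_size = 4, 8192, 128_256
num_output_chunks = 8  # default of CEWithChunkedOutputLoss

bf16_logits = logits_bytes(bsz, num_tokens, vocab_size, 2)      # 2 bytes per bf16 value
fp32_full_copy = logits_bytes(bsz, num_tokens, vocab_size, 4)   # upcast everything at once
fp32_one_chunk = logits_bytes(bsz, num_tokens // num_output_chunks, vocab_size, 4)

gib = 1024 ** 3
print(f"bf16 logits:              {bf16_logits / gib:.2f} GiB")
print(f"fp32 copy, all tokens:    {fp32_full_copy / gib:.2f} GiB")
print(f"fp32 copy, one chunk:     {fp32_one_chunk / gib:.2f} GiB")
```

Under these assumptions the temporary fp32 copy shrinks by a factor of num_output_chunks, because only one chunk is upcast at any moment.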

The cross-entropy and the upcast have to be compiled together for better performance. When using this class, we recommend applying torch.compile() only to the compute_cross_entropy method. The gains from chunking won’t be realized if you compile the entire class.
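One way to follow this recommendation is to rebind only the per-chunk method to its compiled version. The sketch below uses a hypothetical stand-in class, TinyChunkedLoss, so that it does not depend on torchtune being installed. The helper name compile_ce_only and the backend argument are likewise assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


class TinyChunkedLoss(torch.nn.Module):
    """Hypothetical stand-in that mirrors the documented method name."""

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index

    def compute_cross_entropy(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # The upcast and the CE sit in one method so they can be compiled together.
        return F.cross_entropy(
            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
        )


def compile_ce_only(loss_fn: torch.nn.Module, backend: str = "inductor") -> torch.nn.Module:
    # Replace only the per-chunk method; the module itself is left uncompiled.
    loss_fn.compute_cross_entropy = torch.compile(
        loss_fn.compute_cross_entropy, backend=backend
    )
    return loss_fn
```

The same rebinding should apply to a CEWithChunkedOutputLoss instance, since it exposes a method with the same name. Compilation itself is deferred until the method is first called.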

For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390

compute_cross_entropy(logits: Tensor, labels: Tensor) → Tensor

Upcast logits to fp32 and compute cross entropy loss.

forward(logits: List[Tensor], labels: Tensor) → Tensor
Parameters:
  • logits (List[torch.Tensor]) – List of chunked logits of length self.num_output_chunks, where each chunk has shape (batch_size, num_tokens / num_output_chunks, vocab_size).

  • labels (torch.Tensor) – Ground truth labels of shape (batch_size, num_tokens).

Returns:

Cross entropy loss of shape (1,).

Return type:

torch.Tensor
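The computation can be sketched in plain PyTorch. This is not the library's implementation. It assumes the per-chunk losses are summed and then divided by the number of non-ignored tokens, which reproduces the usual mean cross-entropy. The function name chunked_cross_entropy and the tensor sizes are made up for illustration.

```python
import torch
import torch.nn.functional as F


def chunked_cross_entropy(logit_chunks, labels, ignore_index=-100):
    """Sum per-chunk CE in fp32, then normalize by the number of valid tokens."""
    # Split labels along the token dimension so they line up with the logit chunks.
    label_chunks = labels.chunk(len(logit_chunks), dim=1)
    total = 0.0
    for logits, target in zip(logit_chunks, label_chunks):
        # Only this chunk is upcast to fp32 at any given time.
        total = total + F.cross_entropy(
            logits.float().reshape(-1, logits.size(-1)),
            target.reshape(-1),
            ignore_index=ignore_index,
            reduction="sum",
        )
    return total / (labels != ignore_index).sum()


torch.manual_seed(0)
bsz, num_tokens, vocab_size, num_chunks = 2, 16, 50, 8
logits = torch.randn(bsz, num_tokens, vocab_size)
labels = torch.randint(0, vocab_size, (bsz, num_tokens))
labels[0, :3] = -100  # mark a few positions as ignored

chunked = chunked_cross_entropy(list(logits.chunk(num_chunks, dim=1)), labels)
reference = F.cross_entropy(
    logits.reshape(-1, vocab_size), labels.reshape(-1), ignore_index=-100
)
```

Here chunked and reference should agree up to floating-point error, which is why chunking at the token level does not change the loss value.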

Example

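>>> import torch
>>> from torchtune.modules.loss import CEWithChunkedOutputLoss
>>>
>>> bsz, num_tokens, dim, vocab_size = 2, 16, 32, 100  # illustrative sizes
>>> output_proj = torch.nn.Linear(dim, vocab_size)  # stand-in for model.output
>>> loss_fn = CEWithChunkedOutputLoss()
>>>
>>> h = torch.randn(bsz, num_tokens, dim)  # final hidden states
>>> output_chunks = [output_proj(chunk) for chunk in h.chunk(loss_fn.num_output_chunks, dim=1)]
>>>
>>> labels = torch.randint(0, vocab_size, (bsz, num_tokens))
>>> loss = loss_fn(output_chunks, labels)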
