torchtune.training¶
Checkpointing¶
torchtune offers checkpointers that allow seamless transitions between checkpoint formats for training and interoperability with the rest of the ecosystem. For a comprehensive overview of checkpointing, please see the checkpointing deep-dive.
FullModelHFCheckpointer | Checkpointer which reads and writes checkpoints in HF's format.
FullModelMetaCheckpointer | Checkpointer which reads and writes checkpoints in Meta's format.
FullModelTorchTuneCheckpointer | Checkpointer which reads and writes checkpoints in a format compatible with torchtune.
ModelType | ModelType is used by the checkpointer to distinguish between different model architectures.
FormattedCheckpointFiles | This class gives a more concise way to represent a list of checkpoint filenames that share a common format.
update_state_dict_for_classifier | Validates the state dict for checkpoint loading for a classifier model.
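As an example, the HF checkpointer can be constructed directly in Python (recipes normally build it from a config). This is a minimal sketch: the directory and file names are hypothetical, and the keyword arguments follow the checkpointing deep-dive, so they may differ slightly across versions.

```python
# Minimal sketch: load an HF-format Llama2 checkpoint and convert it to
# torchtune's format. Paths and file names below are hypothetical.
from torchtune.training import FullModelHFCheckpointer, ModelType

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Llama-2-7b-hf",  # directory containing the HF checkpoint
    checkpoint_files=[
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ],
    model_type=ModelType.LLAMA2,          # tells the checkpointer how to map weight keys
    output_dir="/tmp/finetune-output",    # where converted checkpoints are written
)

# load_checkpoint returns the checkpoint contents with weights converted to torchtune's format
ckpt_dict = checkpointer.load_checkpoint()
```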
Reduced Precision¶
Utilities for working in a reduced precision setting.
get_dtype | Get the torch.dtype corresponding to the given precision string.
set_default_dtype | Context manager to set torch's default dtype.
validate_expected_param_dtype | Validates that all input parameters have the expected dtype.
get_quantizer_mode | Given a quantizer object, returns a string that specifies the type of quantization.
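As a sketch of how these fit together (the argument names are assumptions based on the summaries above):

```python
# Minimal sketch: resolve a precision string, build a module in that dtype,
# and verify the resulting parameter dtypes.
import torch
from torchtune import training

dtype = training.get_dtype("bf16")  # -> torch.bfloat16; requires bf16-capable hardware, use "fp32" otherwise

with training.set_default_dtype(dtype):
    # parameters created inside the context use the reduced-precision dtype
    model = torch.nn.Linear(16, 16)

# raises if any parameter does not match the expected dtype
training.validate_expected_param_dtype(model.named_parameters(), dtype=dtype)
```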
Distributed¶
Utilities for enabling and working with distributed training.
init_distributed | Initialize the process group required for torch.distributed.
is_distributed | Check if all environment variables required to initialize torch.distributed are set and distributed is properly installed.
gather_cpu_state_dict | Convert a sharded state dict into a full state dict on CPU, returning a non-empty result only on rank 0 to avoid spiking CPU memory on every rank.
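Under a torchrun-style launch the first two helpers are typically used together. A minimal sketch, assuming extra keyword arguments (such as backend) are forwarded to torch.distributed's process-group initialization:

```python
# Minimal sketch: initialize torch.distributed only when the launcher has set
# the required environment variables (e.g. under torchrun).
from torchtune import training

if training.is_distributed():
    # the backend kwarg here is an assumption; it is forwarded to the process-group init
    training.init_distributed(backend="nccl")
```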
Memory Management¶
Utilities to reduce memory consumption during training.
set_activation_checkpointing | Utility to set up activation checkpointing and wrap the model for checkpointing.
apply_selective_activation_checkpointing | Utility to apply activation checkpointing to the passed-in model.
OptimizerInBackwardWrapper | A bare-bones class meant for checkpoint save and load for optimizers running in backward.
create_optim_in_bwd_wrapper | Create a wrapper for optimizer step running in backward.
register_optim_in_bwd_hooks | Register hooks for optimizer step running in backward.
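The sketch below wires these utilities up around a toy model. The nn.Linear wrap policy and the SGD learning rate are illustrative choices; in the actual recipes the wrap policy is usually a transformer layer class and the optimizer comes from the config.

```python
# Minimal sketch with a toy model: activation checkpointing plus
# optimizer-in-backward steps (one optimizer per parameter).
import torch
from torch import nn
from torchtune import training

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# Recompute activations of the wrapped modules during backward to save memory
training.set_activation_checkpointing(model, auto_wrap_policy={nn.Linear})

# Run each parameter's optimizer step inside backward instead of after it
optim_dict = {p: torch.optim.SGD([p], lr=2e-5) for p in model.parameters()}
training.register_optim_in_bwd_hooks(model=model, optim_dict=optim_dict)

# Wrapper used to save/load optimizer state in this setup
optim_wrapper = training.create_optim_in_bwd_wrapper(model=model, optim_dict=optim_dict)
```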
Schedulers¶
Utilities to control the learning rate during training.
get_cosine_schedule_with_warmup | Create a learning rate schedule that linearly increases the learning rate from 0.0 to lr over the warmup steps, then decreases it following a cosine schedule over the remaining training steps.
get_lr | The full_finetune_distributed and full_finetune_single_device recipes assume all optimizer parameter groups share the same learning rate; this utility validates that assumption and returns the learning rate.
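A minimal sketch of the warmup-plus-cosine schedule on a toy optimizer; the step counts are arbitrary.

```python
# Minimal sketch: linear warmup for 100 steps, cosine decay for the rest.
import torch
from torchtune import training

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

lr_scheduler = training.get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1_000
)

for step in range(1_000):
    # ... forward, backward, optimizer.step(), optimizer.zero_grad() ...
    lr_scheduler.step()

# All param groups share one LR, so this returns a single value
current_lr = training.get_lr(optimizer)
```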
Metric Logging¶
Various logging utilities.
CometLogger | Logger for use with Comet (https://www.comet.com/site/).
WandBLogger | Logger for use with the Weights & Biases application (https://wandb.ai/).
TensorBoardLogger | Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).
StdoutLogger | Logger to standard output.
DiskLogger | Logger to disk.
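All of these loggers expose the same small interface. The sketch below uses the disk logger with a hypothetical log directory.

```python
# Minimal sketch: log scalar metrics to disk; the other loggers (stdout,
# TensorBoard, W&B, Comet) expose the same log/log_dict/close interface.
from torchtune.training.metric_logging import DiskLogger

metric_logger = DiskLogger(log_dir="/tmp/torchtune-logs")  # hypothetical path
metric_logger.log("loss", 1.23, step=0)
metric_logger.log_dict({"loss": 1.17, "lr": 3e-4}, step=1)
metric_logger.close()
```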
Performance and Profiling¶
torchtune provides utilities to profile and debug the memory and performance of your finetuning job.
get_memory_stats | Computes a memory summary for the passed-in device.
log_memory_stats | Logs a dict containing memory stats to the logger.
setup_torch_profiler | Sets up torch.profiler.profile and returns the profiler config with post-setup updates.
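For example, the memory helpers are typically called right after model setup. This sketch assumes a CUDA device is available, since the stats come from torch.cuda counters.

```python
# Minimal sketch: summarize peak memory for a CUDA device and log it.
import torch
from torchtune import training

device = torch.device("cuda:0")
memory_stats = training.get_memory_stats(device=device)  # dict of peak allocation/reservation stats
training.log_memory_stats(memory_stats)                  # prints the dict via the logger
```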
Miscellaneous¶
get_unmasked_sequence_lengths | Returns the sequence lengths for each batch element, excluding masked tokens.
set_seed | Function that sets seed for pseudo-random number generators across commonly used libraries.
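For instance, recipes call the seeding utility once at startup:

```python
# Minimal sketch: seed the commonly used RNGs (Python, NumPy, torch) in one call.
from torchtune import training

training.set_seed(seed=1234)
```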