torchtune.utils¶
Checkpointing¶
torchtune offers checkpointers to allow seamless transitioning between checkpoint formats for training and interoperability with the rest of the ecosystem. For a comprehensive overview of checkpointing, please see the checkpointing deep-dive.
- Checkpointer which reads and writes checkpoints in HF's format.
- Checkpointer which reads and writes checkpoints in Meta's format.
- Checkpointer which reads and writes checkpoints in a format compatible with torchtune.
- ModelType is used by the checkpointer to distinguish between different model architectures.
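For example, a checkpointer can be constructed and used directly outside of a recipe. The sketch below assumes a Llama2 checkpoint in HF format; the directory and file names are placeholders, and the exact constructor arguments may differ between torchtune versions:

    from torchtune.utils import FullModelHFCheckpointer

    # Placeholder paths and file names: point these at a real HF checkpoint.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir="/tmp/llama2-7b-hf",
        checkpoint_files=[
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ],
        model_type="LLAMA2",  # tells the checkpointer which architecture to convert for
        output_dir="/tmp/finetune-output",
    )

    # Reads the HF-format weights and converts them into a torchtune-compatible
    # state dict, keyed by component (e.g. "model").
    state_dict = checkpointer.load_checkpoint()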
Distributed¶
Utilities for enabling and working with distributed training.
- A datatype for a function that can be used as an FSDP wrapping policy.
- Initialize the process group required for torch.distributed.
- Check that all environment variables required to initialize torch.distributed are set and that distributed is properly installed.
- Get the current world size (i.e. total number of ranks) and the rank of the current process in the default process group.
- Retrieve an FSDP wrapping policy based on the specified flags.
- A default policy for wrapping models trained with LoRA using FSDP.
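A minimal sketch of the distributed utilities, assuming the process was launched with torchrun so that the environment variables required by torch.distributed (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are already set:

    from torchtune import utils

    # Initializes the default process group; kwargs are forwarded to
    # torch.distributed.init_process_group.
    utils.init_distributed(backend="nccl")

    # Total number of ranks and this process's rank in the default process group.
    world_size, rank = utils.get_world_size_and_rank()
    if rank == 0:
        print(f"Initialized {world_size} ranks")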
Reduced Precision¶
Utilities for working in a reduced precision setting.
- Get the torch.dtype corresponding to the given precision string.
- Context manager to set torch's default dtype.
- Validates that all input parameters have the expected dtype.
- Given a quantizer object, returns a string that specifies the type of quantization.
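For example, a precision string from a config can be resolved to a torch.dtype and used to construct a model directly in reduced precision. This sketch assumes hardware with bf16 support:

    import torch
    from torchtune import utils

    # Map the config string to a torch.dtype (torch.bfloat16 here).
    dtype = utils.get_dtype("bf16")

    # Parameters created inside this context are allocated in bf16 directly,
    # avoiding a full fp32 copy of the model.
    with utils.set_default_dtype(dtype):
        model = torch.nn.Linear(4096, 4096)

    assert model.weight.dtype == torch.bfloat16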
Memory Management¶
Utilities to reduce memory consumption during training.
- Utility to apply activation checkpointing to the passed-in model.
- A bare-bones class meant for checkpoint save and load for optimizers running in backward.
- Create a wrapper for optimizer step running in backward.
- Register hooks for optimizer step running in backward.
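A rough sketch of applying activation checkpointing and running the optimizer step in the backward pass. The toy model, module types, and learning rate below are arbitrary placeholders:

    import torch
    import torch.nn as nn
    from torchtune import utils

    # Stand-in for a transformer; in a recipe the wrap policy would target
    # the transformer decoder layer class instead of nn.Linear.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

    # Recompute activations of the wrapped modules during backward instead of
    # storing them, trading compute for memory.
    utils.set_activation_checkpointing(model, auto_wrap_policy={nn.Linear})

    # Optimizer-in-backward: one optimizer per parameter, stepped as gradients
    # become available so they can be freed immediately.
    optim_dict = {p: torch.optim.SGD([p], lr=1e-2) for p in model.parameters()}
    utils.register_optim_in_bwd_hooks(model=model, optim_dict=optim_dict)
    optim_wrapper = utils.create_optim_in_bwd_wrapper(model=model, optim_dict=optim_dict)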
Performance and Profiling¶
torchtune provides utilities to profile and debug the memory and performance of your finetuning job.
- Computes a memory summary for the passed-in device.
- Logs a dict containing memory stats to the logger.
- Sets up the PyTorch profiler (torch.profiler) for profiling a finetuning run.
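For example, memory stats can be computed after a forward and backward pass and logged in a human-readable form. This sketch assumes a CUDA device is available:

    import torch
    from torchtune import utils

    device = utils.get_device("cuda")
    model = torch.nn.Linear(4096, 4096, device=device)
    model(torch.randn(16, 4096, device=device)).sum().backward()

    # Peak active / allocated / reserved memory for the device.
    memory_stats = utils.get_memory_stats(device=device)
    utils.log_memory_stats(memory_stats)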
Metric Logging¶
Various logging utilities.
- Logger for use with the Weights and Biases application (https://wandb.ai/).
- Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).
- Logger to standard output.
- Logger to disk.
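A minimal sketch of logging scalar metrics to disk; the log directory and metric names are placeholders, and the other loggers expose the same log / log_dict / close interface with their own constructor arguments:

    from torchtune.utils.metric_logging import DiskLogger

    logger = DiskLogger(log_dir="/tmp/torchtune-logs")  # placeholder directory

    for step in range(3):
        logger.log("loss", 2.0 - 0.1 * step, step=step)
        logger.log_dict({"lr": 3e-4, "tokens_per_second": 1000.0}, step=step)

    logger.close()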
Data¶
Utilities for working with data and datasets.
- Pad a batch of sequences to the longest sequence length in the batch, and convert integer lists to tensors.
- Pad a batch of sequences for Direct Preference Optimization (DPO).
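For example, padded_collate can be passed as the collate_fn of a DataLoader. The toy batch below shows shorter sequences being padded to the longest length in the batch:

    from torchtune.utils import padded_collate

    batch = [
        {"tokens": [1, 2, 3], "labels": [4, 5, 6]},
        {"tokens": [7], "labels": [10]},
    ]

    # Tokens are padded with padding_idx and labels with ignore_idx so that
    # padded positions are ignored by the loss.
    collated = padded_collate(batch, padding_idx=0, ignore_idx=-100)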
Miscellaneous¶
- Get a logger with a stream handler.
- Function that takes an optional device string, verifies that it is correct and available given the machine and distributed settings, and returns a torch.device.
- Function that sets the seed for pseudo-random number generators across commonly used libraries.
- Generates tokens from a model conditioned on a prompt.
- Check if the torch version is greater than or equal to the given version.
- A helpful utility subclass of argparse.ArgumentParser for parsing recipe configs and command-line overrides.
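A short sketch tying a few of these together; the seed, device string, and version below are arbitrary values:

    from torchtune import utils

    # Seed the commonly used pseudo-random number generators.
    utils.set_seed(42)

    # Validate and resolve a device string into a torch.device.
    device = utils.get_device("cpu")

    logger = utils.get_logger("INFO")
    logger.info(f"device={device}, torch>=2.1: {utils.torch_version_ge('2.1.0')}")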