torchtune.training

Checkpointing

torchtune offers checkpointers to allow seamless transitioning between checkpoint formats for training and interoperability with the rest of the ecosystem. For a comprehensive overview of checkpointing, please see the checkpointing deep-dive.

FullModelHFCheckpointer

Checkpointer which reads and writes checkpoints in HF's format.
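
A minimal sketch of how this checkpointer is typically instantiated; the paths, checkpoint filenames, and model type below are placeholders, and the exact constructor arguments may vary across torchtune versions.

from torchtune.training import FullModelHFCheckpointer, ModelType

# Paths and filenames here are hypothetical placeholders.
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Llama-2-7b-hf",
    checkpoint_files=[
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ],
    model_type=ModelType.LLAMA2,
    output_dir="/tmp/finetune-output",
)

# Returns a state dict keyed for use with torchtune models.
state_dict = checkpointer.load_checkpoint()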

FullModelMetaCheckpointer

Checkpointer which reads and writes checkpoints in Meta's format.

FullModelTorchTuneCheckpointer

Checkpointer which reads and writes checkpoints in a format compatible with torchtune.

ModelType

ModelType is used by the checkpointer to distinguish between different model architectures.

FormattedCheckpointFiles

This class gives a more concise way to represent a list of filenames of the format file_{i}_of_{n_files}.pth.

update_state_dict_for_classifier

Validates the state dict when loading a checkpoint into a classifier model.

Reduced Precision

Utilities for working in a reduced precision setting.

get_dtype

Get the torch.dtype corresponding to the given precision string.

set_default_dtype

Context manager to set torch's default dtype.
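
A minimal sketch combining get_dtype and set_default_dtype, assuming a CUDA device with bf16 support.

import torch
from torchtune import training

# Resolve a precision string to a torch.dtype; the device argument is used
# to verify that the requested precision is supported.
dtype = training.get_dtype("bf16", device=torch.device("cuda"))

# Temporarily set torch's default dtype so parameters created inside the
# context are allocated in reduced precision.
with training.set_default_dtype(dtype):
    layer = torch.nn.Linear(16, 16)

assert layer.weight.dtype == dtype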

validate_expected_param_dtype

Validates that all input parameters have the expected dtype.

get_quantizer_mode

Given a quantizer object, returns a string that specifies the type of quantization.

Distributed

Utilities for enabling and working with distributed training.

init_distributed

Initialize process group required for torch.distributed.

is_distributed

Check if all environment variables required to initialize torch.distributed are set and distributed is properly installed.

get_world_size_and_rank

Function that gets the current world size (aka total number of ranks) and rank number of the current process in the default process group.
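
A rough sketch of how these utilities fit together in a recipe launched with torchrun; the backend keyword shown is an assumption and is forwarded to torch.distributed.init_process_group.

from torchtune import training

# Only initialize the process group when the required environment
# variables (RANK, WORLD_SIZE, MASTER_ADDR, ...) have been set by torchrun.
if training.is_distributed():
    training.init_distributed(backend="nccl")

world_size, rank = training.get_world_size_and_rank()
print(f"rank {rank} of {world_size}")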

gather_cpu_state_dict

Converts a sharded state dict into a full state dict on CPU, returning a non-empty result only on rank 0 to avoid peak CPU memory usage on every rank.

Memory Management

Utilities to reduce memory consumption during training.

apply_selective_activation_checkpointing

Utility to set up activation checkpointing and wrap the model for checkpointing.

set_activation_checkpointing

Utility to apply activation checkpointing to the passed-in model.
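
A minimal sketch of wrapping a model's transformer layers with activation checkpointing; the layer class passed as the auto-wrap policy is an assumption and its name has changed across torchtune releases.

from torchtune import modules, training
from torchtune.models.llama2 import llama2_7b

model = llama2_7b()

# Checkpoint each transformer layer, trading recomputation in the backward
# pass for a lower activation memory footprint.
training.set_activation_checkpointing(
    model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)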

OptimizerInBackwardWrapper

A bare-bones class meant for checkpoint save and load for optimizers running in backward.

create_optim_in_bwd_wrapper

Create a wrapper for optimizer step running in backward.

register_optim_in_bwd_hooks

Register hooks for optimizer step running in backward.
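
A sketch of how the optimizer-in-backward utilities are typically combined, loosely mirroring torchtune's full-finetune recipes; the exact keyword names are assumptions.

import torch
from torchtune import training
from torchtune.models.llama2 import llama2_7b

model = llama2_7b()

# One optimizer per parameter: each optimizer steps inside the backward pass,
# so gradients for the full model never need to be resident at once.
optim_dict = {
    p: torch.optim.AdamW([p], lr=2e-5) for p in model.parameters()
}
training.register_optim_in_bwd_hooks(model=model, optim_dict=optim_dict)

# Wrapper exposing state_dict()/load_state_dict() so these per-parameter
# optimizers can be checkpointed like a single optimizer.
optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
    model=model, optim_dict=optim_dict
)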

Schedulers

Utilities to control the learning rate during the training process.

get_cosine_schedule_with_warmup

Create a learning rate schedule that linearly increases the learning rate from 0.0 to lr over num_warmup_steps, then decreases it to 0.0 on a cosine schedule over the remaining num_training_steps - num_warmup_steps steps (assuming num_cycles = 0.5).
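
For example, a schedule with 100 warmup steps over a 1000-step run might be set up as follows (a sketch with a toy model).

import torch
from torchtune import training

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Linear warmup from 0 to 3e-4 over the first 100 steps, then cosine decay
# back to 0 over the remaining 900 steps.
scheduler = training.get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    # ... forward, loss.backward() ...
    optimizer.step()
    scheduler.step()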

get_lr

The full_finetune_distributed and full_finetune_single_device recipes assume all optimizers use the same learning rate; this utility validates that all learning rates are equal and returns the common value.

Metric Logging

Various logging utilities.

metric_logging.CometLogger

Logger for use with Comet (https://www.comet.com/site/).

metric_logging.WandBLogger

Logger for use with the Weights & Biases application (https://wandb.ai/).

metric_logging.TensorBoardLogger

Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).

metric_logging.StdoutLogger

Logger to standard output.

metric_logging.DiskLogger

Logger to disk.
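
All of these loggers share the same small interface: log a single scalar by name, log a dict of metrics, and close when done. A sketch using DiskLogger (the log directory is a placeholder).

from torchtune.training import metric_logging

logger = metric_logging.DiskLogger(log_dir="/tmp/torchtune-logs")

# Metrics are keyed by training step.
logger.log("loss", 1.23, step=0)
logger.log_dict({"loss": 1.10, "lr": 3e-4}, step=1)
logger.close()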

Performance and Profiling

torchtune provides utilities to profile and debug the memory and performance of your finetuning job.

get_memory_stats

Computes a memory summary for the passed-in device.

log_memory_stats

Logs a dict containing memory stats to the logger.
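
A minimal sketch pairing these two utilities, assuming a CUDA device.

import torch
from torchtune import training

device = torch.device("cuda")

# ... run a training step on `device` ...

# Summary of allocated/reserved/peak memory for the device.
memory_stats = training.get_memory_stats(device=device)
training.log_memory_stats(memory_stats)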

setup_torch_profiler

Sets up torch.profiler.profile and returns the profiler config with post-setup updates.

Miscellaneous

get_unmasked_sequence_lengths

Returns the sequence lengths for each batch element, excluding masked tokens.

set_seed

Function that sets seed for pseudo-random number generators across commonly used libraries.
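
For example, seeding at the start of a recipe (a sketch; the returned value is the seed that was actually applied).

from torchtune import training

# Seeds Python, NumPy, and PyTorch RNGs for reproducibility.
seed = training.set_seed(1234)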
