torchtune.training

Checkpointing

torchtune offers checkpointers for seamless transitions between checkpoint formats during training and for interoperability with the rest of the ecosystem. For a comprehensive overview of checkpointing, please see the checkpointing deep-dive.

FullModelHFCheckpointer

Checkpointer which reads and writes checkpoints in HF's format.

FullModelMetaCheckpointer

Checkpointer which reads and writes checkpoints in Meta's format.

FullModelTorchTuneCheckpointer

Checkpointer which reads and writes checkpoints in a format compatible with torchtune.

ModelType

ModelType is used by the checkpointer to distinguish between different model architectures.

update_state_dict_for_classifier

Validates the state dict for checkpoint loading for a classifier model.
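
A minimal end-to-end sketch of the checkpointer workflow, assuming Llama-2 weights in HF format on local disk; the paths and file names are placeholders, and exact constructor arguments may vary between torchtune releases.

    # Minimal sketch: load HF-format weights, fine-tune, then write them back out.
    # Directory and file names below are hypothetical placeholders.
    from torchtune import training

    checkpointer = training.FullModelHFCheckpointer(
        checkpoint_dir="/tmp/Llama-2-7b-hf",
        checkpoint_files=[
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ],
        model_type="LLAMA2",  # resolved against ModelType
        output_dir="/tmp/finetune-output",
    )

    # Converts the HF weights into torchtune's format and returns a dict of
    # state dicts (model weights, plus optimizer/recipe state when resuming).
    ckpt_dict = checkpointer.load_checkpoint()
    model_state = ckpt_dict["model"]  # "model" key assumed; torchtune also exposes it as a constant

    # ... fine-tune, then save back in the original HF format so the weights
    # remain interoperable with the rest of the ecosystem.
    checkpointer.save_checkpoint(ckpt_dict, epoch=0)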

Reduced Precision

Utilities for working in a reduced precision setting.

get_dtype

Get the torch.dtype corresponding to the given precision string.

set_default_dtype

Context manager to set torch's default dtype.

validate_expected_param_dtype

Validates that all input parameters have the expected dtype.

get_quantizer_mode

Given a quantizer object, returns a string that specifies the type of quantization.
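
A minimal sketch tying these utilities together, assuming a bf16-capable device; signatures follow recent torchtune releases and may differ slightly in older ones.

    import torch
    from torchtune import training

    # Resolve a precision string (e.g. from a config) into a torch.dtype;
    # raises if the hardware does not support bf16.
    dtype = training.get_dtype("bf16")

    # Parameters created inside the context default to the requested dtype.
    with training.set_default_dtype(dtype):
        model = torch.nn.Linear(4096, 4096)

    # Raises if any parameter ended up in a different dtype.
    training.validate_expected_param_dtype(model.named_parameters(), dtype=dtype)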

Distributed

Utilities for enabling and working with distributed training.

FSDPPolicyType

A datatype for a function that can be used as an FSDP wrapping policy.

init_distributed

Initialize process group required for torch.distributed.

is_distributed

Check if all environment variables required to initialize torch.distributed are set and distributed is properly installed.

get_world_size_and_rank

Function that gets the current world size (i.e., the total number of ranks) and the rank of the current process in the default process group.

get_full_finetune_fsdp_wrap_policy

Retrieves an FSDP wrapping policy based on the specified flags memory_efficient_fsdp_wrap and modules_to_wrap.

lora_fsdp_wrap_policy

A default policy for wrapping models trained with LoRA using FSDP.
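
A minimal sketch of a torchrun-launched entry point; it assumes init_distributed forwards its keyword arguments to torch.distributed.init_process_group.

    from torchtune import training

    # Only initialize when the torchrun/torch.distributed env vars are present.
    if training.is_distributed():
        training.init_distributed(backend="nccl")  # kwargs forwarded to init_process_group

    world_size, rank = training.get_world_size_and_rank()
    if rank == 0:
        print(f"training with {world_size} rank(s)")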

Memory Management

Utilities to reduce memory consumption during training.

apply_selective_activation_checkpointing

Utility to set up activation checkpointing and wrap the model for checkpointing.

set_activation_checkpointing

Utility to apply activation checkpointing to the passed-in model.

OptimizerInBackwardWrapper

A bare-bones class meant for checkpoint save and load for optimizers running in backward.

create_optim_in_bwd_wrapper

Create a wrapper for optimizer step running in backward.

register_optim_in_bwd_hooks

Register hooks for optimizer step running in backward.
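
A minimal sketch combining activation checkpointing with the optimizer-in-backward utilities; the toy model and the per-parameter AdamW mapping are purely illustrative, and argument names follow recent releases.

    import torch
    from torchtune import training

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.Linear(1024, 1024),
    )

    # Recompute activations of the wrapped module classes during backward
    # instead of storing them (every nn.Linear here, for illustration).
    training.set_activation_checkpointing(model, auto_wrap_policy={torch.nn.Linear})

    # One optimizer per parameter, so the step can run inside backward and
    # each gradient can be freed immediately after it is applied.
    optim_dict = {p: torch.optim.AdamW([p], lr=2e-5) for p in model.parameters()}
    training.register_optim_in_bwd_hooks(model=model, optim_dict=optim_dict)

    # Wrapper exposing a single state_dict()/load_state_dict() for checkpointing.
    optim_wrapper = training.create_optim_in_bwd_wrapper(model=model, optim_dict=optim_dict)

    # A backward pass now also applies the optimizer updates.
    loss = model(torch.randn(2, 1024)).sum()
    loss.backward()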

Metric Logging

Various logging utilities.

metric_logging.CometLogger

Logger for use with Comet (https://www.comet.com/site/).

metric_logging.WandBLogger

Logger for use with the Weights and Biases application (https://wandb.ai/).

metric_logging.TensorBoardLogger

Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).

metric_logging.StdoutLogger

Logger to standard output.

metric_logging.DiskLogger

Logger to disk.
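
The loggers share a small common interface (log, log_dict, close), so they can be swapped out via config. A minimal sketch, using an illustrative log directory:

    from torchtune.training import metric_logging

    logger = metric_logging.DiskLogger(log_dir="/tmp/torchtune-logs")
    # or: metric_logging.StdoutLogger()
    # or: metric_logging.WandBLogger(project="my-finetune")

    for step in range(3):
        logger.log("loss", 2.5 - 0.1 * step, step=step)
        logger.log_dict({"lr": 2e-5, "tokens_per_second": 1500.0}, step=step)

    logger.close()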

Performance and Profiling

torchtune provides utilities to profile and debug the memory and performance of your finetuning job.

get_memory_stats

Computes a memory summary for the passed-in device.

log_memory_stats

Logs a dict containing memory stats to the logger.

setup_torch_profiler

Sets up torch.profiler.profile and returns the profiler config with post-setup updates.
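
A minimal sketch of profiling a few steps and logging a CUDA memory summary; the schedule and activity keyword arguments are an assumption based on the documented interface and may differ across versions.

    import torch
    from torchtune import training

    profiler, profiler_cfg = training.setup_torch_profiler(
        enabled=True,
        profile_memory=True,
        with_stack=True,        # stack traces are needed for memory profiling
        wait_steps=1,
        warmup_steps=1,
        active_steps=2,
        num_cycles=1,
        output_dir="/tmp/profiler-traces",  # hypothetical output location
    )

    with profiler:
        for step in range(4):
            ...                 # forward / backward / optimizer step goes here
            profiler.step()

    # CUDA only: peak / active / reserved memory for the given device.
    stats = training.get_memory_stats(device=torch.device("cuda:0"))
    training.log_memory_stats(stats)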

Miscellaneous

get_unmasked_sequence_lengths

Returns the sequence lengths for each batch element, excluding masked tokens.

set_seed

Function that sets the seed for pseudo-random number generators across commonly used libraries.
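
A minimal sketch; the padding-mask convention here (True marking positions to exclude) is an assumption, so check the docstring for the exact semantics in your version.

    import torch
    from torchtune import training

    # Seed the commonly used PRNGs (random, NumPy, PyTorch) in one call.
    training.set_seed(42)

    tokens = torch.tensor([[5, 9, 3, 0, 0],
                           [7, 2, 0, 0, 0]])    # 0 is the padding id here
    padding_mask = tokens == 0                  # True where the token is masked out
    seq_lens = training.get_unmasked_sequence_lengths(padding_mask)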
