torchtune.utils

Checkpointing

TorchTune offers checkpointers that allow seamless transitions between checkpoint formats, both during training and for interoperability with the rest of the ecosystem. For a comprehensive overview of checkpointing, please see the checkpointing deep-dive.

FullModelHFCheckpointer

Checkpointer which reads and writes checkpoints in HF's format.

FullModelMetaCheckpointer

Checkpointer which reads and writes checkpoints in Meta's format.
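A minimal sketch of the load/save flow with the HF checkpointer follows; the constructor arguments mirror the checkpointing deep-dive, but the exact names and accepted values are assumptions, so check the class reference for your version:

    from torchtune.utils import FullModelHFCheckpointer

    # Reads HF-format weight files from checkpoint_dir; save_checkpoint
    # writes them back out in the same format. The file names and the
    # "LLAMA2" model type below are illustrative assumptions.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir="/tmp/Llama-2-7b-hf",
        checkpoint_files=[
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ],
        model_type="LLAMA2",
        output_dir="/tmp/finetune-output",
    )
    state_dict = checkpointer.load_checkpoint()
    # ... finetune, updating state_dict ...
    checkpointer.save_checkpoint(state_dict, epoch=0)

FullModelMetaCheckpointer exposes the same interface; only the on-disk format differs.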

Distributed

Utilities for enabling and working with distributed training.

FSDPPolicyType

A datatype for a function that can be used as an FSDP wrapping policy.

init_distributed

Initialize torch.distributed.

get_world_size_and_rank

Function that gets the current world size (i.e., the total number of ranks) and the rank of the current trainer.
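For example, in a recipe launched with torchrun (a minimal sketch; init_distributed is assumed to forward its keyword arguments to torch.distributed.init_process_group):

    from torchtune import utils

    # Run under torchrun so the process-group environment variables are set.
    utils.init_distributed(backend="nccl")

    # world_size is the total number of ranks; rank identifies this trainer.
    world_size, rank = utils.get_world_size_and_rank()
    if rank == 0:
        print(f"Training on {world_size} ranks")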

get_full_finetune_fsdp_wrap_policy

Retrieves an FSDP wrapping policy based on the specified flags memory_efficient_fsdp_wrap and modules_to_wrap.
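A sketch of how the returned policy plugs into FSDP; the llama2_7b builder and the argument values are illustrative assumptions:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torchtune import utils
    from torchtune.models.llama2 import llama2_7b
    from torchtune.modules import TransformerDecoderLayer

    model = llama2_7b()
    # FSDPPolicyType is a callable that FSDP consults to decide which
    # submodules get their own FSDP unit. Wrapping requires an already
    # initialized process group (see init_distributed above).
    policy = utils.get_full_finetune_fsdp_wrap_policy(
        memory_efficient_fsdp_wrap=True,
        modules_to_wrap={TransformerDecoderLayer},
    )
    model = FSDP(model, auto_wrap_policy=policy)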

Reduced Precision

Utilities for working in a reduced precision setting.

get_dtype

Get the torch.dtype corresponding to the given precision string.

list_dtypes

Return a list of supported dtypes for finetuning.
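For example (the exact precision strings are whatever list_dtypes reports for your version):

    import torch
    from torchtune import utils

    print(utils.list_dtypes())        # supported precision strings
    dtype = utils.get_dtype("bf16")   # maps the string to a torch.dtype
    assert dtype == torch.bfloat16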

Memory Management

Utilities to reduce memory consumption during training.

set_activation_checkpointing

Utility to apply activation checkpointing to the passed-in model.
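For example, wrapping each transformer layer of a Llama2 model (a sketch; the builder and the wrap-policy set follow the pattern used in torchtune's recipes):

    from torchtune import utils
    from torchtune.models.llama2 import llama2_7b
    from torchtune.modules import TransformerDecoderLayer

    model = llama2_7b()
    # Each TransformerDecoderLayer is wrapped so its activations are
    # recomputed during the backward pass instead of being stored.
    utils.set_activation_checkpointing(
        model, auto_wrap_policy={TransformerDecoderLayer}
    )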

Performance and Profiling

TorchTune provides utilities to profile and debug the performance of your finetuning job.

profiler

Utility component that wraps around torch.profiler to profile a model's operators.
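Since the wrapper's exact arguments are not listed here, the sketch below uses torch.profiler directly; the torchtune utility layers on top of the same machinery:

    import torch
    from torch.profiler import ProfilerActivity, profile

    # Profile a forward pass on CPU and print the most expensive operators.
    linear = torch.nn.Linear(128, 128)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        linear(torch.randn(8, 128))
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))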

Metric Logging

Various logging utilities.

metric_logging.WandBLogger

Logger for use with the Weights & Biases application (https://wandb.ai/).

metric_logging.TensorBoardLogger

Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).

metric_logging.StdoutLogger

Logger to standard output.

metric_logging.DiskLogger

Logger to disk.
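All four loggers implement the same interface, so a recipe can swap backends by changing only the constructor. A minimal sketch (the log/log_dict/close method names are assumed from the shared metric-logger interface):

    from torchtune.utils.metric_logging import StdoutLogger

    logger = StdoutLogger()
    for step in range(3):
        # Log a single named scalar, then a dict of scalars, per step.
        logger.log("loss", 2.0 / (step + 1), step=step)
        logger.log_dict({"lr": 1e-4, "grad_norm": 0.5}, step=step)
    logger.close()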

Data

Utilities for working with data and datasets.

padded_collate

Pad a batch of sequences to the longest sequence length in the batch, and convert integer lists to tensors.
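A sketch of calling padded_collate directly; the per-sample layout (dicts of "tokens" and "labels" token-id lists) and the padding_idx/ignore_idx defaults are assumptions, so check the signature for your version:

    from torchtune.utils import padded_collate

    batch = [
        {"tokens": [1, 5, 9], "labels": [5, 9, 2]},
        {"tokens": [1, 7], "labels": [7, 2]},
    ]
    # Both sequences are right-padded to length 3; padded label positions
    # use ignore_idx so they are masked out of the loss.
    out = padded_collate(batch, padding_idx=0, ignore_idx=-100)
    print(out["tokens"], out["labels"])

The same function can be passed as the collate_fn of a torch.utils.data.DataLoader.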

Miscellaneous

TuneRecipeArgumentParser

A utility subclass of argparse.ArgumentParser that adds a built-in "config" argument.
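A sketch of how a recipe might use it; the config-merging behavior is inferred from the summary above, so treat the details as assumptions:

    from torchtune.utils import TuneRecipeArgumentParser

    # Behaves like argparse.ArgumentParser, plus a --config argument that
    # points at a YAML file whose entries are merged into the namespace.
    parser = TuneRecipeArgumentParser(description="My finetuning recipe")
    args, _ = parser.parse_known_args()  # e.g. python recipe.py --config cfg.yaml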

get_logger

Get a logger with a stream handler.

get_device

Function that takes a device or device string, verifies that it is correct and available given the machine and distributed settings, and returns a torch.device.

set_seed

Function that sets seed for pseudo-random number generators across commonly used libraries.
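Together, the three utilities above cover typical recipe setup, as in this sketch (the level and device strings are illustrative):

    from torchtune import utils

    utils.set_seed(42)                  # seeds common PRNGs for reproducibility
    device = utils.get_device("cuda")   # validated torch.device
    logger = utils.get_logger("INFO")   # logging.Logger with a stream handler
    logger.info(f"Training on {device}")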

generate

Generates tokens from a model conditioned on a prompt.
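A sketch of decoding from a prompt; the parameter names follow torchtune's generation utility at the time of writing but should be treated as assumptions, and the token ids below are placeholders for real tokenizer output:

    import torch
    from torchtune import utils
    from torchtune.models.llama2 import llama2_7b

    model = llama2_7b()
    prompt = torch.tensor([1, 15043, 3186])  # placeholder token ids
    tokens = utils.generate(
        model=model,
        prompt=prompt,
        max_generated_tokens=32,
        temperature=0.8,
        top_k=50,
    )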
