
torchtune.utils

Checkpointing

TorchTune offers checkpointers to allow seamless transitions between checkpoint formats, both for training and for interoperability with the rest of the ecosystem. For a comprehensive overview of checkpointing, please see the checkpointing deep-dive.

FullModelHFCheckpointer

Checkpointer which reads and writes checkpoints in HF's format.

FullModelMetaCheckpointer

Checkpointer which reads and writes checkpoints in Meta's format.
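
A minimal sketch of a load/save round trip with FullModelHFCheckpointer. The directory, file names, and model type are placeholders, and the exact constructor arguments may vary between releases, so check the class signature before copying.

    from torchtune import utils

    # Placeholder paths and file names; point these at a real Hugging Face checkpoint.
    checkpointer = utils.FullModelHFCheckpointer(
        checkpoint_dir="/tmp/llama2-7b-hf",
        checkpoint_files=[
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ],
        model_type="LLAMA2",
        output_dir="/tmp/finetune-output",
    )

    # Weights are converted from HF's key naming into TorchTune's format on load ...
    checkpoint_dict = checkpointer.load_checkpoint()
    # ... train and update checkpoint_dict ...
    # ... and converted back to HF's format on save.
    checkpointer.save_checkpoint(checkpoint_dict, epoch=0)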

Distributed

Utilities for enabling and working with distributed training.

init_distributed

Initialize torch.distributed.

get_world_size_and_rank

Function that gets the current world size (i.e., the total number of ranks) and the rank of the current training process.
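
A short sketch of how these two utilities are typically paired at the start of a recipe. The backend keyword is an assumption (keyword arguments are assumed to be forwarded to torch.distributed.init_process_group), and get_world_size_and_rank is assumed to fall back to (1, 0) when no process group has been initialized.

    from torchtune import utils

    # Typically run under torchrun, which sets the rank / world-size environment variables.
    # Assumption: kwargs are forwarded to torch.distributed.init_process_group.
    utils.init_distributed(backend="nccl")

    # Assumed to fall back to (1, 0) when no process group has been initialized.
    world_size, rank = utils.get_world_size_and_rank()
    if rank == 0:
        print(f"world size: {world_size}")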

Reduced Precision

Utilities for working in a reduced precision setting.

get_dtype

Get the torch.dtype corresponding to the given precision string.

list_dtypes

Return a list of supported dtypes for finetuning.
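
A quick sketch of resolving a precision string into a torch.dtype. The "bf16" string and the exact contents of list_dtypes() are assumptions; print the list on your install to see what is supported.

    from torch import nn
    from torchtune import utils

    print(utils.list_dtypes())                  # precision strings supported for finetuning

    # Resolves to torch.bfloat16 (may raise on hardware without bf16 support).
    dtype = utils.get_dtype("bf16")
    model = nn.Linear(16, 16).to(dtype=dtype)   # stand-in module for illustration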

Memory Management

Utilities to reduce memory consumption during training.

set_activation_checkpointing

Utility to set up activation checkpointing and wrap the model for checkpointing.
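
A sketch of turning on activation checkpointing for every layer of a given type. The tiny stand-in module and the assumption that auto_wrap_policy accepts a set of module classes are illustrative; in a recipe the model would be a TorchTune model and the wrapped type one of its transformer layers.

    from torch import nn
    from torchtune import utils

    # Tiny stand-in model; in a recipe this would be e.g. a Llama2 transformer decoder.
    model = nn.Sequential(
        nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
        nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
    )

    # Wrapped layers recompute their activations in the backward pass instead of
    # storing them, trading compute for memory.
    # Assumption: auto_wrap_policy accepts a set of module types to wrap.
    utils.set_activation_checkpointing(
        model, auto_wrap_policy={nn.TransformerEncoderLayer}
    )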

Performance and Profiling

TorchTune provides utilities to profile and debug the performance of your finetuning job.

profiler

Utility component that wraps around torch.profiler to profile a model's operators.
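
The profiler component wraps torch.profiler and is normally instantiated from a recipe config, so the sketch below drives the underlying torch.profiler API directly on a stand-in module to show the kind of operator-level trace it produces.

    import torch
    from torch import nn, profiler

    model = nn.Linear(128, 128)          # stand-in module
    inputs = torch.randn(32, 128)

    # This is the torch.profiler machinery that the TorchTune component wraps.
    with profiler.profile(activities=[profiler.ProfilerActivity.CPU]) as prof:
        model(inputs)

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))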

Metric Logging

Various logging utilities.

metric_logging.WandBLogger

Logger for use with the Weights & Biases application (https://wandb.ai/).

metric_logging.TensorBoardLogger

Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).

metric_logging.StdoutLogger

Logger to standard output.

metric_logging.DiskLogger

Logger to disk.
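
These loggers are intended to share the same small interface; the sketch below uses StdoutLogger since it needs no external service, and assumes the interface exposes log, log_dict, and close. Swapping in WandBLogger(project=...) or TensorBoardLogger(log_dir=...) should leave the logging calls unchanged.

    from torchtune.utils import metric_logging

    logger = metric_logging.StdoutLogger()

    # Log a single scalar, then a dictionary of scalars, for a given training step.
    logger.log("loss", 1.23, step=0)
    logger.log_dict({"loss": 1.10, "lr": 2e-5}, step=1)
    logger.close()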

Data

Utilities for working with data and datasets.

padded_collate

Pad a batch of sequences to the longest sequence length in the batch, and convert integer lists to tensors.
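
A sketch of padded_collate used the way the recipes use it, as the collate_fn of a DataLoader. The tokenizer path is a placeholder, and the llama2_tokenizer / alpaca_dataset builders and the pad_id attribute are assumptions about the surrounding library; padding_idx pads the token sequences and ignore_idx pads the labels so the loss skips padded positions.

    from functools import partial

    from torch.utils.data import DataLoader
    from torchtune import utils
    from torchtune.datasets import alpaca_dataset
    from torchtune.models.llama2 import llama2_tokenizer

    # Placeholder path to a downloaded Llama2 tokenizer.model file.
    tokenizer = llama2_tokenizer("/tmp/llama2/tokenizer.model")
    ds = alpaca_dataset(tokenizer=tokenizer)

    # Each batch is padded to its longest sequence: tokens with padding_idx,
    # labels with ignore_idx so padded positions are ignored by the loss.
    dataloader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=partial(
            utils.padded_collate,
            padding_idx=tokenizer.pad_id,
            ignore_idx=-100,
        ),
    )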

Miscellaneous

TuneRecipeArgumentParser

A helpful utility subclass of argparse.ArgumentParser that adds a built-in "config" argument.

get_logger

Get a logger with a stream handler.

get_device

Function that takes a device or device string, verifies that it is correct and available given the machine and distributed settings, and returns a torch.device.

set_seed

Function that sets seed for pseudo-random number generators across commonly used libraries.
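
A short sketch of these setup helpers as they might appear at the top of a recipe. The "INFO" level, the "cpu" device string, and the seed value are placeholder choices; TuneRecipeArgumentParser is normally used by the recipe entry points themselves to merge a YAML config with command-line overrides, so it is left out here.

    from torchtune import utils

    logger = utils.get_logger("INFO")    # logging.Logger with a stream handler
    device = utils.get_device("cpu")     # validated torch.device; an unavailable device is rejected
    utils.set_seed(1234)                 # seed the commonly used RNGs for reproducibility

    logger.info(f"Running on {device}")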
