Training-related utilities. These are independent of the framework and can be used as needed.

Data Utils

AllDatasetBatchesIterator AllDatasetBatchesIterator returns a dict containing batches from all dataloaders.
CudaDataPrefetcher CudaDataPrefetcher prefetches batches and moves them to the device.
InOrderIterator InOrderIterator returns all batches from a single dataset until it is exhausted and then moves to the next one.
MultiDataLoader MultiDataLoader cycles through individual dataloaders passed to it.
MultiIterator MultiIterator defines the iteration logic to get a batch, given batches from all individual dataloaders.
RandomizedBatchSamplerIterator RandomizedBatchSamplerIterator randomly samples from each dataset using the provided weights.
RoundRobinIterator RoundRobinIterator cycles over the dataloaders one by one.
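To make the difference between round-robin and in-order iteration concrete, here is a minimal, framework-free sketch over plain Python iterables standing in for dataloaders. The generator names `round_robin` and `in_order` are hypothetical, and the sketch assumes that round-robin iteration drops a dataloader once it is exhausted and stops only when all of them are exhausted; the real iterators may support other stopping behaviors.

```python
from typing import Dict, Iterable, Iterator, Tuple


def round_robin(loaders: Dict[str, Iterable]) -> Iterator[Tuple[str, object]]:
    """Yield (name, batch) pairs, taking one batch from each loader in turn.

    Assumption: an exhausted loader is dropped, and iteration ends once
    every loader is exhausted.
    """
    iterators = {name: iter(loader) for name, loader in loaders.items()}
    while iterators:
        # Iterate over a snapshot of the keys so we can delete while looping.
        for name in list(iterators):
            try:
                yield name, next(iterators[name])
            except StopIteration:
                del iterators[name]


def in_order(loaders: Dict[str, Iterable]) -> Iterator[Tuple[str, object]]:
    """Yield every batch from one loader before moving on to the next."""
    for name, loader in loaders.items():
        for batch in loader:
            yield name, batch
```

For example, with `{"a": [1, 2, 3], "b": [10]}`, `round_robin` yields a:1, b:10, a:2, a:3, while `in_order` yields a:1, a:2, a:3, b:10.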

Device Utils

get_device_from_env Function that gets the torch.device based on the current environment.
copy_data_to_device Function that recursively copies data to a torch.device.
record_data_in_stream Records the tensor element on certain streams to prevent its memory from being reused for another tensor.
get_nvidia_smi_gpu_stats Get GPU stats from nvidia-smi.
get_psutil_cpu_stats Get CPU process stats using psutil.
maybe_enable_tf32 Conditionally sets the precision of float32 matrix multiplications and conv operations.
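The recursive copy performed by copy_data_to_device can be sketched without PyTorch by duck-typing the leaves. This is a minimal sketch, assuming that any leaf exposing a `.to(device)` method (as `torch.Tensor` does) should be moved and everything else returned unchanged; the function name `copy_to_device_sketch` is hypothetical, and the real utility may handle additional container types.

```python
from typing import Any


def copy_to_device_sketch(data: Any, device: Any) -> Any:
    """Recursively move every leaf that has a .to() method onto `device`."""
    if isinstance(data, dict):
        # Preserve keys; recurse into values.
        return {k: copy_to_device_sketch(v, device) for k, v in data.items()}
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple: rebuild positionally so the field names are kept.
        return type(data)(*(copy_to_device_sketch(v, device) for v in data))
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (list or plain tuple).
        return type(data)(copy_to_device_sketch(v, device) for v in data)
    if hasattr(data, "to"):
        # Tensor-like leaf, e.g. torch.Tensor.to(device).
        return data.to(device)
    # Anything else (ints, strings, None) is returned unchanged.
    return data
```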

Distributed Utils

PGWrapper A wrapper around ProcessGroup that allows collectives to be issued in a consistent fashion regardless of the following scenarios:
get_global_rank Get rank using torch.distributed if available.
get_local_rank Get the local rank from the LOCAL_RANK environment variable, if populated. Defaults to 0 if LOCAL_RANK is not set.
get_world_size Get world size using torch.distributed if available.
barrier Add a synchronization point across all processes when running in a distributed setting.
destroy_process_group Destroy the global process group, if one is already initialized.
get_process_group_backend_from_device Function that gets the default process group backend from the device.
get_file_init_method Gets the init method for the file protocol (file://) for the distributed environment.
get_tcp_init_method Gets the init method for the TCP protocol for the distributed environment.
all_gather_tensors Function to gather tensors from several distributed processes into a list that is broadcast to all processes.
rank_zero_fn Function that can be used as a decorator to enable a function to be called on global rank 0 only.
revert_sync_batchnorm Helper function to convert all torch.nn.SyncBatchNorm layers in the module to BatchNorm*D layers.
sync_bool Utility to synchronize a boolean value across members of a provided process group.
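The local-rank lookup and the rank-zero decorator pattern can be sketched with the standard library alone. This is a minimal sketch: the names `local_rank_sketch`, `global_rank_sketch`, and `rank_zero_fn_sketch` are hypothetical, and reading the global rank from the `RANK` environment variable (set by launchers such as torchrun) is an assumption that stands in for querying torch.distributed.

```python
import functools
import os
from typing import Any, Callable, Optional, TypeVar

T = TypeVar("T")


def local_rank_sketch() -> int:
    # LOCAL_RANK is set by launchers such as torchrun; default to 0 otherwise.
    return int(os.environ.get("LOCAL_RANK", "0"))


def global_rank_sketch() -> int:
    # Assumption: read RANK from the environment instead of querying
    # torch.distributed, so the sketch runs without a process group.
    return int(os.environ.get("RANK", "0"))


def rank_zero_fn_sketch(fn: Callable[..., T]) -> Callable[..., Optional[T]]:
    """Decorator: run `fn` only on global rank 0; other ranks get None."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[T]:
        if global_rank_sketch() == 0:
            return fn(*args, **kwargs)
        return None

    return wrapper
```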

Early Stop Checker

EarlyStopChecker Monitor a metric and signal if execution should stop early.
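The core early-stopping logic, tracking the best value seen and counting checks without improvement, can be sketched as follows. The class name `EarlyStopSketch` and its `mode`, `patience`, and `min_delta` parameters are assumptions chosen for illustration; the real EarlyStopChecker may expose more options.

```python
import math


class EarlyStopSketch:
    """Signal a stop after `patience` consecutive checks without improvement."""

    def __init__(self, mode: str = "min", patience: int = 3, min_delta: float = 0.0) -> None:
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta
        # Start from the worst possible value so the first check improves.
        self.best = math.inf if mode == "min" else -math.inf
        self.bad_checks = 0

    def check(self, value: float) -> bool:
        """Record `value`; return True if execution should stop."""
        if self.mode == "min":
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta
        if improved:
            self.best = value
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience
```

With `mode="min"` and `patience=2`, the sequence 1.0, 0.5, 0.6, 0.7 signals a stop only on the last value, after two consecutive checks without a new minimum.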

Environment Utils

init_from_env Utility function that initializes the device and process group, if applicable.
seed Function that sets the seed for pseudo-random number generators across commonly used libraries.
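Seeding "across commonly used libraries" can be sketched as seeding each generator that is importable. This minimal sketch, named `seed_sketch` (hypothetical), seeds Python's built-in `random` module and, only if installed, NumPy and PyTorch; which libraries the real seed function covers, and any determinism flags it sets, are assumptions not shown here.

```python
import importlib.util
import random


def seed_sketch(seed: int) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    if importlib.util.find_spec("numpy") is not None:
        import numpy as np

        np.random.seed(seed)  # NumPy requires 0 <= seed < 2**32
    if importlib.util.find_spec("torch") is not None:
        import torch

        torch.manual_seed(seed)
```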

Flops Utils

FlopTensorDispatchMode A context manager to measure the FLOPs of a module.

Filesystem Spec Utils

get_filesystem Returns the appropriate filesystem to use when handling the given path.

Logger Utils

FileLogger Abstract file logger.
MetricLogger Abstract metric logger.
CSVLogger CSV file logger.
InMemoryLogger Simple logger that buffers data in-memory.
JSONLogger JSON file logger.
TensorBoardLogger Simple logger for TensorBoard.
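An in-memory metric logger amounts to a buffer keyed by metric name. This sketch, `InMemoryLoggerSketch`, is hypothetical: its `log`, `log_dict`, and `close` methods are modeled loosely on a metric-logger interface and should be treated as assumptions rather than the real InMemoryLogger API.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


class InMemoryLoggerSketch:
    """Buffer scalar metrics in memory, keyed by metric name."""

    def __init__(self) -> None:
        # name -> list of (step, value) pairs, in logging order
        self.log_buffer: Dict[str, List[Tuple[int, float]]] = defaultdict(list)

    def log(self, name: str, data: float, step: int) -> None:
        self.log_buffer[name].append((step, float(data)))

    def log_dict(self, payload: Dict[str, float], step: int) -> None:
        # Log every entry of the payload under the same step.
        for name, data in payload.items():
            self.log(name, data, step)

    def close(self) -> None:
        # Nothing to flush for an in-memory buffer; clear it to free memory.
        self.log_buffer.clear()
```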

Memory Utils

RSSProfiler A profiler that periodically measures RSS (resident set size) delta.
measure_rss_deltas A context manager that periodically measures RSS (resident set size) delta.

Module Summary Utils

ModuleSummary Summary of a module and its submodules.
get_module_summary Generate a ModuleSummary object, then assign its values and generate the submodule tree.
get_summary_table Generates a string summary_table, tabularizing the information in module_summary.
prune_module_summary Prune the module summaries that are deeper than max_depth in the module summary tree.
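Pruning a summary tree by depth is a short recursion. In this sketch, `SummaryNode` and `prune_sketch` are hypothetical stand-ins for ModuleSummary and prune_module_summary, and the convention that the root sits at depth 1 is an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class SummaryNode:
    name: str
    submodules: Dict[str, "SummaryNode"] = field(default_factory=dict)


def prune_sketch(node: SummaryNode, max_depth: int) -> None:
    """Drop submodule summaries deeper than `max_depth`, in place.

    Assumption: the root node is at depth 1.
    """
    if max_depth <= 1:
        # This node is at the depth limit, so its children are removed.
        node.submodules = {}
        return
    for child in node.submodules.values():
        prune_sketch(child, max_depth - 1)
```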

OOM Utils

is_out_of_cpu_memory Returns True if the exception is related to CPU OOM.
is_out_of_cuda_memory Returns True if the exception is related to CUDA OOM.
is_out_of_memory_error Returns True if an exception is due to an OOM, based on its error message.
log_memory_snapshot Writes the memory snapshots to the provided output_dir.
attach_oom_observer Attaches a function to record the PyTorch memory snapshot when an out-of-memory error occurs.
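Detecting an OOM "based on its error message" comes down to substring checks on the exception text. This is a minimal sketch with hypothetical `*_sketch` names; the exact message substrings are assumptions based on typical PyTorch error text, and treating Python's built-in `MemoryError` as a CPU OOM is a choice made for illustration.

```python
def is_out_of_cuda_memory_sketch(exc: BaseException) -> bool:
    # Assumption: PyTorch CUDA OOM messages contain one of these substrings.
    msg = str(exc)
    return isinstance(exc, RuntimeError) and (
        "CUDA out of memory" in msg or "CUDA error: out of memory" in msg
    )


def is_out_of_cpu_memory_sketch(exc: BaseException) -> bool:
    # Python's own MemoryError is counted as a CPU OOM in this sketch.
    if isinstance(exc, MemoryError):
        return True
    # Assumption: the message raised by PyTorch's default CPU allocator.
    return isinstance(exc, RuntimeError) and (
        "DefaultCPUAllocator: can't allocate memory" in str(exc)
    )


def is_out_of_memory_error_sketch(exc: BaseException) -> bool:
    return is_out_of_cuda_memory_sketch(exc) or is_out_of_cpu_memory_sketch(exc)
```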

Optimizer Utils

init_optim_state Initialize optimizer states by calling step() with zero grads.

Precision Utils

convert_precision_str_to_dtype Converts precision as a string to a torch.dtype.

Prepare Module Utils

prepare_module Utility to move a module to a device, set up parallelism and activation checkpointing, and compile the module.
prepare_ddp Utility to move a module to a device and wrap it in DistributedDataParallel.
prepare_fsdp Utility to move a module to a device and wrap it in FullyShardedDataParallel.
convert_str_to_strategy Converts strategy as a string to a default instance of the Strategy dataclass.

Progress Utils

Progress Class to track progress during the loop.
estimated_steps_in_epoch Estimate the number of remaining steps for the current epoch.
estimated_steps_in_loop Estimate the total number of steps for the current loop.
estimated_steps_in_fit Estimate the total number of steps for a fit run.
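A step estimate of this kind is typically the tightest of several bounds: the dataloader length (if known), a per-epoch step cap, and whatever remains of a global step budget. This sketch, evaluated at the start of an epoch, uses the hypothetical name `estimate_epoch_steps_sketch` and hypothetical parameters; it illustrates the capping logic rather than reproducing estimated_steps_in_epoch.

```python
from typing import List, Optional


def estimate_epoch_steps_sketch(
    num_batches: Optional[int],
    steps_completed_total: int,
    max_steps: Optional[int] = None,
    max_steps_per_epoch: Optional[int] = None,
) -> Optional[int]:
    """Estimate the steps the upcoming epoch will run; None means unbounded."""
    bounds: List[int] = []
    if num_batches is not None:  # length of the dataloader, if sized
        bounds.append(num_batches)
    if max_steps_per_epoch is not None:  # cap on steps per epoch
        bounds.append(max_steps_per_epoch)
    if max_steps is not None:  # global cap across all epochs
        bounds.append(max(0, max_steps - steps_completed_total))
    return min(bounds) if bounds else None
```

For instance, with 100 batches per epoch, 95 steps already completed, and `max_steps=100`, only 5 steps remain for the epoch.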

Rank Zero Log Utils

rank_zero_print Call print function only from rank 0.
rank_zero_debug Log debug message only from rank 0.
rank_zero_info Log info message only from rank 0.
rank_zero_warn Log warn message only from rank 0.
rank_zero_error Log error message only from rank 0.
rank_zero_critical Log critical message only from rank 0.


Stateful Defines the interface for checkpoint saving and loading.

Test Utils

get_pet_launch_config Initialize pet.LaunchConfig for single-node, multi-rank functions.
is_asan Determines if the Python interpreter is running with ASAN.
is_tsan Determines if the Python interpreter is running with TSAN.
skip_if_asan Skip the test run if running in ASAN mode.
spawn_multi_process Spawn a single-node, multi-rank function.

Timer Utils

log_elapsed_time Utility to measure and log elapsed time for a given event.
TimerProtocol Defines a Timer Protocol with time and reset methods and an attribute recorded_durations for storing timings.
FullSyncPeriodicTimer Measures time (resets if given interval elapses) on rank 0 and propagates result to other ranks.
BoundedTimer A Timer class which implements TimerProtocol and stores timings in a dictionary recorded_durations.
get_timer_summary Given a timer, generate a summary of all the recorded actions.
get_durations_histogram Computes a histogram of percentiles from the recorded durations passed in.
get_synced_durations_histogram Synchronizes the recorded durations across ranks.
get_synced_timer_histogram Synchronizes the input timer's recorded durations across ranks.
get_recorded_durations_table Helper function to format recorded durations as a table.
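A percentile histogram over recorded durations can be computed with the nearest-rank method. In this sketch, the function name `durations_histogram_sketch`, the default percentiles (p50, p90, p95, p99), and the extra `avg` entry are assumptions for illustration; the real get_durations_histogram may use different percentiles or an interpolating method.

```python
import math
from typing import Dict, Iterable, Sequence


def durations_histogram_sketch(
    durations: Sequence[float],
    percentiles: Iterable[float] = (50.0, 90.0, 95.0, 99.0),
) -> Dict[str, float]:
    """Map 'p50', 'p90', ... and 'avg' to statistics of `durations`."""
    if not durations:
        raise ValueError("durations must be non-empty")
    ordered = sorted(durations)
    n = len(ordered)
    result: Dict[str, float] = {}
    for p in percentiles:
        # Nearest-rank: smallest value with at least p% of samples <= it.
        rank = max(1, math.ceil(p * n / 100.0))
        result[f"p{p:g}"] = ordered[rank - 1]
    result["avg"] = sum(ordered) / n
    return result
```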

TQDM Utils

create_progress_bar Constructs a tqdm() progress bar.
update_progress_bar Updates a progress bar to reflect the number of steps completed.
close_progress_bar Updates and closes a progress bar.

Version Utils

is_windows Is the current program running in the Windows operating system?
get_python_version Get the current runtime Python version as a Version.
get_torch_version Get the PyTorch version for the current runtime environment as a Version.

Misc Utils

days_to_secs Convert time from days to seconds.
transfer_batch_norm_stats Transfer batch norm statistics between two models with the same architecture.
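The days-to-seconds conversion is a single multiplication by 24 × 60 × 60 = 86,400. The name `days_to_secs_sketch` is hypothetical, and the real days_to_secs may differ in the argument types it accepts.

```python
def days_to_secs_sketch(days: float) -> float:
    """Convert days to seconds: 1 day = 24 h * 60 min * 60 s = 86,400 s."""
    return days * 24 * 60 * 60
```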

