Utils¶
Training-related utilities. These are independent of the framework and can be used as needed.
Data Utils¶
AllDatasetBatchesIterator | Returns a dict containing batches from all dataloaders.
CudaDataPrefetcher | Prefetches batches and moves them to the device.
InOrderIterator | Returns all batches from a single dataset until it is exhausted, then moves on to the next one.
MultiDataLoader | Cycles through the individual dataloaders passed to it.
MultiIterator | Defines the iteration logic for producing a batch, given batches from all individual dataloaders.
RandomizedBatchSamplerIterator | Randomly samples from each dataset using the provided weights.
RoundRobinIterator | Cycles over the dataloaders one by one.
Device Utils¶
get_device_from_env | Gets the torch.device based on the current environment.
copy_data_to_device | Recursively copies data to a torch.device.
record_data_in_stream | Records the tensor on the given streams to prevent its memory from being reused for another tensor.
get_nvidia_smi_gpu_stats | Gets GPU stats from nvidia-smi.
get_psutil_cpu_stats | Gets CPU process stats using psutil.
maybe_enable_tf32 | Conditionally sets the precision of float32 matrix multiplications and convolution operations.
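A minimal usage sketch, assuming these helpers are importable from torchtnt.utils and that copy_data_to_device takes the data followed by the target device (adjust import paths to your installation):

```python
import torch
from torchtnt.utils import copy_data_to_device, get_device_from_env

# Pick the device implied by the current environment (e.g. CUDA if available).
device = get_device_from_env()

# Nested containers of tensors are copied to the target device recursively.
batch = {"inputs": torch.randn(8, 16), "labels": torch.randint(0, 2, (8,))}
batch_on_device = copy_data_to_device(batch, device)
```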
Distributed Utils¶
PGWrapper | A wrapper around ProcessGroup that allows collectives to be issued consistently, whether or not a process group is passed in and whether or not torch.distributed is initialized.
get_global_rank | Gets the rank using torch.distributed if available.
get_local_rank | Gets the rank from the LOCAL_RANK environment variable, if populated (https://pytorch.org/docs/stable/elastic/run.html#environment-variables); defaults to 0 if LOCAL_RANK is not set.
get_world_size | Gets the world size using torch.distributed if available.
barrier | Adds a synchronization point across all processes when using distributed.
destroy_process_group | Destroys the global process group, if one is already initialized.
get_process_group_backend_from_device | Gets the default process group backend from the device.
get_file_init_method | Gets the file-based init method for the distributed environment.
get_tcp_init_method | Gets the TCP init method for the distributed environment.
all_gather_tensors | Gathers tensors from several distributed processes onto a list that is broadcast to all processes.
rank_zero_fn | Decorator that restricts a function to be called on global rank 0 only.
revert_sync_batchnorm | Helper function to convert all torch.nn.SyncBatchNorm layers in the module to BatchNorm*D layers.
sync_bool | Utility to synchronize a boolean value across members of a provided process group.
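A sketch of how the rank helpers, rank_zero_fn, and sync_bool might be combined; it assumes they are exported from torchtnt.utils and that sync_bool accepts a single boolean positionally:

```python
from torchtnt.utils import get_global_rank, get_world_size, rank_zero_fn, sync_bool

@rank_zero_fn
def report(msg: str) -> None:
    # Only global rank 0 executes the body; other ranks are no-ops.
    print(msg)

report(f"world size: {get_world_size()}, global rank: {get_global_rank()}")

# Agree across ranks on a local decision (e.g. whether to stop training).
local_should_stop = False
should_stop = sync_bool(local_should_stop)
```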
Early Stop Checker¶
EarlyStopChecker | Monitor a metric and signal if execution should stop early.
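A hedged sketch of the intended usage; the mode/patience constructor arguments and the check() method name are assumptions about the API rather than a documented signature:

```python
import torch
from torchtnt.utils import EarlyStopChecker

# Assumed arguments: stop when the monitored loss fails to improve ("min")
# for 3 consecutive checks.
checker = EarlyStopChecker(mode="min", patience=3)

for val_loss in (0.9, 0.8, 0.81, 0.82, 0.83):
    # check() (assumed name) returns True once the stop condition is met.
    if checker.check(torch.tensor(val_loss)):
        break
```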
Environment Utils¶
init_from_env | Initializes the device and process group, if applicable.
seed | Sets the seed for pseudo-random number generators across commonly used libraries.
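These two are typically called at the start of a training script; a minimal sketch, assuming both are exported from torchtnt.utils:

```python
from torchtnt.utils import init_from_env, seed

# Seed the commonly used RNGs (Python, NumPy, PyTorch) for reproducibility.
seed(42)

# When launched via torchrun, this also initializes the process group and
# returns the device assigned to the current process.
device = init_from_env()
```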
Flops Utils¶
FlopTensorDispatchMode | A context manager to measure the FLOPs of a module.
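A sketch of measuring forward and backward FLOPs, assuming FlopTensorDispatchMode yields an object with flop_counts and reset() (the attribute names and the torchtnt.utils.flops import path are assumptions):

```python
import copy
import torch
from torchtnt.utils.flops import FlopTensorDispatchMode

model = torch.nn.Linear(128, 64)
inp = torch.randn(32, 128)

with FlopTensorDispatchMode(model) as ftdm:
    out = model(inp).sum()
    # flop_counts (assumed attribute) accumulates per-module FLOP counts.
    forward_flops = copy.deepcopy(ftdm.flop_counts)

    ftdm.reset()
    out.backward()
    backward_flops = copy.deepcopy(ftdm.flop_counts)
```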
Filesystem Spec Utils¶
get_filesystem | Returns the appropriate filesystem to use when handling the given path.
Logger Utils¶
FileLogger | Abstract file logger.
MetricLogger | Abstract metric logger.
CSVLogger | CSV file logger.
InMemoryLogger | Simple logger that buffers data in memory.
JSONLogger | JSON file logger.
TensorBoardLogger | Simple logger for TensorBoard.
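A sketch of scalar logging with TensorBoardLogger, assuming a torchtnt.utils.loggers import path and a log(name, data, step) signature:

```python
from torchtnt.utils.loggers import TensorBoardLogger

logger = TensorBoardLogger(path="/tmp/tb_logs")

for step in range(100):
    loss = 1.0 / (step + 1)
    logger.log("train/loss", loss, step)  # assumed (name, data, step) order

logger.close()
```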
Memory Utils¶
RSSProfiler | A profiler that periodically measures RSS (resident set size) deltas.
measure_rss_deltas | A context manager that periodically measures RSS (resident set size) deltas.
Module Summary Utils¶
ModuleSummary | Summary of a module and its submodules.
get_module_summary | Generates a ModuleSummary object, populating its values and the submodule tree.
get_summary_table | Generates a summary table string from the information in a ModuleSummary.
prune_module_summary | Prunes module summaries that are deeper than max_depth in the module summary tree.
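A sketch of summarizing a small model, assuming a torchtnt.utils.module_summary import path:

```python
import torch
from torchtnt.utils.module_summary import get_module_summary, get_summary_table

model = torch.nn.Sequential(
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

# Build the summary tree, then render it as a table string.
summary = get_module_summary(model)
print(get_summary_table(summary))
```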
OOM Utils¶
is_out_of_cpu_memory | Returns True if the exception is related to CPU OOM.
is_out_of_cuda_memory | Returns True if the exception is related to CUDA OOM.
is_out_of_memory_error | Returns True if the exception is due to an OOM, based on its error message.
log_memory_snapshot | Writes memory snapshots to the provided output_dir.
attach_oom_observer | Attaches a function that records the PyTorch memory snapshot when an out-of-memory error occurs.
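A sketch of using the OOM predicates to back off a batch size, assuming they take the caught exception and are exported from torchtnt.utils; the fits_in_memory helper is hypothetical:

```python
import torch
from torchtnt.utils import is_out_of_memory_error

def fits_in_memory(model: torch.nn.Module, batch_size: int, feature_dim: int = 128) -> bool:
    """Return True if a forward pass at this batch size does not OOM (hypothetical helper)."""
    try:
        device = next(model.parameters()).device
        model(torch.randn(batch_size, feature_dim, device=device))
        return True
    except RuntimeError as e:
        if is_out_of_memory_error(e):
            # CPU or CUDA OOM: signal the caller to reduce the batch size.
            return False
        raise
```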
Optimizer Utils¶
init_optim_state | Initialize optimizer states by calling step() with zero grads.
Precision Utils¶
convert_precision_str_to_dtype | Converts a precision string to a torch.dtype.
Prepare Module Utils¶
prepare_module | Moves a module to the device and sets up parallelism, activation checkpointing, and torch.compile.
prepare_ddp | Moves a module to the device and wraps it in DistributedDataParallel.
prepare_fsdp | Moves a module to the device and wraps it in FullyShardedDataParallel.
convert_str_to_strategy | Converts a strategy string to a default instance of the corresponding Strategy dataclass.
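A sketch of preparing a module for DDP training under a torchrun launch; passing the strategy as the string "ddp" is an assumption based on convert_str_to_strategy above, and the torchtnt.utils.prepare_module import path may differ:

```python
import torch
from torchtnt.utils import init_from_env
from torchtnt.utils.prepare_module import prepare_module

# Assumes the script is launched with torchrun so a process group exists.
device = init_from_env()
model = torch.nn.Linear(128, 10)

# "ddp" (assumed string form) is converted to a default DDP strategy; the
# module is moved to the device and wrapped in DistributedDataParallel.
model = prepare_module(model, device, strategy="ddp")
```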
Progress Utils¶
Progress | Class to track progress during the loop.
estimated_steps_in_epoch | Estimate the number of remaining steps for the current epoch.
estimated_steps_in_loop | Estimate the total number of steps for the current loop.
estimated_steps_in_fit | Estimate the total number of steps for a fit run.
Rank Zero Log Utils¶
rank_zero_print | Call the print function only from rank 0.
rank_zero_debug | Log a debug message only from rank 0.
rank_zero_info | Log an info message only from rank 0.
rank_zero_warn | Log a warning message only from rank 0.
rank_zero_error | Log an error message only from rank 0.
rank_zero_critical | Log a critical message only from rank 0.
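These helpers mirror the standard logging levels; a minimal sketch, assuming they are exported from torchtnt.utils:

```python
from torchtnt.utils import rank_zero_info, rank_zero_warn

# Emitted once from global rank 0 rather than once per process.
rank_zero_info("starting training loop")
rank_zero_warn("gradient clipping is disabled")
```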
Test Utils¶
get_pet_launch_config | Initializes pet.LaunchConfig for single-node, multi-rank functions.
is_asan | Determines whether the Python interpreter is running with ASAN.
is_tsan | Determines whether the Python interpreter is running with TSAN.
skip_if_asan | Skips the test run if we are in ASAN mode.
spawn_multi_process | Spawns a single-node, multi-rank function.
Timer Utils¶
log_elapsed_time | Utility to measure and log elapsed time for a given event.
TimerProtocol | Defines a Timer protocol with time and reset methods and a recorded_durations attribute for storing timings.
Timer | A Timer that implements TimerProtocol and stores timings in the recorded_durations dictionary.
FullSyncPeriodicTimer | Measures time (resetting if the given interval elapses) on rank 0 and propagates the result to other ranks.
BoundedTimer | A Timer that implements TimerProtocol and keeps the stored recorded_durations bounded.
get_timer_summary | Given a timer, generates a summary of all the recorded actions.
get_durations_histogram | Computes a histogram of percentiles from the recorded durations passed in.
get_synced_durations_histogram | Synchronizes the recorded durations across ranks.
get_synced_timer_histogram | Synchronizes the input timer's recorded durations across ranks.
get_recorded_durations_table | Helper function to render recorded durations in tabular format.
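A sketch of recording and summarizing durations, assuming Timer.time(action_name) is usable as a context manager as TimerProtocol suggests, and a torchtnt.utils.timer import path:

```python
import time
from torchtnt.utils.timer import Timer, get_timer_summary

timer = Timer()

# Each timed block appends its duration to recorded_durations under the
# given action name.
with timer.time("data_loading"):
    time.sleep(0.01)
with timer.time("forward"):
    time.sleep(0.02)

print(get_timer_summary(timer))
```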
TQDM Utils¶
create_progress_bar | Constructs a tqdm() progress bar.
update_progress_bar | Updates a progress bar to reflect the number of steps completed.
close_progress_bar | Updates and closes a progress bar.
Version Utils¶
is_windows | Returns whether the current program is running on the Windows operating system.
get_python_version | Gets the current runtime Python version as a Version.
get_torch_version | Gets the PyTorch version of the current runtime environment as a Version.
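A sketch of gating a feature on runtime versions; comparing against packaging.version.Version assumes that is the Version type these helpers return, and the torchtnt.utils.version import path may differ:

```python
from packaging.version import Version
from torchtnt.utils.version import get_torch_version, is_windows

# Guard features that need a minimum PyTorch version or a non-Windows host.
use_torch_compile = get_torch_version() >= Version("2.0.0") and not is_windows()
```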
Misc Utils¶
days_to_secs | Converts time from days to seconds.
transfer_batch_norm_stats | Transfers batch norm statistics between two models with the same architecture.