get_memory_stats
- torchtune.training.get_memory_stats(device: torch.device, reset_stats: bool = True) → dict
Computes a memory summary for the passed-in device. If reset_stats is True, this also resets CUDA’s peak memory tracking, so successive calls each report the peak reached since the previous call. This makes it possible to measure the peak memory of individual sections of training (e.g. model initialization, the forward pass) and to optimize memory for each section separately.
- Parameters:
device (torch.device) – Device to get memory summary for. Only CUDA devices are supported.
reset_stats (bool) – Whether to also reset CUDA’s peak memory tracking. Defaults to True.
- Returns:
A dictionary containing the peak active memory, peak allocated memory, and peak reserved memory for the device. This dict is useful for logging memory stats.
- Return type:
dict
- Raises:
ValueError – If the passed-in device is not CUDA.
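The per-section measurement described above can be sketched with PyTorch's public torch.cuda memory counters. This is a minimal, hypothetical sketch: peak_memory_summary is not a torchtune function, and the dictionary key names and the conversion to gigabytes (bytes / 1e9) are assumptions for illustration, not a statement of how torchtune implements get_memory_stats.

```python
import torch


def peak_memory_summary(device: torch.device, reset_stats: bool = True) -> dict:
    """Hypothetical sketch of a peak-memory summary for one CUDA device.

    Assumptions: key names are illustrative and values are in GB (bytes / 1e9).
    """
    if device.type != "cuda":
        # Mirror the documented contract: only CUDA devices are supported.
        raise ValueError(f"Only CUDA devices are supported, got {device}")

    stats = {
        # Peak bytes in active memory blocks since the last reset.
        "peak_memory_active": torch.cuda.memory_stats(device).get("active_bytes.all.peak", 0) / 1e9,
        # Peak bytes occupied by tensors since the last reset.
        "peak_memory_alloc": torch.cuda.max_memory_allocated(device) / 1e9,
        # Peak bytes reserved by the caching allocator since the last reset.
        "peak_memory_reserved": torch.cuda.max_memory_reserved(device) / 1e9,
    }
    if reset_stats:
        # Reset peaks after reading them, so the next call covers only the next section.
        torch.cuda.reset_peak_memory_stats(device)
    return stats


if torch.cuda.is_available():
    dev = torch.device("cuda")
    a = torch.randn(4096, 4096, device=dev)  # section 1 allocates a ~67 MB float32 tensor
    del a
    print("section 1:", peak_memory_summary(dev))  # peak still reflects the freed tensor
    b = torch.randn(1024, 1024, device=dev)  # section 2 allocates only ~4 MB
    # peak_memory_alloc now covers only section 2; reserved may stay higher
    # because the caching allocator keeps freed blocks cached.
    print("section 2:", peak_memory_summary(dev))
else:
    print("CUDA not available; skipping demo")
```

With torchtune itself, the same pattern would call torchtune.training.get_memory_stats(device) at each section boundary, relying on the default reset_stats=True so that each call reports only the section that just finished.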