ModuleSummary
class torchtnt.framework.callbacks.ModuleSummary(max_depth: Optional[int] = None, process_fn: Callable[[List[torchtnt.utils.module_summary.ModuleSummary]], None] = _log_module_summary_tables, module_inputs: Optional[MutableMapping[str, Tuple[Tuple[Any, ...], Dict[str, Any]]]] = None)

A callback which generates and logs a summary of the modules.
Parameters:
- max_depth – The maximum depth of module summaries to keep.
- process_fn – Function called with the list of module summaries. The default, _log_module_summary_tables, logs all module summary tables.
- module_inputs – A mapping from module name to (args, kwargs) for that module. Useful when you want FLOPs, activation sizes, etc. to be reported.
Raises: RuntimeError – If torcheval is not installed.
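The sketch below shows the shape that module_inputs takes and a custom process_fn that satisfies the documented Callable[[List[ModuleSummary]], None] contract. It is a minimal illustration under assumptions: make_module_inputs and print_module_names are hypothetical helpers, the key "module" assumes the unit tracks its nn.Module under that attribute name, the string "example_batch" stands in for a real example tensor, and reading module_name via getattr is a precaution rather than a guarantee about the summary object. Construction is guarded so the snippet still runs when torchtnt or torcheval is absent.

```python
from typing import Any, Dict, List, MutableMapping, Tuple

# Shape of `module_inputs`: module name -> (positional args, keyword args).
ModuleInputs = MutableMapping[str, Tuple[Tuple[Any, ...], Dict[str, Any]]]


def make_module_inputs(name: str, *args: Any, **kwargs: Any) -> ModuleInputs:
    """Hypothetical helper: package one module's forward inputs as (args, kwargs)."""
    return {name: (tuple(args), dict(kwargs))}


def print_module_names(summaries: List[Any]) -> None:
    """Custom `process_fn`: receives the list of summaries and returns None.

    `module_name` is read with getattr as a precaution; falls back to repr().
    """
    for summary in summaries:
        print(getattr(summary, "module_name", repr(summary)))


# "module" is an assumed attribute name for the unit's tracked nn.Module;
# "example_batch" is a placeholder for a real example input tensor.
inputs = make_module_inputs("module", "example_batch")

try:
    from torchtnt.framework.callbacks import ModuleSummary

    callback = ModuleSummary(
        max_depth=2, process_fn=print_module_names, module_inputs=inputs
    )
except (ImportError, RuntimeError):
    # torchtnt is not installed, or torcheval is not (the documented RuntimeError).
    callback = None
```

Keeping process_fn as a plain function makes it easy to swap the default table logging for any other sink (a file, a metrics logger) without subclassing the callback.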
on_predict_start(state: State, unit: PredictUnit[TPredictData]) → None

Hook called before prediction starts.
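Because ModuleSummary implements this hook, its summaries are presumably produced once, before the first predict step runs. The toy loop below illustrates only that ordering; it is not torchtnt's implementation. RecordingCallback and toy_predict are hypothetical names, and None stands in for the real State and PredictUnit objects.

```python
from typing import Any, Callable, Iterable, List


class RecordingCallback:
    """Hypothetical callback that records when its hook fires."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def on_predict_start(self, state: Any, unit: Any) -> None:
        self.log.append("on_predict_start")


def toy_predict(
    batches: Iterable[Any],
    step: Callable[[Any], Any],
    callbacks: List[Any],
    log: List[str],
) -> None:
    """Toy loop: fire on_predict_start once for each callback, then run every step."""
    state, unit = None, None  # placeholders for the real State and PredictUnit
    for cb in callbacks:
        cb.on_predict_start(state, unit)
    for i, batch in enumerate(batches):
        step(batch)
        log.append(f"step {i}")


events: List[str] = []
toy_predict(["a", "b"], str.upper, [RecordingCallback(events)], events)
print(events)  # ['on_predict_start', 'step 0', 'step 1']
```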