Callbacks
- class torchtnt.framework.callback.Callback
A Callback is an optional extension that supplements your loop with additional functionality. Logic that can be reused across units is a good candidate for a callback. Callbacks are generally not intended for modeling code; that belongs in your Unit. To write your own callback, subclass the Callback class and add your code to the hooks.
Below is an example of a basic callback which prints a message at various points during execution.
    from torchtnt.framework.callback import Callback
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit

    class PrintingCallback(Callback):
        def on_train_start(self, state: State, unit: TTrainUnit) -> None:
            print("Starting training")

        def on_train_end(self, state: State, unit: TTrainUnit) -> None:
            print("Ending training")

        def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
            print("Starting evaluation")

        def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
            print("Ending evaluation")

        def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
            print("Starting prediction")

        def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
            print("Ending prediction")
To use a callback, instantiate the class and pass it in the callbacks parameter to the train(), evaluate(), predict(), or fit() entry point.

    from torchtnt.framework.train import train

    # train_unit and train_dataloader are assumed to be defined elsewhere.
    printing_callback = PrintingCallback()
    train(train_unit, train_dataloader, callbacks=[printing_callback])
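Because a callback is an ordinary Python object, it can keep state between hooks. The following is a minimal sketch of that idea rather than anything shipped with torchtnt: EpochTimer is a hypothetical name, the hooks are driven by hand purely for illustration, and the fallback stub class is an assumption added only so the snippet can run where torchtnt is not installed.

```python
import time
from typing import Any, List

try:
    from torchtnt.framework.callback import Callback
except ImportError:
    # Stand-in so the sketch runs without torchtnt; NOT the real base class.
    class Callback:  # type: ignore[no-redef]
        pass


class EpochTimer(Callback):
    """Hypothetical callback recording the wall-clock duration of each train epoch."""

    def __init__(self) -> None:
        super().__init__()
        self.durations: List[float] = []
        self._start: float = 0.0

    # `state` and `unit` are typed as Any here to keep the sketch import-light;
    # real code would use State and TTrainUnit as in the example above.
    def on_train_epoch_start(self, state: Any, unit: Any) -> None:
        self._start = time.perf_counter()

    def on_train_epoch_end(self, state: Any, unit: Any) -> None:
        self.durations.append(time.perf_counter() - self._start)


# Drive the hooks manually; in real use the train() entry point calls them.
timer = EpochTimer()
for _ in range(2):
    timer.on_train_epoch_start(None, None)
    timer.on_train_epoch_end(None, None)

print(f"recorded {len(timer.durations)} epoch durations")
```

In a real run, the instance would be passed through the callbacks parameter in the same way as PrintingCallback, and timer.durations could be inspected after training finishes.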
- property name: str
  A distinct name per instance. This is useful for debugging, profiling, and checkpointing purposes.
- on_eval_epoch_end(state: State, unit: EvalUnit[TEvalData]) -> None
  Hook called after an eval epoch ends.
- on_eval_epoch_start(state: State, unit: EvalUnit[TEvalData]) -> None
  Hook called before a new eval epoch starts.
- on_eval_step_end(state: State, unit: EvalUnit[TEvalData]) -> None
  Hook called after an eval step ends.
- on_eval_step_start(state: State, unit: EvalUnit[TEvalData]) -> None
  Hook called before a new eval step starts.
- on_exception(state: State, unit: Union[TrainUnit[TTrainData], EvalUnit[TEvalData], PredictUnit[TPredictData]], exc: BaseException) -> None
  Hook called when an exception occurs.
- on_predict_end(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called after prediction ends.
- on_predict_epoch_end(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called after a predict epoch ends.
- on_predict_epoch_start(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called before a new predict epoch starts.
- on_predict_start(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called before prediction starts.
- on_predict_step_end(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called after a predict step ends.
- on_predict_step_start(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called before a new predict step starts.
- on_train_epoch_end(state: State, unit: TrainUnit[TTrainData]) -> None
  Hook called after a train epoch ends.
- on_train_epoch_start(state: State, unit: TrainUnit[TTrainData]) -> None
  Hook called before a new train epoch starts.
- on_train_start(state: State, unit: TrainUnit[TTrainData]) -> None
  Hook called before training starts.
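Unlike the other hooks listed above, on_exception receives a third argument, the exception itself. The sketch below shows one way this might be used to keep a record of failures. ExceptionRecorder is a hypothetical name, the hook is invoked by hand for illustration, and the fallback stub class is an assumption added only so the snippet runs without torchtnt installed.

```python
from typing import Any, List

try:
    from torchtnt.framework.callback import Callback
except ImportError:
    # Stand-in so the sketch runs without torchtnt; NOT the real base class.
    class Callback:  # type: ignore[no-redef]
        pass


class ExceptionRecorder(Callback):
    """Hypothetical callback that stores a short description of each exception."""

    def __init__(self) -> None:
        super().__init__()
        self.seen: List[str] = []

    def on_exception(self, state: Any, unit: Any, exc: BaseException) -> None:
        # Record the exception type and message; the hook does not handle
        # or suppress the exception in this sketch.
        self.seen.append(f"{type(exc).__name__}: {exc}")


# Simulate a failure and call the hook manually; in real use the loop
# calls on_exception when an exception occurs.
recorder = ExceptionRecorder()
try:
    raise ValueError("bad batch")
except ValueError as err:
    recorder.on_exception(None, None, err)

print(recorder.seen)  # ['ValueError: bad batch']
```

Since exc is typed as BaseException, a callback like this should be prepared to see exceptions such as KeyboardInterrupt as well as ordinary errors.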
Built-in callbacks
We offer several pre-written callbacks which are ready to be used out of the box:
BaseCSVWriter | A callback to write prediction outputs to a CSV file.
GarbageCollector | A callback that performs periodic synchronous garbage collection.
Lambda | A callback that accepts functions run during the training, evaluation, and prediction loops.
LearningRateMonitor | A callback which logs learning rate of tracked optimizers and learning rate schedulers.
ModuleSummary | A callback which generates and logs a summary of the modules.
PyTorchProfiler | A callback which profiles user code using PyTorch Profiler.
SystemResourcesMonitor | A callback which logs system stats, including CPU usage, resident set size, GPU usage, and CUDA memory stats.
TensorBoardParameterMonitor | A callback which logs module parameters as histograms to TensorBoard.
IterationTimeLogger | A callback which logs iteration times as scalars to TensorBoard.
TorchSnapshotSaver | A callback which periodically saves the application state during training using TorchSnapshot.
TQDMProgressBar | A callback for progress bar visualization in training, evaluation, and prediction.
TrainProgressMonitor | A callback which logs training progress in terms of steps vs epochs.