

class torchtnt.framework.callback.Callback

A Callback is an optional extension that can be used to supplement your loop with additional functionality. Good candidates for such logic are ones that can be re-used across units. Callbacks are generally not intended for modeling code; this should go in your Unit. To write your own callback, subclass the Callback class and add your own code into the hooks.

Below is an example of a basic callback which prints a message at various points during execution.

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit

class PrintingCallback(Callback):
    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        print("Starting training")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        print("Ending training")

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        print("Starting evaluation")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        print("Ending evaluation")

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        print("Starting prediction")

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        print("Ending prediction")

To use a callback, instantiate the class and pass the instance in the callbacks list of the train(), evaluate(), predict(), or fit() entry point.

from torchtnt.framework.train import train

printing_callback = PrintingCallback()
train(train_unit, train_dataloader, callbacks=[printing_callback])

property name: str

A distinct name per instance. This is useful for debugging, profiling, and checkpointing purposes.
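If you register more than one instance of the same callback class, you may want each instance to report its own name. The sketch below overrides `name` as a property in a subclass. `TaggedPrintingCallback` and its `tag` argument are hypothetical, not part of torchtnt. The `try`/`except` falls back to a stand-in base class only so the snippet runs where torchtnt is not installed.

```python
from typing import Any

try:
    # Use the real base class when torchtnt is installed.
    from torchtnt.framework.callback import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without torchtnt.
    class Callback:
        pass


class TaggedPrintingCallback(Callback):
    """Hypothetical callback whose name includes a per-instance tag."""

    def __init__(self, tag: str) -> None:
        self._tag = tag

    @property
    def name(self) -> str:
        # Combine the class name with the tag so that two instances of the
        # same class still report distinct names.
        return f"{self.__class__.__qualname__}_{self._tag}"

    def on_train_start(self, state: Any, unit: Any) -> None:
        print(f"[{self.name}] Starting training")


first = TaggedPrintingCallback("a")
second = TaggedPrintingCallback("b")
print(first.name, second.name)
```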

on_eval_end(state: State, unit: EvalUnit[TEvalData]) → None

Hook called after evaluation ends.

on_eval_epoch_end(state: State, unit: EvalUnit[TEvalData]) → None

Hook called after an eval epoch ends.

on_eval_epoch_start(state: State, unit: EvalUnit[TEvalData]) → None

Hook called before a new eval epoch starts.

on_eval_start(state: State, unit: EvalUnit[TEvalData]) → None

Hook called before evaluation starts.

on_eval_step_end(state: State, unit: EvalUnit[TEvalData]) → None

Hook called after an eval step ends.

on_eval_step_start(state: State, unit: EvalUnit[TEvalData]) → None

Hook called before a new eval step starts.
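The epoch and step hooks nest: an epoch-start hook fires, then the step hooks fire once per step, then the epoch-end hook fires. A minimal sketch of a hypothetical `EvalStepCounter` that uses this nesting to count eval steps per epoch. The driver loop at the bottom only imitates the call order and passes `None` placeholders for `state` and `unit`; a real loop passes its own State and EvalUnit.

```python
from typing import Any, List

try:
    # Use the real base class when torchtnt is installed.
    from torchtnt.framework.callback import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without torchtnt.
    class Callback:
        pass


class EvalStepCounter(Callback):
    """Hypothetical callback that records how many eval steps ran per epoch."""

    def __init__(self) -> None:
        self.steps_per_epoch: List[int] = []
        self._current = 0

    def on_eval_epoch_start(self, state: Any, unit: Any) -> None:
        # Reset the counter at the beginning of each eval epoch.
        self._current = 0

    def on_eval_step_end(self, state: Any, unit: Any) -> None:
        # Called once after every eval step finishes.
        self._current += 1

    def on_eval_epoch_end(self, state: Any, unit: Any) -> None:
        # Store the count for the epoch that just ended.
        self.steps_per_epoch.append(self._current)


# Imitate the hook call order for two epochs of 3 and 2 steps.
counter = EvalStepCounter()
for num_steps in (3, 2):
    counter.on_eval_epoch_start(None, None)
    for _ in range(num_steps):
        counter.on_eval_step_end(None, None)
    counter.on_eval_epoch_end(None, None)
print(counter.steps_per_epoch)
```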

on_exception(state: State, unit: Union[TrainUnit[TTrainData], EvalUnit[TEvalData], PredictUnit[TPredictData]], exc: BaseException) → None

Hook called when an exception occurs.
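Because `on_exception` receives the exception object, a callback can record or report failures. A minimal sketch of a hypothetical `ExceptionRecorder` that stores the exception type and message. It makes no assumption about whether the loop re-raises the exception after the hook returns; the `try`/`except` at the bottom simply calls the hook by hand with `None` placeholders for `state` and `unit`.

```python
from typing import Any, List

try:
    # Use the real base class when torchtnt is installed.
    from torchtnt.framework.callback import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without torchtnt.
    class Callback:
        pass


class ExceptionRecorder(Callback):
    """Hypothetical callback that keeps a log of exceptions it is told about."""

    def __init__(self) -> None:
        self.errors: List[str] = []

    def on_exception(self, state: Any, unit: Any, exc: BaseException) -> None:
        # Record the exception type and message for later inspection.
        self.errors.append(f"{type(exc).__name__}: {exc}")


recorder = ExceptionRecorder()
try:
    raise RuntimeError("loss is NaN")
except RuntimeError as err:
    # Call the hook manually, the way a loop might when a step fails.
    recorder.on_exception(None, None, err)
print(recorder.errors)
```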

on_predict_end(state: State, unit: PredictUnit[TPredictData]) → None

Hook called after prediction ends.

on_predict_epoch_end(state: State, unit: PredictUnit[TPredictData]) → None

Hook called after a predict epoch ends.

on_predict_epoch_start(state: State, unit: PredictUnit[TPredictData]) → None

Hook called before a new predict epoch starts.

on_predict_start(state: State, unit: PredictUnit[TPredictData]) → None

Hook called before prediction starts.

on_predict_step_end(state: State, unit: PredictUnit[TPredictData]) → None

Hook called after a predict step ends.

on_predict_step_start(state: State, unit: PredictUnit[TPredictData]) → None

Hook called before a new predict step starts.
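Pairing a step-start hook with its step-end hook lets a callback measure each step. A minimal sketch of a hypothetical `PredictStepTimer` that records wall-clock durations with `time.perf_counter`. For real runs, the built-in IterationTimeLogger already logs iteration times to TensorBoard; this sketch only illustrates the start/end pairing. The driver loop uses `time.sleep` as a placeholder for an actual predict step.

```python
import time
from typing import Any, List, Optional

try:
    # Use the real base class when torchtnt is installed.
    from torchtnt.framework.callback import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without torchtnt.
    class Callback:
        pass


class PredictStepTimer(Callback):
    """Hypothetical callback that times each predict step."""

    def __init__(self) -> None:
        self.durations: List[float] = []
        self._start: Optional[float] = None

    def on_predict_step_start(self, state: Any, unit: Any) -> None:
        # Remember when the step began.
        self._start = time.perf_counter()

    def on_predict_step_end(self, state: Any, unit: Any) -> None:
        # Record the elapsed time, ignoring an end without a matching start.
        if self._start is not None:
            self.durations.append(time.perf_counter() - self._start)
            self._start = None


timer = PredictStepTimer()
for _ in range(3):
    timer.on_predict_step_start(None, None)
    time.sleep(0.01)  # placeholder for the actual predict step
    timer.on_predict_step_end(None, None)
print([round(d, 3) for d in timer.durations])
```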

on_train_end(state: State, unit: TrainUnit[TTrainData]) → None

Hook called after training ends.

on_train_epoch_end(state: State, unit: TrainUnit[TTrainData]) → None

Hook called after a train epoch ends.

on_train_epoch_start(state: State, unit: TrainUnit[TTrainData]) → None

Hook called before a new train epoch starts.

on_train_start(state: State, unit: TrainUnit[TTrainData]) → None

Hook called before training starts.

on_train_step_end(state: State, unit: TrainUnit[TTrainData]) → None

Hook called after a train step ends.

on_train_step_start(state: State, unit: TrainUnit[TTrainData]) → None

Hook called before a new train step starts.
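The run-level hooks (`on_train_start`, `on_train_end`) pair naturally with the step hooks for periodic reporting. A minimal sketch of a hypothetical `PeriodicStepLogger` that reports every `every_n_steps` train steps (a made-up parameter). It keeps its own step counter instead of reading progress from `State`, since this page does not document `State`'s fields. The driver at the bottom imitates seven train steps with `None` placeholders for `state` and `unit`.

```python
from typing import Any, List

try:
    # Use the real base class when torchtnt is installed.
    from torchtnt.framework.callback import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without torchtnt.
    class Callback:
        pass


class PeriodicStepLogger(Callback):
    """Hypothetical callback that reports progress every N train steps."""

    def __init__(self, every_n_steps: int) -> None:
        if every_n_steps < 1:
            raise ValueError("every_n_steps must be at least 1")
        self.every_n_steps = every_n_steps
        self.messages: List[str] = []
        self._steps = 0

    def on_train_start(self, state: Any, unit: Any) -> None:
        # Start counting from zero for each training run.
        self._steps = 0

    def on_train_step_end(self, state: Any, unit: Any) -> None:
        # Count the finished step and report on every N-th one.
        self._steps += 1
        if self._steps % self.every_n_steps == 0:
            self._log(f"completed {self._steps} train steps")

    def on_train_end(self, state: Any, unit: Any) -> None:
        self._log(f"training finished after {self._steps} steps")

    def _log(self, msg: str) -> None:
        self.messages.append(msg)
        print(msg)


logger = PeriodicStepLogger(every_n_steps=3)
logger.on_train_start(None, None)
for _ in range(7):
    logger.on_train_step_end(None, None)
logger.on_train_end(None, None)
```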

Built-in callbacks

We offer several pre-written callbacks which are ready to be used out of the box:

BaseCSVWriter: A callback to write prediction outputs to a CSV file.
GarbageCollector: A callback that performs periodic synchronous garbage collection.
Lambda: A callback that accepts functions to run during the training, evaluation, and prediction loops.
LearningRateMonitor: A callback which logs the learning rates of tracked optimizers and learning rate schedulers.
ModuleSummary: A callback which generates and logs a summary of the modules.
PyTorchProfiler: A callback which profiles user code using PyTorch Profiler.
SystemResourcesMonitor: A callback which logs system stats, including CPU usage, resident set size, GPU usage, and CUDA memory stats.
TensorBoardParameterMonitor: A callback which logs module parameters as histograms to TensorBoard.
IterationTimeLogger: A callback which logs iteration times as scalars to TensorBoard.
TorchSnapshotSaver: A callback which periodically saves the application state during training using TorchSnapshot.
TQDMProgressBar: A callback for progress bar visualization in training, evaluation, and prediction.
TrainProgressMonitor: A callback which logs training progress in terms of steps vs. epochs.
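The idea behind a function-accepting callback such as Lambda is that you pass plain functions instead of writing a subclass, and each hook forwards to the matching function if one was given. The sketch below is a hypothetical re-implementation of that idea for two hooks only; `FunctionCallback` and its keyword arguments are illustrative and are not the built-in Lambda's actual constructor signature. The usage at the bottom calls the hooks by hand with `None` placeholders for `state` and `unit`.

```python
from typing import Any, Callable, List, Optional

try:
    # Use the real base class when torchtnt is installed.
    from torchtnt.framework.callback import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without torchtnt.
    class Callback:
        pass

# A hook function receives (state, unit) and returns nothing.
HookFn = Callable[[Any, Any], None]


class FunctionCallback(Callback):
    """Hypothetical callback that forwards two hooks to user-supplied functions."""

    def __init__(
        self,
        on_train_step_end: Optional[HookFn] = None,
        on_eval_step_end: Optional[HookFn] = None,
    ) -> None:
        # Store under private names so the methods below are not shadowed.
        self._train_step_end = on_train_step_end
        self._eval_step_end = on_eval_step_end

    def on_train_step_end(self, state: Any, unit: Any) -> None:
        if self._train_step_end is not None:
            self._train_step_end(state, unit)

    def on_eval_step_end(self, state: Any, unit: Any) -> None:
        if self._eval_step_end is not None:
            self._eval_step_end(state, unit)


calls: List[str] = []
cb = FunctionCallback(on_train_step_end=lambda state, unit: calls.append("train"))
cb.on_train_step_end(None, None)
cb.on_eval_step_end(None, None)  # no function supplied, so this is a no-op
print(calls)
```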
