Lambda
class torchtnt.framework.callbacks.Lambda(
    *,
    on_exception: Optional[Callable[[State, Union[TrainUnit[TTrainData], EvalUnit[TEvalData], PredictUnit[TPredictData]], BaseException], None]] = None,
    on_train_start: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None,
    on_train_epoch_start: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None,
    on_train_step_start: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None,
    on_train_step_end: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None,
    on_train_epoch_end: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None,
    on_train_end: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None,
    on_eval_start: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None,
    on_eval_epoch_start: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None,
    on_eval_step_start: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None,
    on_eval_step_end: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None,
    on_eval_epoch_end: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None,
    on_eval_end: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None,
    on_predict_start: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None,
    on_predict_epoch_start: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None,
    on_predict_step_start: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None,
    on_predict_step_end: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None,
    on_predict_epoch_end: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None,
    on_predict_end: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None,
)

A callback that accepts functions to be run at specific points in the training, evaluation, and prediction loops. Each function is called with the current State and the active unit (TrainUnit, EvalUnit, or PredictUnit); on_exception additionally receives the raised exception.
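To illustrate the idea, here is a simplified, hypothetical sketch of a Lambda-style callback in plain Python. It is not torchtnt's source: `LambdaSketch` and the `State`/`Unit` aliases are illustrative names, and only three of the train hooks are shown. Each hook method forwards to the stored function when one was supplied and otherwise does nothing.

```python
from typing import Any, Callable, Optional

# Hypothetical stand-ins for torchtnt's State and TrainUnit types.
State = Any
Unit = Any
Hook = Optional[Callable[[State, Unit], None]]


class LambdaSketch:
    """Hypothetical, simplified illustration of a Lambda-style callback.

    Each hook stores an optional function and forwards the call to it
    only when one was provided; unset hooks are no-ops.
    """

    def __init__(
        self,
        *,
        on_train_start: Hook = None,
        on_train_step_end: Hook = None,
        on_train_end: Hook = None,
    ) -> None:
        self._on_train_start = on_train_start
        self._on_train_step_end = on_train_step_end
        self._on_train_end = on_train_end

    def on_train_start(self, state: State, unit: Unit) -> None:
        if self._on_train_start is not None:
            self._on_train_start(state, unit)

    def on_train_step_end(self, state: State, unit: Unit) -> None:
        if self._on_train_step_end is not None:
            self._on_train_step_end(state, unit)

    def on_train_end(self, state: State, unit: Unit) -> None:
        if self._on_train_end is not None:
            self._on_train_end(state, unit)


events = []
cb = LambdaSketch(on_train_start=lambda state, unit: events.append("start"))
cb.on_train_start(None, None)     # forwards to the lambda
cb.on_train_step_end(None, None)  # no function was given, so nothing happens
```

Keeping every argument keyword-only (the leading `*`) mirrors the real signature and lets callers set just the hooks they care about.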
Parameters:
- on_exception – function to run when an exception is raised; it receives the state, the unit, and the exception.
- on_train_start – function to run when training starts.
- on_train_epoch_start – function to run when each training epoch starts.
- on_train_step_start – function to run when each training step starts.
- on_train_step_end – function to run when each training step ends.
- on_train_epoch_end – function to run when each training epoch ends.
- on_train_end – function to run when training ends.
- on_eval_start – function to run when evaluation starts.
- on_eval_epoch_start – function to run when each evaluation epoch starts.
- on_eval_step_start – function to run when each evaluation step starts.
- on_eval_step_end – function to run when each evaluation step ends.
- on_eval_epoch_end – function to run when each evaluation epoch ends.
- on_eval_end – function to run when evaluation ends.
- on_predict_start – function to run when prediction starts.
- on_predict_epoch_start – function to run when each prediction epoch starts.
- on_predict_step_start – function to run when each prediction step starts.
- on_predict_step_end – function to run when each prediction step ends.
- on_predict_epoch_end – function to run when each prediction epoch ends.
- on_predict_end – function to run when prediction ends.
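The parameter descriptions imply a fixed call order within a loop. The toy driver below is a hypothetical stand-in for torchtnt's training loop, written only to make that order concrete. It fires the train hooks in sequence and, consistent with the on_exception signature, passes a raised exception as a third argument before re-raising it. `run_toy_train_loop`, `HOOK_NAMES`, and `record` are illustrative names, not part of torchtnt.

```python
import functools
from typing import Callable, Dict, List, Optional

# Hook names mirror Lambda's train-loop parameters; the driver below is a
# hypothetical toy, not torchtnt's train() implementation.
HOOK_NAMES = [
    "on_train_start",
    "on_train_epoch_start",
    "on_train_step_start",
    "on_train_step_end",
    "on_train_epoch_end",
    "on_train_end",
    "on_exception",
]


def run_toy_train_loop(
    hooks: Dict[str, Callable[..., None]],
    num_epochs: int,
    steps_per_epoch: int,
    fail_at_step: Optional[int] = None,
) -> None:
    state, unit = object(), object()  # stand-ins for torchtnt's State and TrainUnit

    def fire(name: str, *extra: object) -> None:
        # Call the hook only if one was registered under this name.
        fn = hooks.get(name)
        if fn is not None:
            fn(state, unit, *extra)

    try:
        fire("on_train_start")
        step = 0
        for _ in range(num_epochs):
            fire("on_train_epoch_start")
            for _ in range(steps_per_epoch):
                fire("on_train_step_start")
                if step == fail_at_step:
                    raise RuntimeError("simulated failure")
                step += 1
                fire("on_train_step_end")
            fire("on_train_epoch_end")
        fire("on_train_end")
    except Exception as exc:
        # on_exception also receives the exception; it is then re-raised.
        fire("on_exception", exc)
        raise


log: List[str] = []


def record(name: str, state: object, unit: object, *rest: object) -> None:
    # Append the hook's name so the call order can be inspected afterwards.
    log.append(name)


hooks = {name: functools.partial(record, name) for name in HOOK_NAMES}
run_toy_train_loop(hooks, num_epochs=1, steps_per_epoch=2)
```

With one epoch of two steps, `log` holds on_train_start, on_train_epoch_start, a step-start/step-end pair for each step, on_train_epoch_end, and finally on_train_end.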
Examples:
```python
from torchtnt.framework.callbacks import Lambda
from torchtnt.framework.evaluate import evaluate

# MyDataLoader and MyUnit are placeholders for your own dataloader and EvalUnit.
dataloader = MyDataLoader()
unit = MyUnit()

def print_on_step_start(state, unit) -> None:
    print(f'starting eval step {unit.eval_progress.num_steps_completed}')

lambda_cb = Lambda(
    # Every hook is called with (state, unit); *args accepts and ignores them.
    on_eval_start=lambda *args: print('starting eval'),
    on_eval_step_start=print_on_step_start,
)
evaluate(unit, dataloader, callbacks=[lambda_cb])
```
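Hook functions do not have to be lambdas. As a sketch, assuming you want to time each evaluation step, the bound methods of a small, hypothetical `StepTimer` class can share state between the step-start and step-end hooks. The torchtnt wiring appears only in a comment; the runnable part simulates the calls with stub objects so it works without torchtnt installed.

```python
import time
from types import SimpleNamespace
from typing import List, Optional


class StepTimer:
    """Hypothetical helper whose bound methods serve as Lambda hook functions.

    A bound method keeps a reference to its instance, so the two hooks can
    share state (here, the start time of the current step).
    """

    def __init__(self) -> None:
        self.durations: List[float] = []
        self._start: Optional[float] = None

    def on_step_start(self, state: object, unit: object) -> None:
        self._start = time.perf_counter()

    def on_step_end(self, state: object, unit: object) -> None:
        # Record the elapsed time only if a matching start was seen.
        if self._start is not None:
            self.durations.append(time.perf_counter() - self._start)
            self._start = None


timer = StepTimer()
# With torchtnt installed, the wiring would look like:
#   Lambda(on_eval_step_start=timer.on_step_start,
#          on_eval_step_end=timer.on_step_end)
# Here three steps are simulated with stub state/unit objects instead.
state, unit = SimpleNamespace(), SimpleNamespace()
for _ in range(3):
    timer.on_step_start(state, unit)
    timer.on_step_end(state, unit)
```

A class like this keeps related hooks and their shared data together, which tends to read better than several closures over module-level variables.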