Lambda

class torchtnt.framework.callbacks.Lambda(*, on_exception: Optional[Callable[[State, Union[TrainUnit[TTrainData], EvalUnit[TEvalData], PredictUnit[TPredictData]], BaseException], None]] = None, on_train_start: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None, on_train_epoch_start: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None, on_train_step_start: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None, on_train_step_end: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None, on_train_epoch_end: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None, on_train_end: Optional[Callable[[State, TrainUnit[TTrainData]], None]] = None, on_eval_start: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None, on_eval_epoch_start: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None, on_eval_step_start: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None, on_eval_step_end: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None, on_eval_epoch_end: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None, on_eval_end: Optional[Callable[[State, EvalUnit[TEvalData]], None]] = None, on_predict_start: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None, on_predict_epoch_start: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None, on_predict_step_start: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None, on_predict_step_end: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None, on_predict_epoch_end: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None, on_predict_end: Optional[Callable[[State, PredictUnit[TPredictData]], None]] = None)

A callback that accepts functions to run at specific points during the training, evaluation, and prediction loops.

Parameters:
  • on_exception – function to run when an exception occurs; it receives the state, the unit, and the exception.
  • on_train_start – function to run when train starts.
  • on_train_epoch_start – function to run when each train epoch starts.
  • on_train_step_start – function to run when each train step starts.
  • on_train_step_end – function to run when each train step ends.
  • on_train_epoch_end – function to run when each train epoch ends.
  • on_train_end – function to run when train ends.
  • on_eval_start – function to run when eval starts.
  • on_eval_epoch_start – function to run when each eval epoch starts.
  • on_eval_step_start – function to run when each eval step starts.
  • on_eval_step_end – function to run when each eval step ends.
  • on_eval_epoch_end – function to run when each eval epoch ends.
  • on_eval_end – function to run when eval ends.
  • on_predict_start – function to run when predict starts.
  • on_predict_epoch_start – function to run when each predict epoch starts.
  • on_predict_step_start – function to run when each predict step starts.
  • on_predict_step_end – function to run when each predict step ends.
  • on_predict_epoch_end – function to run when each predict epoch ends.
  • on_predict_end – function to run when predict ends.
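To make the dispatch model concrete, here is a minimal sketch of a Lambda-style callback. It is written from scratch and is not taken from torchtnt. It assumes, as the parameter list suggests, that every hook is optional and that an unset hook is simply a no-op. `MiniLambda`, `fire`, and `toy_train` are hypothetical names, only the train hooks are modeled, and `None` stands in for the real `State` and `TrainUnit` objects.

```python
from typing import Any, Callable, List

# Hypothetical sketch: NOT torchtnt's implementation, only the dispatch idea.
HookFn = Callable[[Any, Any], None]

_HOOK_NAMES = (
    "on_train_start",
    "on_train_epoch_start",
    "on_train_step_start",
    "on_train_step_end",
    "on_train_epoch_end",
    "on_train_end",
)


class MiniLambda:
    def __init__(self, **hooks: HookFn) -> None:
        # Reject misspelled hook names instead of silently ignoring them.
        unknown = set(hooks) - set(_HOOK_NAMES)
        if unknown:
            raise TypeError(f"unknown hooks: {sorted(unknown)}")
        self._hooks = hooks

    def fire(self, name: str, state: Any, unit: Any) -> None:
        # Hooks that were not passed to the constructor are skipped.
        fn = self._hooks.get(name)
        if fn is not None:
            fn(state, unit)


def toy_train(callback: MiniLambda, num_epochs: int, steps_per_epoch: int) -> None:
    state, unit = None, None  # placeholders for State and TrainUnit
    callback.fire("on_train_start", state, unit)
    for _ in range(num_epochs):
        callback.fire("on_train_epoch_start", state, unit)
        for _ in range(steps_per_epoch):
            callback.fire("on_train_step_start", state, unit)
            # ... the unit's training step would run here ...
            callback.fire("on_train_step_end", state, unit)
        callback.fire("on_train_epoch_end", state, unit)
    callback.fire("on_train_end", state, unit)


events: List[str] = []
cb = MiniLambda(
    on_train_start=lambda state, unit: events.append("train_start"),
    on_train_step_end=lambda state, unit: events.append("step_end"),
    on_train_end=lambda state, unit: events.append("train_end"),
)
toy_train(cb, num_epochs=2, steps_per_epoch=2)
print(events)
# ['train_start', 'step_end', 'step_end', 'step_end', 'step_end', 'train_end']
```

The exact points where torchtnt fires each hook are defined by its own train, evaluate, and predict loops; the toy loop only mirrors the nesting implied by the hook names (run, then epoch, then step).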

Examples:

from torchtnt.framework.callbacks import Lambda
from torchtnt.framework.evaluate import evaluate

dataloader = MyDataLoader()  # placeholder: your evaluation DataLoader
unit = MyUnit()  # placeholder: your EvalUnit implementation

def print_on_step_start(state, unit) -> None:
    print(f'starting eval step {unit.eval_progress.num_steps_completed}')


lambda_cb = Lambda(
    on_eval_start=lambda *args: print('starting eval'),
    on_eval_step_start=print_on_step_start,
)

evaluate(unit, dataloader, callbacks=[lambda_cb])
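Because each hook is a plain callable, any state it needs across calls can live in a closure. The sketch below uses a toy loop rather than torchtnt's `evaluate`. It counts completed steps with a closure and shows an `on_exception` hook receiving the state, the unit, and the exception, matching the `on_exception` signature above. `make_step_counter` and `toy_eval` are hypothetical names, and the assumption that the loop re-raises after notifying the hook should be checked against the torchtnt version you use. With the real callback, the same functions would be passed as `Lambda(on_eval_step_end=on_step_end, on_exception=...)`.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

StepHook = Callable[[Any, Any], None]


def make_step_counter() -> Tuple[StepHook, Dict[str, int]]:
    """Hypothetical helper: returns a step hook and the dict it updates."""
    counts = {"steps": 0}

    def on_eval_step_end(state: Any, unit: Any) -> None:
        # The closure, not the callback object, holds the running count.
        counts["steps"] += 1

    return on_eval_step_end, counts


def toy_eval(
    hooks: Dict[str, Callable[..., None]],
    num_steps: int,
    fail_at: Optional[int] = None,
) -> None:
    """Toy eval loop (NOT torchtnt's evaluate): fires hooks by name."""
    state, unit = None, None  # placeholders for State and EvalUnit
    try:
        for step in range(num_steps):
            if step == fail_at:
                raise RuntimeError(f"step {step} failed")
            hook = hooks.get("on_eval_step_end")
            if hook is not None:
                hook(state, unit)
    except Exception as exc:
        # Assumption: notify on_exception with (state, unit, exc), then re-raise.
        on_exception = hooks.get("on_exception")
        if on_exception is not None:
            on_exception(state, unit, exc)
        raise


on_step_end, counts = make_step_counter()
errors: List[BaseException] = []
hooks = {
    "on_eval_step_end": on_step_end,
    "on_exception": lambda state, unit, exc: errors.append(exc),
}

try:
    toy_eval(hooks, num_steps=5, fail_at=3)
except RuntimeError:
    pass  # the toy loop re-raises after calling on_exception

print(counts["steps"], [str(e) for e in errors])  # 3 ['step 3 failed']
```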
