Evaluate
Evaluate Entry Point
torchtnt.framework.evaluate.evaluate(eval_unit: EvalUnit[TEvalData], eval_dataloader: Iterable[TEvalData], *, max_steps_per_epoch: Optional[int] = None, callbacks: Optional[List[Callback]] = None, timer: Optional[TimerProtocol] = None) -> None

The evaluate entry point takes an EvalUnit object, an eval dataloader (any Iterable), and optional arguments that modify loop execution, then runs the evaluation loop.

Parameters:
- eval_unit – an instance of EvalUnit which implements eval_step.
- eval_dataloader – the dataloader used during evaluation; it can be any iterable, including a PyTorch DataLoader, DataLoader2, etc.
- max_steps_per_epoch – the maximum number of steps to run per epoch. None means evaluate until the dataloader is exhausted.
- callbacks – an optional list of Callbacks.
- timer – an optional Timer used to time key events (using a Timer with CUDA synchronization may degrade performance).
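The max_steps_per_epoch semantics can be illustrated with a small, library-independent sketch. It relies only on what the parameter description states: a limit of N stops the epoch after at most N steps, while None runs until the iterable is exhausted. `batches_for_epoch` is a hypothetical helper written for illustration; it is not part of torchtnt.

```python
from itertools import islice
from typing import Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def batches_for_epoch(
    dataloader: Iterable[T], max_steps_per_epoch: Optional[int] = None
) -> Iterator[T]:
    """Yield the batches one evaluation epoch would consume (hypothetical helper)."""
    if max_steps_per_epoch is None:
        # No limit: evaluate until the dataloader is exhausted.
        return iter(dataloader)
    # Stop after at most max_steps_per_epoch batches, even if more remain.
    return islice(dataloader, max_steps_per_epoch)


# A plain list stands in for a DataLoader here.
data = list(range(50))
print(len(list(batches_for_epoch(data, max_steps_per_epoch=20))))  # 20
print(len(list(batches_for_epoch(data))))  # 50
```

If the dataloader is shorter than the limit, the epoch simply ends when the data runs out, so the limit acts as an upper bound rather than a required step count.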
Below is an example of calling evaluate():

    import torch
    from torchtnt.framework.evaluate import evaluate

    # MyEvalUnit stands for your own EvalUnit subclass; fill in its arguments.
    eval_unit = MyEvalUnit(module=...)
    eval_dataloader = torch.utils.data.DataLoader(...)
    evaluate(eval_unit, eval_dataloader, max_steps_per_epoch=20)
Below is pseudocode of what the evaluate() entry point does:

    set unit's tracked modules to eval mode
    call on_eval_start on unit first and then callbacks
    call on_eval_epoch_start on unit first and then callbacks
    while not done (max_steps_per_epoch not yet reached):
        try:
            data = next(dataloader)
            call on_eval_step_start on callbacks
            call eval_step on unit
            increment step counter
            call on_eval_step_end on callbacks
        except StopIteration:
            break
    increment epoch counter
    call on_eval_epoch_end on unit first and then callbacks
    call on_eval_end on unit first and then callbacks
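The evaluation loop can be modeled in plain Python to make the hook ordering concrete. This is a minimal sketch, not torchtnt code: `Recorder`, `ToyEvalUnit`, and `toy_evaluate` are hypothetical names, the real hooks also receive a state object, switching modules to eval mode is left out, and the sketch assumes the epoch hooks fire once per epoch around the step loop. Each hook appends an entry to a shared log so the "unit first, then callbacks" order can be checked.

```python
from typing import Dict, Iterable, List, Optional


class Recorder:
    """Hypothetical hook receiver that appends '<name>.<hook>' to a shared log."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def record(self, hook: str) -> None:
        self.log.append(f"{self.name}.{hook}")


class ToyEvalUnit(Recorder):
    """Hypothetical stand-in for an EvalUnit; eval_step just remembers its batch."""

    def __init__(self, log: List[str]) -> None:
        super().__init__("unit", log)
        self.seen: List[int] = []

    def eval_step(self, data: int) -> None:
        self.seen.append(data)
        self.record("eval_step")


def toy_evaluate(
    unit: ToyEvalUnit,
    dataloader: Iterable[int],
    max_steps_per_epoch: Optional[int] = None,
    callbacks: Optional[List[Recorder]] = None,
) -> Dict[str, int]:
    """Hypothetical single-epoch evaluation loop modeled on the pseudocode."""
    callbacks = callbacks or []

    def fire(hook: str, include_unit: bool = True) -> None:
        # Unit hook runs first (if any), then each callback in list order.
        if include_unit:
            unit.record(hook)
        for cb in callbacks:
            cb.record(hook)

    steps = 0
    epochs = 0
    fire("on_eval_start")
    fire("on_eval_epoch_start")
    batches = iter(dataloader)
    while max_steps_per_epoch is None or steps < max_steps_per_epoch:
        try:
            data = next(batches)
        except StopIteration:
            break  # dataloader exhausted before the step limit
        fire("on_eval_step_start", include_unit=False)  # callbacks only
        unit.eval_step(data)
        steps += 1
        fire("on_eval_step_end", include_unit=False)  # callbacks only
    epochs += 1
    fire("on_eval_epoch_end")
    fire("on_eval_end")
    return {"steps": steps, "epochs": epochs}


log: List[str] = []
unit = ToyEvalUnit(log)
result = toy_evaluate(unit, range(5), max_steps_per_epoch=2, callbacks=[Recorder("cb", log)])
print(result)  # {'steps': 2, 'epochs': 1}
print(log[:4])  # ['unit.on_eval_start', 'cb.on_eval_start', 'unit.on_eval_epoch_start', 'cb.on_eval_epoch_start']
```

One deliberate difference from the pseudocode: the sketch catches StopIteration only around next(), so an exception raised inside eval_step is not mistaken for the end of the dataloader.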