Predict
Predict Entry Point
torchtnt.framework.predict.predict(predict_unit: PredictUnit[TPredictData], predict_dataloader: Iterable[TPredictData], *, max_steps_per_epoch: Optional[int] = None, callbacks: Optional[List[Callback]] = None, timer: Optional[TimerProtocol] = None) -> None

The predict entry point takes a PredictUnit object, a predict dataloader (any Iterable), and optional arguments that modify loop execution. It then runs the prediction loop.

Parameters:
- predict_unit – an instance of PredictUnit that implements predict_step.
- predict_dataloader – the dataloader used during prediction. It can be any iterable, including a PyTorch DataLoader, DataLoader2, etc.
- max_steps_per_epoch – the maximum number of steps to run per epoch. None means predict until the dataloader is exhausted.
- callbacks – an optional list of Callbacks.
- timer – an optional Timer used to time key events (using a Timer with CUDA synchronization may degrade performance).
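The effect of max_steps_per_epoch on an arbitrary iterable can be illustrated without torchtnt. The following is a minimal sketch under stated assumptions, not the library's code. steps_for_epoch and batch_generator are hypothetical helpers, and itertools.islice stands in for however predict() actually counts steps. With a cap, the epoch stops after that many batches even if data remains. If the iterable runs out first, the epoch simply ends early.

```python
from itertools import islice
from typing import Any, Iterable, Iterator, List, Optional


def steps_for_epoch(dataloader: Iterable[Any],
                    max_steps_per_epoch: Optional[int] = None) -> Iterator[Any]:
    """Hypothetical helper: iterator over the batches one predict epoch would consume.

    None: keep going until the iterable is exhausted.
    int:  stop after that many steps, even if more data remains.
    """
    batches = iter(dataloader)
    if max_steps_per_epoch is None:
        return batches
    return islice(batches, max_steps_per_epoch)


def batch_generator(n: int) -> Iterator[List[int]]:
    # Hypothetical data source: any iterable (here a plain generator) can act as a dataloader.
    for i in range(n):
        yield [i, i + 1]


print(len(list(steps_for_epoch(range(50)))))                             # 50
print(len(list(steps_for_epoch(range(50), max_steps_per_epoch=20))))     # 20
print(list(steps_for_epoch(batch_generator(5), max_steps_per_epoch=2)))  # [[0, 1], [1, 2]]
```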
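Below is an example of calling predict(). MyPredictUnit stands for your own PredictUnit subclass. Prediction does not update weights, so the unit is given only a module here.

    import torch
    from torchtnt.framework.predict import predict

    # MyPredictUnit: a user-defined subclass of PredictUnit
    predict_unit = MyPredictUnit(module=...)
    predict_dataloader = torch.utils.data.DataLoader(...)
    predict(predict_unit, predict_dataloader, max_steps_per_epoch=20)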
Below is pseudocode of what the predict() entry point does:

    set unit's tracked modules to eval mode
    call on_predict_start on unit first and then callbacks
    call on_predict_epoch_start on unit first and then callbacks
    while not done:
        try:
            data = next(dataloader)
            call on_predict_step_start on callbacks
            call predict_step on unit
            increment step counter
            call on_predict_step_end on callbacks
        except StopIteration:
            break
    increment epoch counter
    call on_predict_epoch_end on unit first and then callbacks
    call on_predict_end on unit first and then callbacks
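The following minimal pure-Python sketch makes the hook ordering concrete. It is not torchtnt's implementation. ToyModule, HookRecorder, ToyPredictUnit, ToyCallback, _fire and toy_predict are all hypothetical stand-ins. No torch or torchtnt import is used, and the hooks take no state argument. The sketch also assumes that a predict run is a single epoch whose epoch-start hook fires once before stepping begins. The timer and max_steps_per_epoch are omitted to keep the focus on call order.

```python
from typing import Any, Iterable, List, Sequence


class ToyModule:
    """Stand-in for torch.nn.Module; only tracks the train/eval flag."""

    def __init__(self) -> None:
        self.training = True

    def eval(self) -> "ToyModule":
        self.training = False
        return self


class HookRecorder:
    """Appends '<name>.<hook>' to a shared log whenever a hook is called."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def _record(self, hook: str) -> None:
        self.log.append(f"{self.name}.{hook}")

    def on_predict_start(self) -> None:
        self._record("on_predict_start")

    def on_predict_epoch_start(self) -> None:
        self._record("on_predict_epoch_start")

    def on_predict_epoch_end(self) -> None:
        self._record("on_predict_epoch_end")

    def on_predict_end(self) -> None:
        self._record("on_predict_end")


class ToyPredictUnit(HookRecorder):
    """Hypothetical unit: owns one module and 'predicts' by doubling each batch."""

    def __init__(self, log: List[str]) -> None:
        super().__init__("unit", log)
        self.module = ToyModule()
        self.predictions: List[Any] = []

    def predict_step(self, data: Any) -> None:
        self.predictions.append(data * 2)


class ToyCallback(HookRecorder):
    """Hypothetical callback; unlike the unit, it also receives per-step hooks."""

    def __init__(self, log: List[str]) -> None:
        super().__init__("cb", log)

    def on_predict_step_start(self) -> None:
        self._record("on_predict_step_start")

    def on_predict_step_end(self) -> None:
        self._record("on_predict_step_end")


def _fire(targets: Sequence[HookRecorder], hook: str) -> None:
    # Call the named hook on each target in order (unit first, then callbacks).
    for target in targets:
        getattr(target, hook)()


def toy_predict(unit: ToyPredictUnit, dataloader: Iterable[Any],
                callbacks: Sequence[ToyCallback]) -> dict:
    """Simplified single-epoch predict loop mirroring the pseudocode."""
    everyone = [unit, *callbacks]
    unit.module.eval()  # set the unit's tracked module to eval mode
    _fire(everyone, "on_predict_start")
    _fire(everyone, "on_predict_epoch_start")
    steps, epochs = 0, 0
    batches = iter(dataloader)
    while True:
        try:
            data = next(batches)
        except StopIteration:
            break  # dataloader exhausted: the epoch is done
        _fire(callbacks, "on_predict_step_start")
        unit.predict_step(data)
        steps += 1
        _fire(callbacks, "on_predict_step_end")
    epochs += 1
    _fire(everyone, "on_predict_epoch_end")
    _fire(everyone, "on_predict_end")
    return {"steps": steps, "epochs": epochs}


log: List[str] = []
unit = ToyPredictUnit(log)
result = toy_predict(unit, [1, 2, 3], [ToyCallback(log)])
print(result)            # {'steps': 3, 'epochs': 1}
print(unit.predictions)  # [2, 4, 6]
```

Inspecting log after the run shows the pattern from the pseudocode. Start and end hooks fire on the unit before the callback. Only the callback sees the per-step hooks, once per batch.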