Predict Entry Point

torchtnt.framework.predict.predict(predict_unit: PredictUnit[TPredictData], predict_dataloader: Iterable[TPredictData], *, max_steps_per_epoch: Optional[int] = None, callbacks: Optional[List[Callback]] = None, timer: Optional[TimerProtocol] = None) → None

The predict entry point takes a PredictUnit object, a predict dataloader (any Iterable), and optional arguments that modify loop execution, and then runs the prediction loop.

  • predict_unit – an instance of PredictUnit which implements predict_step.
  • predict_dataloader – dataloader to be used during prediction, which can be any iterable, including PyTorch DataLoader, DataLoader2, etc.
  • max_steps_per_epoch – the max number of steps to run per epoch. None means predict until the dataloader is exhausted.
  • callbacks – an optional list of Callbacks.
  • timer – an optional Timer which will be used to time key events (using a Timer with CUDA synchronization may degrade performance).
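As a rough illustration of the max_steps_per_epoch semantics, here is a plain-Python sketch. It assumes the loop simply stops after max_steps_per_epoch batches or when the iterable is exhausted, whichever comes first. limit_steps is a hypothetical helper written for this example; it is not part of the torchtnt API.

```python
from itertools import islice
from typing import Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def limit_steps(dataloader: Iterable[T], max_steps_per_epoch: Optional[int]) -> Iterator[T]:
    """Yield at most max_steps_per_epoch batches; None means no limit.

    Hypothetical helper illustrating the documented behavior, not torchtnt code.
    """
    it = iter(dataloader)
    if max_steps_per_epoch is None:
        # No limit: run until the iterable raises StopIteration.
        return it
    # Stop after max_steps_per_epoch batches, or earlier if the data runs out.
    return islice(it, max_steps_per_epoch)


# Any iterable can act as the "dataloader": a list, a generator, a DataLoader, ...
batches = [[i, i + 1] for i in range(0, 100, 2)]  # 50 toy batches
print(len(list(limit_steps(batches, 20))))                # 20
print(len(list(limit_steps(batches, None))))              # 50
print(len(list(limit_steps((b for b in range(5)), 20))))  # 5
```

Because only iter() and next() are needed, the same logic applies to any iterable, which is why predict_dataloader is typed as Iterable rather than as a specific DataLoader class.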

Below is an example of calling predict().

import torch
from torchtnt.framework.predict import predict

# MyPredictUnit is your own PredictUnit subclass that implements predict_step
predict_unit = MyPredictUnit(module=...)
predict_dataloader = torch.utils.data.DataLoader(...)
predict(predict_unit, predict_dataloader, max_steps_per_epoch=20)

Below is pseudocode of what the predict() entry point does.

set unit's tracked modules to eval mode
call on_predict_start on unit first and then callbacks
call on_predict_epoch_start on unit first and then callbacks
while not done (max_steps_per_epoch not yet reached):
    try:
        data = next(dataloader)
        call on_predict_step_start on callbacks
        call predict_step on unit
        increment step counter
        call on_predict_step_end on callbacks
    except StopIteration:
        break
increment epoch counter
call on_predict_epoch_end on unit first and then callbacks
call on_predict_end on unit first and then callbacks
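To make the hook ordering concrete, here is a toy, torch-free simulation of that control flow. It assumes a single prediction epoch whose epoch-level hooks fire once, and it skips the eval-mode step because there are no modules. ToyPredictUnit, ToyCallback and toy_predict are hypothetical stand-ins invented for this sketch; they are not torchtnt's PredictUnit, Callback or predict().

```python
from typing import Any, Iterable, List, Optional


class ToyPredictUnit:
    """Hypothetical stand-in for a PredictUnit; NOT torchtnt's class."""

    def __init__(self, log: List[str]) -> None:
        self.log = log
        self.outputs: List[Any] = []

    def hook(self, name: str) -> None:
        # Record that a unit-level hook (e.g. on_predict_start) ran.
        self.log.append(f"unit.{name}")

    def predict_step(self, data: Any) -> None:
        self.log.append("unit.predict_step")
        self.outputs.append(data * 2)  # toy "model": double the input


class ToyCallback:
    """Hypothetical stand-in for a Callback; every hook just logs its name."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def hook(self, name: str) -> None:
        self.log.append(f"callback.{name}")


def toy_predict(unit: ToyPredictUnit, dataloader: Iterable[Any],
                callbacks: List[ToyCallback],
                max_steps_per_epoch: Optional[int] = None) -> int:
    """Mimic the prediction loop's hook order; returns the number of steps run."""

    def fire(name: str, include_unit: bool = True) -> None:
        # Unit hook runs first (when it has one), then each callback in list order.
        if include_unit:
            unit.hook(name)
        for cb in callbacks:
            cb.hook(name)

    fire("on_predict_start")
    fire("on_predict_epoch_start")
    steps = 0
    it = iter(dataloader)
    while max_steps_per_epoch is None or steps < max_steps_per_epoch:
        try:
            data = next(it)
        except StopIteration:
            break
        fire("on_predict_step_start", include_unit=False)  # callbacks only
        unit.predict_step(data)
        steps += 1
        fire("on_predict_step_end", include_unit=False)  # callbacks only
    fire("on_predict_epoch_end")
    fire("on_predict_end")
    return steps


log: List[str] = []
unit = ToyPredictUnit(log)
print(toy_predict(unit, [1, 2, 3], [ToyCallback(log)], max_steps_per_epoch=2))  # 2
print(unit.outputs)  # [2, 4]
```

Note that the step-level hooks are fired on callbacks only, matching the pseudocode, while the start and end hooks go to the unit first and then to the callbacks.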