Fit
Fit Entry Point
torchtnt.framework.fit.fit(unit: TrainUnit[TTrainData], train_dataloader: Iterable[TTrainData], eval_dataloader: Iterable[TEvalData], *, max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_train_steps_per_epoch: Optional[int] = None, max_eval_steps_per_epoch: Optional[int] = None, evaluate_every_n_steps: Optional[int] = None, evaluate_every_n_epochs: Optional[int] = 1, callbacks: Optional[List[Callback]] = None, timer: Optional[TimerProtocol] = None) -> None

The fit entry point interleaves training and evaluation loops. It takes an object that subclasses both TrainUnit and EvalUnit, a train dataloader and an eval dataloader (any Iterables), and optional arguments that modify loop execution, and then runs the fit loop.

Parameters:
- unit – an instance that subclasses both TrainUnit and EvalUnit, implementing train_step() and eval_step().
- train_dataloader – dataloader to be used during training; any iterable works, including PyTorch DataLoader, DataLoader2, etc.
- eval_dataloader – dataloader to be used during evaluation; any iterable works, including PyTorch DataLoader, DataLoader2, etc.
- max_epochs – the maximum number of training epochs to run. None means no epoch limit (infinite training) unless training is stopped by max_steps.
- max_steps – the maximum number of training steps to run. None means no step limit (infinite training) unless training is stopped by max_epochs.
- max_train_steps_per_epoch – the maximum number of training steps to run per epoch. None means train until train_dataloader is exhausted.
- max_eval_steps_per_epoch – the maximum number of evaluation steps to run per evaluation epoch. None means evaluate until eval_dataloader is exhausted.
- evaluate_every_n_steps – how often to run the evaluation loop, measured in training steps.
- evaluate_every_n_epochs – how often to run the evaluation loop, measured in training epochs. Defaults to 1, i.e. after every epoch.
- callbacks – an optional list of callbacks.
- timer – an optional Timer used to time key events (using a Timer with CUDA synchronization may degrade performance).
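The limit and cadence parameters above can be read as a few simple predicates. The sketch below spells them out under stated assumptions: training stops once either max_epochs or max_steps is reached (None disables that limit), an epoch is cut short once max_train_steps_per_epoch steps have run, and evaluation is assumed to fire whenever the corresponding counter is a positive multiple of evaluate_every_n_steps or evaluate_every_n_epochs. The helper names training_done, epoch_capped, and should_evaluate are hypothetical and are not part of torchtnt.

```python
from typing import Optional


def training_done(
    epochs_completed: int,
    steps_completed: int,
    max_epochs: Optional[int],
    max_steps: Optional[int],
) -> bool:
    """Hypothetical helper: True once either global limit is hit; None disables a limit."""
    if max_epochs is not None and epochs_completed >= max_epochs:
        return True
    if max_steps is not None and steps_completed >= max_steps:
        return True
    return False


def epoch_capped(steps_in_epoch: int, max_steps_per_epoch: Optional[int]) -> bool:
    """Hypothetical helper for the per-epoch cap; dataloader exhaustion is handled by the loop."""
    return max_steps_per_epoch is not None and steps_in_epoch >= max_steps_per_epoch


def should_evaluate(counter: int, every_n: Optional[int]) -> bool:
    """Assumed cadence rule: evaluate when the counter is a positive multiple of every_n."""
    return every_n is not None and counter > 0 and counter % every_n == 0


# With max_epochs=None and max_steps=1000, only the step limit can stop training.
print(training_done(50, 999, None, 1000))  # False
print(training_done(50, 1000, None, 1000))  # True
# With the default evaluate_every_n_epochs=1, evaluation runs after every epoch.
print(should_evaluate(3, 1))  # True
```

Setting both max_epochs and max_steps to None makes training_done always return False, which matches the "infinite training" wording above.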
Below is an example of calling fit():

```python
import torch
from torchtnt.framework.fit import fit

# MyFitUnit is a user-defined class that subclasses both TrainUnit and EvalUnit.
fit_unit = MyFitUnit(module=..., optimizer=..., lr_scheduler=...)
train_dataloader = torch.utils.data.DataLoader(...)
eval_dataloader = torch.utils.data.DataLoader(...)
fit(fit_unit, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, max_epochs=4)
```
Below is pseudocode of what the fit() entry point does:

```
set unit's tracked modules to train mode
call on_train_start on unit first and then callbacks
while training is not done:
    call on_train_epoch_start on unit first and then callbacks
    while epoch is not done:
        try:
            data = next(dataloader)
            call on_train_step_start on callbacks
            call train_step on unit
            increment step counter
            call on_train_step_end on callbacks
            if should evaluate after this step:
                run eval loop
        except StopIteration:
            break
    increment epoch counter
    call on_train_epoch_end on unit first and then callbacks
    if should evaluate after this epoch:
        run eval loop
call on_train_end on unit first and then callbacks
```
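To make the control flow above concrete, here is a minimal, framework-free Python sketch of the same loop. It is not torchtnt's implementation: Recorder, ToyUnit, run_eval, and toy_fit are hypothetical names, there is no State object or train/eval mode switching, and the eval loop is reduced to a single pass over eval_dataloader. Every hook call is recorded in a list so the ordering (unit before callbacks, step hooks on callbacks only) and the interleaving of training and evaluation are visible.

```python
from typing import Any, Iterable, List, Optional


class Recorder:
    """Hypothetical listener that logs each hook it receives as "<name>.<event>"."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def hook(self, event: str) -> None:
        self.log.append(f"{self.name}.{event}")


class ToyUnit(Recorder):
    """Hypothetical unit providing both train_step() and eval_step()."""

    def train_step(self, data: Any) -> None:
        self.log.append(f"train_step({data})")

    def eval_step(self, data: Any) -> None:
        self.log.append(f"eval_step({data})")


def run_eval(unit: ToyUnit, eval_dataloader: Iterable[Any], max_steps: Optional[int]) -> None:
    """Simplified eval loop: one pass over eval_dataloader, optionally capped."""
    for i, batch in enumerate(eval_dataloader):
        if max_steps is not None and i >= max_steps:
            break
        unit.eval_step(batch)


def toy_fit(
    unit: ToyUnit,
    train_dataloader: Iterable[Any],
    eval_dataloader: Iterable[Any],
    *,
    max_epochs: Optional[int] = None,
    max_steps: Optional[int] = None,
    max_train_steps_per_epoch: Optional[int] = None,
    max_eval_steps_per_epoch: Optional[int] = None,
    evaluate_every_n_steps: Optional[int] = None,
    evaluate_every_n_epochs: Optional[int] = 1,
    callbacks: Optional[List[Recorder]] = None,
) -> None:
    cbs: List[Recorder] = list(callbacks or [])
    unit_then_cbs: List[Recorder] = [unit] + cbs  # unit first, then callbacks

    def notify(event: str, targets: List[Recorder]) -> None:
        for target in targets:
            target.hook(event)

    def training_done() -> bool:
        # None disables a limit; with both None this loops forever ("infinite training").
        return (max_epochs is not None and epochs >= max_epochs) or (
            max_steps is not None and steps >= max_steps
        )

    epochs = steps = 0
    notify("on_train_start", unit_then_cbs)
    while not training_done():
        notify("on_train_epoch_start", unit_then_cbs)
        batches = iter(train_dataloader)  # assumes the dataloader can be re-iterated each epoch
        steps_in_epoch = 0
        while max_train_steps_per_epoch is None or steps_in_epoch < max_train_steps_per_epoch:
            if max_steps is not None and steps >= max_steps:
                break
            try:
                data = next(batches)
            except StopIteration:
                break  # dataloader exhausted: the epoch is over
            notify("on_train_step_start", cbs)  # step hooks go to callbacks only
            unit.train_step(data)
            steps += 1
            steps_in_epoch += 1
            notify("on_train_step_end", cbs)
            if evaluate_every_n_steps is not None and steps % evaluate_every_n_steps == 0:
                run_eval(unit, eval_dataloader, max_eval_steps_per_epoch)
        epochs += 1
        notify("on_train_epoch_end", unit_then_cbs)
        if evaluate_every_n_epochs is not None and epochs % evaluate_every_n_epochs == 0:
            run_eval(unit, eval_dataloader, max_eval_steps_per_epoch)
    notify("on_train_end", unit_then_cbs)


log: List[str] = []
toy_fit(ToyUnit("unit", log), [1, 2, 3], ["e"], max_epochs=2,
        max_train_steps_per_epoch=2, callbacks=[Recorder("cb", log)])
print(log[:2])  # ['unit.on_train_start', 'cb.on_train_start']
```

Re-creating the iterator at the start of each epoch is what lets a list or a PyTorch DataLoader be reused across epochs; a one-shot generator would be empty after the first epoch.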