Train
Train Entry Point
torchtnt.framework.train.train(train_unit: TrainUnit[TTrainData], train_dataloader: Iterable[TTrainData], *, max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_steps_per_epoch: Optional[int] = None, callbacks: Optional[List[Callback]] = None, timer: Optional[TimerProtocol] = None) -> None

The train entry point takes in a TrainUnit object, a train dataloader (any Iterable), and optional arguments to modify loop execution, and runs the training loop.

Parameters:
- train_unit – an instance of TrainUnit which implements train_step.
- train_dataloader – dataloader to be used during training, which can be any iterable, including PyTorch DataLoader, DataLoader2, etc.
- max_epochs – the max number of epochs to run. None means no limit (infinite training) unless stopped by max_steps.
- max_steps – the max number of steps to run. None means no limit (infinite training) unless stopped by max_epochs.
- max_steps_per_epoch – the max number of steps to run per epoch. None means train until the dataloader is exhausted.
- callbacks – an optional list of Callbacks.
- timer – an optional Timer which will be used to time key events (using a Timer with CUDA synchronization may degrade performance).
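The three limits (max_epochs, max_steps, max_steps_per_epoch) interact, so a small model of the stopping rules may help. The sketch below is plain Python and does not use torchtnt; count_steps is a hypothetical helper written for illustration, and it only models how many steps would run under the rules described above, not the library's actual implementation.

```python
from typing import Iterable, Optional


def count_steps(
    dataloader: Iterable,
    max_epochs: Optional[int] = None,
    max_steps: Optional[int] = None,
    max_steps_per_epoch: Optional[int] = None,
) -> int:
    """Hypothetical model of the stopping rules; returns total steps run."""
    if max_epochs is None and max_steps is None:
        # Both None means no limit; this sketch refuses to loop forever.
        raise ValueError("set max_epochs or max_steps")
    epochs = 0
    steps = 0
    while (max_epochs is None or epochs < max_epochs) and (
        max_steps is None or steps < max_steps
    ):
        steps_in_epoch = 0
        for _batch in dataloader:
            steps += 1
            steps_in_epoch += 1
            # The global step limit ends training mid-epoch.
            if max_steps is not None and steps >= max_steps:
                break
            # The per-epoch limit ends only the current epoch.
            if max_steps_per_epoch is not None and steps_in_epoch >= max_steps_per_epoch:
                break
        if steps_in_epoch == 0:
            # Sketch-only guard: an empty dataloader would never reach max_steps.
            break
        epochs += 1
    return steps


batches = range(10)  # stand-in for a dataloader with 10 batches
print(count_steps(batches, max_epochs=4))                         # 40
print(count_steps(batches, max_epochs=4, max_steps=25))           # 25
print(count_steps(batches, max_epochs=4, max_steps_per_epoch=3))  # 12
```

In this model, whichever of max_epochs and max_steps is reached first ends training, while max_steps_per_epoch only shortens each epoch.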
Below is an example of calling train().

```python
import torch
from torchtnt.framework.train import train

# MyTrainUnit is a user-defined TrainUnit subclass that implements train_step.
train_unit = MyTrainUnit(module=..., optimizer=..., lr_scheduler=...)
train_dataloader = torch.utils.data.DataLoader(...)
train(train_unit, train_dataloader, max_epochs=4)
```
Below is pseudocode of what the train() entry point does.

```
set unit's tracked modules to train mode
call on_train_start on unit first and then callbacks
while training is not done:
    call on_train_epoch_start on unit first and then callbacks
    while epoch is not done:
        try:
            data = next(dataloader)
            call on_train_step_start on callbacks
            call train_step on unit
            increment step counter
            call on_train_step_end on callbacks
        except StopIteration:
            break
    increment epoch counter
    call on_train_epoch_end on unit first and then callbacks
call on_train_end on unit first and then callbacks
```
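To make the hook ordering concrete, here is a minimal plain-Python sketch that records the order in which hooks fire. It does not import torchtnt: Recorder and run_loop are hypothetical stand-ins, the hooks take no state argument, train mode and step counters are omitted, and the sketch assumes on_train_epoch_start fires once per epoch before the first step.

```python
from typing import Iterable, List


class Recorder:
    """Hypothetical stand-in for a unit or callback; logs every hook call."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def __getattr__(self, hook: str):
        # Only called for attributes that are not set in __init__.
        if not (hook.startswith("on_train") or hook == "train_step"):
            raise AttributeError(hook)

        def record(*_args) -> None:
            self.log.append(f"{self.name}.{hook}")

        return record


def run_loop(unit, callbacks, dataloader: Iterable, max_epochs: int) -> None:
    """Simplified loop that mirrors the hook order in the pseudocode."""

    def unit_then_callbacks(hook: str) -> None:
        getattr(unit, hook)()
        for cb in callbacks:
            getattr(cb, hook)()

    def callbacks_only(hook: str) -> None:
        for cb in callbacks:
            getattr(cb, hook)()

    unit_then_callbacks("on_train_start")
    for _ in range(max_epochs):
        unit_then_callbacks("on_train_epoch_start")
        data_iter = iter(dataloader)
        while True:
            try:
                data = next(data_iter)
            except StopIteration:
                break  # dataloader exhausted: the epoch is done
            callbacks_only("on_train_step_start")
            unit.train_step(data)
            callbacks_only("on_train_step_end")
        unit_then_callbacks("on_train_epoch_end")
    unit_then_callbacks("on_train_end")


log: List[str] = []
run_loop(Recorder("unit", log), [Recorder("cb", log)], dataloader=[0, 1], max_epochs=1)
print("\n".join(log))
```

Running it with one callback and two batches prints the unit's hook before the callback's hook for the start, epoch, and end events, while the step start and step end hooks are called on callbacks only, around the unit's train_step.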