Train

Train Entry Point

torchtnt.framework.train.train(train_unit: TrainUnit[TTrainData], train_dataloader: Iterable[TTrainData], *, max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_steps_per_epoch: Optional[int] = None, callbacks: Optional[List[Callback]] = None, timer: Optional[TimerProtocol] = None) -> None

The train entry point takes a TrainUnit object, a train dataloader (any Iterable), and optional arguments that modify loop execution, then runs the training loop.

Parameters:
  • train_unit – an instance of TrainUnit which implements train_step.
  • train_dataloader – dataloader to be used during training, which can be any iterable, including PyTorch DataLoader, DataLoader2, etc.
  • max_epochs – the max number of epochs to run. None means no limit (infinite training) unless stopped by max_steps.
  • max_steps – the max number of steps to run. None means no limit (infinite training) unless stopped by max_epochs.
  • max_steps_per_epoch – the max number of steps to run per epoch. None means train until the dataloader is exhausted.
  • callbacks – an optional list of Callbacks.
  • timer – an optional Timer which will be used to time key events (using a Timer with CUDA synchronization may degrade performance).
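The interaction of the three limits can be modeled in a few lines of plain Python. This is a sketch under stated assumptions, not torchtnt's code: simulate_stopping is a hypothetical helper that assumes a dataloader yielding a fixed number of batches per epoch and applies the rules listed above (training stops once max_epochs or max_steps is reached, and an epoch ends early at max_steps_per_epoch).

```python
from typing import Optional, Tuple


def simulate_stopping(
    batches_per_epoch: int,
    max_epochs: Optional[int] = None,
    max_steps: Optional[int] = None,
    max_steps_per_epoch: Optional[int] = None,
) -> Tuple[int, int]:
    """Hypothetical helper (not part of torchtnt): return (epochs, steps)
    reached before the stopping rules described above end training."""
    if batches_per_epoch <= 0:
        raise ValueError("batches_per_epoch must be positive")
    if max_steps_per_epoch is not None and max_steps_per_epoch <= 0:
        raise ValueError("max_steps_per_epoch must be positive")
    if max_epochs is None and max_steps is None:
        raise ValueError("neither max_epochs nor max_steps is set: training never stops")

    epochs, steps = 0, 0

    def training_done() -> bool:
        # Training ends as soon as either global limit is reached.
        epochs_hit = max_epochs is not None and epochs >= max_epochs
        steps_hit = max_steps is not None and steps >= max_steps
        return epochs_hit or steps_hit

    while not training_done():
        steps_in_epoch = 0
        # One pass over a dataloader that yields batches_per_epoch batches.
        for _ in range(batches_per_epoch):
            if max_steps_per_epoch is not None and steps_in_epoch >= max_steps_per_epoch:
                break  # per-epoch cap reached before the dataloader ran out
            if max_steps is not None and steps >= max_steps:
                break  # global step budget used up mid-epoch
            steps += 1
            steps_in_epoch += 1
        epochs += 1  # counted once the inner loop ends, as in the pseudocode
    return epochs, steps


print(simulate_stopping(10, max_epochs=4))                         # (4, 40)
print(simulate_stopping(10, max_steps=25))                         # (3, 25)
print(simulate_stopping(10, max_epochs=4, max_steps_per_epoch=3))  # (4, 12)
```

Whichever global limit is hit first wins: with max_epochs=2 and max_steps=5 on a 10-batch dataloader, the model stops after 5 steps, partway through the first epoch.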

Below is an example of calling train().

import torch
from torchtnt.framework.train import train

# MyTrainUnit stands for your own TrainUnit subclass that implements train_step
train_unit = MyTrainUnit(module=..., optimizer=..., lr_scheduler=...)
train_dataloader = torch.utils.data.DataLoader(...)
train(train_unit, train_dataloader, max_epochs=4)
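Because train_dataloader can be any iterable, it is worth checking that the object can be iterated more than once when max_epochs is greater than 1. Assuming train() requests a fresh iterator at the start of each epoch, as a Python for-loop does, a re-iterable object such as a list or a PyTorch DataLoader yields its batches every epoch, while a one-shot generator is empty after the first pass. ListBatches and batches_seen_per_epoch below are hypothetical names used only for this illustration.

```python
from typing import Iterable, Iterator, List


class ListBatches:
    """Hypothetical re-iterable 'dataloader': each iter() call starts a new pass."""

    def __init__(self, batches: List[int]) -> None:
        self.batches = batches

    def __iter__(self) -> Iterator[int]:
        return iter(self.batches)


def batches_seen_per_epoch(dataloader: Iterable[int], num_epochs: int) -> List[int]:
    """Count the batches obtained in each epoch, calling iter() once per epoch."""
    # The generator expression calls iter(dataloader) once, at the start of the epoch.
    return [sum(1 for _ in dataloader) for _ in range(num_epochs)]


reiterable = ListBatches([1, 2, 3])
one_shot = (b for b in [1, 2, 3])  # iter() on a generator returns the same, exhausted object

print(batches_seen_per_epoch(reiterable, 3))  # [3, 3, 3]
print(batches_seen_per_epoch(one_shot, 3))    # [3, 0, 0]
```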

Below is pseudocode of what the train() entry point does.

set unit's tracked modules to train mode
call on_train_start on unit first and then callbacks
while training is not done:
    call on_train_epoch_start on unit first and then callbacks
    data_iter = iter(dataloader)
    while epoch is not done:
        try:
            data = next(data_iter)
            call on_train_step_start on callbacks
            call train_step on unit
            increment step counter
            call on_train_step_end on callbacks
        except StopIteration:
            break
    increment epoch counter
    call on_train_epoch_end on unit first and then callbacks
call on_train_end on unit first and then callbacks
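The pseudocode can be turned into a small, self-contained Python model that is handy for checking the order in which hooks fire. This is a simplified sketch, not torchtnt's implementation: Recorder, fire and train_loop are hypothetical names, hooks take no state argument, only max_epochs is supported as a limit, and module train mode and timing are omitted.

```python
from typing import Any, Callable, Iterable, List


class Recorder:
    """Hypothetical stand-in for a unit or callback: every hook it is asked
    for simply appends "<name>.<hook>" to a shared log."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def __getattr__(self, hook: str) -> Callable[..., None]:
        # Only reached for attributes not set in __init__, i.e. hook names.
        def record(*args: Any) -> None:
            self.log.append(f"{self.name}.{hook}")

        return record


def fire(hook: str, unit: Any, callbacks: List[Any]) -> None:
    """Call a hook on the unit first and then on each callback, in order."""
    getattr(unit, hook)()
    for cb in callbacks:
        getattr(cb, hook)()


def train_loop(unit: Any, dataloader: Iterable[Any], callbacks: List[Any], max_epochs: int) -> int:
    """Run the simplified loop and return the total number of steps taken."""
    fire("on_train_start", unit, callbacks)
    epoch, steps = 0, 0
    while epoch < max_epochs:
        fire("on_train_epoch_start", unit, callbacks)  # once per epoch
        data_iter = iter(dataloader)  # fresh iterator for this epoch
        while True:
            try:
                data = next(data_iter)
            except StopIteration:
                break  # dataloader exhausted: the epoch is done
            for cb in callbacks:
                cb.on_train_step_start()
            unit.train_step(data)
            steps += 1
            for cb in callbacks:
                cb.on_train_step_end()
        epoch += 1
        fire("on_train_epoch_end", unit, callbacks)
    fire("on_train_end", unit, callbacks)
    return steps


log: List[str] = []
steps = train_loop(Recorder("unit", log), [10, 20], [Recorder("cb", log)], max_epochs=2)
print(steps)  # 4
```

Here only next() sits inside the try block, so a StopIteration raised by a hook or by train_step itself is not mistaken for the end of the epoch.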
