Fit

Fit Entry Point

torchtnt.framework.fit.fit(unit: TrainUnit[TTrainData], train_dataloader: Iterable[TTrainData], eval_dataloader: Iterable[TEvalData], *, max_epochs: Optional[int] = None, max_steps: Optional[int] = None, max_train_steps_per_epoch: Optional[int] = None, max_eval_steps_per_epoch: Optional[int] = None, evaluate_every_n_steps: Optional[int] = None, evaluate_every_n_epochs: Optional[int] = 1, callbacks: Optional[List[Callback]] = None, timer: Optional[TimerProtocol] = None) -> None

The fit entry point interleaves the training and evaluation loops. It takes an object that subclasses both TrainUnit and EvalUnit, train and eval dataloaders (any Iterables), and optional arguments to modify loop execution, and then runs the fit loop.

Parameters:
  • unit – an instance that subclasses both TrainUnit and EvalUnit, implementing train_step() and eval_step().
  • train_dataloader – dataloader to be used during training, which can be any iterable, including PyTorch DataLoader, DataLoader2, etc.
  • eval_dataloader – dataloader to be used during evaluation, which can be any iterable, including PyTorch DataLoader, DataLoader2, etc.
  • max_epochs – the max number of epochs to run for training. None means no limit (infinite training) unless stopped by max_steps.
  • max_steps – the max number of steps to run for training. None means no limit (infinite training) unless stopped by max_epochs.
  • max_train_steps_per_epoch – the max number of steps to run per epoch for training. None means train until train_dataloader is exhausted.
  • max_eval_steps_per_epoch – the max number of steps to run per epoch for evaluation. None means evaluate until eval_dataloader is exhausted.
  • evaluate_every_n_steps – how often to run the evaluation loop in terms of training steps.
  • evaluate_every_n_epochs – how often to run the evaluation loop in terms of training epochs.
  • callbacks – an optional list of callbacks.
  • timer – an optional Timer which will be used to time key events (using a Timer with CUDA synchronization may degrade performance).

Below is an example of calling fit().

import torch

from torchtnt.framework.fit import fit

# module, optimizer, and lr_scheduler are constructed elsewhere
fit_unit = MyFitUnit(module=..., optimizer=..., lr_scheduler=...)
train_dataloader = torch.utils.data.DataLoader(...)
eval_dataloader = torch.utils.data.DataLoader(...)
fit(fit_unit, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, max_epochs=4)
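
The example above assumes a MyFitUnit class that subclasses both TrainUnit and EvalUnit, as fit() requires. Below is a minimal sketch of what such a unit might look like, assuming an (inputs, targets) batch format and a cross-entropy loss; the step bodies and the scheduler handling are illustrative only and should be adapted to your task.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import EvalUnit, TrainUnit

Batch = Tuple[torch.Tensor, torch.Tensor]

class MyFitUnit(TrainUnit[Batch], EvalUnit[Batch]):
    def __init__(self, module, optimizer, lr_scheduler) -> None:
        super().__init__()
        self.module = module
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def train_step(self, state: State, data: Batch) -> None:
        # one optimization step on a single batch
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

    def on_train_epoch_end(self, state: State) -> None:
        # illustrative choice: step the LR scheduler once per training epoch
        self.lr_scheduler.step()

    def eval_step(self, state: State, data: Batch) -> None:
        # compute a validation loss; log or accumulate it as needed
        inputs, targets = data
        with torch.no_grad():
            outputs = self.module(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, targets)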

Below is pseudocode of what the fit() entry point does.

set unit's tracked modules to train mode
call on_train_start on unit first and then callbacks
while training is not done:
    while epoch is not done:
        call on_train_epoch_start on unit first and then callbacks
        try:
            data = next(dataloader)
            call on_train_step_start on callbacks
            call train_step on unit
            increment step counter
            call on_train_step_end on callbacks
            if should evaluate after this step:
                run eval loop
        except StopIteration:
            break
    increment epoch counter
    call on_train_epoch_end on unit first and then callbacks
    if should evaluate after this epoch:
        run eval loop
call on_train_end on unit first and then callbacks
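
The callback hooks named in the pseudocode (on_train_start, on_train_step_end, on_train_epoch_end, and so on) correspond to methods on the Callback base class. Below is a minimal sketch of a callback wired into some of these hooks; the progress attribute accessed in on_train_epoch_end and the print statements are illustrative assumptions, not prescribed usage.

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TTrainUnit

class PrintingCallback(Callback):
    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        print("training started")

    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
        # assumes the unit tracks progress via train_progress
        print(f"finished epoch {unit.train_progress.num_epochs_completed}")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        print("finished an eval pass")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        print("training complete")

# callbacks are passed to fit() via the callbacks argument; evaluate_every_n_steps
# here additionally triggers evaluation every 100 training steps
fit(fit_unit, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader,
    max_epochs=4, evaluate_every_n_steps=100, callbacks=[PrintingCallback()])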
