AutoUnit

class torchtnt.framework.auto_unit.AutoUnit(*args, **kwargs)

The AutoUnit is a convenience for users who are training with stochastic gradient descent and would like to have model optimization and data parallel replication handled for them. The AutoUnit subclasses TrainUnit, EvalUnit, and PredictUnit and implements the train_step, eval_step, and predict_step methods for the user.

For the train_step it runs:

  • forward pass and loss computation
  • backward pass
  • optimizer step

For the eval_step it only runs forward and loss computation.

For the predict_step it only runs forward computation.

To benefit from the AutoUnit, the user must subclass it and implement the compute_loss and configure_optimizers_and_lr_scheduler methods. Additionally, the AutoUnit offers these optional hooks:

  • on_train_step_end
  • on_eval_step_end
  • on_predict_step_end

Then use with the train(), evaluate(), fit(), or predict() entry point as normal.

For more advanced customization, directly use the TrainUnit, EvalUnit, and PredictUnit interfaces.
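
As an illustration, here is a minimal sketch of an AutoUnit subclass for a classification model; the batch type, class name, optimizer choice, and model below are hypothetical stand-ins, not part of the library:

    import torch
    from typing import Any, Optional, Tuple

    from torchtnt.framework.auto_unit import AutoUnit
    from torchtnt.framework.state import State

    # Hypothetical batch type for this sketch: (inputs, targets) tensor pairs.
    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyUnit(AutoUnit[Batch]):
        def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
            inputs, targets = data
            outputs = self.module(inputs)  # the forward pass must run here
            loss = torch.nn.functional.cross_entropy(outputs, targets)
            return loss, outputs

        def configure_optimizers_and_lr_scheduler(
            self, module: torch.nn.Module
        ) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler.StepLR]]:
            optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return optimizer, lr_scheduler

    unit = MyUnit(module=torch.nn.Linear(16, 4))
    # train(unit, dataloader)  # then hand the unit to an entry point as usual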

Note

Stochastic Weight Averaging is currently not supported with the FSDP strategy.

Note

Torch compile support is only available in PyTorch 2.0 or higher.

abstract compute_loss(state: State, data: TData) → Tuple[Tensor, Any]

The user should implement this method with their loss computation. It will be called on every train_step and eval_step.

Parameters:
  • state – a State object which is passed from the train_step/eval_step
  • data – a batch of data which is passed from the train_step/eval_step
Returns: A tuple containing the loss and the output of the model

Note

The module’s forward pass must be run as part of this method.

abstract configure_optimizers_and_lr_scheduler(module: Module) → Tuple[Optimizer, Optional[LRScheduler]]

The user should implement this method with their optimizer and learning rate scheduler construction code. This will be called upon initialization of the AutoUnit.

Parameters: module – the module with which to construct the optimizer and lr_scheduler
Returns: A tuple containing the optimizer and, optionally, the learning rate scheduler

eval_step(state: State, data: TData) → Tuple[Tensor, Any]

This method is called at each iteration of the eval dataloader. The AutoUnit implements it for the user, running the forward pass and loss computation via compute_loss() and then calling the on_eval_step_end() hook.

Parameters:
  • state – a State object containing metadata about the evaluation run.
  • data – one batch of evaluation data.
move_data_to_device(state: State, data: TData, non_blocking: bool) → TData

The user can override this method with custom code to copy data to device. This will be called at the start of every train_step/eval_step. By default this uses the utility function copy_data_to_device().

If on GPU, this method will be called on a separate CUDA stream.

Parameters:
  • state – a State object which is passed from the train_step/eval_step
  • data – a batch of data which is passed from the train_step/eval_step
  • non_blocking – parameter to pass to torch.Tensor.to
Returns: A batch of data on the device
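
Continuing the MyUnit sketch from above, an override might customize which parts of the batch are moved; the policy here is hypothetical, and it assumes the unit's target device is available as self.device:

    from torchtnt.utils import copy_data_to_device

    class MyUnit(AutoUnit[Batch]):
        ...

        def move_data_to_device(
            self, state: State, data: Batch, non_blocking: bool
        ) -> Batch:
            # Hypothetical policy: move only the inputs; keep targets on the host.
            inputs, targets = data
            inputs = copy_data_to_device(inputs, self.device, non_blocking=non_blocking)
            return inputs, targets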

on_eval_step_end(state: State, data: TData, step: int, loss: Tensor, outputs: Any) → None

This will be called at the end of every eval_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.

Parameters:
  • state – a State object which is passed from the eval_step
  • data – a batch of data which is passed from the eval_step
  • step – how many steps have been completed (train_step calls when running fit, eval_step calls when running evaluation)
  • loss – the loss computed in the compute_loss function
  • outputs – the outputs of the model forward pass
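
Continuing the MyUnit sketch, here is how this hook might update an evaluation metric; self.eval_accuracy is a hypothetical metric object (with an update() method) created in the subclass's __init__:

    class MyUnit(AutoUnit[Batch]):
        ...

        def on_eval_step_end(
            self, state: State, data: Batch, step: int, loss: torch.Tensor, outputs: Any
        ) -> None:
            _, targets = data
            # Hypothetical metric object, e.g. from a metrics library
            self.eval_accuracy.update(outputs.argmax(dim=-1), targets)
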
on_predict_step_end(state: State, data: TData, step: int, outputs: Any) → None

This will be called at the end of every predict_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.

Parameters:
  • state – a State object which is passed from the predict_step
  • data – a batch of data which is passed from the predict_step
  • step – how many predict_step calls have been completed
  • outputs – the outputs of the model forward pass
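
Continuing the MyUnit sketch, a hook like this might accumulate predictions; self.predictions is a hypothetical list created in the subclass's __init__:

    class MyUnit(AutoUnit[Batch]):
        ...

        def on_predict_step_end(
            self, state: State, data: Batch, step: int, outputs: Any
        ) -> None:
            # Hypothetical accumulator; detach and move to CPU before storing
            self.predictions.append(outputs.detach().cpu())
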
on_train_end(state: State) → None

Note that if using SWA and overriding on_train_end(), you must call super().on_train_end().

on_train_epoch_end(state: State) → None

Note: if overriding on_train_epoch_end(), remember to call super().on_train_epoch_end().
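
For example, an override that adds epoch-level logging while preserving the AutoUnit's own epoch-end logic (a sketch continuing MyUnit; the logging itself is arbitrary):

    class MyUnit(AutoUnit[Batch]):
        ...

        def on_train_epoch_end(self, state: State) -> None:
            super().on_train_epoch_end()  # keep the AutoUnit's epoch-end bookkeeping
            print("finished a training epoch")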

on_train_step_end(state: State, data: TData, step: int, loss: Tensor, outputs: Any) → None

This will be called at the end of every train_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.

Parameters:
  • state – a State object which is passed from the train_step
  • data – a batch of data which is passed from the train_step
  • step – how many train_step calls have been completed
  • loss – the loss computed in the compute_loss function
  • outputs – the outputs of the model forward pass
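
Continuing the MyUnit sketch, a hook like this might log the loss periodically; the interval and print target are arbitrary choices:

    class MyUnit(AutoUnit[Batch]):
        ...

        def on_train_step_end(
            self, state: State, data: Batch, step: int, loss: torch.Tensor, outputs: Any
        ) -> None:
            # Hypothetical periodic console logging
            if step % 100 == 0:
                print(f"step {step}: loss={loss.item():.4f}")
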
predict_step(state: State, data: TData) → Any

This method is called at each iteration of the predict dataloader. The AutoUnit implements it for the user, running the forward pass and then calling the on_predict_step_end() hook.

Parameters:
  • state – a State object containing metadata about the prediction run.
  • data – one batch of prediction data.
train_step(state: State, data: Iterator[TData]) → Tuple[Tensor, Any]

This method is called at each iteration of the train dataloader. The AutoUnit implements it for the user, running the forward pass and loss computation, the backward pass, and the optimizer step.

Parameters:
  • state – a State object containing metadata about the training run.
  • data – the iterator over training data, from which the next batch is drawn.
class torchtnt.framework.auto_unit.AutoPredictUnit(*, module: Module, device: Optional[device] = None, strategy: Optional[Union[Strategy, str]] = None, precision: Optional[Union[str, dtype]] = None, torch_compile_params: Optional[TorchCompileParams] = None, detect_anomaly: Optional[bool] = None)
move_data_to_device(state: State, data: TPredictData, non_blocking: bool) → TPredictData

The user can override this method with custom code to copy data to device. This will be called at the start of every predict_step. By default this uses the utility function copy_data_to_device().

If on GPU, this method will be called on a separate CUDA stream.

Parameters:
  • state – a State object which is passed from the predict_step
  • data – a batch of data which is passed from the predict_step
  • non_blocking – parameter to pass to torch.Tensor.to
Returns: A batch of data on the device

on_predict_step_end(state: State, data: TPredictData, step: int, outputs: Any) → None

This will be called at the end of every predict_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.

Parameters:
  • state – a State object which is passed from the predict_step
  • data – a batch of data which is passed from the predict_step
  • step – how many predict_step calls have been completed
  • outputs – the outputs of the model forward pass
predict_step(state: State, data: Iterator[TPredictData]) → Any

This method is called at each iteration of the predict dataloader. The AutoPredictUnit implements it for the user, running the forward pass.

Parameters:
  • state – a State object containing metadata about the prediction run.
  • data – the iterator over prediction data, from which the next batch is drawn.
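
Putting it together, a sketch of running inference with an AutoPredictUnit; my_model and predict_dataloader are assumed to exist, and the precision string follows the signature above:

    from torchtnt.framework import predict
    from torchtnt.framework.auto_unit import AutoPredictUnit

    predict_unit = AutoPredictUnit(module=my_model, precision="fp16")
    predict(predict_unit, predict_dataloader)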
