AutoUnit
class torchtnt.framework.auto_unit.AutoUnit(*args, **kwargs)

The AutoUnit is a convenience for users who are training with stochastic gradient descent and would like to have model optimization and data parallel replication handled for them. The AutoUnit subclasses TrainUnit, EvalUnit, and PredictUnit and implements the train_step, eval_step, and predict_step methods for the user.

For the train_step, it runs:
- forward pass and loss computation
- backward pass
- optimizer step
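The three train_step stages listed above can be illustrated with a pure-Python sketch that fits a one-parameter linear model using hand-derived gradients. It does not use PyTorch or torchtnt: plain floats stand in for tensors and autograd, and every function name is hypothetical. The eval_step and predict_step variants described next reuse the same forward code but stop after the loss or after the forward pass, so the weight is never updated.

```python
from typing import List, Tuple

# Hypothetical one-parameter model y_hat = w * x, trained with squared error.
# Plain Python stands in for tensors and autograd; this only mirrors the order
# of operations, not the AutoUnit's implementation.


def forward(w: float, x: float) -> float:
    return w * x


def compute_loss(w: float, x: float, y: float) -> Tuple[float, float]:
    """Forward pass plus loss, returned as (loss, outputs)."""
    y_hat = forward(w, x)
    return (y_hat - y) ** 2, y_hat


def backward(w: float, x: float, y: float) -> float:
    """Hand-derived gradient d/dw of (w * x - y) ** 2."""
    return 2.0 * (forward(w, x) - y) * x


def train_step(w: float, x: float, y: float, lr: float) -> Tuple[float, float]:
    loss, _ = compute_loss(w, x, y)  # 1. forward pass and loss computation
    grad = backward(w, x, y)         # 2. backward pass
    w = w - lr * grad                # 3. optimizer step (plain SGD)
    return w, loss


def eval_step(w: float, x: float, y: float) -> float:
    loss, _ = compute_loss(w, x, y)  # forward and loss only; w is untouched
    return loss


def predict_step(w: float, x: float) -> float:
    return forward(w, x)             # forward only; no targets needed


w = 0.0
losses: List[float] = []
for _ in range(20):
    w, loss = train_step(w, x=2.0, y=6.0, lr=0.05)
    losses.append(loss)
```

Each train_step call shrinks the error in w, while eval_step and predict_step can be called any number of times without changing it.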
For the eval_step, it only runs forward and loss computation. For the predict_step, it only runs forward computation.

To benefit from the AutoUnit, the user must subclass it and implement the compute_loss and configure_optimizers_and_lr_scheduler methods. Additionally, the AutoUnit offers these optional hooks:
- on_train_step_end
- on_eval_step_end
- on_predict_step_end

Then use it with the train(), evaluate(), fit(), or predict() entry point as normal. For more advanced customization, use the TrainUnit, EvalUnit, and PredictUnit interfaces directly.

Parameters:
- module – module to be used during training/evaluation.
- device – the device to be used.
- strategy – the data parallelization strategy to be used. If a string, must be one of ddp or fsdp.
- step_lr_interval – whether to step the lr_scheduler every step ("step") or every epoch ("epoch"). Defaults to every epoch.
- precision – the precision to use in training/evaluation, as either a string or a torch.dtype.
- gradient_accumulation_steps – how many batches to accumulate gradients over.
- detect_anomaly – whether to enable anomaly detection for the autograd engine https://pytorch.org/docs/stable/autograd.html#anomaly-detection
- clip_grad_norm – max norm of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
- clip_grad_value – max value of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html
- swa_params – params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
- torch_compile_params – params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html
- activation_checkpoint_params – params for enabling activation checkpointing
- training – if True, the optimizer and optionally LR scheduler will be created after the class is initialized.
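To make gradient_accumulation_steps and step_lr_interval concrete, the sketch below simulates which train steps would trigger an optimizer step and a learning-rate scheduler step. It is a pure-Python model under stated assumptions: step_lr_interval takes the strings "step" or "epoch"; the optimizer steps once every gradient_accumulation_steps batches and also on the last batch of an epoch; with "step" the scheduler steps together with the optimizer, and with "epoch" it steps once at the end of each epoch. Those rules and the helper name plan_steps are assumptions for illustration, not torchtnt's code.

```python
from typing import List, Tuple

# Hypothetical helper: lists the (epoch, batch_index) pairs at which an
# optimizer step and an LR-scheduler step would happen. The rules encoded
# here are assumptions made for illustration, not torchtnt's implementation.


def plan_steps(
    num_epochs: int,
    batches_per_epoch: int,
    gradient_accumulation_steps: int,
    step_lr_interval: str,
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    if step_lr_interval not in ("step", "epoch"):
        raise ValueError("step_lr_interval must be 'step' or 'epoch'")
    optimizer_steps: List[Tuple[int, int]] = []
    scheduler_steps: List[Tuple[int, int]] = []
    for epoch in range(num_epochs):
        for i in range(batches_per_epoch):
            is_last_batch = i == batches_per_epoch - 1
            # Gradients accumulate across batches; weights are updated every
            # N batches (assumed: also on the last batch of the epoch).
            if (i + 1) % gradient_accumulation_steps == 0 or is_last_batch:
                optimizer_steps.append((epoch, i))
                if step_lr_interval == "step":
                    scheduler_steps.append((epoch, i))
        if step_lr_interval == "epoch":
            scheduler_steps.append((epoch, batches_per_epoch - 1))
    return optimizer_steps, scheduler_steps


opt, sched = plan_steps(
    num_epochs=2,
    batches_per_epoch=5,
    gradient_accumulation_steps=2,
    step_lr_interval="epoch",
)
```

With five batches per epoch and gradient_accumulation_steps=2, the weights are updated three times per epoch under these assumptions, while an "epoch" scheduler steps only twice in total.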
Note
Stochastic Weight Averaging is currently not supported with the FSDP strategy.
Note
Torch compile support is only available in PyTorch 2.0 or higher.
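The clip_grad_norm and clip_grad_value parameters refer to PyTorch's torch.nn.utils.clip_grad_norm_ and clip_grad_value_ utilities. The pure-Python sketch below reproduces the arithmetic of those two techniques on plain lists of gradient values (rescaling by the global L2 norm, and clamping each entry) so the difference is visible. It is an illustration only, not a call into PyTorch, and the function names are hypothetical.

```python
import math
from typing import List, Tuple

# Hypothetical pure-Python versions of the two clipping rules; gradients are
# lists of floats, one inner list per parameter.


def clip_by_global_norm(
    grads: List[List[float]], max_norm: float, eps: float = 1e-6
) -> Tuple[List[List[float]], float]:
    # L2 norm over *all* gradient entries of all parameters.
    total_norm = math.sqrt(sum(g * g for param in grads for g in param))
    # Scale down only when the norm exceeds max_norm (coefficient capped at 1).
    coef = min(1.0, max_norm / (total_norm + eps))
    return [[g * coef for g in param] for param in grads], total_norm


def clip_by_value(grads: List[List[float]], clip_value: float) -> List[List[float]]:
    # Clamp each entry independently into [-clip_value, clip_value].
    return [[max(-clip_value, min(clip_value, g)) for g in param] for param in grads]


grads = [[3.0, 0.0], [4.0]]  # global L2 norm is 5.0
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
clamped = clip_by_value(grads, clip_value=1.0)
```

Norm clipping preserves the direction of the full gradient (here roughly [[0.6, 0.0], [0.8]]), whereas value clipping clamps entries independently and can change that direction ([[1.0, 0.0], [1.0]]).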
abstract compute_loss(state: State, data: TData) → Tuple[Tensor, Any]

The user should implement this method with their loss computation. This will be called every train_step/eval_step.

Parameters:
- state – a State object which is passed from the train_step/eval_step
- data – a batch of data which is passed from the train_step/eval_step

Returns: Tuple containing the loss and the output of the model

Note
The module’s forward pass must be run as part of this method.
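A minimal sketch of the compute_loss contract described above, written in pure Python so it runs without PyTorch: the method runs the forward pass on the batch, computes a loss, and returns a (loss, outputs) tuple. MyUnit and LinearModel are hypothetical stand-ins; a real implementation would be defined on an AutoUnit subclass, run the wrapped module on tensors, and use a torch loss function.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple

# Hypothetical stand-ins: a batch is (inputs, targets), the "module" is a
# callable, and lists of floats replace tensors.
Batch = Tuple[List[float], List[float]]


@dataclass
class LinearModel:
    weight: float
    bias: float

    def __call__(self, xs: List[float]) -> List[float]:
        return [self.weight * x + self.bias for x in xs]


class MyUnit:
    def __init__(self, module: LinearModel) -> None:
        self.module = module

    def compute_loss(self, state: Any, data: Batch) -> Tuple[float, List[float]]:
        inputs, targets = data
        outputs = self.module(inputs)  # the forward pass runs inside compute_loss
        # Mean squared error over the batch.
        loss = sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)
        return loss, outputs           # (loss, model outputs)


unit = MyUnit(LinearModel(weight=2.0, bias=1.0))
loss, outputs = unit.compute_loss(state=None, data=([0.0, 1.0], [1.0, 4.0]))
```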
abstract configure_optimizers_and_lr_scheduler(module: Module) → Tuple[Optimizer, Optional[LRScheduler]]

The user should implement this method with their optimizer and learning rate scheduler construction code. This will be called upon initialization of the AutoUnit.

Parameters:
- module – the module with which to construct the optimizer and lr_scheduler

Returns: A tuple containing the optimizer and, optionally, the learning rate scheduler
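The sketch below shows the shape of a configure_optimizers_and_lr_scheduler implementation: build an optimizer, optionally build a scheduler around it, and return both as a tuple, with None in place of the scheduler when none is wanted. To keep it runnable without PyTorch, SGD and StepLR here are hypothetical pure-Python stand-ins; StepLR mimics the decay rule of torch.optim.lr_scheduler.StepLR (multiply the learning rate by gamma every step_size scheduler steps). The real method receives the module rather than a parameter list.

```python
from typing import List, Optional, Tuple

# Hypothetical pure-Python stand-ins for torch.optim.SGD and
# torch.optim.lr_scheduler.StepLR, used only to show the return contract.


class SGD:
    def __init__(self, params: List[float], lr: float) -> None:
        self.params = params
        self.lr = lr


class StepLR:
    """Multiplies the optimizer's lr by gamma every step_size scheduler steps."""

    def __init__(self, optimizer: SGD, step_size: int, gamma: float) -> None:
        self.optimizer = optimizer
        self.base_lr = optimizer.lr
        self.step_size = step_size
        self.gamma = gamma
        self.last_epoch = 0

    def step(self) -> None:
        self.last_epoch += 1
        decay = self.gamma ** (self.last_epoch // self.step_size)
        self.optimizer.lr = self.base_lr * decay


def configure_optimizers_and_lr_scheduler(
    module_params: List[float], use_scheduler: bool = True
) -> Tuple[SGD, Optional[StepLR]]:
    optimizer = SGD(module_params, lr=0.1)
    # The scheduler is optional: returning None in its place is allowed.
    scheduler = StepLR(optimizer, step_size=2, gamma=0.5) if use_scheduler else None
    return optimizer, scheduler


optimizer, scheduler = configure_optimizers_and_lr_scheduler([0.0])
assert scheduler is not None
lrs = []
for _ in range(4):
    scheduler.step()
    lrs.append(optimizer.lr)
```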
eval_step(state: State, data: TData) → Tuple[Tensor, Any]

Core method of the EvalUnit interface, implemented here by the AutoUnit: it runs the forward pass and loss computation and returns the loss and the model outputs. This method will be called at each iteration of the eval dataloader. When implementing eval_step directly on an EvalUnit, it can optionally be decorated with @torch.inference_mode() for improved performance.

Parameters:
- state – a State object containing metadata about the evaluation run.
- data – one batch of evaluation data.
move_data_to_device(state: State, data: TData, non_blocking: bool) → TData

The user can override this method with custom code to copy data to the device. This will be called at the start of every train_step/eval_step. By default this uses the utility function copy_data_to_device(). If on GPU, this method will be called on a separate CUDA stream.

Parameters:
- state – a State object which is passed from the train_step/eval_step
- data – a batch of data which is passed from the train_step/eval_step
- non_blocking – parameter to pass to torch.Tensor.to

Returns: A batch of data which is on the device
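By default the AutoUnit delegates to copy_data_to_device(). If you override move_data_to_device for a custom batch type, the usual shape is a recursive walk over the batch that calls .to(device, non_blocking=...) on each tensor leaf. The sketch below is a hypothetical, simplified walk of that kind, not copy_data_to_device itself: FakeTensor stands in for torch.Tensor, and only dicts, lists, and tuples are traversed (namedtuples and dataclass batches are not handled).

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that records its device."""

    value: float
    device: str = "cpu"

    def to(self, device: str, non_blocking: bool = False) -> "FakeTensor":
        return FakeTensor(self.value, device)


def move_to_device(data: Any, device: str, non_blocking: bool = False) -> Any:
    # Recurse through dicts, lists, and tuples, preserving the container type.
    if isinstance(data, dict):
        return {k: move_to_device(v, device, non_blocking) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(v, device, non_blocking) for v in data)
    # Move any leaf that exposes a .to() method, like a tensor would.
    if hasattr(data, "to"):
        return data.to(device, non_blocking=non_blocking)
    return data  # leave other leaves (ints, strings, ...) unchanged


batch = {"x": [FakeTensor(1.0), FakeTensor(2.0)], "y": (FakeTensor(3.0), "label")}
moved = move_to_device(batch, "cuda:0", non_blocking=True)
```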
on_eval_step_end(state: State, data: TData, step: int, loss: Tensor, outputs: Any) → None

This will be called at the end of every eval_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.

Parameters:
- state – a State object which is passed from the eval_step
- data – a batch of data which is passed from the eval_step
- step – how many steps have been completed (train_steps when running fit, eval_steps when running evaluation)
- loss – the loss computed in the compute_loss function
- outputs – the outputs of the model forward pass
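One common use of on_eval_step_end is to accumulate a metric across batches. The pure-Python sketch below keeps a sample-weighted running mean of the per-batch loss. RunningMean and MyEvalHooks are hypothetical, floats stand in for the loss tensor, and len(outputs) is assumed to equal the batch size.

```python
from typing import Any, List


class RunningMean:
    """Hypothetical metric: sample-weighted mean of per-batch losses."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, weight: int = 1) -> None:
        self.total += value * weight
        self.count += weight

    def compute(self) -> float:
        return self.total / self.count if self.count else float("nan")


class MyEvalHooks:
    def __init__(self) -> None:
        self.eval_loss = RunningMean()

    def on_eval_step_end(
        self, state: Any, data: Any, step: int, loss: float, outputs: List[float]
    ) -> None:
        # Weight each batch's mean loss by its batch size so the final value
        # is the mean over all evaluated samples.
        self.eval_loss.update(loss, weight=len(outputs))


hooks = MyEvalHooks()
batches = [(1.0, [0.0, 0.0]), (4.0, [0.0])]  # (batch loss, batch outputs)
for step, (loss, outputs) in enumerate(batches, start=1):
    hooks.on_eval_step_end(state=None, data=None, step=step, loss=loss, outputs=outputs)
```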
on_predict_step_end(state: State, data: TData, step: int, outputs: Any) → None

This will be called at the end of every predict_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.

Parameters:
- state – a State object which is passed from the predict_step
- data – a batch of data which is passed from the predict_step
- step – how many predict_steps have been completed
- outputs – the outputs of the model forward pass
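A hypothetical on_predict_step_end that gathers each batch's outputs into a single list for the whole prediction run, sketched in pure Python with lists of floats standing in for output tensors.

```python
from typing import Any, List


class MyPredictHooks:
    def __init__(self) -> None:
        self.predictions: List[Any] = []

    def on_predict_step_end(
        self, state: Any, data: Any, step: int, outputs: List[float]
    ) -> None:
        # Append this batch's outputs; the result is one flat list per run.
        self.predictions.extend(outputs)


hooks = MyPredictHooks()
for step, outputs in enumerate([[0.1, 0.9], [0.4]], start=1):
    hooks.on_predict_step_end(state=None, data=None, step=step, outputs=outputs)
```

With real tensors, it is usually worth calling .detach().cpu() on the outputs before storing them so that GPU memory is not held for the whole run.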
on_train_end(state: State) → None

Note: if using SWA and overriding on_train_end(), you must call super().on_train_end().
on_train_epoch_end(state: State) → None

Note: if overriding on_train_epoch_end, remember to call super().on_train_epoch_end().
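The notes on on_train_end and on_train_epoch_end describe the standard cooperative-override pattern: the subclass adds its own work and calls super() so that the parent's hook still runs. The sketch below shows that pattern with a hypothetical BaseUnit whose hooks only append to a log. In the AutoUnit the parent hooks carry real work (for on_train_end, the SWA handling mentioned in its note), which would be skipped if super() were not called.

```python
from typing import Any, List


class BaseUnit:
    """Hypothetical parent whose hooks do work that subclasses must keep."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def on_train_epoch_end(self, state: Any) -> None:
        self.log.append("base epoch-end work")

    def on_train_end(self, state: Any) -> None:
        self.log.append("base train-end work")


class MyUnit(BaseUnit):
    def on_train_epoch_end(self, state: Any) -> None:
        self.log.append("my epoch-end work")
        super().on_train_epoch_end(state)  # keep the parent's behavior

    def on_train_end(self, state: Any) -> None:
        super().on_train_end(state)        # parent first, then custom work
        self.log.append("my train-end work")


unit = MyUnit()
unit.on_train_epoch_end(state=None)
unit.on_train_end(state=None)
```

Whether the custom work runs before or after super() is up to the subclass; what matters is that the parent's hook is not dropped.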
on_train_step_end(state: State, data: TData, step: int, loss: Tensor, outputs: Any) → None

This will be called at the end of every train_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.

Parameters:
- state – a State object which is passed from the train_step
- data – a batch of data which is passed from the train_step
- step – how many train_steps have been completed
- loss – the loss computed in the compute_loss function
- outputs – the outputs of the model forward pass
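The step argument of on_train_step_end can gate periodic logging. In the sketch below, MyTrainHooks and log_every_n_steps are hypothetical, the loop feeds assumed step values 1 to 5 directly instead of running a trainer, and appending to a list stands in for a real logger call.

```python
from typing import Any, List, Tuple


class MyTrainHooks:
    def __init__(self, log_every_n_steps: int) -> None:
        self.log_every_n_steps = log_every_n_steps
        self.logged: List[Tuple[int, float]] = []

    def on_train_step_end(
        self, state: Any, data: Any, step: int, loss: float, outputs: Any
    ) -> None:
        # step counts completed train_steps, so it can gate periodic logging.
        if step % self.log_every_n_steps == 0:
            self.logged.append((step, loss))  # stand-in for a logger call


hooks = MyTrainHooks(log_every_n_steps=2)
for step, loss in enumerate([0.9, 0.7, 0.6, 0.5, 0.45], start=1):
    hooks.on_train_step_end(state=None, data=None, step=step, loss=loss, outputs=None)
```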
predict_step(state: State, data: TData) → Any

Core method of the PredictUnit interface, implemented here by the AutoUnit: it only runs the forward computation. This method will be called at each iteration of the predict dataloader. When implementing predict_step directly on a PredictUnit, it can optionally be decorated with @torch.inference_mode() for improved performance.

Parameters:
- state – a State object containing metadata about the prediction run.
- data – one batch of prediction data.
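train_step(state: State, data: Iterator[TData]) → Tuple[Tensor, Any]

Core method of the TrainUnit interface, implemented here by the AutoUnit: it runs the forward pass and loss computation, the backward pass, and the optimizer step (subject to gradient_accumulation_steps), and returns the loss and the model outputs. This method will be called at each iteration of the train dataloader.

Parameters:
- state – a State object containing metadata about the training run.
- data – an iterator over the training data, from which the batch for this step is taken.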
class torchtnt.framework.auto_unit.AutoPredictUnit(*, module: Module, device: Optional[device] = None, strategy: Optional[Union[Strategy, str]] = None, precision: Optional[Union[str, dtype]] = None, torch_compile_params: Optional[TorchCompileParams] = None, detect_anomaly: Optional[bool] = None)
move_data_to_device(state: State, data: TPredictData, non_blocking: bool) → TPredictData

The user can override this method with custom code to copy data to the device. This will be called at the start of every predict_step. By default this uses the utility function copy_data_to_device(). If on GPU, this method will be called on a separate CUDA stream.

Parameters:
- state – a State object which is passed from the predict_step
- data – a batch of data which is passed from the predict_step
- non_blocking – parameter to pass to torch.Tensor.to

Returns: A batch of data which is on the device
on_predict_step_end(state: State, data: TPredictData, step: int, outputs: Any) → None

This will be called at the end of every predict_step before returning. The user can implement this method with code to update and log their metrics, or do anything else.

Parameters:
- state – a State object which is passed from the predict_step
- data – a batch of data which is passed from the predict_step
- step – how many predict_steps have been completed
- outputs – the outputs of the model forward pass
predict_step(state: State, data: Iterator[TPredictData]) → Any

Core required method for the user to implement. This method will be called at each iteration of the predict dataloader, and can return any data the user wishes. Optionally, it can be decorated with @torch.inference_mode() for improved performance.

Parameters:
- state – a State object containing metadata about the prediction run.
- data – one batch of prediction data.