Unit¶
The Unit is the primary place to organize your model code in TorchTNT. TorchTNT offers three Unit interfaces, one each for training, evaluation, and prediction. The interfaces are independent of one another and can be combined as needed, e.g. in the case of fitting (interleaving training and evaluation).
TrainUnit¶
class torchtnt.framework.unit.TrainUnit¶

The TrainUnit is an interface that can be used to organize your training logic. The core of it is the train_step, an abstract method in which you define the code you want to run on each iteration of the dataloader.

To use the TrainUnit, create a class which subclasses TrainUnit. Then implement the train_step method on your class, and optionally implement any of the hooks, which allow you to control the behavior of the loop at different points.

Below is a simple example of a user's subclass of TrainUnit that implements a basic train_step and the on_train_epoch_end hook.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit

# specify the type of the data in each batch of the dataloader to allow for typechecking
Batch = Tuple[torch.Tensor, torch.Tensor]

class MyTrainUnit(TrainUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    ):
        super().__init__()
        self.module = module
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def train_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

    def on_train_epoch_end(self, state: State) -> None:
        # step the learning rate scheduler once per epoch
        self.lr_scheduler.step()

train_unit = MyTrainUnit(module=..., optimizer=..., lr_scheduler=...)
on_train_end(state: State) → None¶

Hook called after training ends.

Parameters: state – a State object containing metadata about the training run.
on_train_epoch_end(state: State) → None¶

Hook called after a train epoch ends.

Parameters: state – a State object containing metadata about the training run.
on_train_epoch_start(state: State) → None¶

Hook called before a train epoch starts.

Parameters: state – a State object containing metadata about the training run.
on_train_start(state: State) → None¶

Hook called before training starts.

Parameters: state – a State object containing metadata about the training run.
abstract train_step(state: State, data: TTrainData) → Any¶

Core required method for user to implement. This method will be called at each iteration of the train dataloader, and can return any data the user wishes.

Parameters:
- state – a State object containing metadata about the training run.
- data – one batch of training data.
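As an illustration of the hooks above, here is a minimal sketch that uses on_train_start for one-time setup (the class name is hypothetical, the other methods are elided, and it assumes the unit stores its module as self.module, as in the example above):

class MySetupTrainUnit(TrainUnit[Batch]):
    # __init__ and train_step omitted for brevity

    def on_train_start(self, state: State) -> None:
        # put the module in training mode once, before the first iteration
        self.module.train()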
EvalUnit¶
class torchtnt.framework.unit.EvalUnit¶

The EvalUnit is an interface that can be used to organize your evaluation logic. The core of it is the eval_step, an abstract method in which you define the code you want to run on each iteration of the dataloader.

To use the EvalUnit, create a class which subclasses EvalUnit. Then implement the eval_step method on your class, and optionally implement any of the hooks, which allow you to control the behavior of the loop at different points.

Below is a simple example of a user's subclass of EvalUnit that implements a basic eval_step.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import EvalUnit

# specify the type of the data in each batch of the dataloader to allow for typechecking
Batch = Tuple[torch.Tensor, torch.Tensor]

class MyEvalUnit(EvalUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
    ):
        super().__init__()
        self.module = module

    def eval_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        outputs = self.module(inputs)
        # compute the evaluation loss (e.g. to log or aggregate into a metric)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)

eval_unit = MyEvalUnit(module=...)
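The unit is then run with the evaluation entry point. A minimal sketch, assuming the evaluate entry point from torchtnt.framework (my_dataloader is a placeholder; check the entry point documentation for the exact signature):

from torchtnt.framework import evaluate

# run a single pass over the dataloader, invoking eval_step on each batch
evaluate(eval_unit, my_dataloader)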
abstract eval_step(state: State, data: TEvalData) → Any¶

Core required method for user to implement. This method will be called at each iteration of the eval dataloader, and can return any data the user wishes. Optionally can be decorated with @torch.inference_mode() for improved performance.

Parameters:
- state – a State object containing metadata about the evaluation run.
- data – one batch of evaluation data.
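For example, the decoration mentioned above is applied directly to the method; a sketch based on the MyEvalUnit example:

class MyEvalUnit(EvalUnit[Batch]):
    # __init__ as in the example above

    @torch.inference_mode()
    def eval_step(self, state: State, data: Batch) -> None:
        # autograd tracking is disabled here, reducing overhead and memory use
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)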
on_eval_end(state: State) → None¶

Hook called after evaluation ends.

Parameters: state – a State object containing metadata about the evaluation run.
on_eval_epoch_end(state: State) → None¶

Hook called after an eval epoch ends.

Parameters: state – a State object containing metadata about the evaluation run.
on_eval_epoch_start(state: State) → None¶

Hook called before an eval epoch starts.

Parameters: state – a State object containing metadata about the evaluation run.

on_eval_start(state: State) → None¶

Hook called before evaluation starts.

Parameters: state – a State object containing metadata about the evaluation run.
PredictUnit¶
class torchtnt.framework.unit.PredictUnit¶

The PredictUnit is an interface that can be used to organize your prediction logic. The core of it is the predict_step, an abstract method in which you define the code you want to run on each iteration of the dataloader.

To use the PredictUnit, create a class which subclasses PredictUnit. Then implement the predict_step method on your class, and optionally implement any of the hooks, which allow you to control the behavior of the loop at different points.

Below is a simple example of a user's subclass of PredictUnit that implements a basic predict_step.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import PredictUnit

# specify the type of the data in each batch of the dataloader to allow for typechecking
Batch = Tuple[torch.Tensor, torch.Tensor]

class MyPredictUnit(PredictUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
    ):
        super().__init__()
        self.module = module

    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
        inputs, targets = data
        outputs = self.module(inputs)
        return outputs

predict_unit = MyPredictUnit(module=...)
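As with training and evaluation, the unit is run with the matching entry point. A minimal sketch, assuming the predict entry point from torchtnt.framework (my_dataloader is a placeholder; check the entry point documentation for the exact signature):

from torchtnt.framework import predict

# run a single pass over the dataloader, invoking predict_step on each batch
predict(predict_unit, my_dataloader)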
on_predict_end(state: State) → None¶

Hook called after prediction ends.

Parameters: state – a State object containing metadata about the prediction run.
on_predict_epoch_end(state: State) → None¶

Hook called after a predict epoch ends.

Parameters: state – a State object containing metadata about the prediction run.
on_predict_epoch_start(state: State) → None¶

Hook called before a predict epoch starts.

Parameters: state – a State object containing metadata about the prediction run.
on_predict_start(state: State) → None¶

Hook called before prediction starts.

Parameters: state – a State object containing metadata about the prediction run.
abstract predict_step(state: State, data: TPredictData) → Any¶

Core required method for user to implement. This method will be called at each iteration of the predict dataloader, and can return any data the user wishes. Optionally can be decorated with @torch.inference_mode() for improved performance.

Parameters:
- state – a State object containing metadata about the prediction run.
- data – one batch of prediction data.
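Because predict_step may return any data, one pattern is to accumulate outputs on the unit itself and read them back after the loop finishes. A minimal sketch building on the example above (the CollectingPredictUnit name and predictions attribute are illustrative, not part of the API):

from typing import List

class CollectingPredictUnit(PredictUnit[Batch]):
    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module
        self.predictions: List[torch.Tensor] = []  # filled in as the loop runs

    @torch.inference_mode()
    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
        inputs, _ = data
        outputs = self.module(inputs)
        self.predictions.append(outputs)
        return outputs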
Combining Multiple Units¶
In some cases, it is convenient to implement multiple Unit interfaces in the same class, for example if you plan to use the class to run several different phases: training followed by prediction, or training and evaluation interleaved (referred to as fitting). Here is an example of a unit which extends TrainUnit, EvalUnit, and PredictUnit.
from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit, EvalUnit, PredictUnit

Batch = Tuple[torch.Tensor, torch.Tensor]

class MyUnit(TrainUnit[Batch], EvalUnit[Batch], PredictUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    ):
        super().__init__()
        self.module = module
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def train_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

    def eval_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)

    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
        inputs, targets = data
        outputs = self.module(inputs)
        return outputs

    def on_train_epoch_end(self, state: State) -> None:
        # step the learning rate scheduler once per epoch
        self.lr_scheduler.step()

my_unit = MyUnit(module=..., optimizer=..., lr_scheduler=...)
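A combined unit like this can then drive fitting. A minimal sketch, assuming the fit entry point from torchtnt.framework (the dataloader names are placeholders; check the entry point documentation for the exact signature and the options controlling how often evaluation runs):

from torchtnt.framework import fit

# interleave training and evaluation over the two dataloaders
fit(my_unit, train_dataloader=train_dl, eval_dataloader=eval_dl, max_epochs=4)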