

The Unit concept represents the primary place to organize your model code in TorchTNT. TorchTNT offers three types of Unit classes: one each for training, evaluation, and prediction. These interfaces are independent of one another and can be combined as needed, e.g. for fitting (interleaving training and evaluation).


class torchtnt.framework.unit.TrainUnit

The TrainUnit is an interface for organizing your training logic. Its core is train_step, an abstract method where you define the code to run on each iteration of the dataloader.

To use the TrainUnit, create a class which subclasses TrainUnit. Then implement the train_step method on your class, and optionally implement any of the hooks, which allow you to control the behavior of the loop at different points.

Below is a simple example of a user’s subclass of TrainUnit that implements a basic train_step, and the on_train_epoch_end hook.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit

# specify type of the data in each batch of the dataloader to allow for typechecking
Batch = Tuple[torch.Tensor, torch.Tensor]

class MyTrainUnit(TrainUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    ) -> None:
        super().__init__()
        self.module = module
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def train_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)
        # backpropagate, then update the parameters and reset the gradients
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

    def on_train_epoch_end(self, state: State) -> None:
        # step the learning rate scheduler
        self.lr_scheduler.step()

train_unit = MyTrainUnit(module=..., optimizer=..., lr_scheduler=...)

on_train_end(state: State) -> None

Hook called after training ends.

Parameters: state – a State object containing metadata about the training run.

on_train_epoch_end(state: State) -> None

Hook called after a train epoch ends.

Parameters: state – a State object containing metadata about the training run.

on_train_epoch_start(state: State) -> None

Hook called before a train epoch starts.

Parameters: state – a State object containing metadata about the training run.

on_train_start(state: State) -> None

Hook called before training starts.

Parameters: state – a State object containing metadata about the training run.

abstract train_step(state: State, data: TTrainData) -> Any

Core required method for user to implement. This method will be called at each iteration of the train dataloader, and can return any data the user wishes.

Parameters:
  • state – a State object containing metadata about the training run.
  • data – one batch of training data.
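
The relationship between the hooks and train_step can be sketched in plain Python. This is only an illustration of the call order that the hook descriptions above imply, not TorchTNT's actual loop: SketchTrainUnit, sketch_train, and RecordingUnit are hypothetical names, and the state argument is omitted for brevity.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, Iterable, List, TypeVar

TData = TypeVar("TData")


class SketchTrainUnit(Generic[TData], ABC):
    """Hypothetical stand-in for TrainUnit: no-op hooks plus one abstract step."""

    def on_train_start(self) -> None: ...
    def on_train_epoch_start(self) -> None: ...
    def on_train_epoch_end(self) -> None: ...
    def on_train_end(self) -> None: ...

    @abstractmethod
    def train_step(self, data: TData) -> Any: ...


def sketch_train(
    unit: SketchTrainUnit[TData], dataloader: Iterable[TData], max_epochs: int
) -> None:
    """Hypothetical loop calling the hooks in the order their descriptions imply."""
    unit.on_train_start()
    for _ in range(max_epochs):
        unit.on_train_epoch_start()
        for batch in dataloader:
            unit.train_step(batch)
        unit.on_train_epoch_end()
    unit.on_train_end()


class RecordingUnit(SketchTrainUnit[int]):
    """Records each call so the ordering can be inspected afterwards."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def on_train_start(self) -> None:
        self.calls.append("on_train_start")

    def on_train_epoch_start(self) -> None:
        self.calls.append("on_train_epoch_start")

    def on_train_epoch_end(self) -> None:
        self.calls.append("on_train_epoch_end")

    def on_train_end(self) -> None:
        self.calls.append("on_train_end")

    def train_step(self, data: int) -> int:
        self.calls.append(f"train_step({data})")
        return data


unit = RecordingUnit()
sketch_train(unit, dataloader=[0, 1], max_epochs=2)
print("\n".join(unit.calls))
```

Because train_step is abstract in the sketch, instantiating SketchTrainUnit directly raises a TypeError, which mirrors why a subclass must implement the step method while the hooks stay optional.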


class torchtnt.framework.unit.EvalUnit

The EvalUnit is an interface for organizing your evaluation logic. Its core is eval_step, an abstract method where you define the code to run on each iteration of the dataloader.

To use the EvalUnit, create a class which subclasses EvalUnit. Then implement the eval_step method on your class, and optionally implement any of the hooks, which allow you to control the behavior of the loop at different points. Below is a simple example of a user’s subclass of EvalUnit that implements a basic eval_step.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import EvalUnit

# specify type of the data in each batch of the dataloader to allow for typechecking
Batch = Tuple[torch.Tensor, torch.Tensor]

class MyEvalUnit(EvalUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
    ) -> None:
        super().__init__()
        self.module = module

    def eval_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        outputs = self.module(inputs)
        # the loss is only computed here; log or accumulate it as needed
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)

eval_unit = MyEvalUnit(module=...)

abstract eval_step(state: State, data: TEvalData) -> Any

Core required method for user to implement. This method will be called at each iteration of the eval dataloader, and can return any data the user wishes. It can optionally be decorated with @torch.inference_mode() for improved performance.

Parameters:
  • state – a State object containing metadata about the evaluation run.
  • data – one batch of evaluation data.

on_eval_end(state: State) -> None

Hook called after evaluation ends.

Parameters: state – a State object containing metadata about the evaluation run.

on_eval_epoch_end(state: State) -> None

Hook called after an eval epoch ends.

Parameters: state – a State object containing metadata about the evaluation run.

on_eval_epoch_start(state: State) -> None

Hook called before a new eval epoch starts.

Parameters: state – a State object containing metadata about the evaluation run.

on_eval_start(state: State) -> None

Hook called before evaluation starts.

Parameters: state – a State object containing metadata about the evaluation run.
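
To illustrate the optional @torch.inference_mode() decorator mentioned for eval_step, here is a minimal sketch using plain PyTorch. The function sketch_eval_step and the tiny Linear module are hypothetical stand-ins for a real eval_step and model; the point is only that tensors produced under inference mode do not track gradients.

```python
import torch

# Hypothetical tiny model, used only for illustration.
module = torch.nn.Linear(4, 1)


@torch.inference_mode()
def sketch_eval_step(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Stand-in for an eval_step body: forward pass and loss, with no autograd tracking."""
    outputs = module(inputs)
    return torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)


inputs = torch.randn(8, 4)
targets = torch.randint(0, 2, (8, 1)).float()

eval_loss = sketch_eval_step(inputs, targets)

# The same computation without the decorator builds an autograd graph.
train_loss = torch.nn.functional.binary_cross_entropy_with_logits(module(inputs), targets)

print(eval_loss.requires_grad, train_loss.requires_grad)
```

In a real unit the decorator would sit directly above the eval_step (or predict_step) method definition. It should not be applied to train_step, since training needs gradients.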


class torchtnt.framework.unit.PredictUnit

The PredictUnit is an interface for organizing your prediction logic. Its core is predict_step, an abstract method where you define the code to run on each iteration of the dataloader.

To use the PredictUnit, create a class which subclasses PredictUnit. Then implement the predict_step method on your class, and optionally implement any of the hooks, which allow you to control the behavior of the loop at different points. Below is a simple example of a user’s subclass of PredictUnit that implements a basic predict_step.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import PredictUnit

# specify type of the data in each batch of the dataloader to allow for typechecking
Batch = Tuple[torch.Tensor, torch.Tensor]

class MyPredictUnit(PredictUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
    ) -> None:
        super().__init__()
        self.module = module

    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
        # targets are not needed for prediction
        inputs, _ = data
        outputs = self.module(inputs)
        return outputs

predict_unit = MyPredictUnit(module=...)

on_predict_end(state: State) -> None

Hook called after prediction ends.

Parameters: state – a State object containing metadata about the prediction run.

on_predict_epoch_end(state: State) -> None

Hook called after a predict epoch ends.

Parameters: state – a State object containing metadata about the prediction run.

on_predict_epoch_start(state: State) -> None

Hook called before a predict epoch starts.

Parameters: state – a State object containing metadata about the prediction run.

on_predict_start(state: State) -> None

Hook called before prediction starts.

Parameters: state – a State object containing metadata about the prediction run.

abstract predict_step(state: State, data: TPredictData) -> Any

Core required method for user to implement. This method will be called at each iteration of the predict dataloader, and can return any data the user wishes. It can optionally be decorated with @torch.inference_mode() for improved performance.

Parameters:
  • state – a State object containing metadata about the prediction run.
  • data – one batch of prediction data.

Combining Multiple Units

In some cases, it is convenient to implement multiple Unit interfaces in the same class, for instance when you plan to use one class to run several phases, such as training followed by prediction, or training interleaved with evaluation (referred to as fitting). Here is an example of a unit which extends TrainUnit, EvalUnit, and PredictUnit.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit, EvalUnit, PredictUnit

Batch = Tuple[torch.Tensor, torch.Tensor]

class MyUnit(TrainUnit[Batch], EvalUnit[Batch], PredictUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    ) -> None:
        super().__init__()
        self.module = module
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def train_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)
        # backpropagate, then update the parameters and reset the gradients
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

    def eval_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        outputs = self.module(inputs)
        # the loss is only computed here; log or accumulate it as needed
        loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)

    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
        # targets are not needed for prediction
        inputs, _ = data
        outputs = self.module(inputs)
        return outputs

    def on_train_epoch_end(self, state: State) -> None:
        # step the learning rate scheduler
        self.lr_scheduler.step()

my_unit = MyUnit(module=..., optimizer=..., lr_scheduler=...)
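
To make "fitting" concrete, the interleaving of training and evaluation can be sketched in plain Python. This is an illustration of the idea only, not TorchTNT's fit implementation: SketchFitUnit, sketch_fit, and the eval_every_n_epochs parameter are hypothetical names, and the state argument and hooks are left out.

```python
from typing import Iterable, List


class SketchFitUnit:
    """Hypothetical unit that exposes both a train_step and an eval_step."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def train_step(self, data: int) -> None:
        self.log.append(f"train:{data}")

    def eval_step(self, data: int) -> None:
        self.log.append(f"eval:{data}")


def sketch_fit(
    unit: SketchFitUnit,
    train_data: Iterable[int],
    eval_data: Iterable[int],
    max_epochs: int,
    eval_every_n_epochs: int = 1,
) -> None:
    """Run a train epoch, then an eval epoch every eval_every_n_epochs train epochs."""
    for epoch in range(1, max_epochs + 1):
        for batch in train_data:
            unit.train_step(batch)
        if epoch % eval_every_n_epochs == 0:
            for batch in eval_data:
                unit.eval_step(batch)


fit_unit = SketchFitUnit()
sketch_fit(fit_unit, train_data=[0, 1], eval_data=[9], max_epochs=2)
print(fit_unit.log)
```

A single object carries the model through both phases, which is the practical reason to implement several Unit interfaces on one class: the module updated in train_step is the same one evaluated in eval_step.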

