supervised_training_step_apex

ignite.engine.supervised_training_step_apex(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>, model_transform=<function <lambda>>, output_transform=<function <lambda>>, gradient_accumulation_steps=1, model_fn=<function <lambda>>)

Factory function for supervised training using NVIDIA Apex (apex.amp) mixed precision.

Parameters
  • model (Module) – the model to train.

  • optimizer (Optimizer) – the optimizer to use.

  • loss_fn (Union[Callable[[Any, Any], Tensor], Module]) – the loss function that receives y_pred and y, and returns the loss as a tensor.

  • device (Optional[Union[str, device]]) – device type specification (default: None). Applies to batches after starting the engine; the model will not be moved. Device can be CPU or GPU.

  • non_blocking (bool) – if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.

  • prepare_batch (Callable) – function that receives batch, device, non_blocking and outputs a tuple of tensors (batch_x, batch_y).

  • model_transform (Callable[[Any], Any]) – function that receives the output from the model and converts it into the form required by the loss function.

  • output_transform (Callable[[Any, Any, Any, Tensor], Any]) – function that receives ‘x’, ‘y’, ‘y_pred’, ‘loss’ and returns the value to be assigned to the engine’s state.output after each iteration. By default, returns loss.item().

  • gradient_accumulation_steps (int) – number of steps over which gradients are accumulated (default: 1, i.e. no gradient accumulation).

  • model_fn (Callable[[Module, Any], Any]) – the model function that receives model and x, and returns y_pred.
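The parameters above describe a chain of callables applied to every batch. The following is a minimal, framework-free sketch of that order, under stated assumptions: plain floats stand in for tensors, `forward_pass` and the `default_*` helpers are hypothetical names (not part of Ignite), and the backward pass, Apex loss scaling and optimizer step are omitted. It also assumes model_fn defaults to calling model(x) and model_transform to the identity; output_transform returns the loss itself because a float has no .item().

```python
from typing import Any, Callable, Tuple


# Hypothetical stand-ins for the defaults (floats instead of tensors).
def default_prepare_batch(batch: Tuple[Any, Any], device=None, non_blocking=False) -> Tuple[Any, Any]:
    # A real prepare_batch would also move x and y to `device`; here we only unpack.
    x, y = batch
    return x, y


def default_model_fn(model: Callable[[Any], Any], x: Any) -> Any:
    # Assumed default: apply the model directly to the input.
    return model(x)


def default_model_transform(output: Any) -> Any:
    # Assumed default: pass the model output through unchanged.
    return output


def default_output_transform(x: Any, y: Any, y_pred: Any, loss: float) -> Any:
    # The documented default returns loss.item(); a float is returned as-is here.
    return loss


def forward_pass(
    batch: Tuple[Any, Any],
    model: Callable[[Any], Any],
    loss_fn: Callable[[Any, Any], float],
    prepare_batch=default_prepare_batch,
    model_fn=default_model_fn,
    model_transform=default_model_transform,
    output_transform=default_output_transform,
) -> Any:
    """Order in which the callables are applied in one iteration (no backward/step)."""
    x, y = prepare_batch(batch, device=None, non_blocking=False)
    output = model_fn(model, x)          # run the model
    y_pred = model_transform(output)     # reshape output for the loss
    loss = loss_fn(y_pred, y)            # compute the loss
    return output_transform(x, y, y_pred, loss)  # value stored in state.output


if __name__ == "__main__":
    squared_error = lambda y_pred, y: (y_pred - y) ** 2
    print(forward_pass((3.0, 5.0), lambda x: 2.0 * x, squared_error))  # 1.0
    # A model returning a dict needs model_transform to extract the prediction.
    dict_model = lambda x: {"logits": 2.0 * x}
    print(forward_pass((3.0, 5.0), dict_model, squared_error,
                       model_transform=lambda out: out["logits"]))  # 1.0
```

Keeping model_fn and model_transform separate lets a model with a non-standard call signature (model_fn) or a structured output (model_transform) be used without wrapping the model itself.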

Returns

The update function, to be passed to Engine as its process function.

Return type

Callable

Examples

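from ignite.engine import Engine, supervised_training_step_apex

model = ...      # placeholder: your torch.nn.Module, already moved to the GPU
optimizer = ...  # placeholder: your torch.optim.Optimizer
loss_fn = ...    # placeholder: your loss function

# Apex expects model and optimizer to have been passed through
# apex.amp.initialize(model, optimizer, opt_level=...) before training.
update_fn = supervised_training_step_apex(model, optimizer, loss_fn, device='cuda')
trainer = Engine(update_fn)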
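In the example above, Engine calls update_fn(engine, batch) once per iteration, with engine.state.iteration counting from 1. The sketch below is a hypothetical, framework-free model of how gradient_accumulation_steps could schedule optimizer calls across those iterations. It assumes gradients are zeroed at the first iteration of each window of gradient_accumulation_steps iterations and optimizer.step() runs at the last one. `accumulation_schedule` is an illustrative name, not an Ignite function, and the event strings only record the order of calls.

```python
from typing import List


def accumulation_schedule(num_iterations: int, gradient_accumulation_steps: int = 1) -> List[str]:
    """Return the order of optimizer events for a simple accumulation scheme (sketch)."""
    events: List[str] = []
    for iteration in range(1, num_iterations + 1):  # iterations counted from 1
        # First iteration of a window: clear the previously applied gradients.
        if (iteration - 1) % gradient_accumulation_steps == 0:
            events.append("zero_grad")
        # A real step would compute the loss, typically divide it by
        # gradient_accumulation_steps, and back-propagate through apex's loss scaler.
        events.append(f"backward@{iteration}")
        # Last iteration of a window: apply the accumulated gradients.
        if iteration % gradient_accumulation_steps == 0:
            events.append("step")
    return events


if __name__ == "__main__":
    print(accumulation_schedule(4, gradient_accumulation_steps=2))
```

With gradient_accumulation_steps=1 every iteration zeroes, back-propagates and steps. When the number of iterations is not a multiple of the window size, the last partial window in this sketch is back-propagated but never stepped.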

New in version 0.4.5.

Changed in version 0.4.7: Added Gradient Accumulation.

Changed in version 0.4.11: Added model_transform to transform the model’s output.

Changed in version 0.4.13: Added model_fn to customize how the model is applied to the sample.