
ignite.contrib.engines

Contribution module of engines

class ignite.contrib.engines.Tbptt_Events(value)[source]

Additional TBPTT events.

Additional events for the trainer dedicated to truncated backpropagation through time.

TIME_ITERATION_COMPLETED = 'time_iteration_completed'
TIME_ITERATION_STARTED = 'time_iteration_started'
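
These two events are fired by the trainer returned by create_supervised_tbptt_trainer around each processed time chunk. A minimal sketch of attaching a handler, assuming such a trainer already exists (the handler name and printed message are illustrative):

    from ignite.contrib.engines import Tbptt_Events

    # `trainer` is assumed to come from create_supervised_tbptt_trainer,
    # which registers Tbptt_Events on the engine it returns.
    @trainer.on(Tbptt_Events.TIME_ITERATION_COMPLETED)
    def on_time_chunk_completed(engine):
        # Illustrative logging: runs once per tbtt_step-sized time chunk.
        print(f"completed a time chunk at iteration {engine.state.iteration}")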
ignite.contrib.engines.create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>)[source]

Create a trainer for supervised models trained with truncated backpropagation through time.

Training a recurrent model on long sequences is computationally intensive, as it requires processing the whole sequence before a gradient can be computed. However, when the training loss is computed over many outputs (X to many), there is an opportunity to compute a gradient over a subsequence. This is known as truncated backpropagation through time. This supervised trainer applies a gradient optimization step every tbtt_step time steps of the sequence, while backpropagating through the same tbtt_step time steps.

Parameters
  • model (torch.nn.Module) – the model to train.

  • optimizer (torch.optim.Optimizer) – the optimizer to use.

  • loss_fn (torch.nn loss function) – the loss function to use.

  • tbtt_step (int) – the length of time chunks (last one may be smaller).

  • dim (int) – axis representing the time dimension.

  • device (str, optional) – device type specification (default: None). Applies to both model and batches.

  • non_blocking (bool, optional) – if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.

  • prepare_batch (callable, optional) – function that receives batch, device, non_blocking and outputs tuple of tensors (batch_x, batch_y).

Returns

a trainer engine with supervised update function.

Return type

Engine
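
As a usage illustration, here is a minimal sketch that builds such a trainer, assuming batches shaped (time, batch, features) so that dim=0 is the time axis; the model, sizes, and data loader are hypothetical:

    import torch
    from torch import nn

    from ignite.contrib.engines import create_supervised_tbptt_trainer


    class SeqModel(nn.Module):
        # Hypothetical recurrent model whose forward returns only predictions,
        # so that loss_fn(y_pred, y) can be applied directly to each chunk.
        def __init__(self):
            super().__init__()
            self.rnn = nn.RNN(input_size=8, hidden_size=16)
            self.head = nn.Linear(16, 1)

        def forward(self, x):
            out, _ = self.rnn(x)   # out: (time, batch, hidden)
            return self.head(out)  # per-time-step predictions


    model = SeqModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    trainer = create_supervised_tbptt_trainer(
        model,
        optimizer,
        loss_fn,
        tbtt_step=4,  # optimize every 4 time steps, backpropagating through them
        dim=0,        # time is the first axis of the batch tensors
    )
    # trainer.run(train_loader, max_epochs=5)  # train_loader is hypothetical

Because the model's forward returns a single tensor of per-time-step predictions, the loss function can be evaluated on each tbtt_step-sized chunk without unpacking recurrent hidden state.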