ignite.contrib.engines
Contribution module of engines
- class ignite.contrib.engines.Tbptt_Events(value)
Additional TBPTT events.
Additional events for the trainer dedicated to truncated backpropagation through time.
- TIME_ITERATION_COMPLETED = 'time_iteration_completed'
- TIME_ITERATION_STARTED = 'time_iteration_started'
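To show when a start/complete pair of these events would fire relative to the time chunks of one sequence, here is a stdlib-only toy. It assumes, as the event names suggest, that each pair brackets the processing of one chunk. `ToyTbpttEvents`, `ToyEngine` and `run_one_iteration` are hypothetical stand-ins written for illustration; they are not Ignite's `Engine` or its dispatching code.

```python
from enum import Enum


class ToyTbpttEvents(Enum):
    # Hypothetical stand-in mirroring the two documented Tbptt_Events members.
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"


class ToyEngine:
    """Hypothetical minimal event dispatcher; NOT ignite.engine.Engine."""

    def __init__(self):
        self._handlers = {event: [] for event in ToyTbpttEvents}

    def on(self, event, handler):
        self._handlers[event].append(handler)

    def fire(self, event, *args):
        for handler in self._handlers[event]:
            handler(event, *args)


def run_one_iteration(engine, sequence, tbtt_step):
    """Process one sequence in time chunks, bracketing each chunk with a start/complete pair."""
    for chunk_index, start in enumerate(range(0, len(sequence), tbtt_step)):
        chunk = sequence[start:start + tbtt_step]  # the last chunk may be shorter
        engine.fire(ToyTbpttEvents.TIME_ITERATION_STARTED, chunk_index, len(chunk))
        # Forward pass, backward pass and optimizer step on `chunk` would happen here.
        engine.fire(ToyTbpttEvents.TIME_ITERATION_COMPLETED, chunk_index, len(chunk))


# Usage: record every event fired while processing a 5-step sequence with tbtt_step=2.
log = []
engine = ToyEngine()
for event in ToyTbpttEvents:
    engine.on(event, lambda ev, i, n: log.append((ev.value, i, n)))
run_one_iteration(engine, list(range(5)), tbtt_step=2)
# log holds a started/completed pair for chunks 0 and 1 (length 2) and chunk 2 (length 1).
```

With the real trainer, handlers would instead be attached to the returned engine, for example with the `@trainer.on(Tbptt_Events.TIME_ITERATION_COMPLETED)` decorator.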
- ignite.contrib.engines.create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>)
Create a trainer for supervised models using truncated backpropagation through time.
Training a recurrent model on long sequences is computationally intensive, as it requires processing the whole sequence before obtaining a gradient. However, when the training loss is computed over many outputs (X to many), a gradient can instead be computed over a subsequence. This is known as truncated backpropagation through time (TBPTT). This supervised trainer applies a gradient optimization step every tbtt_step time steps of the sequence, backpropagating only through those same tbtt_step time steps.
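The scheme can be sketched without PyTorch on a scalar linear RNN, h_t = w * h_{t-1} + x_t, trained by plain gradient descent on a mean squared error. Everything below (`split_time`, `tbptt_epoch`, the toy data and the learning rate) is hypothetical and only illustrates the idea; it is not Ignite's implementation. The sequence is cut into chunks of `tbtt_step` steps (the last may be shorter), the hidden state is carried across chunks as a plain number so that gradients stop at chunk boundaries, and the weight is updated once per chunk.

```python
def split_time(seq, tbtt_step):
    """Split `seq` into consecutive chunks of length `tbtt_step`; the last may be shorter."""
    return [seq[i:i + tbtt_step] for i in range(0, len(seq), tbtt_step)]


def tbptt_epoch(w, xs, ys, tbtt_step, lr):
    """One pass over a sequence with truncated BPTT for the toy RNN h_t = w * h_{t-1} + x_t.

    Returns the updated weight and the mean squared loss of every chunk.
    """
    h = 0.0  # hidden state carried across chunks as a plain value (i.e. "detached")
    chunk_losses = []
    for x_chunk, y_chunk in zip(split_time(xs, tbtt_step), split_time(ys, tbtt_step)):
        n = len(x_chunk)
        # Forward pass over the chunk, remembering the incoming hidden state of each step.
        h_prev, h_out = [], []
        for x in x_chunk:
            h_prev.append(h)
            h = w * h + x
            h_out.append(h)
        chunk_losses.append(sum((hk - yk) ** 2 for hk, yk in zip(h_out, y_chunk)) / n)

        # Backward pass through the n steps of this chunk only.
        grad_w, dh = 0.0, 0.0
        for k in reversed(range(n)):
            dh += 2.0 * (h_out[k] - y_chunk[k]) / n  # loss term at step k
            grad_w += dh * h_prev[k]                 # partial of h_k w.r.t. w is h_{k-1}
            dh *= w                                  # partial of h_k w.r.t. h_{k-1} is w
        # dh is now the gradient w.r.t. the chunk's incoming hidden state; dropping it
        # is the truncation.

        w -= lr * grad_w  # one optimizer step per chunk
    return w, chunk_losses


# Usage: toy targets generated by the same recurrence with a "true" weight of 0.5.
xs = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]
ys, h = [], 0.0
for x in xs:
    h = 0.5 * h + x
    ys.append(h)

w = 0.0
for _ in range(100):
    w, losses = tbptt_epoch(w, xs, ys, tbtt_step=3, lr=0.1)
```

In this toy case every error is zero at w = 0.5, so the truncated gradient vanishes there, and repeated epochs drive w toward 0.5 even though each update ignores dependencies older than the current chunk.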
- Parameters
model (torch.nn.Module) – the model to train.
optimizer (torch.optim.Optimizer) – the optimizer to use.
loss_fn (torch.nn loss function) – the loss function to use.
tbtt_step (int) – the length of each time chunk (the last one may be smaller).
dim (int) – axis representing the time dimension.
device (str, optional) – device type specification (default: None). Applies to both model and batches.
non_blocking (bool, optional) – if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional) – function that receives batch, device, non_blocking and outputs a tuple of tensors (batch_x, batch_y).
- Returns
a trainer engine with a supervised update function.
- Return type