ignite.engine#
Main module of the library containing:
ignite.engine.engine
- Engine: Runs a given process_function over each batch of a dataset, emitting events as it goes.
ignite.engine.events
- CallableEventWithFilter: Single Event containing a filter, specifying whether the event should be run at the current event (if the event type is correct)
- EventEnum: Base class for all Events.
- Events: Events that are fired by the Engine during execution.
- State: An object that is used to pass internal and user-defined state between event handlers.
- EventsList: Collection of events stacked by operator __or__.
- RemovableEventHandle: A weakref handle to remove a registered event.
ignite.engine.deterministic (helper methods for deterministic training)
- update_dataloader: Helper function to replace current batch sampler of the dataloader by a new batch sampler.
- keep_random_state: Helper decorator to keep random state of torch, numpy and random intact while executing a function.
- ReproducibleBatchSampler: Reproducible batch sampler.
- DeterministicEngine: Deterministic engine derived from Engine.
and helper methods to define supervised trainer and evaluator:
- create_supervised_trainer: Factory function for creating a trainer for supervised models.
- create_supervised_evaluator: Factory function for creating an evaluator for supervised models.
More details about those structures can be found in Concepts.
- class ignite.engine.engine.Engine(process_function)[source]#
Runs a given process_function over each batch of a dataset, emitting events as it goes.
- Parameters
process_function (callable) – A function that receives a handle to the engine and the current batch in each iteration, and returns data to be stored in the engine’s state.
- state#
An object that is used to pass internal and user-defined state between event handlers. It is created with the engine, and its attributes (e.g. state.iteration, state.epoch, etc.) are reset on every run().
- Type
State
Examples
Create a basic trainer
def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss}, lr: {lr}")

trainer.run(data_loader, max_epochs=5)

> Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
> ...
> Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01
Create a basic evaluator to compute metrics
from ignite.metrics import Accuracy

def predict_on_batch(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
    return y_pred, y

evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)
Compute image mean/std on training dataset
from ignite.metrics import Average

def compute_mean_std(engine, batch):
    b, c, *_ = batch['image'].shape
    data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
    mean = torch.mean(data, dim=-1).sum(dim=0)
    mean2 = torch.mean(data ** 2, dim=-1).sum(dim=0)
    return {"mean": mean, "mean^2": mean2}

compute_engine = Engine(compute_mean_std)
img_mean = Average(output_transform=lambda output: output['mean'])
img_mean.attach(compute_engine, 'mean')
img_mean2 = Average(output_transform=lambda output: output['mean^2'])
img_mean2.attach(compute_engine, 'mean2')
state = compute_engine.run(train_loader)
state.metrics['std'] = torch.sqrt(state.metrics['mean2'] - state.metrics['mean'] ** 2)
mean = state.metrics['mean'].tolist()
std = state.metrics['std'].tolist()
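The standard deviation in the example above relies on the identity Var[x] = E[x²] − E[x]². The following is a minimal, standard-library-only check of that identity. The `pixels` values are made-up illustration data, and the snippet does not use Ignite at all.

```python
import math
import statistics

# Hypothetical pixel values for a single channel (illustration only).
pixels = [0.1, 0.4, 0.4, 0.7, 0.9]

mean = sum(pixels) / len(pixels)                  # E[x]
mean2 = sum(p * p for p in pixels) / len(pixels)  # E[x^2]
std = math.sqrt(mean2 - mean ** 2)                # population standard deviation

# The result agrees with the population standard deviation from the stdlib.
assert math.isclose(std, statistics.pstdev(pixels))
```

Note that this identity yields the population (biased) standard deviation, not the sample estimate.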
Resume engine’s run from a state. User can load a state_dict and run the engine starting from the loaded state:
# Restore from an epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or an iteration
# state_dict = {"iteration": 500, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)
- add_event_handler(event_name, handler, *args, **kwargs)[source]#
Add an event handler to be executed when the specified event is fired.
- Parameters
event_name (Any) – An event or a list of events to attach the handler to. Valid events are from Events or any event_name added by register_events().
handler (callable) – the callable event handler that should be invoked. No restrictions on its signature. The first argument can optionally be engine, the Engine object the handler is bound to.
*args – optional args to be passed to handler.
**kwargs – optional keyword args to be passed to handler.
Note
Note that other arguments can be passed to the handler in addition to the *args and **kwargs passed here, for example during EXCEPTION_RAISED.
- Returns
RemovableEventHandle, which can be used to remove the handler.
Example usage:
engine = Engine(process_function)

def print_epoch(engine):
    print(f"Epoch: {engine.state.epoch}")

engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)

events_list = Events.EPOCH_COMPLETED | Events.COMPLETED

def execute_something():
    # do some thing not related to engine
    pass

engine.add_event_handler(events_list, execute_something)
Note
Since v0.3.0, Events are more flexible and allow passing an event filter to the Engine. See Events for more details.
- fire_event(event_name)[source]#
Execute all the handlers associated with given event.
This method executes all handlers associated with the event event_name. This is the method used in run() to call the core events found in Events.
Custom events can be fired if they have been registered before with register_events(). The engine state attribute should be used to exchange “dynamic” data among process_function and handlers.
This method is called automatically for core events. If no custom events are used in the engine, there is no need for the user to call the method.
- Parameters
event_name (Any) – event for which the handlers should be executed. Valid events are from Events or any event_name added by register_events().
- Return type
None
- has_event_handler(handler, event_name=None)[source]#
Check if the specified event has the specified handler.
- load_state_dict(state_dict)[source]#
Sets up the engine from state_dict.
State dictionary should contain keys: iteration or epoch and max_epochs, epoch_length. If engine.state_dict_user_keys contains keys, they should also be present in the state dictionary. Iteration and epoch values are 0-based: the first iteration or epoch is zero.
This method does not remove any custom attributes added by the user.
- Parameters
state_dict (Mapping) – a dict with parameters
- Return type
None
# Restore from the 4th epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or 500th iteration
# state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)
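As a rough illustration of how the 0-based epoch and iteration keys relate, the sketch below derives one from the other using epoch_length. The helper `resolve_position` is hypothetical and not part of Ignite. It assumes that a resumed epoch starts at iteration `epoch * epoch_length`, which may differ from what the library does internally.

```python
def resolve_position(state_dict):
    """Return the 0-based (epoch, iteration) implied by a state dict.

    Hypothetical helper for illustration; not Ignite's implementation.
    """
    epoch_length = state_dict["epoch_length"]
    if "iteration" in state_dict:
        iteration = state_dict["iteration"]
        epoch = iteration // epoch_length  # epochs completed so far
    elif "epoch" in state_dict:
        epoch = state_dict["epoch"]
        iteration = epoch * epoch_length   # assumed start of the resumed epoch
    else:
        raise ValueError("state_dict needs either 'iteration' or 'epoch'")
    return epoch, iteration


# With 10 iterations per epoch, epoch 3 corresponds to iteration 30.
position = resolve_position({"epoch": 3, "max_epochs": 100, "epoch_length": 10})
```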
- on(event_name, *args, **kwargs)[source]#
Decorator shortcut for add_event_handler.
- Parameters
event_name (Any) – An event to attach the handler to. Valid events are from Events or any event_name added by register_events().
*args (Any) – optional args to be passed to handler.
**kwargs (Any) – optional keyword args to be passed to handler.
- Return type
Example usage:
engine = Engine(process_function)

@engine.on(Events.EPOCH_COMPLETED)
def print_epoch():
    print(f"Epoch: {engine.state.epoch}")

@engine.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
def execute_something():
    # do some thing not related to engine
    pass
- register_events(*event_names, event_to_attr=None)[source]#
Add events that can be fired.
Registering an event will let the user trigger these events at any point. This opens the door to make the run() loop even more configurable.
By default, the events from Events are registered.
- Return type
None
Example usage:
from ignite.engine import Engine, Events, EventEnum

class CustomEvents(EventEnum):
    FOO_EVENT = "foo_event"
    BAR_EVENT = "bar_event"

def process_function(e, batch):
    # ...
    trainer.fire_event("bwd_event")
    loss.backward()
    # ...
    trainer.fire_event("opt_event")
    optimizer.step()

trainer = Engine(process_function)
trainer.register_events(*CustomEvents)
trainer.register_events("bwd_event", "opt_event")

@trainer.on(Events.EPOCH_COMPLETED)
def trigger_custom_event():
    if required(...):
        trainer.fire_event(CustomEvents.FOO_EVENT)
    else:
        trainer.fire_event(CustomEvents.BAR_EVENT)

@trainer.on(CustomEvents.FOO_EVENT)
def do_foo_op():
    # ...
    pass

@trainer.on(CustomEvents.BAR_EVENT)
def do_bar_op():
    # ...
    pass
Example with State Attribute:
from enum import Enum
from ignite.engine import Engine, EventEnum

class TBPTT_Events(EventEnum):
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"

TBPTT_event_to_attr = {
    TBPTT_Events.TIME_ITERATION_STARTED: 'time_iteration',
    TBPTT_Events.TIME_ITERATION_COMPLETED: 'time_iteration'
}

engine = Engine(process_function)
engine.register_events(*TBPTT_Events, event_to_attr=TBPTT_event_to_attr)
engine.run(data)
# engine.state contains an attribute time_iteration, which can be accessed using engine.state.time_iteration
- remove_event_handler(handler, event_name)[source]#
Remove the event handler handler from the registered handlers of the engine.
- Parameters
handler (callable) – the callable event handler that should be removed
event_name (Any) – The event the handler is attached to.
- Return type
None
- run(data, max_epochs=None, epoch_length=None, seed=None)[source]#
Runs the process_function over the passed data.
Engine has a state and the following logic is applied in this function:
- At the first call, a new state is defined by max_epochs and epoch_length, if provided. A timer for total and per-epoch time is initialized when Events.STARTED is handled.
- If state is already defined such that there are iterations to run until max_epochs and no input arguments are provided, the state is kept and used in the function.
- If state is defined and the engine is “done” (no iterations to run until max_epochs), a new state is defined.
- If state is defined and the engine is NOT “done”, then input arguments, if provided, override the defined state.
- Parameters
data (Iterable) – Collection of batches allowing repeated iteration (e.g., list or DataLoader).
max_epochs (int, optional) – Max epochs to run for (default: None). If a new state should be created (first run or run again from ended engine), its default value is 1. If run is resuming from a state, provided max_epochs will be taken into account and should be larger than engine.state.max_epochs.
epoch_length (int, optional) – Number of iterations to count as one epoch. By default, it can be set as len(data). If data is an iterator and epoch_length is not set, then it will be automatically determined as the iteration on which the data iterator raises StopIteration. This argument should not change if run is resuming from a state.
seed (int, optional) – Deprecated argument since v0.4.0 and will be removed in v0.5.0. Please use torch.manual_seed or manual_seed().
- Returns
output state.
- Return type
Note
User can dynamically preprocess the input batch at ITERATION_STARTED and store the output batch in engine.state.batch. The latter is passed as usual to process_function as argument:
trainer = ...

@trainer.on(Events.ITERATION_STARTED)
def switch_batch(engine):
    engine.state.batch = preprocess_batch(engine.state.batch)

Restart the training from the beginning. User can reset max_epochs = None:
# ...
trainer.run(train_loader, max_epochs=5)

# Reset model weights etc. and restart the training
trainer.state.max_epochs = None
trainer.run(train_loader, max_epochs=2)
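The state-handling rules listed for run() can be sketched in plain Python as follows. This is a simplified model, not Ignite's implementation: `ToyState`, `is_done` and `prepare_state` are hypothetical names, and "done" is assumed to mean that `iteration` has reached `max_epochs * epoch_length`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyState:
    """Hypothetical stand-in for the engine state (illustration only)."""
    iteration: int = 0
    epoch: int = 0
    max_epochs: Optional[int] = None
    epoch_length: Optional[int] = None


def is_done(state):
    # Assumed meaning of "done": no iterations are left until max_epochs.
    if state.max_epochs is None or state.epoch_length is None:
        return False
    return state.iteration >= state.max_epochs * state.epoch_length


def prepare_state(state, max_epochs=None, epoch_length=None):
    """Return the state a run would start from, following the listed rules."""
    if state is None or state.max_epochs is None or is_done(state):
        # First call, explicit reset (max_epochs = None), or finished engine:
        # a new state is defined; max_epochs defaults to 1.
        return ToyState(
            max_epochs=1 if max_epochs is None else max_epochs,
            epoch_length=epoch_length,
        )
    # State is defined and not done: keep it, letting max_epochs override it.
    if max_epochs is not None:
        state.max_epochs = max_epochs
    return state


state = prepare_state(None, max_epochs=5, epoch_length=10)  # first call
state.iteration, state.epoch = 20, 2                        # pretend 2 epochs ran
resumed = prepare_state(state)                              # kept as-is
```

In this sketch, epoch_length is deliberately not overridden on resume, mirroring the note that it should not change when a run resumes from a state.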
- set_data(data)[source]#
Method to set data. After calling the method, the next batch passed to process_function is from the newly provided data. Please note that epoch length is not modified.
- Parameters
data (Iterable) – Collection of batches allowing repeated iteration (e.g., list or DataLoader).
- Return type
None
- Example usage:
User can switch data provider during the training:
data1 = ...
data2 = ...

switch_iteration = 5000

def train_step(e, batch):
    # when iteration <= switch_iteration
    # batch is from data1
    # when iteration > switch_iteration
    # batch is from data2
    ...

trainer = Engine(train_step)

@trainer.on(Events.ITERATION_COMPLETED(once=switch_iteration))
def switch_dataloader():
    trainer.set_data(data2)

trainer.run(data1, max_epochs=100)
- state_dict()[source]#
Returns a dictionary containing engine’s state: “epoch_length”, “max_epochs” and “iteration” and other state values defined by engine.state_dict_user_keys
engine = Engine(...)
engine.state_dict_user_keys.append("alpha")
engine.state_dict_user_keys.append("beta")
...

@engine.on(Events.STARTED)
def init_user_value(_):
    engine.state.alpha = 0.1
    engine.state.beta = 1.0

@engine.on(Events.COMPLETED)
def save_engine(_):
    state_dict = engine.state_dict()
    assert "alpha" in state_dict and "beta" in state_dict
    torch.save(state_dict, "/tmp/engine.pt")
- Returns
a dictionary containing engine’s state
- Return type
OrderedDict
- ignite.engine.create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>, output_transform=<function <lambda>>, deterministic=False)[source]#
Factory function for creating a trainer for supervised models.
- Parameters
model (torch.nn.Module) – the model to train.
optimizer (torch.optim.Optimizer) – the optimizer to use.
loss_fn (torch.nn loss function) – the loss function to use.
device (str, optional) – device type specification (default: None). Applies to batches after starting the engine. Model will not be moved. Device can be CPU, GPU or TPU.
non_blocking (bool, optional) – if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional) – function that receives batch, device, non_blocking and outputs tuple of tensors (batch_x, batch_y).
output_transform (callable, optional) – function that receives ‘x’, ‘y’, ‘y_pred’, ‘loss’ and returns value to be assigned to engine’s state.output after each iteration. Default is returning loss.item().
deterministic (bool, optional) – if True, returns deterministic engine of type DeterministicEngine, otherwise Engine (default: False).
- Return type
Note
engine.state.output for this engine is defined by output_transform parameter and is the loss of the processed batch by default.
Warning
The internal use of device has changed. device will now only be used to move the input data to the correct device. The model should be moved by the user before creating an optimizer. For more information see:
- Returns
a trainer engine with supervised update function.
- Return type
- Parameters
- ignite.engine.create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>, output_transform=<function <lambda>>)[source]#
Factory function for creating an evaluator for supervised models.
- Parameters
model (torch.nn.Module) – the model to evaluate.
metrics (dict of str - Metric) – a map of metric names to Metrics.
device (str, optional) – device type specification (default: None). Applies to batches after starting the engine. Model will not be moved.
non_blocking (bool, optional) – if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional) – function that receives batch, device, non_blocking and outputs tuple of tensors (batch_x, batch_y).
output_transform (callable, optional) – function that receives ‘x’, ‘y’, ‘y_pred’ and returns value to be assigned to engine’s state.output after each iteration. Default is returning (y_pred, y,) which fits output expected by metrics. If you change it you should use output_transform in metrics.
- Return type
Note
engine.state.output for this engine is defined by output_transform parameter and is a tuple of (batch_pred, batch_y) by default.
Warning
The internal use of device has changed. device will now only be used to move the input data to the correct device. The model should be moved by the user before creating an optimizer.
For more information see:
Resuming the training#
It is possible to resume the training from a checkpoint and approximately reproduce the original run’s behaviour. Using Ignite, this can be easily done using the Checkpoint handler. Engine provides two methods to serialize and deserialize its internal state: state_dict() and load_state_dict(). In addition to serializing the model, optimizer, lr scheduler etc., the user can store the trainer and then resume the training. For example:
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver
trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)
ls /tmp/training
> "checkpoint_50000.pt"
We can then restore the training from the last checkpoint.
import torch
from ignite.handlers import Checkpoint
trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...
to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
checkpoint = torch.load(checkpoint_file)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
trainer.run(data_loader, max_epochs=100)
It is also possible to store checkpoints every N iterations and continue the training from one of these checkpoints, i.e. from a given iteration.
Complete examples that resume the training from a checkpoint can be found here:
ignite.engine.events#
- class ignite.engine.events.CallableEventWithFilter(value, event_filter=None, name=None)[source]#
Single Event containing a filter, specifying whether the event should be run at the current event (if the event type is correct)
- Parameters
value (str) – The actual enum value. Only needed for internal use. Do not touch!
event_filter (callable) – A function taking the engine and the current event value as input and returning a boolean to indicate whether this event should be executed. Defaults to None, which will result in a function that always returns True
name (str, optional) – The enum-name of the current object. Only needed for internal use. Do not touch!
- name#
The name of the Enum member.
- value#
The value of the Enum member.
- class ignite.engine.events.Events(value)[source]#
Events that are fired by the Engine during execution. Built-in events:
- STARTED : triggered when engine’s run is started
- EPOCH_STARTED : triggered when the epoch is started
- GET_BATCH_STARTED : triggered before next batch is fetched
- GET_BATCH_COMPLETED : triggered after the batch is fetched
- ITERATION_STARTED : triggered when an iteration is started
- ITERATION_COMPLETED : triggered when the iteration is ended
- DATALOADER_STOP_ITERATION : engine’s specific event triggered when dataloader has no more data to provide
- EXCEPTION_RAISED : triggered when an exception is encountered
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch, after receiving a terminate_epoch() or terminate() call.
- TERMINATE : triggered when the run is about to end completely, after receiving terminate() call.
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even when terminate_epoch() is called.
- COMPLETED : triggered when engine’s run is completed
The table below illustrates which events are triggered when various termination methods are called.
Method | EPOCH_COMPLETED | TERMINATE_SINGLE_EPOCH | TERMINATE
no termination | ✔ | ✗ | ✗
terminate_epoch() | ✔ | ✔ | ✗
terminate() | ✗ | ✔ | ✔
Since v0.3.0, Events are more flexible and allow passing an event filter to the Engine:
engine = Engine(process_function)

# a) custom event filter
def custom_event_filter(engine, event):
    if event in [1, 2, 5, 10, 50, 100]:
        return True
    return False

@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):
    # do something on 1, 2, 5, 10, 50, 100 iterations
    pass

# b) "every" event filter
@engine.on(Events.ITERATION_STARTED(every=10))
def call_every(engine):
    # do something every 10th iteration
    pass

# c) "once" event filter
@engine.on(Events.ITERATION_STARTED(once=50))
def call_once(engine):
    # do something on 50th iteration
    pass
Event filter function event_filter accepts as input engine and event and should return True/False. Argument event is the value of iteration or epoch, depending on which type of Events the function is passed.
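To make the filter semantics concrete, the sketch below writes "every" and "once" as ordinary functions with the documented `event_filter(engine, event)` signature. The names `every_filter` and `once_filter` are hypothetical helpers for illustration and are not Ignite functions. The `engine` argument is unused here, so `None` is passed.

```python
def every_filter(every):
    """Build a filter that passes on every `every`-th event value."""
    def _filter(engine, event):
        return event % every == 0
    return _filter


def once_filter(once):
    """Build a filter that passes only when the event value equals `once`."""
    def _filter(engine, event):
        return event == once
    return _filter


# Event values are 1-based iteration (or epoch) counters.
iterations = range(1, 31)
every_10 = [i for i in iterations if every_filter(10)(None, i)]
once_5 = [i for i in iterations if once_filter(5)(None, i)]
```

Either function could be passed as `event_filter=...` in the same way as `custom_event_filter` in the example above.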
Since v0.4.0, user can also combine events with |-operator:
events = Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3)

engine = ...

@engine.on(events)
def call_on_events(engine):
    # do something
    pass
Since v0.4.0, custom events defined by user should inherit from EventEnum:
class CustomEvents(EventEnum):
    FOO_EVENT = "foo_event"
    BAR_EVENT = "bar_event"
- class ignite.engine.events.EventEnum(value)[source]#
Base class for all Events. User-defined custom events should also inherit from this class. For example, custom events based on the loss calculation and backward pass can be created as follows:
from ignite.engine import EventEnum

class BackpropEvents(EventEnum):
    BACKWARD_STARTED = 'backward_started'
    BACKWARD_COMPLETED = 'backward_completed'
    OPTIM_STEP_COMPLETED = 'optim_step_completed'

def update(engine, batch):
    # ...
    loss = criterion(y_pred, y)
    engine.fire_event(BackpropEvents.BACKWARD_STARTED)
    loss.backward()
    engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)
    optimizer.step()
    engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)
    # ...

trainer = Engine(update)
trainer.register_events(*BackpropEvents)

@trainer.on(BackpropEvents.BACKWARD_STARTED)
def function_before_backprop(engine):
    # ...
    pass
- class ignite.engine.events.State(**kwargs)[source]#
An object that is used to pass internal and user-defined state between event handlers. By default, state contains the following attributes:
state.iteration         # 1-based, the first iteration is 1
state.epoch             # 1-based, the first epoch is 1
state.seed              # seed to set at each epoch
state.dataloader        # data passed to engine
state.epoch_length      # optional length of an epoch
state.max_epochs        # number of epochs to run
state.batch             # batch passed to `process_function`
state.output            # output of `process_function` after a single iteration
state.metrics           # dictionary with defined metrics if any
state.times             # dictionary with total and per-epoch times fetched on
                        # keys: Events.EPOCH_COMPLETED.name and Events.COMPLETED.name
- Parameters
kwargs (Any) –
- class ignite.engine.events.RemovableEventHandle(event_name, handler, engine)[source]#
A weakref handle to remove a registered event.
A handle that may be used to remove a registered event handler via the remove method, with-statement, or context manager protocol. Returned from add_event_handler().
Example usage:
engine = Engine(process_function)

def print_epoch(engine):
    print(f"Epoch: {engine.state.epoch}")

with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
    # print_epoch handler registered for a single run
    engine.run(data)

# print_epoch handler is now unregistered
ignite.engine.deterministic#
Deterministic training#
In general, achieving deterministic and reproducible training is rather difficult, as it relies on multiple aspects, e.g. data version, code version, software environment, hardware etc. According to the PyTorch note on randomness, there are some steps to take in order to make computations deterministic on your specific problem on one specific platform and PyTorch release:
setup random state seed
set cudnn to deterministic if applicable
By default, these two options can be enough to run and rerun experiments in a deterministic way. Ignite’s engine does not impact this behaviour.
In this module we provide helper methods and classes that add “dataflow synchronization” to ensure that the model sees the same data for a given epoch:
- class ignite.engine.deterministic.DeterministicEngine(process_function)[source]#
Deterministic engine derived from Engine.
“Deterministic” run is done by adding additional handlers to synchronize the dataflow and overriding some methods of Engine:
for e in range(num_epochs):
    set_seed(seed_offset + e)
    if resume:
        setup_saved_rng_states()
    do_single_epoch_iterations(dataloader)
If input data provider is DataLoader, its batch sampler is replaced by ReproducibleBatchSampler.
for e in range(num_epochs):
    set_seed(seed_offset + e)
    setup_sampling(dataloader)
    if resume:
        setup_saved_rng_states()
    do_single_epoch_iterations(dataloader)
Internally, torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False are also applied.
For more details about dataflow synchronization, please see Dataflow synchronization.
Note
This class can produce exactly the same dataflow when resuming the run from an epoch (or more precisely from dataflow restart) and using torch DataLoader with num_workers > 1 as data provider.
- Parameters
process_function (Callable) –
- state_dict()[source]#
Returns a dictionary containing engine’s state: “epoch_length”, “max_epochs” and “iteration” and other state values defined by engine.state_dict_user_keys
engine = Engine(...)
engine.state_dict_user_keys.append("alpha")
engine.state_dict_user_keys.append("beta")
...

@engine.on(Events.STARTED)
def init_user_value(_):
    engine.state.alpha = 0.1
    engine.state.beta = 1.0

@engine.on(Events.COMPLETED)
def save_engine(_):
    state_dict = engine.state_dict()
    assert "alpha" in state_dict and "beta" in state_dict
    torch.save(state_dict, "/tmp/engine.pt")
- Returns
a dictionary containing engine’s state
- Return type
OrderedDict
- class ignite.engine.deterministic.ReproducibleBatchSampler(batch_sampler, start_iteration=None)[source]#
Reproducible batch sampler. This class internally iterates and stores indices of the input batch sampler. This helps to start providing data batches from an iteration in a deterministic way.
Usage:
Setup dataloader with ReproducibleBatchSampler and start providing data batches from an iteration:
from ignite.engine.deterministic import ReproducibleBatchSampler, update_dataloader

dataloader = update_dataloader(dataloader, ReproducibleBatchSampler(dataloader.batch_sampler))
# rewind dataloader to a specific iteration:
dataloader.batch_sampler.start_iteration = start_iteration
- Parameters
batch_sampler (torch.utils.data.sampler.BatchSampler) – batch sampler same as used with torch.utils.data.DataLoader
start_iteration (int, optional) – optional start iteration
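The idea of storing batch indices and starting from a given iteration can be illustrated without PyTorch. The class below is a toy stand-in written for this purpose. `CachedBatchSampler` is a hypothetical name and the code is not Ignite's implementation. It assumes that the wrapped sampler is any iterable of index lists and that start_iteration applies to a single pass only.

```python
class CachedBatchSampler:
    """Toy sampler: caches batches of indices and can start mid-epoch."""

    def __init__(self, batch_sampler, start_iteration=None):
        self.batch_sampler = batch_sampler
        self.start_iteration = start_iteration
        self.batch_indices = []

    def __iter__(self):
        # Iterate the wrapped sampler once and store all index batches.
        self.batch_indices = list(self.batch_sampler)
        start = self.start_iteration or 0
        # Assumed one-shot behaviour: the next pass starts from the beginning.
        self.start_iteration = None
        # Skipping happens on indices only; no data is fetched for skipped batches.
        yield from self.batch_indices[start:]


batches = [[0, 1], [2, 3], [4, 5]]
sampler = CachedBatchSampler(batches, start_iteration=1)
first_pass = list(sampler)   # resumes after the first batch
second_pass = list(sampler)  # full epoch again
```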
- ignite.engine.deterministic.keep_random_state(func)[source]#
Helper decorator to keep random state of torch, numpy and random intact while executing a function. For more details on usage, please see Dataflow synchronization.
- Parameters
func (callable) – function to decorate
- Return type
- ignite.engine.deterministic.update_dataloader(dataloader, new_batch_sampler)[source]#
Helper function to replace current batch sampler of the dataloader by a new batch sampler. Function returns new dataloader with new batch sampler.
- Parameters
dataloader (torch.utils.data.DataLoader) – input dataloader
new_batch_sampler (torch.utils.data.sampler.BatchSampler) – new batch sampler to use
- Returns
DataLoader
- Return type
Dataflow synchronization#
Ignite provides an option to control the dataflow by synchronizing random state on epochs. In this way, for a given iteration/epoch the dataflow can be the same for a given seed. More precisely, it roughly looks like:
for e in range(num_epochs):
    set_seed(seed + e)
    do_single_epoch_iterations(dataloader)
In addition, if the data provider is torch.utils.data.DataLoader, batch data indices can be made completely deterministic.
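The per-epoch seeding loop above can be mimicked with Python's standard `random` module. This is only a sketch of the idea under simplifying assumptions: `run_epochs` is a hypothetical function, the "batches" are plain random integers, and nothing here calls Ignite or PyTorch.

```python
import random


def run_epochs(seed, num_epochs, epoch_length, start_epoch=0):
    """Return {epoch: batches}; every epoch reseeds the RNG with seed + epoch."""
    produced = {}
    for e in range(start_epoch, num_epochs):
        random.seed(seed + e)  # plays the role of set_seed(seed + e)
        produced[e] = [random.randint(0, 99) for _ in range(epoch_length)]
    return produced


full_run = run_epochs(seed=56, num_epochs=2, epoch_length=5)
# "Resume" from the second epoch (index 1) without replaying the first one.
resumed_run = run_epochs(seed=56, num_epochs=2, epoch_length=5, start_epoch=1)

# The resumed epoch sees exactly the same data as in the full run.
assert resumed_run[1] == full_run[1]
```

Because the seed depends only on the epoch index, an epoch's data does not depend on how many random numbers earlier epochs consumed.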
Here is a trivial example of usage:
import torch
from torch.utils.data import DataLoader
from ignite.engine import DeterministicEngine, Events
from ignite.utils import manual_seed
def random_train_data_loader(size):
    data = torch.arange(0, size)
    return DataLoader(data, batch_size=4, shuffle=True)
def print_train_data(engine, batch):
    i = engine.state.iteration
    e = engine.state.epoch
    print("train", e, i, batch.tolist())
trainer = DeterministicEngine(print_train_data)
print("Original Run")
manual_seed(56)
trainer.run(random_train_data_loader(40), max_epochs=2, epoch_length=5)
print("Resumed Run")
# Resume from 2nd epoch
trainer.load_state_dict({"epoch": 1, "epoch_length": 5, "max_epochs": 2, "rng_states": None})
manual_seed(56)
trainer.run(random_train_data_loader(40))
Original Run
train 1 1 [31, 13, 3, 4]
train 1 2 [23, 18, 6, 16]
train 1 3 [10, 8, 33, 36]
train 1 4 [1, 37, 19, 9]
train 1 5 [20, 30, 14, 26]
train 2 6 [29, 35, 38, 34]
train 2 7 [7, 22, 12, 17]
train 2 8 [25, 21, 24, 15]
train 2 9 [39, 5, 2, 28]
train 2 10 [27, 11, 32, 0]
Resumed Run
train 2 6 [29, 35, 38, 34]
train 2 7 [7, 22, 12, 17]
train 2 8 [25, 21, 24, 15]
train 2 9 [39, 5, 2, 28]
train 2 10 [27, 11, 32, 0]
We can see that the data samples are exactly the same between original and resumed runs.
Complete examples that simulate a crash on a defined iteration and resume the training from a checkpoint can be found here:
Note
In case when input data is torch.utils.data.DataLoader, previous batches are skipped and the first provided batch corresponds to the batch after the checkpoint iteration. Internally, while resuming, previous datapoint indices are just skipped without fetching the data.
Warning
However, while resuming from an iteration, random data augmentations are not synchronized in the middle of the epoch, and thus the batches remaining until the end of the epoch can be different from those of the initial run.
Warning
However, please keep in mind that there can be an issue with dataflow synchronization on every epoch if a user’s handler synchronizes the random state, for example, by periodically calling torch.manual_seed(seed) during the run. This can have an impact on the dataflow:
def random_train_data_generator():
    while True:
        yield torch.randint(0, 100, size=(1, ))
trainer = DeterministicEngine(print_train_data)
@trainer.on(Events.ITERATION_COMPLETED(every=3))
def user_handler():
    # handler synchronizes the random state
    torch.manual_seed(12)
    a = torch.rand(1)
trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)
train 1 1 [32]
train 1 2 [29]
train 1 3 [40]
train 1 4 [3] <---
train 1 5 [22]
train 2 6 [77]
train 2 7 [3] <---
train 2 8 [22]
train 2 9 [77]
train 2 10 [3] <---
train 3 11 [22]
train 3 12 [77]
train 3 13 [3] <---
train 3 14 [22]
train 3 15 [77]
Initially, the function random_train_data_generator() generates random data batches using the random state set up by trainer. This is the intended behaviour until user_handler() is called. After user_handler() execution, the random state is altered and thus random_train_data_generator() will produce random batches based on the altered random state.
We provide the helper decorator keep_random_state() to save and restore random states for torch, numpy and random. Therefore, we can deal with the described issue using this decorator:
from ignite.engine.deterministic import keep_random_state
@trainer.on(Events.ITERATION_COMPLETED(every=3))
@keep_random_state
def user_handler():
    # handler synchronizes the random state
    torch.manual_seed(12)
    a = torch.rand(1)
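The save-and-restore pattern behind such a decorator can be sketched with the standard library alone. The decorator below, `keep_python_random_state`, is a hypothetical analogue that only covers Python's `random` module. It is not the Ignite decorator and does not handle torch or numpy states.

```python
import functools
import random


def keep_python_random_state(func):
    """Run func, then restore the `random` module state seen before the call."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        saved_state = random.getstate()
        try:
            return func(*args, **kwargs)
        finally:
            # Restore the state even if func raises.
            random.setstate(saved_state)
    return wrapper


@keep_python_random_state
def user_handler():
    # The handler reseeds the RNG, which would normally disturb the dataflow.
    random.seed(12)
    return random.random()


random.seed(0)
expected = random.random()  # value drawn without any handler call

random.seed(0)
user_handler()              # reseeds internally, but the state is restored
observed = random.random()
assert observed == expected
```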