
Engine

class ignite.engine.engine.Engine(process_function)

Runs a given process_function over each batch of a dataset, emitting events as it goes.

Parameters

process_function (Callable[[Engine, Any], Any]) – A function that receives a handle to the engine and the current batch in each iteration, and returns data to be stored in the engine’s state.

state

Object used to pass internal and user-defined state between event handlers. It is created with the engine, and its attributes (e.g. state.iteration, state.epoch, etc.) are reset on every run().

last_event_name

last event name triggered by the engine.
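As a rough mental model of how process_function, state and last_event_name fit together, the sketch below is a toy, ignite-free loop (the names ToyEngine and the string event names are hypothetical, and this is not ignite's implementation). It calls process_function(engine, batch) once per batch, stores the return value in state.output, and records the last fired event. It always starts a fresh run and does not model interrupt/resume or handlers.

```python
from types import SimpleNamespace


class ToyEngine:
    """Hypothetical, ignite-free model of the Engine run loop."""

    def __init__(self, process_function):
        self._process_function = process_function
        self.state = SimpleNamespace()
        self.last_event_name = None

    def _fire(self, event_name):
        # Handlers are omitted; only the last event name is recorded.
        self.last_event_name = event_name

    def run(self, data, max_epochs=1):
        # A fresh state on every call (resuming is not modeled here).
        self.state = SimpleNamespace(epoch=0, iteration=0, max_epochs=max_epochs, output=None)
        self._fire("STARTED")
        for _ in range(max_epochs):
            self.state.epoch += 1
            self._fire("EPOCH_STARTED")
            for batch in data:
                self.state.iteration += 1
                self._fire("ITERATION_STARTED")
                # process_function receives the engine and the current batch;
                # its return value is stored in state.output.
                self.state.output = self._process_function(self, batch)
                self._fire("ITERATION_COMPLETED")
            self._fire("EPOCH_COMPLETED")
        self._fire("COMPLETED")
        return self.state


engine = ToyEngine(lambda e, batch: batch * 2)
state = engine.run([1, 2, 3], max_epochs=2)
print(state.epoch, state.iteration, state.output, engine.last_event_name)
```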

Note

The Engine implementation changed in v0.4.10 with the “interrupt/resume” feature. The Engine may behave differently in certain corner cases compared to v0.4.9 and earlier. In such cases, you can set Engine.interrupt_resume_enabled = False to restore the previous behaviour.

Examples

Create a basic trainer

def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss}, lr: {lr}")

trainer.run(data_loader, max_epochs=5)

> Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
> ...
> Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01

Create a basic evaluator to compute metrics

from ignite.metrics import Accuracy

def predict_on_batch(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)

    return y_pred, y

evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)

Compute image mean/std on training dataset

from ignite.metrics import Average

def compute_mean_std(engine, batch):
    b, c, *_ = batch['image'].shape
    data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
    mean = torch.mean(data, dim=-1).sum(dim=0)
    mean2 = torch.mean(data ** 2, dim=-1).sum(dim=0)
    return {"mean": mean, "mean^2": mean2}

compute_engine = Engine(compute_mean_std)
img_mean = Average(output_transform=lambda output: output['mean'])
img_mean.attach(compute_engine, 'mean')
img_mean2 = Average(output_transform=lambda output: output['mean^2'])
img_mean2.attach(compute_engine, 'mean2')
state = compute_engine.run(train_loader)
state.metrics['std'] = torch.sqrt(state.metrics['mean2'] - state.metrics['mean'] ** 2)
mean = state.metrics['mean'].tolist()
std = state.metrics['std'].tolist()
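The last lines rely on the identity Var[x] = E[x^2] - E[x]^2. The small pure-Python check below (the pixel values are hypothetical) confirms that the two averages yield the population standard deviation, i.e. the one that divides by N, as statistics.pstdev does. Averaging per-image means matches the dataset-wide mean only when every image contributes the same number of pixels, and casting to torch.float64, as the example does, helps limit the cancellation error of the subtraction.

```python
import math
import statistics

# Hypothetical single-channel pixel values, pooled over the whole dataset.
pixels = [0.1, 0.4, 0.35, 0.8, 0.55]

mean = sum(pixels) / len(pixels)                     # E[x]
mean_sq = sum(p * p for p in pixels) / len(pixels)   # E[x^2]

# Population standard deviation from the two averages.
std = math.sqrt(mean_sq - mean ** 2)

print(round(mean, 4), round(std, 4))
```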

Resume an engine’s run from a state. Users can load a state_dict and run the engine starting from the loaded state:

# Restore from an epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or an iteration
# state_dict = {"iteration": 500, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)

Methods

add_event_handler

Add an event handler to be executed when the specified event is fired.

fire_event

Execute all the handlers associated with given event.

has_event_handler

Check if the specified event has the specified handler.

interrupt

Sends interrupt signal to the engine, so that it interrupts the run after the current iteration.

load_state_dict

Sets up the engine from state_dict.

on

Decorator shortcut for add_event_handler().

register_events

Add events that can be fired.

remove_event_handler

Remove event handler handler from the registered handlers of the engine.

run

Runs the process_function over the passed data.

set_data

Method to set data.

state_dict

Returns a dictionary containing the engine’s state: “seed”, “epoch_length”, “max_epochs” and “iteration”, and other state values defined by engine.state_dict_user_keys.

terminate

Sends terminate signal to the engine, so that it completely terminates the run.

terminate_epoch

Sends terminate signal to the engine, so that it terminates the current epoch.

add_event_handler(event_name, handler, *args, **kwargs)

Add an event handler to be executed when the specified event is fired.

Parameters
  • event_name (Any) – An event or a list of events to attach the handler to. Valid events are from Events or any event_name added by register_events().

  • handler (Callable) – the callable event handler that should be invoked. There are no restrictions on its signature. The first argument can optionally be engine, the Engine object the handler is bound to.

  • args (Any) – optional args to be passed to handler.

  • kwargs (Any) – optional keyword args to be passed to handler.

Returns

RemovableEventHandle, which can be used to remove the handler.

Return type

RemovableEventHandle

Note

Note that other arguments can be passed to the handler in addition to the *args and **kwargs passed here, for example during EXCEPTION_RAISED.

Examples

engine = Engine(process_function)

def print_epoch(engine):
    print(f"Epoch: {engine.state.epoch}")

engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)

events_list = Events.EPOCH_COMPLETED | Events.COMPLETED

def execute_something():
    # do something not related to the engine
    pass

engine.add_event_handler(events_list, execute_something)
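To make the roles of args, kwargs and the returned handle concrete, here is a toy dispatcher (ToyDispatcher and tag_epoch are hypothetical names; this is not ignite's code). It stores each handler with its extra arguments, passes them on every fire, and returns a plain closure standing in for RemovableEventHandle. For simplicity it always passes the engine as the first argument, whereas the text above says that argument is optional in ignite.

```python
from collections import defaultdict


class ToyDispatcher:
    """Hypothetical model of add_event_handler() and handler removal."""

    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        entry = (handler, args, kwargs)
        self._handlers[event_name].append(entry)

        def remove():
            # Stands in for the RemovableEventHandle described above.
            if entry in self._handlers[event_name]:
                self._handlers[event_name].remove(entry)

        return remove

    def fire_event(self, event_name):
        for handler, args, kwargs in list(self._handlers[event_name]):
            # This toy always passes the engine first, then the stored extras.
            handler(self, *args, **kwargs)


calls = []


def tag_epoch(engine, label, suffix=""):
    calls.append(label + suffix)


dispatcher = ToyDispatcher()
remove = dispatcher.add_event_handler("EPOCH_COMPLETED", tag_epoch, "epoch", suffix="!")
dispatcher.fire_event("EPOCH_COMPLETED")
remove()
dispatcher.fire_event("EPOCH_COMPLETED")  # no handler left, nothing appended
print(calls)
```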

Note

Since v0.3.0, Events are more flexible and allow passing an event filter to the Engine. See Events for more details.

fire_event(event_name)

Execute all the handlers associated with given event.

This method executes all handlers associated with the event event_name. This is the method used in run() to call the core events found in Events.

Custom events can be fired if they have been registered before with register_events(). The engine state attribute should be used to exchange “dynamic” data among process_function and handlers.

This method is called automatically for core events. If no custom events are used in the engine, there is no need for the user to call the method.

Parameters

event_name (Any) – event for which the handlers should be executed. Valid events are from Events or any event_name added by register_events().

Return type

None

has_event_handler(handler, event_name=None)

Check if the specified event has the specified handler.

Parameters
  • handler (Callable) – the callable event handler.

  • event_name (Optional[Any]) – The event the handler is attached to. Set this to None to search all events.

Return type

bool
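The lookup contract, including the event_name=None case, can be sketched with a plain dictionary. The helper and the handlers mapping below are hypothetical stand-ins for the engine's internal registry, which this document does not describe.

```python
def has_event_handler(handlers, handler, event_name=None):
    """Hypothetical helper mirroring the documented has_event_handler contract.

    `handlers` maps each event name to the list of callables attached to it.
    """
    if event_name is not None:
        return handler in handlers.get(event_name, [])
    # event_name=None: search the handlers of every event.
    return any(handler in attached for attached in handlers.values())


def log_epoch():
    pass


handlers = {"EPOCH_COMPLETED": [log_epoch], "COMPLETED": []}
print(has_event_handler(handlers, log_epoch))               # searched in all events
print(has_event_handler(handlers, log_epoch, "COMPLETED"))  # not attached to COMPLETED
```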

interrupt()

Sends interrupt signal to the engine, so that it interrupts the run after the current iteration. The run can be resumed by calling run(). Data iteration will continue from the interrupted state.

Examples

from ignite.engine import Engine, Events

data = range(10)
max_epochs = 3

def check_input_data(e, b):
    print(f"Epoch {e.state.epoch}, Iter {e.state.iteration} | data={b}")
    i = (e.state.iteration - 1) % len(data)
    assert b == data[i]

engine = Engine(check_input_data)

@engine.on(Events.ITERATION_COMPLETED(every=11))
def call_interrupt():
    engine.interrupt()

print("Start engine run with interruptions:")
state = engine.run(data, max_epochs=max_epochs)
print("1 Engine run is interrupted at ", state.epoch, state.iteration)
state = engine.run(data, max_epochs=max_epochs)
print("2 Engine run is interrupted at ", state.epoch, state.iteration)
state = engine.run(data, max_epochs=max_epochs)
print("3 Engine ended the run at ", state.epoch, state.iteration)

Output
Start engine run with interruptions:
Epoch 1, Iter 1 | data=0
Epoch 1, Iter 2 | data=1
Epoch 1, Iter 3 | data=2
Epoch 1, Iter 4 | data=3
Epoch 1, Iter 5 | data=4
Epoch 1, Iter 6 | data=5
Epoch 1, Iter 7 | data=6
Epoch 1, Iter 8 | data=7
Epoch 1, Iter 9 | data=8
Epoch 1, Iter 10 | data=9
Epoch 2, Iter 11 | data=0
1 Engine run is interrupted at  2 11
Epoch 2, Iter 12 | data=1
Epoch 2, Iter 13 | data=2
Epoch 2, Iter 14 | data=3
Epoch 2, Iter 15 | data=4
Epoch 2, Iter 16 | data=5
Epoch 2, Iter 17 | data=6
Epoch 2, Iter 18 | data=7
Epoch 2, Iter 19 | data=8
Epoch 2, Iter 20 | data=9
Epoch 3, Iter 21 | data=0
Epoch 3, Iter 22 | data=1
2 Engine run is interrupted at  3 22
Epoch 3, Iter 23 | data=2
Epoch 3, Iter 24 | data=3
Epoch 3, Iter 25 | data=4
Epoch 3, Iter 26 | data=5
Epoch 3, Iter 27 | data=6
Epoch 3, Iter 28 | data=7
Epoch 3, Iter 29 | data=8
Epoch 3, Iter 30 | data=9
3 Engine ended the run at  3 30

New in version 0.4.10.

Return type

None
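The interrupt/resume trace above can be reproduced with a generator that pauses after the interrupting iteration and keeps its position inside the data iterator. ToyResumableRun and its interrupt_every parameter are hypothetical, and this is only a model of the observable behaviour, not ignite's implementation. Once a run finishes, the next call starts a fresh run, matching the run() rule that a "done" engine gets a new state.

```python
class ToyResumableRun:
    """Hypothetical model of interrupt()/run() resumption; not ignite's code."""

    def __init__(self, data, max_epochs, interrupt_every):
        self.data = data
        self.max_epochs = max_epochs
        self.interrupt_every = interrupt_every
        self.epoch = 0
        self.iteration = 0
        self._run_generator = None

    def _loop(self):
        # Executed once per fresh run: counters start from scratch.
        self.epoch, self.iteration = 0, 0
        while self.epoch < self.max_epochs:
            self.epoch += 1
            for _ in self.data:
                self.iteration += 1
                if self.iteration % self.interrupt_every == 0:
                    # "interrupt()": pause after this iteration, keeping the
                    # position inside the data iterator.
                    yield

    def run(self):
        if self._run_generator is None:
            self._run_generator = self._loop()
        try:
            next(self._run_generator)  # resume until the next interrupt
        except StopIteration:
            self._run_generator = None  # the run finished normally
        return self.epoch, self.iteration


toy = ToyResumableRun(range(10), max_epochs=3, interrupt_every=11)
positions = [toy.run() for _ in range(3)]
print(positions)
```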

load_state_dict(state_dict)

Sets up the engine from state_dict.

The state dictionary should contain the keys iteration or epoch, together with max_epochs, epoch_length and seed. If engine.state_dict_user_keys contains keys, they should also be present in the state dictionary. Iteration and epoch values are 0-based: the first iteration or epoch is zero.

This method does not remove any custom attributes added by the user.

Parameters

state_dict (Mapping) – a dict with parameters

Return type

None

# Restore from the 4th epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or 500th iteration
# state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)
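Since only one of epoch or iteration is given, the other counter has to be derived from epoch_length. The helper below is a hypothetical sketch that assumes the derivation is a multiplication or an integer division by epoch_length, with both counters 0-based as stated above; epoch_length=50 stands in for len(data_loader).

```python
def restore_counters(state_dict):
    """Hypothetical helper: derive (epoch, iteration) from a state_dict.

    Assumes exactly one of "epoch" or "iteration" is present, together with
    "epoch_length"; both counters are 0-based.
    """
    if ("epoch" in state_dict) == ("iteration" in state_dict):
        raise ValueError("state_dict should contain exactly one of 'epoch' or 'iteration'")
    epoch_length = state_dict["epoch_length"]
    if "iteration" in state_dict:
        iteration = state_dict["iteration"]
        return iteration // epoch_length, iteration
    epoch = state_dict["epoch"]
    return epoch, epoch * epoch_length


print(restore_counters({"epoch": 3, "max_epochs": 100, "epoch_length": 50}))
print(restore_counters({"iteration": 499, "max_epochs": 100, "epoch_length": 50}))
```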
on(event_name, *args, **kwargs)

Decorator shortcut for add_event_handler().

Parameters
  • event_name (Any) – An event to attach the handler to. Valid events are from Events or any event_name added by register_events().

  • args (Any) – optional args to be passed to handler.

  • kwargs (Any) – optional keyword args to be passed to handler.

Return type

Callable

Examples

engine = Engine(process_function)

@engine.on(Events.EPOCH_COMPLETED)
def print_epoch():
    print(f"Epoch: {engine.state.epoch}")

@engine.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
def execute_something():
    # do something not related to the engine
    pass
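A decorator shortcut of this kind can be built directly on add_event_handler(). The ToyEngine class below is a hypothetical sketch of that pattern, not ignite's source: the decorator registers the function together with any extra arguments and returns it unchanged. This toy never passes the engine to handlers, so they receive only the extra arguments.

```python
from collections import defaultdict


class ToyEngine:
    """Hypothetical engine where on() is a thin wrapper over add_event_handler()."""

    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        self._handlers[event_name].append((handler, args, kwargs))

    def on(self, event_name, *args, **kwargs):
        def decorator(f):
            self.add_event_handler(event_name, f, *args, **kwargs)
            return f  # the decorated function stays usable as a plain function

        return decorator

    def fire_event(self, event_name):
        for handler, args, kwargs in self._handlers[event_name]:
            handler(*args, **kwargs)


engine = ToyEngine()
messages = []


@engine.on("EPOCH_COMPLETED", "epoch done")
def record(message):
    messages.append(message)


engine.fire_event("EPOCH_COMPLETED")
print(messages)
```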
register_events(*event_names, event_to_attr=None)

Add events that can be fired.

Registering an event lets the user trigger it at any point. This opens the door to making the run() loop even more configurable.

By default, the events from Events are registered.

Parameters
  • event_names (Union[List[str], List[EventEnum]]) – Defines the name of the event being supported. New events can be a str or an object derived from EventEnum. See example below.

  • event_to_attr (Optional[dict]) – A dictionary to map an event to a state attribute.

Return type

None

Examples

from ignite.engine import Engine, Events, EventEnum

class CustomEvents(EventEnum):
    FOO_EVENT = "foo_event"
    BAR_EVENT = "bar_event"

def process_function(e, batch):
    # ...
    trainer.fire_event("bwd_event")
    loss.backward()
    # ...
    trainer.fire_event("opt_event")
    optimizer.step()

trainer = Engine(process_function)
trainer.register_events(*CustomEvents)
trainer.register_events("bwd_event", "opt_event")

@trainer.on(Events.EPOCH_COMPLETED)
def trigger_custom_event():
    if required(...):
        trainer.fire_event(CustomEvents.FOO_EVENT)
    else:
        trainer.fire_event(CustomEvents.BAR_EVENT)

@trainer.on(CustomEvents.FOO_EVENT)
def do_foo_op():
    # ...
    pass

@trainer.on(CustomEvents.BAR_EVENT)
def do_bar_op():
    # ...
    pass

Example with State Attribute:

from ignite.engine import Engine, EventEnum

class TBPTT_Events(EventEnum):
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"

TBPTT_event_to_attr = {
    TBPTT_Events.TIME_ITERATION_STARTED: 'time_iteration',
    TBPTT_Events.TIME_ITERATION_COMPLETED: 'time_iteration'
}

engine = Engine(process_function)
engine.register_events(*TBPTT_Events, event_to_attr=TBPTT_event_to_attr)
engine.run(data)
# engine.state contains an attribute time_iteration, which can be accessed
# using engine.state.time_iteration
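The following toy shows one way an event_to_attr mapping can be used: the mapped state attribute is created on registration, and a simplified "every" filter reads it when the event fires. Everything here is a hypothetical, ignite-free sketch: TBPTTEvents uses a plain Enum in place of EventEnum, the every= keyword on on() is a simplified stand-in for ignite's event filters, and the caller advances time_iteration itself because the text above does not say who increments it.

```python
from enum import Enum
from types import SimpleNamespace


class TBPTTEvents(Enum):
    # Plain Enum standing in for an EventEnum subclass (keeps this ignite-free).
    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"


class ToyEngine:
    """Hypothetical model of register_events(..., event_to_attr=...)."""

    def __init__(self):
        self.state = SimpleNamespace()
        self._handlers = {}
        self._event_to_attr = {}

    def register_events(self, *event_names, event_to_attr=None):
        for name in event_names:
            self._handlers.setdefault(name, [])
        for event, attr in (event_to_attr or {}).items():
            self._event_to_attr[event] = attr
            # Make sure the mapped state attribute exists.
            if not hasattr(self.state, attr):
                setattr(self.state, attr, 0)

    def on(self, event_name, every=1):
        if event_name not in self._handlers:
            raise ValueError(f"{event_name!r} is not a registered event")

        def decorator(f):
            self._handlers[event_name].append((f, every))
            return f

        return decorator

    def fire_event(self, event_name):
        attr = self._event_to_attr.get(event_name)
        value = getattr(self.state, attr) if attr is not None else None
        for handler, every in self._handlers[event_name]:
            # The simplified "every" filter reads the state attribute mapped to the event.
            if value is None or value % every == 0:
                handler(self)


event_to_attr = {
    TBPTTEvents.TIME_ITERATION_STARTED: "time_iteration",
    TBPTTEvents.TIME_ITERATION_COMPLETED: "time_iteration",
}
engine = ToyEngine()
engine.register_events(*TBPTTEvents, "bwd_event", event_to_attr=event_to_attr)
fired_at = []


@engine.on(TBPTTEvents.TIME_ITERATION_COMPLETED, every=2)
def log_time_step(e):
    fired_at.append(e.state.time_iteration)


for _ in range(6):
    # In this toy the caller advances the counter before firing the event.
    engine.state.time_iteration += 1
    engine.fire_event(TBPTTEvents.TIME_ITERATION_COMPLETED)

print(fired_at)
```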
remove_event_handler(handler, event_name)

Remove event handler handler from the registered handlers of the engine.

Parameters
  • handler (Callable) – the callable event handler that should be removed.

  • event_name (Any) – The event the handler is attached to.

Return type

None

run(data=None, max_epochs=None, epoch_length=None, seed=None)

Runs the process_function over the passed data.

Engine has a state and the following logic is applied in this function:

  • On the first call, a new state is defined by max_epochs, epoch_length and seed, if provided. A timer for total and per-epoch time is initialized when Events.STARTED is handled.

  • If the state is already defined such that there are iterations to run until max_epochs and no input arguments are provided, the state is kept and used in the function.

  • If the state is defined and the engine is “done” (no iterations to run until max_epochs), a new state is defined.

  • If the state is defined and the engine is NOT “done”, then input arguments, if provided, override the defined state.

Parameters
  • data (Optional[Iterable]) – Collection of batches allowing repeated iteration (e.g., list or DataLoader). If not provided, then epoch_length is required and batch argument of process_function will be None.

  • max_epochs (Optional[int]) – Max epochs to run for (default: None). If a new state should be created (first run or run again from ended engine), its default value is 1. If run is resuming from a state, provided max_epochs will be taken into account and should be larger than engine.state.max_epochs.

  • epoch_length (Optional[int]) – Number of iterations to count as one epoch. By default, it is set to len(data). If data is an iterator and epoch_length is not set, then it will be automatically determined as the iteration on which the data iterator raises StopIteration. This argument should not change if run is resuming from a state.

  • seed (Optional[int]) – Deprecated argument. Please use torch.manual_seed or manual_seed().

Returns

output state.

Return type

State

Note

Users can dynamically preprocess the input batch at ITERATION_STARTED and store the resulting batch in engine.state.batch. The latter is then passed to process_function as usual:

trainer = ...

@trainer.on(Events.ITERATION_STARTED)
def switch_batch(engine):
    engine.state.batch = preprocess_batch(engine.state.batch)

To restart the training from the beginning, reset max_epochs to None:

# ...
trainer.run(train_loader, max_epochs=5)

# Reset model weights etc. and restart the training
trainer.state.max_epochs = None
trainer.run(train_loader, max_epochs=2)
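The four state rules listed for run(), plus the max_epochs = None reset shown just above, can be condensed into a small decision function. resolve_run_state and its dict-based state are hypothetical illustrations, not ignite's code; "done" is assumed to mean iteration >= max_epochs * epoch_length, and an epoch_length override on resume is deliberately left out because that argument should not change when resuming.

```python
def resolve_run_state(state, max_epochs=None, epoch_length=None):
    """Hypothetical summary of the state rules listed for run().

    `state` is None on the first call, otherwise a dict with "iteration",
    "max_epochs" and "epoch_length". Illustration only, not ignite's code.
    """
    is_done = (
        state is not None
        and state["max_epochs"] is not None
        and state["iteration"] >= state["max_epochs"] * state["epoch_length"]
    )
    if state is None or state["max_epochs"] is None or is_done:
        # New state; max_epochs falls back to 1 when not provided.
        return {
            "iteration": 0,
            "max_epochs": 1 if max_epochs is None else max_epochs,
            "epoch_length": epoch_length,
        }
    # Not done: keep the state; a provided max_epochs overrides the stored one.
    resumed = dict(state)
    if max_epochs is not None:
        resumed["max_epochs"] = max_epochs
    return resumed


interrupted = {"iteration": 11, "max_epochs": 3, "epoch_length": 10}
finished = {"iteration": 30, "max_epochs": 3, "epoch_length": 10}
print(resolve_run_state(interrupted))                # kept as-is
print(resolve_run_state(finished, epoch_length=10))  # fresh state
```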
set_data(data)

Method to set data. After calling the method, the next batch passed to process_function comes from the newly provided data. Please note that the epoch length is not modified.

Parameters

data (Union[Iterable, DataLoader]) – Collection of batches allowing repeated iteration (e.g., list or DataLoader).

Return type

None

Examples

User can switch data provider during the training:

data1 = ...
data2 = ...

switch_iteration = 5000

def train_step(e, batch):
    # when iteration <= switch_iteration
    # batch is from data1
    # when iteration > switch_iteration
    # batch is from data2
    ...

trainer = Engine(train_step)

@trainer.on(Events.ITERATION_COMPLETED(once=switch_iteration))
def switch_dataloader():
    trainer.set_data(data2)

trainer.run(data1, max_epochs=100)
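The switch can be modelled with a toy batch source whose iterator is replaced mid-run while epoch_length stays untouched. ToyBatchSource, its next_batch() method and the string batches are hypothetical; the restart-on-exhaustion behaviour is an assumption made to keep the example short.

```python
class ToyBatchSource:
    """Hypothetical model of set_data(): swap the data source mid-run."""

    def __init__(self, data, epoch_length):
        self.epoch_length = epoch_length  # set_data() leaves this untouched
        self.set_data(data)

    def set_data(self, data):
        self._data = data
        self._iterator = iter(data)

    def next_batch(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # Start a new pass over the current data source.
            self._iterator = iter(self._data)
            return next(self._iterator)


source = ToyBatchSource(["a1", "a2"], epoch_length=2)
switch_iteration = 3
batches = []
for iteration in range(1, 7):
    batches.append(source.next_batch())
    if iteration == switch_iteration:
        # Plays the role of a handler on ITERATION_COMPLETED(once=switch_iteration).
        source.set_data(["b1", "b2"])
print(batches)
```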
state_dict()

Returns a dictionary containing the engine’s state: “seed”, “epoch_length”, “max_epochs” and “iteration”, and other state values defined by engine.state_dict_user_keys.

engine = Engine(...)
engine.state_dict_user_keys.append("alpha")
engine.state_dict_user_keys.append("beta")
...

@engine.on(Events.STARTED)
def init_user_value(_):
    engine.state.alpha = 0.1
    engine.state.beta = 1.0

@engine.on(Events.COMPLETED)
def save_engine(_):
    state_dict = engine.state_dict()
    assert "alpha" in state_dict and "beta" in state_dict
    torch.save(state_dict, "/tmp/engine.pt")

Returns

a dictionary containing engine’s state

Return type

OrderedDict
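How user keys extend the saved dictionary can be sketched with two small helpers. toy_state_dict, toy_load_user_keys and the SimpleNamespace state are hypothetical; the base key names are taken from the description above, and the helpers only illustrate the documented requirement that user keys must be present when loading.

```python
from collections import OrderedDict
from types import SimpleNamespace

# Base keys named in the description above; user keys are appended after them.
BASE_KEYS = ("seed", "epoch_length", "max_epochs", "iteration")


def toy_state_dict(state, user_keys):
    """Hypothetical: gather the base keys and user keys into an OrderedDict."""
    return OrderedDict((key, getattr(state, key)) for key in BASE_KEYS + tuple(user_keys))


def toy_load_user_keys(state, state_dict, user_keys):
    """Hypothetical: restore user keys, failing loudly if any is missing."""
    missing = [key for key in user_keys if key not in state_dict]
    if missing:
        raise ValueError(f"state_dict is missing user keys: {missing}")
    for key in user_keys:
        setattr(state, key, state_dict[key])


user_keys = ["alpha", "beta"]
state = SimpleNamespace(seed=12, epoch_length=100, max_epochs=5, iteration=250, alpha=0.1, beta=1.0)
saved = toy_state_dict(state, user_keys)
restored = SimpleNamespace()
toy_load_user_keys(restored, saved, user_keys)
print(list(saved), restored.alpha, restored.beta)
```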

terminate()

Sends a terminate signal to the engine, so that it completely terminates the run. The run is terminated after the event on which the terminate method was called. The following events are triggered:

Examples

from ignite.engine import Engine, Events

def func(engine, batch):
    print(engine.state.epoch, engine.state.iteration, " | ", batch)

max_epochs = 4
data = range(10)
engine = Engine(func)

@engine.on(Events.ITERATION_COMPLETED(once=14))
def terminate():
    print(f"-> terminate at iteration: {engine.state.iteration}")
    engine.terminate()

print("Start engine run:")
state = engine.run(data, max_epochs=max_epochs)
print("1 Engine run is terminated at ", state.epoch, state.iteration)
state = engine.run(data, max_epochs=max_epochs)
print("2 Engine ended the run at ", state.epoch, state.iteration)

Output
Start engine run:
1 1  |  0
1 2  |  1
1 3  |  2
1 4  |  3
1 5  |  4
1 6  |  5
1 7  |  6
1 8  |  7
1 9  |  8
1 10  |  9
2 11  |  0
2 12  |  1
2 13  |  2
2 14  |  3
-> terminate at iteration: 14
1 Engine run is terminated at  2 14
3 15  |  0
3 16  |  1
3 17  |  2
3 18  |  3
3 19  |  4
3 20  |  5
3 21  |  6
3 22  |  7
3 23  |  8
3 24  |  9
4 25  |  0
4 26  |  1
4 27  |  2
4 28  |  3
4 29  |  4
4 30  |  5
4 31  |  6
4 32  |  7
4 33  |  8
4 34  |  9
2 Engine ended the run at  4 34

Changed in version 0.4.10: Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669

Return type

None

terminate_epoch()

Sends terminate signal to the engine, so that it terminates the current epoch. The run continues from the next epoch. The following events are triggered:

Return type

None
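The difference from terminate() is that only the inner, per-epoch loop is left, so the run continues with the next epoch. The function below is a hypothetical, ignite-free sketch of that control flow; stop_at lists (epoch, batch_index) positions at which a handler would call terminate_epoch().

```python
def run_with_terminate_epoch(data, max_epochs, stop_at):
    """Hypothetical model of terminate_epoch(): end an epoch early, then continue.

    `stop_at` holds (epoch, batch_index) pairs at which a handler would call
    terminate_epoch(); the function returns the batches seen in each epoch.
    """
    seen_per_epoch = []
    for epoch in range(1, max_epochs + 1):
        seen = []
        for index, batch in enumerate(data):
            seen.append(batch)
            if (epoch, index) in stop_at:
                # Leave only the inner loop; the outer loop starts the next epoch.
                break
        seen_per_epoch.append(seen)
    return seen_per_epoch


result = run_with_terminate_epoch(range(4), max_epochs=3, stop_at={(2, 1)})
print(result)
```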