import numbers
import warnings
import weakref
from collections.abc import Sequence
from enum import Enum
from types import DynamicClassAttribute
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TYPE_CHECKING, Union
from torch.utils.data import DataLoader
from ignite.engine.utils import _check_signature
if TYPE_CHECKING:
    # Only needed for the "Engine" string annotations below; presumably kept out of the
    # runtime import path to avoid a circular import with the engine module.
    from ignite.engine.engine import Engine

# Public API of this module.
__all__ = [
    "CallableEventWithFilter",
    "EventEnum",
    "Events",
    "State",
    "EventsList",
    "RemovableEventHandle",
]
class CallableEventWithFilter:
    """Single Event containing a filter, specifying whether the event should
    be run at the current event (if the event type is correct).

    Args:
        value: The actual enum value. Only needed for internal use. Do not touch!
        event_filter: A function taking the engine and the current event value as input and returning a
            boolean to indicate whether this event should be executed. Defaults to None, meaning the event
            is not filtered (the handler runs every time the event fires).
        name: The enum-name of the current object. Only needed for internal use. Do not touch!
    """

    def __init__(self, value: str, event_filter: Optional[Callable] = None, name: Optional[str] = None) -> None:
        self.filter = event_filter
        # Members built through EventEnum.__new__ already carry ``_value_`` (and presumably ``_name_``
        # from the Enum machinery) before __init__ runs, so never overwrite existing values.
        if not hasattr(self, "_value_"):
            self._value_ = value
        if not hasattr(self, "_name_") and name is not None:
            self._name_ = name

    # copied to be compatible to enum
    @DynamicClassAttribute
    def name(self) -> str:
        """The name of the Enum member."""
        return self._name_

    @DynamicClassAttribute
    def value(self) -> str:
        """The value of the Enum member."""
        return self._value_

    def __call__(
        self,
        event_filter: Optional[Callable] = None,
        every: Optional[int] = None,
        once: Optional[Union[int, List]] = None,
        before: Optional[int] = None,
        after: Optional[int] = None,
    ) -> "CallableEventWithFilter":
        """
        Makes the event class callable and accepts either an arbitrary callable as filter
        (which must take in the engine and current event value and return a boolean) or an every or once value.

        Args:
            event_filter: a filter function to check if the event should be executed when
                the event type was fired
            every: a value specifying how often the event should be fired
            once: a value or list of values specifying when the event should be fired (if only once)
            before: a value specifying the number of occurrence that event should be fired before
            after: a value specifying the number of occurrence that event should be fired after

        Returns:
            CallableEventWithFilter: A new event having the same value but a different filter function

        Raises:
            ValueError: if more than one of ``event_filter`` / ``once`` / (``every``, ``before``, ``after``)
                is given (or none), or if a numeric argument is out of range.
            TypeError: if ``event_filter`` is not callable.
        """
        # ``every``, ``before`` and ``after`` may be combined; they count as one "kind" of filter.
        if (
            sum(
                (
                    event_filter is not None,
                    once is not None,
                    (every is not None or before is not None or after is not None),
                )
            )
            != 1
        ):
            raise ValueError("Only one of the input arguments should be specified, except before, after and every")

        if (event_filter is not None) and not callable(event_filter):
            raise TypeError("Argument event_filter should be a callable")

        if (every is not None) and not (isinstance(every, numbers.Integral) and every > 0):
            raise ValueError("Argument every should be integer and greater than zero")

        if once is not None:
            # Use numbers.Integral for both the scalar and the list form so that e.g. numpy integers
            # are handled consistently.
            c1 = isinstance(once, numbers.Integral) and once > 0
            c2 = (
                isinstance(once, Sequence)
                and len(once) > 0
                and all(isinstance(e, numbers.Integral) and e > 0 for e in once)
            )
            if not (c1 or c2):
                raise ValueError(
                    f"Argument once should either be a positive integer or a list of positive integers, got {once}"
                )

        if (before is not None) and not (isinstance(before, numbers.Integral) and before >= 0):
            raise ValueError("Argument before should be integer and greater or equal to zero")

        if (after is not None) and not (isinstance(after, numbers.Integral) and after >= 0):
            raise ValueError("Argument after should be integer and greater or equal to zero")

        if every is not None:
            if every == 1:
                # Just return the event itself
                event_filter = None
            else:
                event_filter = self.every_event_filter(every)

        if once is not None:
            # A scalar must be wrapped: once_event_filter tests membership with ``in``, which would fail
            # on a bare (non-``int``) Integral such as numpy.int64.
            event_filter = self.once_event_filter([once] if isinstance(once, numbers.Integral) else once)

        if before is not None or after is not None:
            if every is not None:
                event_filter = self.every_before_and_after_event_filter(every, before, after)
            else:
                event_filter = self.before_and_after_event_filter(before, after)

        # check signature:
        if event_filter is not None:
            _check_signature(event_filter, "event_filter", "engine", "event")

        return CallableEventWithFilter(self.value, event_filter, self.name)

    @staticmethod
    def every_event_filter(every: int) -> Callable:
        """A wrapper for every event filter: fires when the event value is a multiple of ``every``."""

        def wrapper(engine: "Engine", event: int) -> bool:
            if event % every == 0:
                return True
            return False

        return wrapper

    @staticmethod
    def once_event_filter(once: List) -> Callable:
        """A wrapper for once event filter: fires when the event value is contained in ``once``."""

        def wrapper(engine: "Engine", event: int) -> bool:
            if event in once:
                return True
            return False

        return wrapper

    @staticmethod
    def before_and_after_event_filter(before: Optional[int] = None, after: Optional[int] = None) -> Callable:
        """A wrapper for before and after event filter: fires when ``after < event < before``.

        A missing ``before`` means no upper bound; a missing ``after`` means a lower bound of 0.
        """
        before_: Union[int, float] = float("inf") if before is None else before
        after_: int = 0 if after is None else after

        def wrapper(engine: "Engine", event: int) -> bool:
            if event > after_ and event < before_:
                return True
            return False

        return wrapper

    @staticmethod
    def every_before_and_after_event_filter(
        every: int, before: Optional[int] = None, after: Optional[int] = None
    ) -> Callable:
        """A wrapper which triggers for every `every` iterations after `after` and before `before`.

        The counting starts at ``after + 1``, i.e. it fires on ``after + 1``, ``after + 1 + every``, ...
        while the value stays strictly below ``before``.
        """
        before_: Union[int, float] = float("inf") if before is None else before
        after_: int = 0 if after is None else after

        def wrapper(engine: "Engine", event: int) -> bool:
            if after_ < event < before_ and (event - after_ - 1) % every == 0:
                return True
            return False

        return wrapper

    @staticmethod
    def default_event_filter(engine: "Engine", event: int) -> bool:
        """Default event filter. This method is is deprecated and will be removed. Please, use None instead"""
        warnings.warn("Events.default_event_filter is deprecated and will be removed. Please, use None instead")
        return True

    def __repr__(self) -> str:
        out = f"Events.{self.name}"
        if self.filter is not None:
            out += f"(filter={self.filter})"
        return out

    def __eq__(self, other: Any) -> bool:
        # Events compare by name only, so a filtered event equals its unfiltered counterpart.
        if isinstance(other, CallableEventWithFilter):
            return self.name == other.name
        elif isinstance(other, str):
            return self.name == other
        else:
            return NotImplemented

    def __hash__(self) -> int:
        # Consistent with __eq__: the filter does not take part in hashing.
        return hash(self._name_)

    def __or__(self, other: Any) -> "EventsList":
        return EventsList() | self | other
class CallableEvents(CallableEventWithFilter):
    """Deprecated alias of :class:`CallableEventWithFilter`, kept only for backward compatibility.

    Instantiating it behaves like ``CallableEventWithFilter`` and additionally emits a
    ``DeprecationWarning`` pointing users to ``ignite.engine.EventEnum``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        deprecation_message = (
            "Class ignite.engine.events.CallableEvents is deprecated. It will be removed in 0.4.14. "
            "Please, use ignite.engine.EventEnum instead"
        )
        warnings.warn(deprecation_message, category=DeprecationWarning)
class EventEnum(CallableEventWithFilter, Enum):
    """Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
    this class.

    Members are enum values that can also be called (see ``CallableEventWithFilter.__call__``) to attach
    an event filter.

    Examples:
        Custom events based on the loss calculation and backward pass can be created as follows:

        .. code-block:: python

            from ignite.engine import EventEnum

            class BackpropEvents(EventEnum):
                BACKWARD_STARTED = 'backward_started'
                BACKWARD_COMPLETED = 'backward_completed'
                OPTIM_STEP_COMPLETED = 'optim_step_completed'

            def update(engine, batch):
                # ...
                loss = criterion(y_pred, y)
                engine.fire_event(BackpropEvents.BACKWARD_STARTED)
                loss.backward()
                engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)
                optimizer.step()
                engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)
                # ...

            trainer = Engine(update)
            trainer.register_events(*BackpropEvents)

            @trainer.on(BackpropEvents.BACKWARD_STARTED)
            def function_before_backprop(engine):
                # ...
    """

    def __new__(cls, value: str) -> "EventEnum":
        # CallableEventWithFilter defines no __new__ of its own, so this allocates a bare instance;
        # the member's value is stored directly in ``_value_`` so that
        # CallableEventWithFilter.__init__ will not overwrite it.
        member = CallableEventWithFilter.__new__(cls)
        member._value_ = value
        return member
class Events(EventEnum):
    """Events that are fired by the :class:`~ignite.engine.engine.Engine` during execution. Built-in events:

    - STARTED : triggered when engine's run is started
    - EPOCH_STARTED : triggered when the epoch is started
    - GET_BATCH_STARTED : triggered before next batch is fetched
    - GET_BATCH_COMPLETED : triggered after the batch is fetched
    - ITERATION_STARTED : triggered when an iteration is started
    - ITERATION_COMPLETED : triggered when the iteration is ended

    - DATALOADER_STOP_ITERATION : engine's specific event triggered when dataloader has no more data to provide

    - EXCEPTION_RAISED : triggered when an exception is encountered
    - TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
      after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
      :meth:`~ignite.engine.engine.Engine.terminate()` call.

    - TERMINATE : triggered when the run is about to end completely,
      after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.

    - INTERRUPT : triggered when the run is interrupted, after receiving an ``interrupt()`` call.

    - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
      when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
    - COMPLETED : triggered when engine's run is completed

    The table below illustrates which events are triggered when various termination methods are called.

    .. list-table::
        :widths: 24 25 33 18
        :header-rows: 1

        * - Method
          - EPOCH_COMPLETED
          - TERMINATE_SINGLE_EPOCH
          - TERMINATE
        * - no termination
          - ✔
          - ✗
          - ✗
        * - :meth:`~ignite.engine.engine.Engine.terminate_epoch()`
          - ✔
          - ✔
          - ✗
        * - :meth:`~ignite.engine.engine.Engine.terminate()`
          - ✗
          - ✔
          - ✔

    Since v0.3.0, Events become more flexible and allow passing an event filter to the Engine:

    .. code-block:: python

        engine = Engine()

        # a) custom event filter
        def custom_event_filter(engine, event):
            if event in [1, 2, 5, 10, 50, 100]:
                return True
            return False

        @engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
        def call_on_special_event(engine):
            # do something on 1, 2, 5, 10, 50, 100 iterations

        # b) "every" event filter
        @engine.on(Events.ITERATION_STARTED(every=10))
        def call_every(engine):
            # do something every 10th iteration

        # c) "once" event filter
        @engine.on(Events.ITERATION_STARTED(once=50))
        def call_once(engine):
            # do something on 50th iteration

        # d) "before" and "after" event filter
        @engine.on(Events.EPOCH_STARTED(before=30, after=10))
        def call_before(engine):
            # do something in 11 to 29 epoch

        # e) Mixing "every" and "before" / "after" event filters
        @engine.on(Events.EPOCH_STARTED(every=5, before=25, after=8))
        def call_every_itr_before_after(engine):
            # do something on 9, 14, 19, 24 epochs

    Event filter function `event_filter` accepts as input `engine` and `event` and should return True/False.
    Argument `event` is the value of iteration or epoch, depending on which type of Events the function is passed.

    Since v0.4.0, user can also combine events with `|`-operator:

    .. code-block:: python

        events = Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3)
        engine = ...

        @engine.on(events)
        def call_on_events(engine):
            # do something

    Since v0.4.0, custom events defined by user should inherit from :class:`~ignite.engine.events.EventEnum` :

    .. code-block:: python

        class CustomEvents(EventEnum):
            FOO_EVENT = "foo_event"
            BAR_EVENT = "bar_event"
    """

    EPOCH_STARTED = "epoch_started"
    """triggered when the epoch is started."""
    EPOCH_COMPLETED = "epoch_completed"
    """triggered when the epoch is ended."""

    STARTED = "started"
    """triggered when engine's run is started."""
    COMPLETED = "completed"
    """triggered when engine's run is completed"""

    ITERATION_STARTED = "iteration_started"
    """triggered when an iteration is started."""
    ITERATION_COMPLETED = "iteration_completed"
    """triggered when the iteration is ended."""
    EXCEPTION_RAISED = "exception_raised"
    """triggered when an exception is encountered."""

    GET_BATCH_STARTED = "get_batch_started"
    """triggered before next batch is fetched."""
    GET_BATCH_COMPLETED = "get_batch_completed"
    """triggered after the batch is fetched."""

    DATALOADER_STOP_ITERATION = "dataloader_stop_iteration"
    """engine's specific event triggered when dataloader has no more data to provide"""
    TERMINATE = "terminate"
    """triggered when the run is about to end completely, after receiving terminate() call."""
    TERMINATE_SINGLE_EPOCH = "terminate_single_epoch"
    """triggered when the run is about to end the current epoch,
    after receiving a terminate_epoch() call."""
    INTERRUPT = "interrupt"
    """triggered when the run is interrupted, after receiving interrupt() call."""

    def __or__(self, other: Any) -> "EventsList":
        # Combine this event with another one into a new EventsList.
        return EventsList() | self | other
class EventsList:
    """Collection of events stacked by operator `__or__`.

    .. code-block:: python

        events = Events.STARTED | Events.COMPLETED
        events |= Events.ITERATION_STARTED(every=3)

        engine = ...

        @engine.on(events)
        def call_on_events(engine):
            # do something

    or

    .. code-block:: python

        @engine.on(Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3))
        def call_on_events(engine):
            # do something

    Two lists can also be combined, e.g.
    ``(Events.STARTED | Events.COMPLETED) | (Events.EPOCH_STARTED | Events.EPOCH_COMPLETED)``:
    the events of the right-hand list are appended to the left-hand one.

    Note:
        ``|`` mutates and returns the left-hand ``EventsList`` (there is no copy).
    """

    def __init__(self) -> None:
        self._events: List[Union[Events, CallableEventWithFilter]] = []

    def _append(self, event: Union[Events, CallableEventWithFilter]) -> None:
        """Append a single event after checking its type.

        Raises:
            TypeError: if ``event`` is neither an ``Events`` member nor a ``CallableEventWithFilter``.
        """
        if not isinstance(event, (Events, CallableEventWithFilter)):
            raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}")
        self._events.append(event)

    def __getitem__(self, item: int) -> Union[Events, CallableEventWithFilter]:
        return self._events[item]

    def __iter__(self) -> Iterator[Union[Events, CallableEventWithFilter]]:
        return iter(self._events)

    def __len__(self) -> int:
        return len(self._events)

    def __or__(self, other: Union[Events, CallableEventWithFilter, "EventsList"]) -> "EventsList":
        """Append ``other`` (a single event, or every event of another ``EventsList``) and return ``self``."""
        if isinstance(other, EventsList):
            # Snapshot first: for ``events | events`` we would otherwise iterate a list we are growing.
            for event in list(other):
                self._append(event)
        else:
            self._append(event=other)
        return self
class State:
    """An object that is used to pass internal and user-defined state between event handlers. By default, state
    contains the following attributes:

    .. code-block:: python

        state.iteration         # 1-based, the first iteration is 1
        state.epoch             # 1-based, the first epoch is 1
        state.seed              # seed to set at each epoch
        state.dataloader        # data passed to engine
        state.epoch_length      # optional length of an epoch
        state.max_epochs        # number of epochs to run
        state.batch             # batch passed to `process_function`
        state.output            # output of `process_function` after a single iteration
        state.metrics           # dictionary with defined metrics if any
        state.times             # dictionary with total and per-epoch times fetched on
                                # keys: Events.EPOCH_COMPLETED.name and Events.COMPLETED.name

    Args:
        kwargs: keyword arguments to be defined as State attributes.
    """

    # Class-level (shared) mapping from an event to the name of the State attribute holding its counter;
    # read by get_event_attrib_value, and every attribute named here is initialized by _update_attrs.
    event_to_attr: Dict[Union[str, "Events", "CallableEventWithFilter"], str] = {
        Events.GET_BATCH_STARTED: "iteration",
        Events.GET_BATCH_COMPLETED: "iteration",
        Events.ITERATION_STARTED: "iteration",
        Events.ITERATION_COMPLETED: "iteration",
        Events.EPOCH_STARTED: "epoch",
        Events.EPOCH_COMPLETED: "epoch",
        Events.STARTED: "epoch",
        Events.COMPLETED: "epoch",
    }

    def __init__(self, **kwargs: Any) -> None:
        self.iteration = 0
        self.epoch = 0
        self.epoch_length: Optional[int] = None
        self.max_epochs: Optional[int] = None
        # Batch and process-function output can be arbitrary objects (tensors, tuples, dicts, ...).
        self.output: Any = None
        self.batch: Any = None
        self.metrics: Dict[str, Any] = {}
        self.dataloader: Optional[Union[DataLoader, Iterable[Any]]] = None
        self.seed: Optional[int] = None
        self.times: Dict[str, Optional[float]] = {
            Events.EPOCH_COMPLETED.name: None,
            Events.COMPLETED.name: None,
        }

        for k, v in kwargs.items():
            setattr(self, k, v)

        self._update_attrs()

    def _update_attrs(self) -> None:
        """Initialize to 0 any counter attribute named in ``event_to_attr`` that does not exist yet."""
        for value in self.event_to_attr.values():
            if not hasattr(self, value):
                setattr(self, value, 0)

    def get_event_attrib_value(self, event_name: Union[str, Events, CallableEventWithFilter]) -> int:
        """Get the value of Event attribute with given `event_name`.

        Raises:
            RuntimeError: if ``event_name`` is not registered in ``State.event_to_attr``.
        """
        if event_name not in State.event_to_attr:
            raise RuntimeError(f"Unknown event name '{event_name}'")
        return getattr(self, State.event_to_attr[event_name])

    def __repr__(self) -> str:
        # Scalars and strings are shown as-is; anything else is summarized by its type to keep output short.
        parts = ["State:\n"]
        for attr, value in self.__dict__.items():
            if not isinstance(value, (numbers.Number, str)):
                value = type(value)
            parts.append(f"\t{attr}: {value}\n")
        return "".join(parts)
class RemovableEventHandle:
    """A weakref handle to remove a registered event.

    A handle that may be used to remove a registered event handler via the
    remove method, with-statement, or context manager protocol. Returned from
    :meth:`~ignite.engine.engine.Engine.add_event_handler`.

    Args:
        event_name: Registered event name (a single event or an ``EventsList``).
        handler: Registered event handler, stored as weakref.
        engine: Target engine, stored as weakref.

    Examples:
        .. code-block:: python

            engine = Engine()

            def print_epoch(engine):
                print(f"Epoch: {engine.state.epoch}")

            with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
                # print_epoch handler registered for a single run
                engine.run(data)

            # print_epoch handler is now unregistered
    """

    def __init__(
        self, event_name: Union[CallableEventWithFilter, Enum, EventsList, Events], handler: Callable, engine: "Engine"
    ) -> None:
        self.event_name = event_name
        # Weak references: the handle must not keep the handler or the engine alive.
        self.handler = weakref.ref(handler)
        self.engine = weakref.ref(engine)

    def remove(self) -> None:
        """Remove handler from engine.

        Does nothing if the handler or the engine has already been garbage-collected.

        Raises:
            RuntimeError: if the handler exposes a ``_parent`` reference that no longer resolves.
        """
        target, owner = self.handler(), self.engine()
        if target is None or owner is None:
            return

        if hasattr(target, "_parent"):
            # The registered object is presumably a wrapper; unregister the object ``_parent()`` resolves to.
            target = target._parent()
            if target is None:
                raise RuntimeError(
                    "Internal error! Please fill an issue on https://github.com/pytorch/ignite/issues "
                    "if encounter this error. Thank you!"
                )

        events = self.event_name if isinstance(self.event_name, EventsList) else (self.event_name,)
        for event in events:
            if owner.has_event_handler(target, event):
                owner.remove_event_handler(target, event)

    def __enter__(self) -> "RemovableEventHandle":
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self.remove()