# Source code for ignite.engine.events

import numbers
import warnings
import weakref
from enum import Enum
from types import DynamicClassAttribute
from typing import Callable, Optional, Union

from ignite.engine.utils import _check_signature

__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State"]

class CallableEventWithFilter:
    """Single Event containing a filter, specifying whether the event should
    be run at the current event (if the event type is correct)

    Args:
        value (str): The actual enum value. Only needed for internal use. Do not touch!
        event_filter (callable): A function taking the engine and the current event value as input and returning a
            boolean to indicate whether this event should be executed. Defaults to None, which will result to a
            function that always returns `True`
        name (str, optional): The enum-name of the current object. Only needed for internal use. Do not touch!

    """

    def __init__(self, value: str, event_filter: Optional[Callable] = None, name=None):
        if event_filter is None:
            event_filter = CallableEventWithFilter.default_event_filter
        self.filter = event_filter

        # Only fill in `_value_` / `_name_` when they are not already present on
        # the instance (presumably set by the Enum machinery for enum members).
        if not hasattr(self, "_value_"):
            self._value_ = value

        if not hasattr(self, "_name_") and name is not None:
            self._name_ = name

    # copied to be compatible to enum
    @DynamicClassAttribute
    def name(self):
        """The name of the Enum member."""
        return self._name_

    @DynamicClassAttribute
    def value(self):
        """The value of the Enum member."""
        return self._value_

    def __call__(
        self, event_filter: Optional[Callable] = None, every: Optional[int] = None, once: Optional[int] = None
    ) -> "CallableEventWithFilter":
        """
        Makes the event class callable and accepts either an arbitrary callable as filter
        (which must take in the engine and current event value and return a boolean) or an every or once value

        Args:
            event_filter (callable, optional): a filter function to check if the event should be executed when
                the event type was fired
            every (int, optional): a value specifying how often the event should be fired
            once (int, optional): a value specifying when the event should be fired (if only once)

        Returns:
            CallableEventWithFilter: A new event having the same value but a different filter function

        Raises:
            ValueError: if zero or more than one argument is given, or if `every` / `once`
                is not a positive integer.
            TypeError: if `event_filter` is given but is not callable.
        """
        # Exactly one argument must be provided. Counting (rather than chaining
        # XOR) also rejects the case where all three are given at once.
        provided = sum(arg is not None for arg in (event_filter, every, once))
        if provided != 1:
            raise ValueError("Only one of the input arguments should be specified")

        if (event_filter is not None) and not callable(event_filter):
            raise TypeError("Argument event_filter should be a callable")

        if (every is not None) and not (isinstance(every, numbers.Integral) and every > 0):
            raise ValueError("Argument every should be integer and greater than zero")

        if (once is not None) and not (isinstance(once, numbers.Integral) and once > 0):
            raise ValueError("Argument once should be integer and positive")

        if every is not None:
            if every == 1:
                # Just return the event itself
                event_filter = None
            else:
                event_filter = self.every_event_filter(every)

        if once is not None:
            event_filter = self.once_event_filter(once)

        # check signature:
        if event_filter is not None:
            _check_signature(event_filter, "event_filter", "engine", "event")

        return CallableEventWithFilter(self.value, event_filter, self.name)

    @staticmethod
    def every_event_filter(every: int) -> Callable:
        """Build a filter that returns True when the event value is a multiple of `every`."""

        def wrapper(engine, event: int) -> bool:
            if event % every == 0:
                return True
            return False

        return wrapper

    @staticmethod
    def once_event_filter(once: int) -> Callable:
        """Build a filter that returns True only when the event value equals `once`."""

        def wrapper(engine, event: int) -> bool:
            if event == once:
                return True
            return False

        return wrapper

    @staticmethod
    def default_event_filter(engine, event: int) -> bool:
        """Filter used when none is given: always returns True."""
        return True

    def __str__(self) -> str:
        return "<event=%s, filter=%r>" % (self.name, self.filter)

    def __eq__(self, other):
        # Events compare by name only (the filter is ignored); a plain string
        # is compared against the name. Any other type is rejected.
        if isinstance(other, CallableEventWithFilter):
            return self.name == other.name
        elif isinstance(other, str):
            return self.name == other
        else:
            raise NotImplementedError

    def __hash__(self):
        # Hash on the name so it stays consistent with `__eq__`.
        return hash(self._name_)

    def __or__(self, other):
        # `a | b` starts a fresh EventsList and appends both operands to it.
        return EventsList() | self | other

class CallableEvents(CallableEventWithFilter):
    """Deprecated subclass of :class:`CallableEventWithFilter`.

    Behaves exactly like its parent, but emits a deprecation warning every
    time an instance is created.
    """

    # For backward compatibility
    def __init__(self, *args, **kwargs):
        super(CallableEvents, self).__init__(*args, **kwargs)
        warnings.warn(
            "Class is deprecated. It will be removed in 0.5.0. "
            "Please, use ignite.engine.EventEnum instead",
            DeprecationWarning,
        )

class EventEnum(CallableEventWithFilter, Enum):
    """Enum base class whose members are also :class:`CallableEventWithFilter` instances.

    It defines no members itself; subclass it to declare a set of events.
    """

    pass

class Events(EventEnum):
    """Events that are fired by the :class:`~ignite.engine.Engine` during execution. Built-in events:

    - STARTED : triggered when engine's run is started
    - EPOCH_STARTED : triggered when the epoch is started
    - GET_BATCH_STARTED : triggered before next batch is fetched
    - GET_BATCH_COMPLETED : triggered after the batch is fetched
    - ITERATION_STARTED : triggered when an iteration is started
    - ITERATION_COMPLETED : triggered when the iteration is ended

    - DATALOADER_STOP_ITERATION : engine's specific event triggered when dataloader has no more data to provide

    - EXCEPTION_RAISED : triggered when an exception is encountered
    - TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
      after receiving :meth:`~ignite.engine.Engine.terminate_epoch()` call.

    - TERMINATE : triggered when the run is about to end completely,
      after receiving :meth:`~ignite.engine.Engine.terminate()` call.

    - EPOCH_COMPLETED : triggered when the epoch is ended
    - COMPLETED : triggered when engine's run is completed

    Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine:

    .. code-block:: python

        engine = Engine()

        # a) custom event filter
        def custom_event_filter(engine, event):
            if event in [1, 2, 5, 10, 50, 100]:
                return True
            return False

        @engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
        def call_on_special_event(engine):
            # do something on 1, 2, 5, 10, 50, 100 iterations

        # b) "every" event filter
        @engine.on(Events.ITERATION_STARTED(every=10))
        def call_every(engine):
            # do something every 10th iteration

        # c) "once" event filter
        @engine.on(Events.ITERATION_STARTED(once=50))
        def call_once(engine):
            # do something on 50th iteration

    Event filter function `event_filter` accepts as input `engine` and `event` and should return True/False.
    Argument `event` is the value of iteration or epoch, depending on which type of Events the function is passed.

    Since v0.4.0, user can also combine events with `|`-operator:

    .. code-block:: python

        events = Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3)
        engine = ...

        @engine.on(events)
        def call_on_events(engine):
            # do something
    """

    EPOCH_STARTED = "epoch_started"
    EPOCH_COMPLETED = "epoch_completed"

    STARTED = "started"
    COMPLETED = "completed"

    ITERATION_STARTED = "iteration_started"
    ITERATION_COMPLETED = "iteration_completed"
    EXCEPTION_RAISED = "exception_raised"

    GET_BATCH_STARTED = "get_batch_started"
    GET_BATCH_COMPLETED = "get_batch_completed"

    DATALOADER_STOP_ITERATION = "dataloader_stop_iteration"
    TERMINATE = "terminate"
    TERMINATE_SINGLE_EPOCH = "terminate_single_epoch"

    def __or__(self, other):
        # `Events.A | Events.B` starts a fresh EventsList and appends both operands.
        return EventsList() | self | other
class EventsList:
    """Collection of events stacked by operator `__or__`.

    .. code-block:: python

        events = Events.STARTED | Events.COMPLETED
        events |= Events.ITERATION_STARTED(every=3)

        engine = ...

        @engine.on(events)
        def call_on_events(engine):
            # do something

    or

    .. code-block:: python

        @engine.on(Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3))
        def call_on_events(engine):
            # do something

    """

    def __init__(self):
        self._events = []

    def _append(self, event: Union[Events, CallableEventWithFilter]):
        """Add `event` to the list; raise ValueError if it is not an event object."""
        if not isinstance(event, (Events, CallableEventWithFilter)):
            raise ValueError("Argument event should be Events or CallableEventWithFilter, got: {}".format(type(event)))
        self._events.append(event)

    def __getitem__(self, item):
        return self._events[item]

    def __iter__(self):
        return iter(self._events)

    def __len__(self):
        return len(self._events)

    def __or__(self, other: Union[Events, CallableEventWithFilter]):
        # Appends in place and returns this same list, so chained `|` keeps
        # accumulating into a single EventsList instance.
        self._append(event=other)
        return self
class State:
    """An object that is used to pass internal and user-defined state between event handlers. By default, state
    contains the following attributes:

    .. code-block:: python

        state.iteration         # 1-based, the first iteration is 1
        state.epoch             # 1-based, the first epoch is 1
        state.seed              # seed to set at each epoch
        state.dataloader        # data passed to engine
        state.epoch_length      # optional length of an epoch
        state.max_epochs        # number of epochs to run
        state.batch             # batch passed to `process_function`
        state.output            # output of `process_function` after a single iteration
        state.metrics           # dictionary with defined metrics if any
        state.times             # dictionary with total and per-epoch times fetched on
                                # keys: Events.EPOCH_COMPLETED.name and Events.COMPLETED.name

    """

    # Maps each supported event to the name of the State attribute that holds
    # its counter.
    event_to_attr = {
        Events.GET_BATCH_STARTED: "iteration",
        Events.GET_BATCH_COMPLETED: "iteration",
        Events.ITERATION_STARTED: "iteration",
        Events.ITERATION_COMPLETED: "iteration",
        Events.EPOCH_STARTED: "epoch",
        Events.EPOCH_COMPLETED: "epoch",
        Events.STARTED: "epoch",
        Events.COMPLETED: "epoch",
    }

    def __init__(self, **kwargs):
        self.iteration = 0
        self.epoch = 0
        self.epoch_length = None
        self.max_epochs = None
        self.output = None
        self.batch = None
        self.metrics = {}
        self.dataloader = None
        self.seed = None
        # Dictionary keyed by event name, as described in the class docstring.
        self.times = {Events.EPOCH_COMPLETED.name: None, Events.COMPLETED.name: None}

        # User-supplied keyword arguments override or extend the defaults above.
        for k, v in kwargs.items():
            setattr(self, k, v)

        self._update_attrs()

    def _update_attrs(self):
        """Ensure every attribute named in `event_to_attr` exists, defaulting to 0."""
        for value in self.event_to_attr.values():
            if not hasattr(self, value):
                setattr(self, value, 0)

    def get_event_attrib_value(self, event_name: Union[CallableEventWithFilter, Enum]) -> int:
        """Return the counter attribute associated with `event_name`.

        Raises:
            RuntimeError: if `event_name` is not a key of `State.event_to_attr`.
        """
        if event_name not in State.event_to_attr:
            raise RuntimeError("Unknown event name '{}'".format(event_name))
        return getattr(self, State.event_to_attr[event_name])

    def __repr__(self) -> str:
        # Numbers and strings are printed as-is; any other value is shown by
        # its type only.
        lines = ["State:\n"]
        for attr, value in self.__dict__.items():
            if not isinstance(value, (numbers.Number, str)):
                value = type(value)
            lines.append("\t{}: {}\n".format(attr, value))
        return "".join(lines)
class RemovableEventHandle:
    """A weakref handle to remove a registered event.

    A handle that may be used to remove a registered event handler via the
    remove method, with-statement, or context manager protocol. Returned from
    :meth:`~ignite.engine.Engine.add_event_handler`.


    Args:
        event_name: Registered event name.
        handler: Registered event handler, stored as weakref.
        engine: Target engine, stored as weakref.

    Example usage:

    .. code-block:: python

        engine = Engine()

        def print_epoch(engine):
            print("Epoch: {}".format(engine.state.epoch))

        with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
            # print_epoch handler registered for a single run
            engine.run(data)

        # print_epoch handler is now unregistered
    """

    def __init__(self, event_name: Union[CallableEventWithFilter, Enum, EventsList], handler: Callable, engine):
        self.event_name = event_name
        # Weak references: the handle itself does not keep the handler or the
        # engine alive.
        self.handler = weakref.ref(handler)
        self.engine = weakref.ref(engine)

    def remove(self) -> None:
        """Remove handler from engine."""
        handler = self.handler()
        engine = self.engine()

        # Either referent has been garbage-collected: nothing left to remove.
        if handler is None or engine is None:
            return

        if isinstance(self.event_name, EventsList):
            # Registered for several events at once: detach from each of them.
            for e in self.event_name:
                if engine.has_event_handler(handler, e):
                    engine.remove_event_handler(handler, e)
        else:
            if engine.has_event_handler(handler, self.event_name):
                engine.remove_event_handler(handler, self.event_name)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs) -> None:
        # Leaving the `with` block unregisters the handler.
        self.remove()

# © Copyright 2024, PyTorch-Ignite Contributors. Last updated on 07/17/2024, 10:10:30 AM.
#
# Built with Sphinx using a theme provided by Read the Docs.