EMAHandler
- class ignite.handlers.ema_handler.EMAHandler(model, momentum=0.0002, momentum_warmup=None, warmup_iters=None)
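Exponential moving average (EMA) handler that can be used to compute a smoothed version of a model. The EMA model is updated as follows:
$$\theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}$$
where $\theta_{\text{EMA}, t}$ and $\theta_{t}$ are the EMA weights and online model weights at the $t$-th iteration, respectively, and $\lambda$ is the update momentum. The handler can linearly warm up the momentum at the beginning of training, when the training process is not yet stable. The current momentum can be retrieved from Engine.state.ema_momentum (ema_momentum is the default attribute name; it is set by the name argument of attach).
- Parameters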
model (Module) – the online model for which an EMA model will be computed. If model is DataParallel or DistributedDataParallel, the EMA smoothing will be applied to model.module.
momentum (float) – the update momentum after the warmup phase; should be a float in the range (0, 1).
momentum_warmup (Optional[float]) – the initial update momentum during the warmup phase; the value should be smaller than momentum. The momentum increases linearly from this value to momentum over warmup_iters iterations. If None, no warmup is performed.
warmup_iters (Optional[int]) – number of warmup iterations. If None, no warmup is performed.
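To make the update rule θ_EMA ← (1 − λ)·θ_EMA + λ·θ and the linear warmup concrete, here is a minimal pure-Python sketch. It treats a model's weights as a flat list of floats and assumes the warmup interpolates linearly from momentum_warmup at iteration 0 to momentum at iteration warmup_iters, then stays constant. The helpers momentum_at and ema_update are hypothetical names used only for illustration, and ignite's exact iteration indexing may differ from this sketch.

```python
from typing import List, Optional


def momentum_at(
    iteration: int,
    momentum: float = 0.0002,
    momentum_warmup: Optional[float] = None,
    warmup_iters: Optional[int] = None,
) -> float:
    """Hypothetical helper: momentum after a linear warmup.

    Assumes the ramp starts at iteration 0 with `momentum_warmup`
    and reaches `momentum` at iteration `warmup_iters`.
    """
    # Without warmup, the momentum is constant from the first update.
    if momentum_warmup is None or warmup_iters is None:
        return momentum
    # After the warmup phase, the momentum stays at its final value.
    if iteration >= warmup_iters:
        return momentum
    # Linear interpolation between momentum_warmup and momentum.
    fraction = iteration / warmup_iters
    return momentum_warmup + (momentum - momentum_warmup) * fraction


def ema_update(ema: List[float], online: List[float], lam: float) -> List[float]:
    """Hypothetical helper: theta_ema <- (1 - lam) * theta_ema + lam * theta."""
    return [(1.0 - lam) * e + lam * o for e, o in zip(ema, online)]


# Momentum schedule for momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000.
for it in (0, 5000, 10000, 20000):
    print(it, momentum_at(it, 0.0002, 0.0001, 10000))

print(ema_update([0.0, 2.0], [1.0, 4.0], 0.5))  # [0.5, 3.0]
```

With a small momentum such as the default 0.0002, each update moves the EMA weights only 0.02% of the way toward the online weights, which is what makes the EMA model a slowly varying, smoothed version of the model.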
- ema_model
the exponential moving averaged model.
- model
the online model that is tracked by EMAHandler. It is model.module if model in the initialization method is an instance of DistributedDataParallel.
- momentum
the update momentum after warmup phase.
- momentum_warmup
the initial update momentum.
- warmup_iters
number of warmup iterations.
Note
The EMA model is already in eval mode. If model in the arguments is an nn.Module or DistributedDataParallel, the EMA model is an nn.Module and it is on the same device as the online model. If the model is an nn.DataParallel, then the EMA model is an nn.DataParallel.
Note
It is recommended to initialize and use an EMA handler in the following steps:
1. Initialize model (nn.Module or DistributedDataParallel) and ema_handler (EMAHandler).
2. Build trainer (ignite.engine.Engine).
3. Resume from checkpoint for model and ema_handler.ema_model.
4. Attach ema_handler to trainer.
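The notes above can be pictured with a short PyTorch sketch of how an EMA copy might be built and updated. It assumes, for illustration only, that the EMA model is a deep copy of the online model (with DistributedDataParallel unwrapped first) switched to eval mode, and that each update applies the formula parameter by parameter without gradient tracking. make_ema_model and update_ema_ are hypothetical helpers, not ignite's API, and buffers such as BatchNorm running statistics are ignored here.

```python
import copy

import torch
import torch.nn as nn


def make_ema_model(model: nn.Module) -> nn.Module:
    """Hypothetical helper: create an EMA copy of `model` in eval mode."""
    # Unwrap DistributedDataParallel so the EMA copy is a plain nn.Module.
    # nn.DataParallel is not unwrapped, so its copy stays an nn.DataParallel.
    if isinstance(model, nn.parallel.DistributedDataParallel):
        model = model.module
    # deepcopy keeps the copied parameters on the online model's device.
    ema_model = copy.deepcopy(model)
    ema_model.eval()
    # The EMA weights are never trained directly, so turn off gradients.
    for p in ema_model.parameters():
        p.requires_grad_(False)
    return ema_model


@torch.no_grad()
def update_ema_(ema_model: nn.Module, model: nn.Module, momentum: float) -> None:
    """Hypothetical helper: in place, theta_ema <- (1 - m) * theta_ema + m * theta."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(1.0 - momentum).add_(p, alpha=momentum)


model = nn.Linear(2, 1)
ema_model = make_ema_model(model)
print(ema_model.training)  # False
update_ema_(ema_model, model, momentum=0.0002)
```

In this sketch the copy is made once, at construction time; this is why step 3 restores ema_handler.ema_model from the checkpoint before training resumes, so that the saved EMA weights replace the freshly copied ones.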
Examples
import torch
import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, EMAHandler

# train_step_fn, resume_from, get_val_step_fn and val_data_loader are user-defined placeholders
device = torch.device("cuda:0")
model = nn.Linear(2, 1).to(device)
# warm up the momentum linearly from 0.0001 to 0.0002 (see warmup_iters)
ema_handler = EMAHandler(
    model, momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000
)
# get the EMA model, which is an instance of nn.Module
ema_model = ema_handler.ema_model

trainer = Engine(train_step_fn)
to_load = {"model": model, "ema_model": ema_model, "trainer": trainer}
if resume_from is not None:
    Checkpoint.load_objects(to_load, checkpoint=resume_from)

# update the EMA model every 5 iterations
ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED(every=5))

# add other handlers
to_save = to_load
ckpt_handler = Checkpoint(to_save, DiskSaver(...), ...)
trainer.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler)

# current momentum can be retrieved from engine.state;
# the attribute name is the `name` parameter used in the attach function
@trainer.on(Events.ITERATION_COMPLETED)
def print_ema_momentum(engine):
    print(f"current momentum: {engine.state.ema_momentum}")

# use the EMA model for validation
val_step_fn = get_val_step_fn(ema_model)
evaluator = Engine(val_step_fn)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(val_data_loader)

trainer.run(...)
The following example shows how to attach two handlers to the same trainer:
from ignite.engine import Engine, Events
from ignite.handlers import EMAHandler

# build_generator, build_discriminator and get_step_fn are user-defined placeholders
generator = build_generator(...)
discriminator = build_discriminator(...)

gen_handler = EMAHandler(generator)
disc_handler = EMAHandler(discriminator)

step_fn = get_step_fn(...)
engine = Engine(step_fn)

# update the EMA model of the generator every iteration
gen_handler.attach(engine, "gen_ema_momentum", event=Events.ITERATION_COMPLETED)
# update the EMA model of the discriminator every 2 iterations
disc_handler.attach(engine, "disc_ema_momentum", event=Events.ITERATION_COMPLETED(every=2))

@engine.on(Events.ITERATION_COMPLETED)
def print_ema_momentum(engine):
    print(f"current momentum for generator: {engine.state.gen_ema_momentum}")
    print(f"current momentum for discriminator: {engine.state.disc_ema_momentum}")

engine.run(...)
New in version 0.4.6.
Methods
- attach(engine, name='ema_momentum', event=Events.ITERATION_COMPLETED)
Attach the handler to engine. After the handler is attached, Engine.state gains a new attribute whose name is given by name, from which the current momentum can be retrieved while the engine runs.
- Parameters
engine (Engine) – trainer to which the handler will be attached.
name (str) – attribute name for retrieving the EMA momentum from Engine.state. It should be a unique name, since a trainer can have multiple EMA handlers.
event (Union[str, Events, CallableEventWithFilter, EventsList]) – event on which the EMA momentum and the EMA model are updated.
- Return type
None
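The role of the name and event parameters can be illustrated without ignite. In the plain-Python sketch below, ToyEngine and ToyEMAHandler are hypothetical stand-ins, not ignite classes: attaching publishes the handler's momentum as engine.state.<name>, the update runs only on iterations that match an every=N filter (mimicking Events.ITERATION_COMPLETED(every=N)), and a second handler reusing an existing name is rejected so it cannot silently overwrite the first handler's attribute. The toy has no warmup, and whether ignite rejects duplicates in exactly this way is an assumption of the sketch.

```python
from types import SimpleNamespace
from typing import Callable, List


class ToyEngine:
    """Hypothetical minimal engine: runs handlers after each iteration."""

    def __init__(self) -> None:
        self.state = SimpleNamespace(iteration=0)
        self._handlers: List[Callable[["ToyEngine"], None]] = []

    def add_iteration_handler(self, fn: Callable[["ToyEngine"], None]) -> None:
        self._handlers.append(fn)

    def run(self, num_iters: int) -> None:
        for _ in range(num_iters):
            self.state.iteration += 1
            for fn in self._handlers:
                fn(self)


class ToyEMAHandler:
    """Hypothetical handler that publishes its momentum on engine.state."""

    def __init__(self, momentum: float) -> None:
        self.momentum = momentum
        self.num_updates = 0

    def attach(self, engine: ToyEngine, name: str, every: int = 1) -> None:
        # Refuse to reuse an attribute that another handler already owns.
        if hasattr(engine.state, name):
            raise ValueError(f"engine.state already has an attribute named {name!r}")
        setattr(engine.state, name, self.momentum)

        def _on_iteration(eng: ToyEngine) -> None:
            # Mimic an every=N filter: fire only on multiples of N.
            if eng.state.iteration % every == 0:
                self.num_updates += 1  # the EMA weights would be updated here
                setattr(eng.state, name, self.momentum)

        engine.add_iteration_handler(_on_iteration)


engine = ToyEngine()
gen, disc = ToyEMAHandler(0.0002), ToyEMAHandler(0.0004)
gen.attach(engine, "gen_ema_momentum")
disc.attach(engine, "disc_ema_momentum", every=2)
engine.run(10)
print(gen.num_updates, disc.num_updates)  # 10 5
print(engine.state.disc_ema_momentum)  # 0.0004
```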