EMAHandler
- class ignite.handlers.ema_handler.EMAHandler(model, momentum=0.0002, momentum_warmup=None, warmup_iters=None, handle_buffers='copy')
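Exponential moving average (EMA) handler, which can be used to compute a smoothed version of a model. The EMA model is updated as follows:

$$\theta_{\text{EMA},t+1} = (1 - \lambda) \cdot \theta_{\text{EMA},t} + \lambda \cdot \theta_t$$

where $\theta_{\text{EMA},t}$ and $\theta_t$ are the EMA weights and the online model weights at the $t$-th iteration, respectively, and $\lambda$ is the update momentum. With the default `name="ema_momentum"` passed to `attach`, the current momentum can be retrieved from `Engine.state.ema_momentum`.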
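The update rule can be shown on plain Python floats. This is a minimal sketch, not EMAHandler's implementation. `ema_step` is a hypothetical helper, and the dicts of floats stand in for the model's parameter tensors.

```python
from typing import Dict


def ema_step(ema_weights: Dict[str, float], online_weights: Dict[str, float], momentum: float) -> None:
    """Hypothetical helper: theta_ema <- (1 - momentum) * theta_ema + momentum * theta, in place."""
    if not 0.0 < momentum < 1.0:
        raise ValueError("momentum should be in the range (0, 1)")
    for name, theta in online_weights.items():
        # blend the previous EMA value with the current online value
        ema_weights[name] = (1.0 - momentum) * ema_weights[name] + momentum * theta


# Usage: with a large momentum of 0.5, the EMA weight closes half the gap each step
ema = {"w": 0.0}
online = {"w": 1.0}
ema_step(ema, online, momentum=0.5)  # ema["w"] is now 0.5
ema_step(ema, online, momentum=0.5)  # ema["w"] is now 0.75
```

With the default momentum of 0.0002, each step moves the EMA weights only 0.02% of the way toward the online weights, so the EMA model roughly averages over the last 1/λ = 5000 updates. In PyTorch the same step is commonly written with in-place tensor ops, e.g. `ema_p.mul_(1 - momentum).add_(p, alpha=momentum)`.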
- Parameters
model (Module) – the online model for which an EMA model will be computed. If `model` is `DataParallel` or `DistributedDataParallel`, the EMA smoothing will be applied to `model.module`.
momentum (float) – the update momentum after the warmup phase; should be a float in the range (0, 1).
momentum_warmup (Optional[float]) – the initial update momentum during the warmup phase.
warmup_iters (Optional[int]) – the number of iterations in the warmup phase.
handle_buffers (str) – how to handle model buffers during training. There are three options: 1. "copy" copies the buffers of the online model; 2. "update" applies EMA to the buffers of the online model; 3. "ema_train" sets the EMA model to `train` mode and skips copying or updating the buffers.
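The three `handle_buffers` options can be sketched on plain Python floats. This is a hypothetical illustration of the documented behavior, not EMAHandler's code. `sync_buffers` is an assumed helper name, and the dicts stand in for buffers such as a BatchNorm running mean.

```python
from typing import Dict


def sync_buffers(ema_buffers: Dict[str, float], online_buffers: Dict[str, float],
                 handle_buffers: str, momentum: float) -> None:
    """Hypothetical sketch of the three documented handle_buffers options."""
    if handle_buffers == "copy":
        # overwrite the EMA buffers with the online model's buffers
        ema_buffers.update(online_buffers)
    elif handle_buffers == "update":
        # smooth the buffers with the same EMA rule used for the weights
        for name, value in online_buffers.items():
            ema_buffers[name] = (1.0 - momentum) * ema_buffers[name] + momentum * value
    elif handle_buffers == "ema_train":
        # leave the buffers untouched; the EMA model is kept in train mode instead
        pass
    else:
        raise ValueError(f"unknown handle_buffers option: {handle_buffers!r}")


# Usage
ema_b = {"running_mean": 0.0}
sync_buffers(ema_b, {"running_mean": 2.0}, "update", momentum=0.5)  # running_mean is now 1.0
```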
- ema_model
the exponential moving average (EMA) model.
- model
the online model that is tracked by EMAHandler. It is `model.module` if `model` in the initialization method is an instance of `DistributedDataParallel`.
- momentum
the update momentum.
- handle_buffers
how to handle model buffers during training.
Note
The EMA model is already in `eval` mode if `handle_buffers` is "copy" or "update". If `model` in the arguments is an `nn.Module` or `DistributedDataParallel`, the EMA model is an `nn.Module` and is on the same device as the online model. If `model` is an `nn.DataParallel`, the EMA model is an `nn.DataParallel`.

Note
It is recommended to initialize and use an EMA handler in the following steps:
1. Initialize `model` (`nn.Module` or `DistributedDataParallel`) and `ema_handler` (`EMAHandler`).
2. Build `trainer` (`ignite.engine.Engine`).
3. Resume from a checkpoint for `model` and `ema_handler.ema_model`.
4. Attach `ema_handler` to `trainer`.
Examples
```python
import torch
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, EMAHandler

# train_step_fn, resume_from, get_val_step_fn and val_data_loader are user-defined
device = torch.device("cuda:0")
model = nn.Linear(2, 1).to(device)
ema_handler = EMAHandler(model, momentum=0.0002)
# get the EMA model, which is an instance of nn.Module
ema_model = ema_handler.ema_model

trainer = Engine(train_step_fn)
to_load = {"model": model, "ema_model": ema_model, "trainer": trainer}
if resume_from is not None:
    Checkpoint.load_objects(to_load, checkpoint=resume_from)

# update the EMA model every 5 iterations
ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED(every=5))

# add other handlers
to_save = to_load
ckpt_handler = Checkpoint(to_save, DiskSaver(...), ...)
trainer.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler)

# the current momentum can be retrieved from engine.state;
# the attribute name is the `name` parameter used in the attach function
@trainer.on(Events.ITERATION_COMPLETED)
def print_ema_momentum(engine):
    print(f"current momentum: {engine.state.ema_momentum}")

# use the EMA model for validation
val_step_fn = get_val_step_fn(ema_model)
evaluator = Engine(val_step_fn)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(val_data_loader)

trainer.run(...)
```
The following example shows how to warm up the EMA momentum:
```python
import torch
import torch.nn as nn

from ignite.engine import Engine
from ignite.handlers import EMAHandler

device = torch.device("cuda:0")
model = nn.Linear(2, 1).to(device)
# linearly change the EMA momentum from 0.2 to 0.002 in the first 100 iterations,
# then keep a constant EMA momentum of 0.002 afterwards
ema_handler = EMAHandler(model, momentum=0.002, momentum_warmup=0.2, warmup_iters=100)
engine = Engine(step_fn)  # step_fn is user-defined
ema_handler.attach(engine, name="ema_momentum")
engine.run(...)
```
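The linear warmup described in the example can be sketched as a function of the iteration count. This is a hypothetical helper, not ignite's scheduler. It assumes the iteration count starts at 0 and that the momentum is held constant once `warmup_iters` is reached; ignite's exact indexing may differ.

```python
from typing import Optional


def momentum_at(iteration: int, momentum: float,
                momentum_warmup: Optional[float] = None,
                warmup_iters: Optional[int] = None) -> float:
    """Hypothetical helper: linear warmup from momentum_warmup to momentum, then constant."""
    if momentum_warmup is None or warmup_iters is None:
        # no warmup configured: constant momentum from the start
        return momentum
    if iteration >= warmup_iters:
        return momentum
    frac = iteration / warmup_iters
    return momentum_warmup + frac * (momentum - momentum_warmup)


# Usage with the values from the example above
schedule = [momentum_at(it, 0.002, 0.2, 100) for it in (0, 50, 100, 500)]
```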
The following example shows how to attach two handlers to the same trainer:
```python
from ignite.engine import Engine, Events
from ignite.handlers import EMAHandler

# build_generator, build_discriminator and get_step_fn are user-defined
generator = build_generator(...)
discriminator = build_discriminator(...)

gen_handler = EMAHandler(generator)
disc_handler = EMAHandler(discriminator)

step_fn = get_step_fn(...)
engine = Engine(step_fn)

# update the EMA model of the generator every iteration
gen_handler.attach(engine, "gen_ema_momentum", event=Events.ITERATION_COMPLETED)
# update the EMA model of the discriminator every 2 iterations
disc_handler.attach(engine, "disc_ema_momentum", event=Events.ITERATION_COMPLETED(every=2))

@engine.on(Events.ITERATION_COMPLETED)
def print_ema_momentum(engine):
    print(f"current momentum for generator: {engine.state.gen_ema_momentum}")
    print(f"current momentum for discriminator: {engine.state.disc_ema_momentum}")

engine.run(...)
```
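The `every=2` filter determines which iterations update each EMA model. The sketch below assumes that such a filter fires when the iteration counter is divisible by `every`, and that the counter is 1 at the end of the first iteration. Both helpers are hypothetical and only model the schedule, not ignite's event system.

```python
from typing import List


def fires(iteration: int, every: int) -> bool:
    """Hypothetical helper: does an ITERATION_COMPLETED(every=...) filter fire here?"""
    return iteration % every == 0


def ema_update_iterations(num_iters: int, every: int = 1) -> List[int]:
    # assume the iteration counter runs 1, 2, ..., num_iters
    return [it for it in range(1, num_iters + 1) if fires(it, every)]


gen_updates = ema_update_iterations(6)            # generator EMA: every iteration
disc_updates = ema_update_iterations(6, every=2)  # discriminator EMA: every 2nd iteration
```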
New in version 0.4.6.
Methods
attach – Attach the handler to engine.
- attach(engine, name='ema_momentum', warn_if_exists=True, event=Events.ITERATION_COMPLETED)
Attach the handler to engine. After the handler is attached, `Engine.state` will have a new attribute named `name` if that attribute does not already exist. The current momentum can then be retrieved from `Engine.state` while the engine runs.

Note
There are two cases where a momentum attribute named `name` already exists: 1. the engine has loaded its state dict after resuming. In this case, there is no need to initialize the momentum again, and users can set `warn_if_exists` to False to suppress the warning message; 2. another handler has created a state attribute with the same name. In this case, users should choose another name for the EMA momentum.
- Parameters
engine (Engine) – trainer to which the handler will be attached.
name (str) – attribute name for retrieving the EMA momentum from `Engine.state`. It should be unique, since a trainer can have multiple EMA handlers.
warn_if_exists (bool) – if True, a warning is issued when a momentum attribute named `name` already exists.
event (Union[str, Events, CallableEventWithFilter, EventsList]) – event on which the EMA momentum and the EMA model are updated.
- Return type
None
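The attribute-registration behavior of `attach` can be sketched as follows. `register_momentum` is a hypothetical helper, and `SimpleNamespace` stands in for `Engine.state`. The sketch assumes that an existing attribute is left untouched and only triggers a warning when `warn_if_exists` is True. It also assumes the initial value is passed in explicitly, since the source does not say which value `attach` stores.

```python
import warnings
from types import SimpleNamespace


def register_momentum(state: SimpleNamespace, name: str, initial_momentum: float,
                      warn_if_exists: bool = True) -> None:
    """Hypothetical helper mirroring the documented attach() behavior on Engine.state."""
    if hasattr(state, name):
        # e.g. the state was restored from a checkpoint, or another handler uses this name
        if warn_if_exists:
            warnings.warn(f"Attribute '{name}' already exists in Engine.state; it is not re-initialized.")
        return
    setattr(state, name, initial_momentum)


# Usage: SimpleNamespace stands in for ignite's Engine.state
state = SimpleNamespace()
register_momentum(state, "ema_momentum", 0.2)
```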