EMAHandler#

class ignite.handlers.ema_handler.EMAHandler(model, momentum=0.0002, momentum_warmup=None, warmup_iters=None, handle_buffers='copy')[source]#

The exponential moving average (EMA) handler can be used to compute a smoothed version of a model. The EMA model is updated as follows:

\theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}

where \(\theta_{\text{EMA}, t}\) and \(\theta_{t}\) are the EMA weights and online model weights at the \(t\)-th iteration, respectively; \(\lambda\) is the update momentum. The current momentum can be retrieved from Engine.state.ema_momentum.
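
For intuition, this update is a per-parameter linear interpolation between the EMA weights and the online weights. A minimal sketch, assuming a hypothetical ema_update helper (this is not the handler's internal implementation):

import torch

@torch.no_grad()
def ema_update(ema_model, online_model, momentum):
    # theta_EMA <- (1 - momentum) * theta_EMA + momentum * theta_online;
    # Tensor.lerp_(end, weight) computes self + weight * (end - self),
    # which is exactly this convex combination
    for ema_p, online_p in zip(ema_model.parameters(), online_model.parameters()):
        ema_p.lerp_(online_p, momentum)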

Parameters
  • model (Module) – the online model for which an EMA model will be computed. If model is DataParallel or DistributedDataParallel, the EMA smoothing will be applied to model.module.

  • momentum (float) – the update momentum after the warmup phase; should be a float in the range \((0, 1)\).

  • momentum_warmup (Optional[float]) – the initial update momentum during warmup phase.

  • warmup_iters (Optional[int]) – iterations of warmup.

  • handle_buffers (str) – how to handle model buffers during training. There are three options: 1. “copy” means copying the buffers of the online model; 2. “update” means applying EMA to the buffers of the online model; 3. “ema_train” means setting the EMA model to train mode and skipping copying or updating the buffers. A sketch of these three behaviors follows this parameter list.
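
A minimal sketch of what the three handle_buffers options amount to, assuming a hypothetical handle_model_buffers helper; the special-casing of non-floating-point buffers under “update” is an assumption, not documented behavior:

import torch

@torch.no_grad()
def handle_model_buffers(ema_model, online_model, handle_buffers, momentum):
    if handle_buffers == "ema_train":
        # leave buffers untouched: the EMA model stays in train mode, so its own
        # buffers (e.g. BatchNorm running statistics) evolve during training
        return
    for ema_b, online_b in zip(ema_model.buffers(), online_model.buffers()):
        if handle_buffers == "copy":
            ema_b.copy_(online_b)  # mirror the online buffers verbatim
        elif handle_buffers == "update":
            if ema_b.dtype.is_floating_point:
                ema_b.lerp_(online_b, momentum)  # EMA on floating-point buffers
            else:
                ema_b.copy_(online_b)  # assumption: non-float buffers are copied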

ema_model#

the exponential moving average (EMA) model.

model#

the online model that is tracked by EMAHandler. It is model.module if model in the initialization method is an instance of DistributedDataParallel.

momentum#

the update momentum.

handle_buffers#

how to handle model buffers during training.

Note

The EMA model is already in eval mode if handle_buffers is “copy” or “update”. If model in the arguments is an nn.Module or DistributedDataParallel, the EMA model is an nn.Module and it is on the same device as the online model. If the model is an nn.DataParallel, then the EMA model is an nn.DataParallel.

Note

It is recommended to initialize and use an EMA handler in the following steps:

  1. Initialize model (nn.Module or DistributedDataParallel) and ema_handler (EMAHandler).

  2. Build trainer (ignite.engine.Engine).

  3. Resume from checkpoint for model and ema_handler.ema_model.

  4. Attach ema_handler to trainer.

Examples

import torch
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, EMAHandler

device = torch.device("cuda:0")
model = nn.Linear(2, 1).to(device)
ema_handler = EMAHandler(model, momentum=0.0002)
# get the ema model, which is an instance of nn.Module
ema_model = ema_handler.ema_model
trainer = Engine(train_step_fn)
to_load = {"model": model, "ema_model": ema_model, "trainer": trainer}
if resume_from is not None:
    Checkpoint.load_objects(to_load, checkpoint=resume_from)

# update the EMA model every 5 iterations
ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED(every=5))

# add other handlers
to_save = to_load
ckpt_handler = Checkpoint(to_save, DiskSaver(...), ...)
trainer.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler)

# current momentum can be retrieved from engine.state,
# the attribute name is the `name` parameter used in the attach function
@trainer.on(Events.ITERATION_COMPLETED)
def print_ema_momentum(engine):
    print(f"current momentum: {engine.state.ema_momentum}")

# use ema model for validation
val_step_fn = get_val_step_fn(ema_model)
evaluator = Engine(val_step_fn)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(val_data_loader)

trainer.run(...)

The following example shows how to warm up the EMA momentum:

device = torch.device("cuda:0")
model = nn.Linear(2, 1).to(device)
# linearly change the EMA momentum from 0.2 to 0.002 in the first 100 iterations,
# then keep a constant EMA momentum of 0.002 afterwards
ema_handler = EMAHandler(model, momentum=0.002, momentum_warmup=0.2, warmup_iters=100)
engine = Engine(step_fn)
ema_handler.attach(engine, name="ema_momentum")
engine.run(...)
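
For reference, the linear warmup in the comment above can be written as a simple schedule. This is a sketch assuming plain linear interpolation; momentum_schedule is a hypothetical helper, not part of the handler's API:

def momentum_schedule(iteration, momentum=0.002, momentum_warmup=0.2, warmup_iters=100):
    # interpolate linearly from momentum_warmup to momentum over warmup_iters,
    # then keep momentum constant afterwards
    alpha = min(iteration, warmup_iters) / warmup_iters
    return momentum_warmup + alpha * (momentum - momentum_warmup)

# momentum_schedule(0) -> 0.2, momentum_schedule(50) -> 0.101, momentum_schedule(100) -> 0.002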

The following example shows how to attach two handlers to the same trainer:

generator = build_generator(...)
discriminator = build_discriminator(...)

gen_handler = EMAHandler(generator)
disc_handler = EMAHandler(discriminator)

step_fn = get_step_fn(...)
engine = Engine(step_fn)

# update the EMA model of the generator every iteration
gen_handler.attach(engine, "gen_ema_momentum", event=Events.ITERATION_COMPLETED)
# update the EMA model of the discriminator every 2 iterations
disc_handler.attach(engine, "disc_ema_momentum", event=Events.ITERATION_COMPLETED(every=2))

@engine.on(Events.ITERATION_COMPLETED)
def print_ema_momentum(engine):
    print(f"current momentum for generator: {engine.state.gen_ema_momentum}")
    print(f"current momentum for discriminator: {engine.state.disc_ema_momentum}")

engine.run(...)

New in version 0.4.6.

Methods

attach

Attach the handler to engine.

attach(engine, name='ema_momentum', warn_if_exists=True, event=Events.ITERATION_COMPLETED)[source]#

Attach the handler to engine. After the handler is attached, Engine.state will add a new attribute with name name if the attribute does not exist. Then, the current momentum can be retrieved from Engine.state when the engine runs.

Note

There are two cases where a momentum with name name already exists: 1. the engine has loaded its state dict after resuming, in which case there is no need to initialize the momentum again, and users can set warn_if_exists to False to suppress the warning message (see the sketch at the end of this section); 2. another handler has created a state attribute with the same name, in which case users should choose another name for the EMA momentum.

Parameters
  • engine (Engine) – trainer to which the handler will be attached.

  • name (str) – attribute name for retrieving EMA momentum from Engine.state. It should be a unique name since a trainer can have multiple EMA handlers.

  • warn_if_exists (bool) – if True, a warning will be emitted if a momentum with name name already exists.

  • event (Union[str, Events, CallableEventWithFilter, EventsList]) – event when the EMA momentum and EMA model are updated.

Return type

None
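
As a usage sketch for the resuming case described in the note above (assuming trainer and ema_handler were built as in the earlier examples):

# after the trainer's state dict has been restored from a checkpoint,
# engine.state.ema_momentum may already exist, so suppress the warning
ema_handler.attach(trainer, name="ema_momentum", warn_if_exists=False)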