
EMAHandler

class ignite.handlers.ema_handler.EMAHandler(model, momentum=0.0002, momentum_warmup=None, warmup_iters=None)

An exponential moving average (EMA) handler computes a smoothed copy of a model. The EMA model is updated as follows:

$$\theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}$$

where $\theta_{\text{EMA}, t}$ and $\theta_{t}$ are the EMA weights and the online model weights at the $t$-th iteration, respectively, and $\lambda$ is the update momentum. The handler can linearly warm up the momentum at the beginning of training, when the training process is not yet stable. The current momentum can be retrieved from Engine.state.ema_momentum (ema_momentum is the default name argument of attach).
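The update rule can be illustrated in plain Python on a flat list of scalar weights. This is a minimal sketch, not ignite's implementation: the helper ema_update is hypothetical, and for a real model the same rule would be applied to every parameter tensor. With the default momentum of 0.0002, each update moves the EMA weights only 0.02% of the way toward the online weights, which is what makes the EMA model a heavily smoothed copy. The toy run below uses a large momentum of 0.5 so the effect is visible after a few steps.

```python
from typing import List


def ema_update(ema_weights: List[float], online_weights: List[float], momentum: float) -> List[float]:
    """Element-wise theta_EMA,t+1 = (1 - momentum) * theta_EMA,t + momentum * theta_t.

    Hypothetical helper for illustration only; it is not part of ignite.
    """
    if not 0.0 < momentum < 1.0:
        raise ValueError("momentum should be in the open interval (0, 1)")
    return [(1.0 - momentum) * e + momentum * w for e, w in zip(ema_weights, online_weights)]


# Toy run: the online weight is held fixed at 1.0 and the EMA weight starts at 0.0.
ema = [0.0]
online = [1.0]
for _ in range(3):
    ema = ema_update(ema, online, momentum=0.5)
print(ema)  # [0.875]
```

After t updates toward a fixed online weight of 1.0, starting from 0.0, the EMA weight equals 1 - (1 - λ)^t, so it approaches the online weight geometrically.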

Parameters
  • model (torch.nn.modules.module.Module) – the online model for which an EMA model will be computed. If model is DataParallel or DistributedDataParallel, the EMA smoothing will be applied to model.module.

  • momentum (float) – the update momentum after the warmup phase; it should be a float in the range $(0, 1)$.

  • momentum_warmup (Optional[float]) – the initial update momentum during the warmup phase; the value should be smaller than momentum. The momentum increases linearly from this value to momentum over warmup_iters iterations. If None, no warmup is performed.

  • warmup_iters (Optional[int]) – the number of warmup iterations. If None, no warmup is performed.

Return type

None
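The linear warmup described by momentum_warmup and warmup_iters can be sketched as a function of the iteration number. This is an illustrative sketch under stated assumptions: the function momentum_at is hypothetical, and ignite's exact indexing (for example, whether counting starts at 0 or 1) may differ from what is shown here.

```python
from typing import Optional


def momentum_at(
    iteration: int,
    momentum: float = 0.0002,
    momentum_warmup: Optional[float] = None,
    warmup_iters: Optional[int] = None,
) -> float:
    """Linearly warm up from momentum_warmup to momentum over warmup_iters iterations.

    Hypothetical sketch; not ignite's implementation.
    """
    # No warmup requested: use the final momentum from the start.
    if momentum_warmup is None or warmup_iters is None:
        return momentum
    # Warmup finished: hold the final momentum.
    if iteration >= warmup_iters:
        return momentum
    # Fraction of the warmup phase completed so far.
    progress = max(iteration, 0) / warmup_iters
    return momentum_warmup + (momentum - momentum_warmup) * progress


# Same configuration as the example below: 0.0001 -> 0.0002 over 10000 iterations.
for it in (0, 5000, 10000, 20000):
    print(it, momentum_at(it, momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000))
```

Leaving either momentum_warmup or warmup_iters as None disables the warmup, matching the parameter descriptions above.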

ema_model

the exponential moving averaged model.

model

the online model that is tracked by EMAHandler. It is model.module if the model passed at initialization is an instance of DistributedDataParallel.

momentum

the update momentum after warmup phase.

momentum_warmup

the initial update momentum.

warmup_iters

number of warmup iterations.

Note

The EMA model is already in eval mode. If the model argument is an nn.Module or a DistributedDataParallel, the EMA model is an nn.Module on the same device as the online model. If the model argument is an nn.DataParallel, the EMA model is also an nn.DataParallel.

Note

It is recommended to initialize and use an EMA handler in the following steps:

  1. Initialize model (nn.Module or DistributedDataParallel) and ema_handler (EMAHandler).

  2. Build trainer (ignite.engine.Engine).

  3. Resume model and ema_handler.ema_model from a checkpoint.

  4. Attach ema_handler to trainer.

Examples

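import torch
import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver
from ignite.handlers.ema_handler import EMAHandler

# train_step_fn, resume_from, get_val_step_fn and val_data_loader are placeholders for your own code
device = torch.device("cuda:0")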
model = nn.Linear(2, 1).to(device)
# momentum warms up linearly from 0.0001 to 0.0002 over warmup_iters=10000 iterations
ema_handler = EMAHandler(
    model, momentum=0.0002, momentum_warmup=0.0001, warmup_iters=10000)
# get the ema model, which is an instance of nn.Module
ema_model = ema_handler.ema_model
trainer = Engine(train_step_fn)
to_load = {"model": model, "ema_model": ema_model, "trainer": trainer}
if resume_from is not None:
    Checkpoint.load_objects(to_load, checkpoint=resume_from)

# update the EMA model every 5 iterations
ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED(every=5))

# add other handlers
to_save = to_load
ckpt_handler = Checkpoint(to_save, DiskSaver(...), ...)
trainer.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler)

# current momentum can be retrieved from engine.state,
# the attribute name is the `name` parameter used in the attach function
@trainer.on(Events.ITERATION_COMPLETED)
def print_ema_momentum(engine):
    print(f"current momentum: {engine.state.ema_momentum}")

# use ema model for validation
val_step_fn = get_val_step_fn(ema_model)
evaluator = Engine(val_step_fn)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(val_data_loader)

trainer.run(...)

The following example shows how to attach two handlers to the same trainer:

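from ignite.engine import Engine, Events
from ignite.handlers.ema_handler import EMAHandler

# build_generator, build_discriminator and get_step_fn are placeholders for your own code
generator = build_generator(...)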
discriminator = build_discriminator(...)

gen_handler = EMAHandler(generator)
disc_handler = EMAHandler(discriminator)

step_fn = get_step_fn(...)
engine = Engine(step_fn)

# update EMA model of generator every iteration
gen_handler.attach(engine, "gen_ema_momentum", event=Events.ITERATION_COMPLETED)
# update EMA model of discriminator every 2 iterations
disc_handler.attach(engine, "disc_ema_momentum", event=Events.ITERATION_COMPLETED(every=2))

@engine.on(Events.ITERATION_COMPLETED)
def print_ema_momentum(engine):
    print(f"current momentum for generator: {engine.state.gen_ema_momentum}")
    print(f"current momentum for discriminator: {engine.state.disc_ema_momentum}")

engine.run(...)
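To make the two update schedules concrete: an `every=n` event filter in ignite fires when the iteration counter (counted from 1) is a multiple of n. The sketch below only enumerates which iterations would trigger each handler in the example above; it does not call ignite, and the helper fires is hypothetical.

```python
def fires(iteration: int, every: int = 1) -> bool:
    """Mimic an every=n filter: fire on iterations n, 2n, 3n, ..."""
    return iteration % every == 0


iterations = range(1, 7)  # first six iterations, counted from 1
gen_updates = [i for i in iterations if fires(i)]            # generator EMA: every iteration
disc_updates = [i for i in iterations if fires(i, every=2)]  # discriminator EMA: every 2 iterations
print(gen_updates)   # [1, 2, 3, 4, 5, 6]
print(disc_updates)  # [2, 4, 6]
```

Over the same run, the discriminator's EMA model therefore receives half as many updates as the generator's.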

New in version 0.4.6.

Methods

attach

Attach the handler to engine.

attach(engine, name='ema_momentum', event=Events.ITERATION_COMPLETED)

Attach the handler to engine. Once the handler is attached, Engine.state gains a new attribute named after the name argument, from which the current momentum can be read while the engine runs.

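Parameters
  • engine (Engine) – the trainer to which the handler will be attached.

  • name (str) – the attribute name under which the current EMA momentum is stored in Engine.state. Use a distinct name for each EMA handler attached to the same engine.

  • event – the event on which the EMA momentum and the EMA model are updated (default: Events.ITERATION_COMPLETED).

Return type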

None
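The naming behavior of attach can be mimicked with a toy stand-in for Engine.state. Everything here (ToyEMAHandler, on_event, and the use of SimpleNamespace as the state object) is hypothetical and only illustrates why each handler attached to the same engine needs its own name: the momentum is stored as an attribute under that name, so two handlers sharing a name would collide on the same attribute. Rejecting a duplicate name is a defensive choice of this toy, not a claim about ignite's behavior.

```python
from types import SimpleNamespace


class ToyEMAHandler:
    """Hypothetical stand-in for illustration; not ignite's EMAHandler."""

    def __init__(self, momentum: float) -> None:
        self.momentum = momentum
        self.name = ""

    def attach(self, state: SimpleNamespace, name: str) -> None:
        # Toy-only guard: refuse a name that another handler already uses.
        if hasattr(state, name):
            raise ValueError(f"'{name}' is already an attribute of state")
        self.name = name
        setattr(state, name, None)  # placeholder until the first update

    def on_event(self, state: SimpleNamespace) -> None:
        # Publish the current momentum under the attribute name chosen at attach time.
        setattr(state, self.name, self.momentum)


state = SimpleNamespace()  # plays the role of Engine.state
gen = ToyEMAHandler(momentum=0.0002)
disc = ToyEMAHandler(momentum=0.0005)
gen.attach(state, "gen_ema_momentum")
disc.attach(state, "disc_ema_momentum")
gen.on_event(state)
disc.on_event(state)
print(state.gen_ema_momentum, state.disc_ema_momentum)  # 0.0002 0.0005
```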