ignite.contrib.handlers#
Contribution module of handlers
custom_events#
- class ignite.contrib.handlers.custom_events.CustomPeriodicEvent(n_iterations=None, n_epochs=None)[source]#
Handler to define custom periodic events as a number of elapsed iterations/epochs for an engine.
When a custom periodic event is created and attached to an engine, the following events are fired:
1) If K iterations is specified:
- Events.ITERATIONS_<K>_STARTED
- Events.ITERATIONS_<K>_COMPLETED
2) If K epochs is specified:
- Events.EPOCHS_<K>_STARTED
- Events.EPOCHS_<K>_COMPLETED
Examples:
from ignite.engine import Engine, Events
from ignite.contrib.handlers import CustomPeriodicEvent

# Let's define an event every 1000 iterations
cpe1 = CustomPeriodicEvent(n_iterations=1000)
cpe1.attach(trainer)

# Let's define an event every 10 epochs
cpe2 = CustomPeriodicEvent(n_epochs=10)
cpe2.attach(trainer)

@trainer.on(cpe1.Events.ITERATIONS_1000_COMPLETED)
def on_every_1000_iterations(engine):
    # run a computation after 1000 iterations
    # ...
    print(engine.state.iterations_1000)

@trainer.on(cpe2.Events.EPOCHS_10_STARTED)
def on_every_10_epochs(engine):
    # run a computation every 10 epochs
    # ...
    print(engine.state.epochs_10)
param_scheduler#
- class ignite.contrib.handlers.param_scheduler.ConcatScheduler(schedulers, durations, save_history=False)[source]#
Concat a list of parameter schedulers.
The ConcatScheduler goes through the list of schedulers given by schedulers, one after another. The duration of each scheduler is defined by the durations list of integers.
- Parameters
schedulers (list of ParamScheduler) – list of parameter schedulers.
durations (list of int) – list of durations: the number of events each parameter scheduler from schedulers runs for.
save_history (bool, optional) – whether to log the parameter values to engine.state.param_history, (default=False).
Examples:
from ignite.contrib.handlers.param_scheduler import ConcatScheduler
from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler
from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler

scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.1, end_value=0.5, cycle_size=60)
scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.5, end_value=0.01, cycle_size=60)

combined_scheduler = ConcatScheduler(schedulers=[scheduler_1, scheduler_2], durations=[30, ])
trainer.add_event_handler(Events.ITERATION_STARTED, combined_scheduler)
#
# Sets the learning rate linearly from 0.1 to 0.5 over 30 iterations. Then
# starts an annealing schedule from 0.5 to 0.01 over 60 iterations.
# The annealing cycles are repeated indefinitely.
#
- load_state_dict(state_dict)[source]#
Copies parameters from state_dict into this ConcatScheduler.
- Parameters
state_dict (dict) – a dict containing parameters.
- classmethod simulate_values(num_events, schedulers, durations, param_names=None, **kwargs)[source]#
Method to simulate scheduled values during num_events events.
- Parameters
num_events (int) – number of events during the simulation.
schedulers (list of ParamScheduler) – list of parameter schedulers.
durations (list of int) – list of durations: the number of events each parameter scheduler from schedulers runs for.
param_names (list or tuple of str, optional) – parameter name or list of parameter names to simulate values. By default, the first scheduler’s parameter name is taken.
- Returns
list of [event_index, value_0, value_1, …], where values correspond to param_names.
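A minimal sketch of simulating a concatenated schedule (the optimizer below is only a placeholder needed to build the schedulers; the values mirror the plotting examples at the end of this page):

import numpy as np
import torch
from ignite.contrib.handlers import LinearCyclicalScheduler, CosineAnnealingScheduler, ConcatScheduler

# a dummy tensor/optimizer is enough for simulation
tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)

scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.1, end_value=0.5, cycle_size=60)
scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.5, end_value=0.01, cycle_size=60)

# simulate 100 events: the linear schedule for the first 30 events, then cosine annealing
values = np.array(ConcatScheduler.simulate_values(num_events=100,
                                                  schedulers=[scheduler_1, scheduler_2],
                                                  durations=[30, ]))
# values[:, 0] are event indices, values[:, 1] the simulated "lr" values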
- class ignite.contrib.handlers.param_scheduler.CosineAnnealingScheduler(optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False, param_group_index=None)[source]#
Anneals ‘start_value’ to ‘end_value’ over each cycle.
The annealing takes the form of the first half of a cosine wave (as suggested in [Smith17]).
- Parameters
optimizer (torch.optim.Optimizer) – optimizer
param_name (str) – name of optimizer’s parameter to update.
start_value (float) – value at start of cycle.
end_value (float) – value at the end of the cycle.
cycle_size (int) – length of cycle.
cycle_mult (float, optional) – ratio by which to change the cycle_size at the end of each cycle (default=1).
start_value_mult (float, optional) – ratio by which to change the start value at the end of each cycle (default=1.0).
end_value_mult (float, optional) – ratio by which to change the end value at the end of each cycle (default=1.0).
save_history (bool, optional) – whether to log the parameter values to engine.state.param_history, (default=False).
param_group_index (int, optional) – optimizer’s parameters group to use.
Note
If the scheduler is bound to an ‘ITERATION_*’ event, ‘cycle_size’ should usually be the number of batches in an epoch.
Examples:
from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler

scheduler = CosineAnnealingScheduler(optimizer, 'lr', 1e-1, 1e-3, len(train_loader))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
#
# Anneals the learning rate from 1e-1 to 1e-3 over the course of 1 epoch.
#
from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler
from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler

optimizer = SGD(
    [
        {"params": model.base.parameters(), 'lr': 0.001},
        {"params": model.fc.parameters(), 'lr': 0.01},
    ]
)

scheduler1 = LinearCyclicalScheduler(optimizer, 'lr', 1e-7, 1e-5, len(train_loader), param_group_index=0)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1, "lr (base)")

scheduler2 = CosineAnnealingScheduler(optimizer, 'lr', 1e-5, 1e-3, len(train_loader), param_group_index=1)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler2, "lr (fc)")
- Smith17
Smith, Leslie N. "Cyclical learning rates for training neural networks." 2017 IEEE Winter Conference on Applications of Computer Vision (WACV). IEEE, 2017.
- class ignite.contrib.handlers.param_scheduler.CyclicalScheduler(optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False, param_group_index=None)[source]#
An abstract class for updating an optimizer’s parameter value over a cycle of some size.
- Parameters
optimizer (torch.optim.Optimizer) – optimizer
param_name (str) – name of optimizer’s parameter to update.
start_value (float) – value at start of cycle.
end_value (float) – value at the middle of the cycle.
cycle_size (int) – length of cycle, value should be larger than 1.
cycle_mult (float, optional) – ratio by which to change the cycle_size at the end of each cycle (default=1.0).
start_value_mult (float, optional) – ratio by which to change the start value at the end of each cycle (default=1.0).
end_value_mult (float, optional) – ratio by which to change the end value at the end of each cycle (default=1.0).
save_history (bool, optional) – whether to log the parameter values to engine.state.param_history, (default=False).
param_group_index (int, optional) – optimizer’s parameters group to use.
Note
If the scheduler is bound to an ‘ITERATION_*’ event, ‘cycle_size’ should usually be the number of batches in an epoch.
- class ignite.contrib.handlers.param_scheduler.LRScheduler(lr_scheduler, save_history=False, **kwds)[source]#
A wrapper class to call torch.optim.lr_scheduler objects as ignite handlers.
- Parameters
lr_scheduler (subclass of torch.optim.lr_scheduler._LRScheduler) – lr_scheduler object to wrap.
save_history (bool, optional) – whether to log the parameter values to engine.state.param_history, (default=False).
Examples:

from ignite.contrib.handlers.param_scheduler import LRScheduler
from torch.optim.lr_scheduler import StepLR

step_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
scheduler = LRScheduler(step_scheduler)

# In this example, we assume to have installed PyTorch>=1.1.0
# (with new `torch.optim.lr_scheduler` behaviour) and
# we attach scheduler to Events.ITERATION_COMPLETED
# instead of Events.ITERATION_STARTED to make sure to use
# the first lr value from the optimizer, otherwise it will be skipped:
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
- class ignite.contrib.handlers.param_scheduler.LinearCyclicalScheduler(optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, save_history=False, param_group_index=None)[source]#
Linearly adjusts param value to ‘end_value’ for a half-cycle, then linearly adjusts it back to ‘start_value’ for a half-cycle.
- Parameters
optimizer (torch.optim.Optimizer) – optimizer
param_name (str) – name of optimizer’s parameter to update.
start_value (float) – value at start of cycle.
end_value (float) – value at the middle of the cycle.
cycle_size (int) – length of cycle.
cycle_mult (float, optional) – ratio by which to change the cycle_size at the end of each cycle (default=1).
start_value_mult (float, optional) – ratio by which to change the start value at the end of each cycle (default=1.0).
end_value_mult (float, optional) – ratio by which to change the end value at the end of each cycle (default=1.0).
save_history (bool, optional) – whether to log the parameter values to engine.state.param_history, (default=False).
param_group_index (int, optional) – optimizer’s parameters group to use.
Note
If the scheduler is bound to an ‘ITERATION_*’ event, ‘cycle_size’ should usually be the number of batches in an epoch.
Examples:
from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler

scheduler = LinearCyclicalScheduler(optimizer, 'lr', 1e-3, 1e-1, len(train_loader))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
#
# Linearly increases the learning rate from 1e-3 to 1e-1 and back to 1e-3
# over the course of 1 epoch
#
- class ignite.contrib.handlers.param_scheduler.ParamGroupScheduler(schedulers, names)[source]#
Scheduler helper to group multiple schedulers into one.
- Parameters
schedulers (list/tuple of ParamScheduler) – list/tuple of parameter schedulers.
names (list/tuple of str) – list/tuple of names of the schedulers.

Examples:

optimizer = SGD(
    [
        {"params": model.base.parameters(), 'lr': 0.001},
        {"params": model.fc.parameters(), 'lr': 0.01},
    ]
)

scheduler1 = LinearCyclicalScheduler(optimizer, 'lr', 1e-7, 1e-5, len(train_loader), param_group_index=0)
scheduler2 = CosineAnnealingScheduler(optimizer, 'lr', 1e-5, 1e-3, len(train_loader), param_group_index=1)

lr_schedulers = [scheduler1, scheduler2]
names = ["lr (base)", "lr (fc)"]

scheduler = ParamGroupScheduler(schedulers=lr_schedulers, names=names)
# Attach single scheduler to the trainer
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
- load_state_dict(state_dict)[source]#
Copies parameters from state_dict into this ParamScheduler.
- Parameters
state_dict (dict) – a dict containing parameters.
- class ignite.contrib.handlers.param_scheduler.ParamScheduler(optimizer, param_name, save_history=False, param_group_index=None)[source]#
An abstract class for updating an optimizer’s parameter value during training.
- Parameters
optimizer (torch.optim.Optimizer) – optimizer.
param_name (str) – name of optimizer’s parameter to update.
save_history (bool, optional) – whether to log the parameter values to engine.state.param_history, (default=False).
param_group_index (int, optional) – optimizer’s parameters group to use.
Note
Parameter scheduler works independently of the internal state of the attached optimizer. More precisely, whatever the state of the optimizer (newly created or already used by another scheduler), the scheduler sets the defined absolute values.
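A small sketch to illustrate this note (calling the handler directly here is only for illustration; normally the attached engine calls it on the configured event):

import torch
from ignite.contrib.handlers import LinearCyclicalScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.5)   # the optimizer starts with lr=0.5

scheduler = LinearCyclicalScheduler(optimizer, "lr", start_value=0.1, end_value=0.01, cycle_size=20)

# the scheduler sets its own absolute value, regardless of the lr currently stored in the optimizer
scheduler(engine=None)
print(optimizer.param_groups[0]["lr"])  # 0.1 (the scheduler's start_value), not 0.5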
- load_state_dict(state_dict)[source]#
Copies parameters from state_dict into this ParamScheduler.
- Parameters
state_dict (dict) – a dict containing parameters.
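A minimal sketch of saving and restoring a scheduler's position in its schedule (state_dict() is assumed as the counterpart of load_state_dict, following the usual PyTorch convention):

import torch
from ignite.contrib.handlers import LinearCyclicalScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
scheduler = LinearCyclicalScheduler(optimizer, "lr", start_value=1e-3, end_value=1e-1, cycle_size=100)

# ... after some training, checkpoint the scheduler state
torch.save(scheduler.state_dict(), "param_scheduler.pt")  # state_dict() assumed to exist (mirrors load_state_dict)

# later, restore the position in the schedule
scheduler.load_state_dict(torch.load("param_scheduler.pt"))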
- classmethod plot_values(num_events, **scheduler_kwargs)[source]#
Method to plot simulated scheduled values during num_events events.
This method requires the matplotlib package to be installed:
pip install matplotlib
- Parameters
num_events (int) – number of events during the simulation.
**scheduler_kwargs – parameter scheduler configuration kwargs.
- Returns
matplotlib.lines.Line2D
Examples
import matplotlib.pylab as plt
from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler

plt.figure(figsize=(10, 7))
LinearCyclicalScheduler.plot_values(num_events=50, param_name='lr',
                                    start_value=1e-1, end_value=1e-3, cycle_size=10)
- classmethod simulate_values(num_events, **scheduler_kwargs)[source]#
Method to simulate scheduled values during num_events events.
- Parameters
num_events (int) – number of events during the simulation.
**scheduler_kwargs – parameter scheduler configuration kwargs.
- Returns
[event_index, value]
- Return type
list of pairs
Examples:
import numpy as np
import matplotlib.pylab as plt
from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler

lr_values = np.array(LinearCyclicalScheduler.simulate_values(num_events=50, param_name='lr',
                                                             start_value=1e-1, end_value=1e-3, cycle_size=10))

plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
- class ignite.contrib.handlers.param_scheduler.PiecewiseLinear(optimizer, param_name, milestones_values, save_history=False, param_group_index=None)[source]#
Piecewise linear parameter scheduler
- Parameters
optimizer (torch.optim.Optimizer) – optimizer.
param_name (str) – name of optimizer’s parameter to update.
milestones_values (list of tuples (int, float)) – list of tuples (event index, parameter value) representing the milestones and parameter values. Milestone event indices should be increasing integers.
save_history (bool, optional) – whether to log the parameter values to engine.state.param_history, (default=False).
param_group_index (int, optional) – optimizer’s parameters group to use.
- Returns
piecewise linear scheduler

Examples:

scheduler = PiecewiseLinear(optimizer, "lr",
                            milestones_values=[(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)])
# Attach to the trainer
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
#
# Sets the learning rate to 0.5 over the first 10 iterations, then decreases linearly from 0.5 to 0.45 between
# the 10th and 20th iterations. Next there is a jump to 0.3 at the 21st iteration and the learning rate decreases
# linearly from 0.3 to 0.1 between the 21st and 30th iterations and remains 0.1 until the end of the run.
#
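To inspect a milestones configuration before training, the simulate_values class method inherited from ParamScheduler can be used. A sketch, assuming the same keyword names as the examples above:

import numpy as np
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

values = np.array(PiecewiseLinear.simulate_values(num_events=50, param_name="lr",
                                                  milestones_values=[(10, 0.5), (20, 0.45),
                                                                     (21, 0.3), (30, 0.1), (40, 0.1)]))
# values[:, 0] are event indices, values[:, 1] the simulated learning rates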
- ignite.contrib.handlers.param_scheduler.create_lr_scheduler_with_warmup(lr_scheduler, warmup_start_value, warmup_end_value, warmup_duration, save_history=False, output_simulated_values=None)[source]#
Helper method to create a learning rate scheduler with a linear warm-up.
- Parameters
lr_scheduler (ParamScheduler or subclass of torch.optim.lr_scheduler._LRScheduler) – learning rate scheduler after the warm-up.
warmup_start_value (float) – learning rate start value of the warm-up phase.
warmup_end_value (float) – learning rate end value of the warm-up phase.
warmup_duration (int) – warm-up phase duration, number of events.
save_history (bool, optional) – whether to log the parameter values to engine.state.param_history, (default=False).
output_simulated_values (list, optional) – optional output of simulated learning rate values. If output_simulated_values is a list of None, e.g. [None] * 100, after the execution it will be filled by 100 simulated learning rate values.
- Returns
learning rate scheduler with linear warm-up.
Note
If the first learning rate value provided by lr_scheduler is different from warmup_end_value, an additional event is added after the warm-up phase so that the warm-up ends with the warmup_end_value and lr_scheduler then provides its learning rate values as usual.
Examples
import numpy as np
import matplotlib.pylab as plt
from torch.optim.lr_scheduler import ExponentialLR
from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup

torch_lr_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.98)

lr_values = [None] * 100
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                            warmup_start_value=0.0,
                                            warmup_end_value=0.1,
                                            warmup_duration=10,
                                            output_simulated_values=lr_values)
lr_values = np.array(lr_values)
# Plot simulated values
plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")

# Attach to the trainer
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
tensorboard_logger#
See tensorboardX mnist example and CycleGAN and EfficientNet notebooks for detailed usage.
- class ignite.contrib.handlers.tensorboard_logger.GradsHistHandler(model, tag=None)[source]#
Helper handler to log model’s gradients as histograms.
Examples
from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log model's gradients as histograms after each iteration
tb_logger.attach(trainer,
                 log_handler=GradsHistHandler(model),
                 event_name=Events.ITERATION_COMPLETED)
- Parameters
model (torch.nn.Module) – model whose gradients to log
tag (str, optional) – common title for all produced plots. For example, ‘generator’
- class ignite.contrib.handlers.tensorboard_logger.GradsScalarHandler(model, reduction=<function norm>, tag=None)[source]#
Helper handler to log model’s gradients as scalars. Handler iterates over the gradients of named parameters of the model, applies the reduction function to each gradient to produce a scalar and then logs the scalar.
Examples
from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log model's gradients norm after each iteration
tb_logger.attach(trainer,
                 log_handler=GradsScalarHandler(model, reduction=torch.norm),
                 event_name=Events.ITERATION_COMPLETED)
- Parameters
model (torch.nn.Module) – model whose gradients to log
reduction (callable) – function to reduce parameters into scalar
tag (str, optional) – common title for all produced plots. For example, ‘generator’
- class ignite.contrib.handlers.tensorboard_logger.OptimizerParamsHandler(optimizer, param_name='lr', tag=None)[source]#
Helper handler to log optimizer parameters
Examples
from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
tb_logger.attach(trainer,
                 log_handler=OptimizerParamsHandler(optimizer),
                 event_name=Events.ITERATION_STARTED)
- Parameters
optimizer (torch.optim.Optimizer) – torch optimizer whose parameters to log
param_name (str) – parameter name
tag (str, optional) – common title for all produced plots. For example, ‘generator’
- class ignite.contrib.handlers.tensorboard_logger.OutputHandler(tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None)[source]#
Helper handler to log engine’s output and/or metrics
Examples
from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
# of the `trainer`:
tb_logger.attach(evaluator,
                 log_handler=OutputHandler(tag="validation",
                                           metric_names=["nll", "accuracy"],
                                           global_step_transform=global_step_from_engine(trainer)),
                 event_name=Events.EPOCH_COMPLETED)
Example with CustomPeriodicEvent, where model is evaluated every 500 iterations:
from ignite.contrib.handlers import CustomPeriodicEvent

cpe = CustomPeriodicEvent(n_iterations=500)
cpe.attach(trainer)

@trainer.on(cpe.Events.ITERATIONS_500_COMPLETED)
def evaluate(engine):
    evaluator.run(validation_set, max_epochs=1)

from ignite.contrib.handlers.tensorboard_logger import *

tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

def global_step_transform(*args, **kwargs):
    return trainer.state.iteration

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# every 500 iterations. Since evaluator engine does not have CustomPeriodicEvent attached to it, we
# provide a global_step_transform to return the trainer.state.iteration for the global_step, each time
# evaluator metrics are plotted on Tensorboard.
tb_logger.attach(evaluator,
                 log_handler=OutputHandler(tag="validation",
                                           metric_names=["nll", "accuracy"],
                                           global_step_transform=global_step_transform),
                 event_name=Events.EPOCH_COMPLETED)
- Parameters
tag (str) – common title for all produced plots. For example, ‘training’
metric_names (list of str, optional) – list of metric names to plot or a string “all” to plot all available metrics.
output_transform (callable, optional) – output transform function to prepare engine.state.output as a number. For example, output_transform = lambda output: output. This function can also return a dictionary, e.g. {'loss': loss1, 'another_loss': loss2}, to label the plot with the corresponding keys.
another_engine (Engine) – Deprecated (see global_step_transform). Another engine to use to provide the value of event. Typically, the user can provide the trainer if this handler is attached to an evaluator, so that it logs the proper trainer epoch/iteration value.
global_step_transform (callable, optional) – global step transform function to output a desired global step. Input of the function is (engine, event_name). Output of the function should be an integer. Default is None, i.e. the global_step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().
Note
Example of global_step_transform:
def global_step_transform(engine, event_name):
    return engine.state.get_event_attrib_value(event_name)
- class ignite.contrib.handlers.tensorboard_logger.TensorboardLogger(*args, **kwargs)[source]#
TensorBoard handler to log metrics, model/optimizer parameters, gradients during the training and validation.
By default, this class favors tensorboardX package if installed:
pip install tensorboardX
otherwise, it falls back to using PyTorch’s SummaryWriter (>=v1.2.0).
- Parameters
*args – positional arguments accepted by SummaryWriter.
**kwargs – keyword arguments accepted by SummaryWriter, for example, log_dir to set up the path to the directory where to log.
Examples
from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log training loss at each iteration
tb_logger.attach(trainer,
                 log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
                 event_name=Events.ITERATION_COMPLETED)

# Attach the logger to the evaluator on the training dataset and log NLL, Accuracy metrics after each epoch
# We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
# of the `trainer` instead of `train_evaluator`.
tb_logger.attach(train_evaluator,
                 log_handler=OutputHandler(tag="training",
                                           metric_names=["nll", "accuracy"],
                                           global_step_transform=global_step_from_engine(trainer)),
                 event_name=Events.EPOCH_COMPLETED)

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch of the
# `trainer` instead of `evaluator`.
tb_logger.attach(evaluator,
                 log_handler=OutputHandler(tag="validation",
                                           metric_names=["nll", "accuracy"],
                                           global_step_transform=global_step_from_engine(trainer)),
                 event_name=Events.EPOCH_COMPLETED)

# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
tb_logger.attach(trainer,
                 log_handler=OptimizerParamsHandler(optimizer),
                 event_name=Events.ITERATION_STARTED)

# Attach the logger to the trainer to log model's weights norm after each iteration
tb_logger.attach(trainer,
                 log_handler=WeightsScalarHandler(model),
                 event_name=Events.ITERATION_COMPLETED)

# Attach the logger to the trainer to log model's weights as a histogram after each epoch
tb_logger.attach(trainer,
                 log_handler=WeightsHistHandler(model),
                 event_name=Events.EPOCH_COMPLETED)

# Attach the logger to the trainer to log model's gradients norm after each iteration
tb_logger.attach(trainer,
                 log_handler=GradsScalarHandler(model),
                 event_name=Events.ITERATION_COMPLETED)

# Attach the logger to the trainer to log model's gradients as a histogram after each epoch
tb_logger.attach(trainer,
                 log_handler=GradsHistHandler(model),
                 event_name=Events.EPOCH_COMPLETED)

# We need to close the logger when we are done
tb_logger.close()
It is also possible to use the logger as context manager:
from ignite.contrib.handlers.tensorboard_logger import *

with TensorboardLogger(log_dir="experiments/tb_logs") as tb_logger:

    trainer = Engine(update_fn)
    # Attach the logger to the trainer to log training loss at each iteration
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
                     event_name=Events.ITERATION_COMPLETED)
- attach(engine, log_handler, event_name)#
Attach the logger to the engine and execute log_handler function at event_name events.
- Parameters
engine (Engine) – engine object.
log_handler (callable) – a logging handler to execute
event_name – event to attach the logging handler to. Valid events are from Events or any event_name added by register_events().
- class ignite.contrib.handlers.tensorboard_logger.WeightsHistHandler(model, tag=None)[source]#
Helper handler to log model’s weights as histograms.
Examples
from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log model's weights as histograms after each iteration
tb_logger.attach(trainer,
                 log_handler=WeightsHistHandler(model),
                 event_name=Events.ITERATION_COMPLETED)
- Parameters
model (torch.nn.Module) – model to log weights
tag (str, optional) – common title for all produced plots. For example, ‘generator’
- class ignite.contrib.handlers.tensorboard_logger.WeightsScalarHandler(model, reduction=<function norm>, tag=None)[source]#
Helper handler to log model’s weights as scalars. Handler iterates over named parameters of the model, applies the reduction function to each parameter to produce a scalar and then logs the scalar.
Examples
from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log model's weights norm after each iteration
tb_logger.attach(trainer,
                 log_handler=WeightsScalarHandler(model, reduction=torch.norm),
                 event_name=Events.ITERATION_COMPLETED)
- Parameters
model (torch.nn.Module) – model to log weights
reduction (callable) – function to reduce parameters into scalar
tag (str, optional) – common title for all produced plots. For example, ‘generator’
- ignite.contrib.handlers.tensorboard_logger.global_step_from_engine(engine)[source]#
Helper method to setup global_step_transform function using another engine. This can be helpful for logging trainer epoch/iteration while output handler is attached to an evaluator.
- Parameters
engine (Engine) – engine whose state is used to provide the global step
- Returns
global step
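For reference, a small sketch of what the returned callable does (its behaviour follows the Note on get_event_attrib_value above; the trainer/evaluator names are illustrative):

from ignite.engine import Events
from ignite.contrib.handlers.tensorboard_logger import global_step_from_engine

step_fn = global_step_from_engine(trainer)

# the returned callable has the (engine, event_name) signature expected by global_step_transform
# and reports the wrapped engine's (here: the trainer's) progress, not the evaluator's
step = step_fn(evaluator, Events.EPOCH_COMPLETED)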
visdom_logger#
See visdom mnist example for detailed usage.
- class ignite.contrib.handlers.visdom_logger.GradsScalarHandler(model, reduction=<function norm>, tag=None, show_legend=False)[source]#
Helper handler to log model’s gradients as scalars. Handler iterates over the gradients of named parameters of the model, applies the reduction function to each gradient to produce a scalar and then logs the scalar.
Examples
from ignite.contrib.handlers.visdom_logger import *

# Create a logger
vd_logger = VisdomLogger()

# Attach the logger to the trainer to log model's gradients norm after each iteration
vd_logger.attach(trainer,
                 log_handler=GradsScalarHandler(model, reduction=torch.norm),
                 event_name=Events.ITERATION_COMPLETED)
- Parameters
model (torch.nn.Module) – model whose gradients to log
reduction (callable) – function to reduce parameters into scalar
tag (str, optional) – common title for all produced plots. For example, ‘generator’
show_legend (bool, optional) – flag to show legend in the window
- add_scalar(logger, k, v, event_name, global_step)#
Helper method to log a scalar with VisdomLogger.
- Parameters
logger (VisdomLogger) – visdom logger
k (str) – scalar name which is used to set window title and y-axis label
event_name – event name which is used to set up the x-axis label. Valid events are from Events or any event_name added by register_events().
global_step (int) – global step, x-axis value
- class ignite.contrib.handlers.visdom_logger.OptimizerParamsHandler(optimizer, param_name='lr', tag=None, show_legend=False)[source]#
Helper handler to log optimizer parameters
Examples
from ignite.contrib.handlers.visdom_logger import *

# Create a logger
vb_logger = VisdomLogger()

# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
vb_logger.attach(trainer,
                 log_handler=OptimizerParamsHandler(optimizer),
                 event_name=Events.ITERATION_STARTED)
- Parameters
optimizer (torch.optim.Optimizer) – torch optimizer whose parameters to log
param_name (str) – parameter name
tag (str, optional) – common title for all produced plots. For example, ‘generator’
show_legend (bool, optional) – flag to show legend in the window
- add_scalar(logger, k, v, event_name, global_step)#
Helper method to log a scalar with VisdomLogger.
- Parameters
logger (VisdomLogger) – visdom logger
k (str) – scalar name which is used to set window title and y-axis label
event_name – event name which is used to set up the x-axis label. Valid events are from Events or any event_name added by register_events().
global_step (int) – global step, x-axis value
- class ignite.contrib.handlers.visdom_logger.OutputHandler(tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None, show_legend=False)[source]#
Helper handler to log engine’s output and/or metrics
Examples
from ignite.contrib.handlers.visdom_logger import *

# Create a logger
vd_logger = VisdomLogger()

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch of
# the `trainer`:
vd_logger.attach(evaluator,
                 log_handler=OutputHandler(tag="validation",
                                           metric_names=["nll", "accuracy"],
                                           global_step_transform=global_step_from_engine(trainer)),
                 event_name=Events.EPOCH_COMPLETED)
Example with CustomPeriodicEvent, where model is evaluated every 500 iterations:
from ignite.contrib.handlers import CustomPeriodicEvent

cpe = CustomPeriodicEvent(n_iterations=500)
cpe.attach(trainer)

@trainer.on(cpe.Events.ITERATIONS_500_COMPLETED)
def evaluate(engine):
    evaluator.run(validation_set, max_epochs=1)

from ignite.contrib.handlers.visdom_logger import *

vd_logger = VisdomLogger()

def global_step_transform(*args, **kwargs):
    return trainer.state.iteration

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# every 500 iterations. Since evaluator engine does not have CustomPeriodicEvent attached to it, we
# provide a global_step_transform to return the trainer.state.iteration for the global_step, each time
# evaluator metrics are plotted on Visdom.
vd_logger.attach(evaluator,
                 log_handler=OutputHandler(tag="validation",
                                           metric_names=["nll", "accuracy"],
                                           global_step_transform=global_step_transform),
                 event_name=Events.EPOCH_COMPLETED)
- Parameters
tag (str) – common title for all produced plots. For example, ‘training’
metric_names (list of str, optional) – list of metric names to plot or a string “all” to plot all available metrics.
output_transform (callable, optional) – output transform function to prepare engine.state.output as a number. For example, output_transform = lambda output: output. This function can also return a dictionary, e.g. {'loss': loss1, 'another_loss': loss2}, to label the plot with the corresponding keys.
another_engine (Engine) – Deprecated (see global_step_transform). Another engine to use to provide the value of event. Typically, the user can provide the trainer if this handler is attached to an evaluator, so that it logs the proper trainer epoch/iteration value.
global_step_transform (callable, optional) – global step transform function to output a desired global step. Input of the function is (engine, event_name). Output of the function should be an integer. Default is None, i.e. the global_step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().
show_legend (bool, optional) – flag to show legend in the window
Note
Example of global_step_transform:
def global_step_transform(engine, event_name):
    return engine.state.get_event_attrib_value(event_name)
- add_scalar(logger, k, v, event_name, global_step)#
Helper method to log a scalar with VisdomLogger.
- Parameters
logger (VisdomLogger) – visdom logger
k (str) – scalar name which is used to set window title and y-axis label
event_name – event name which is used to set up the x-axis label. Valid events are from Events or any event_name added by register_events().
global_step (int) – global step, x-axis value
- class ignite.contrib.handlers.visdom_logger.VisdomLogger(server=None, port=None, num_workers=1, **kwargs)[source]#
VisdomLogger handler to log metrics, model/optimizer parameters, gradients during the training and validation.
This class requires visdom package to be installed:
pip install git+https://github.com/facebookresearch/visdom.git
- Parameters
server (str, optional) – visdom server URL. It can be also specified by environment variable VISDOM_SERVER_URL
port (int, optional) – visdom server’s port. It can be also specified by environment variable VISDOM_PORT
num_workers (int, optional) – number of workers to use in concurrent.futures.ThreadPoolExecutor to post data to the visdom server. Default, num_workers=1. If num_workers=0, the logger uses the main thread. If using Python 2.7 and num_workers>0, the package futures should be installed: pip install futures
**kwargs – kwargs to pass into visdom.Visdom.
Note
We can also specify username/password using environment variables: VISDOM_USERNAME, VISDOM_PASSWORD
Warning
Frequent logging, e.g. when logger is attached to Events.ITERATION_COMPLETED, can slow down the run if the main thread is used to send the data to visdom server (num_workers=0). To avoid this situation we can either log less frequently or set num_workers=1.
Examples
from ignite.contrib.handlers.visdom_logger import *

# Create a logger
vd_logger = VisdomLogger()

# Attach the logger to the trainer to log training loss at each iteration
vd_logger.attach(trainer,
                 log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
                 event_name=Events.ITERATION_COMPLETED)

# Attach the logger to the evaluator on the training dataset and log NLL, Accuracy metrics after each epoch
# We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
# of the `trainer` instead of `train_evaluator`:
vd_logger.attach(train_evaluator,
                 log_handler=OutputHandler(tag="training",
                                           metric_names=["nll", "accuracy"],
                                           global_step_transform=global_step_from_engine(trainer)),
                 event_name=Events.EPOCH_COMPLETED)

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch of
# the `trainer` instead of `evaluator`:
vd_logger.attach(evaluator,
                 log_handler=OutputHandler(tag="validation",
                                           metric_names=["nll", "accuracy"],
                                           global_step_transform=global_step_from_engine(trainer)),
                 event_name=Events.EPOCH_COMPLETED)

# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
vd_logger.attach(trainer,
                 log_handler=OptimizerParamsHandler(optimizer),
                 event_name=Events.ITERATION_COMPLETED)

# Attach the logger to the trainer to log model's weights norm after each iteration
vd_logger.attach(trainer,
                 log_handler=WeightsScalarHandler(model),
                 event_name=Events.ITERATION_COMPLETED)

# Attach the logger to the trainer to log model's gradients norm after each iteration
vd_logger.attach(trainer,
                 log_handler=GradsScalarHandler(model),
                 event_name=Events.ITERATION_COMPLETED)

# We need to close the logger when we are done
vd_logger.close()
It is also possible to use the logger as context manager:
from ignite.contrib.handlers.visdom_logger import *

with VisdomLogger() as vd_logger:

    trainer = Engine(update_fn)
    # Attach the logger to the trainer to log training loss at each iteration
    vd_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
                     event_name=Events.ITERATION_COMPLETED)
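A minimal connection sketch (the server and port values below are illustrative; as noted above, they can also come from the VISDOM_SERVER_URL and VISDOM_PORT environment variables):

from ignite.contrib.handlers.visdom_logger import VisdomLogger

# explicit connection parameters; num_workers=1 posts data from a background thread
vd_logger = VisdomLogger(server="http://localhost", port=8097, num_workers=1)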
- attach(engine, log_handler, event_name)#
Attach the logger to the engine and execute log_handler function at event_name events.
- Parameters
engine (Engine) – engine object.
log_handler (callable) – a logging handler to execute
event_name – event to attach the logging handler to. Valid events are from Events or any event_name added by register_events().
- class ignite.contrib.handlers.visdom_logger.WeightsScalarHandler(model, reduction=<function norm>, tag=None, show_legend=False)[source]#
Helper handler to log model’s weights as scalars. Handler iterates over named parameters of the model, applies the reduction function to each parameter to produce a scalar and then logs the scalar.
Examples
from ignite.contrib.handlers.visdom_logger import *

# Create a logger
vd_logger = VisdomLogger()

# Attach the logger to the trainer to log model's weights norm after each iteration
vd_logger.attach(trainer,
                 log_handler=WeightsScalarHandler(model, reduction=torch.norm),
                 event_name=Events.ITERATION_COMPLETED)
- Parameters
model (torch.nn.Module) – model to log weights
reduction (callable) – function to reduce parameters into scalar
tag (str, optional) – common title for all produced plots. For example, ‘generator’
show_legend (bool, optional) – flag to show legend in the window
- add_scalar(logger, k, v, event_name, global_step)#
Helper method to log a scalar with VisdomLogger.
- Parameters
logger (VisdomLogger) – visdom logger
k (str) – scalar name which is used to set window title and y-axis label
event_name – event name which is used to set up the x-axis label. Valid events are from Events or any event_name added by register_events().
global_step (int) – global step, x-axis value
- ignite.contrib.handlers.visdom_logger.global_step_from_engine(engine)[source]#
Helper method to setup global_step_transform function using another engine. This can be helpful for logging trainer epoch/iteration while output handler is attached to an evaluator.
- Parameters
engine (Engine) – engine whose state is used to provide the global step
- Returns
global step
mlflow_logger#
- class ignite.contrib.handlers.mlflow_logger.MLflowLogger(tracking_uri=None)[source]#
MLflow tracking client handler to log parameters and metrics during the training and validation.
This class requires mlflow package to be installed:
pip install mlflow
- Parameters
tracking_uri (str) – MLflow tracking uri. See MLflow docs for more details
Examples
from ignite.contrib.handlers.mlflow_logger import *

# Create a logger
mlflow_logger = MLflowLogger()

# Log experiment parameters:
mlflow_logger.log_params(**{
    "seed": seed,
    "batch_size": batch_size,
    "model": model.__class__.__name__,
    "pytorch version": torch.__version__,
    "ignite version": ignite.__version__,
    "cuda version": torch.version.cuda,
    "device name": torch.cuda.get_device_name(0)
})

# Attach the logger to the evaluator on the training dataset and log NLL, Accuracy metrics after each epoch
# We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
# of the `trainer` instead of `train_evaluator`.
mlflow_logger.attach(train_evaluator,
                     log_handler=OutputHandler(tag="training",
                                               metric_names=["nll", "accuracy"],
                                               global_step_transform=global_step_from_engine(trainer)),
                     event_name=Events.EPOCH_COMPLETED)

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch of the
# `trainer` instead of `evaluator`.
mlflow_logger.attach(evaluator,
                     log_handler=OutputHandler(tag="validation",
                                               metric_names=["nll", "accuracy"],
                                               global_step_transform=global_step_from_engine(trainer)),
                     event_name=Events.EPOCH_COMPLETED)
- attach(engine, log_handler, event_name)#
Attach the logger to the engine and execute log_handler function at event_name events.
- Parameters
engine (Engine) – engine object.
log_handler (callable) – a logging handler to execute
event_name – event to attach the logging handler to. Valid events are from Events or any event_name added by register_events().
- class ignite.contrib.handlers.mlflow_logger.OptimizerParamsHandler(optimizer, param_name='lr', tag=None)[source]#
Helper handler to log optimizer parameters
Examples
from ignite.contrib.handlers.mlflow_logger import *

# Create a logger
mlflow_logger = MLflowLogger()

# Optionally, user can specify tracking_uri which corresponds to MLFLOW_TRACKING_URI
# mlflow_logger = MLflowLogger(tracking_uri="uri")

# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
mlflow_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
- Parameters
optimizer (torch.optim.Optimizer) – torch optimizer whose parameters to log
param_name (str) – parameter name
tag (str, optional) – common title for all produced plots. For example, ‘generator’
- class ignite.contrib.handlers.mlflow_logger.OutputHandler(tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None)[source]#
Helper handler to log engine’s output and/or metrics.
Examples
from ignite.contrib.handlers.mlflow_logger import *

# Create a logger
mlflow_logger = MLflowLogger()

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
# of the `trainer`:
mlflow_logger.attach(evaluator,
                     log_handler=OutputHandler(tag="validation",
                                               metric_names=["nll", "accuracy"],
                                               global_step_transform=global_step_from_engine(trainer)),
                     event_name=Events.EPOCH_COMPLETED)
Example with CustomPeriodicEvent, where model is evaluated every 500 iterations:
from ignite.contrib.handlers import CustomPeriodicEvent

cpe = CustomPeriodicEvent(n_iterations=500)
cpe.attach(trainer)

@trainer.on(cpe.Events.ITERATIONS_500_COMPLETED)
def evaluate(engine):
    evaluator.run(validation_set, max_epochs=1)

from ignite.contrib.handlers.mlflow_logger import *

mlflow_logger = MLflowLogger()

def global_step_transform(*args, **kwargs):
    return trainer.state.iteration

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# every 500 iterations. Since evaluator engine does not have CustomPeriodicEvent attached to it, we
# provide a global_step_transform to return the trainer.state.iteration for the global_step, each time
# evaluator metrics are plotted on MLflow.
mlflow_logger.attach(evaluator,
                     log_handler=OutputHandler(tag="validation",
                                               metric_names=["nll", "accuracy"],
                                               global_step_transform=global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)
- Parameters
tag (str) – common title for all produced plots. For example, ‘training’
metric_names (list of str, optional) – list of metric names to plot or a string “all” to plot all available metrics.
output_transform (callable, optional) – output transform function to prepare engine.state.output as a number. For example, output_transform = lambda output: output. This function can also return a dictionary, e.g. {'loss': loss1, 'another_loss': loss2}, to label the plot with the corresponding keys.
another_engine (Engine) – Deprecated (see global_step_transform). Another engine to use to provide the value of event. Typically, the user can provide the trainer if this handler is attached to an evaluator, so that it logs the proper trainer epoch/iteration value.
global_step_transform (callable, optional) – global step transform function to output a desired global step. Input of the function is (engine, event_name). Output of the function should be an integer. Default is None, i.e. the global_step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().
Note
Example of global_step_transform:
def global_step_transform(engine, event_name):
    return engine.state.get_event_attrib_value(event_name)
- ignite.contrib.handlers.mlflow_logger.global_step_from_engine(engine)[source]#
Helper method to setup global_step_transform function using another engine. This can be helpful for logging trainer epoch/iteration while output handler is attached to an evaluator.
- Parameters
engine (Engine) – engine whose state is used to provide the global step
- Returns
global step
tqdm_logger#
- class ignite.contrib.handlers.tqdm_logger.ProgressBar(persist=False, bar_format='{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]', **tqdm_kwargs)[source]#
TQDM progress bar handler to log training progress and computed metrics.
- Parameters
persist (bool, optional) – set to True to persist the progress bar after completion (default = False)
bar_format (str, optional) – specify a custom bar string formatting. May impact performance. [default: ‘{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]’]. Set to None to use tqdm’s default bar formatting: ‘{l_bar}{bar}{r_bar}’, where l_bar=’{desc}: {percentage:3.0f}%|’ and r_bar=’| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]’. For more details on the formatting, see the tqdm docs.
**tqdm_kwargs – kwargs passed to the tqdm progress bar. By default, the progress bar description displays “Epoch [5/10]”, where 5 is the current epoch and 10 is the number of epochs. If tqdm_kwargs defines desc, e.g. “Predictions”, then the description is “Predictions [5/10]” if the number of epochs is more than one, otherwise it is simply “Predictions”.
Examples
Simple progress bar
trainer = create_supervised_trainer(model, optimizer, loss)

pbar = ProgressBar()
pbar.attach(trainer)

# Progress bar will look like
# Epoch [2/50]: [64/128]  50%|█████      [06:17<12:34]
Log output to a file instead of stderr (tqdm’s default output)
trainer = create_supervised_trainer(model, optimizer, loss)

log_file = open("output.log", "w")
pbar = ProgressBar(file=log_file)
pbar.attach(trainer)
Attach metrics that already have been computed at ITERATION_COMPLETED (such as RunningAverage)

trainer = create_supervised_trainer(model, optimizer, loss)

RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

pbar = ProgressBar()
pbar.attach(trainer, ['loss'])

# Progress bar will look like
# Epoch [2/50]: [64/128]  50%|█████      , loss=0.123 [06:17<12:34]
Directly attach the engine’s output
trainer = create_supervised_trainer(model, optimizer, loss)

pbar = ProgressBar()
pbar.attach(trainer, output_transform=lambda x: {'loss': x})

# Progress bar will look like
# Epoch [2/50]: [64/128]  50%|█████      , loss=0.123 [06:17<12:34]
Note
When attaching the progress bar to an engine, it is recommended to replace every print operation in the engine’s handlers that is triggered every iteration with pbar.log_message, to guarantee the correct format of the stdout.
Note
When used inside a Jupyter notebook, ProgressBar automatically uses tqdm_notebook. For correct rendering, please install ipywidgets. Due to tqdm notebook bugs, the bar format may need to be set to an empty string value.
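For instance, a short sketch of the two options above (a persistent bar with tqdm's default format, and an empty format as a workaround inside notebooks):

from ignite.contrib.handlers import ProgressBar

# keep the bar on screen after completion and use tqdm's default bar string
pbar = ProgressBar(persist=True, bar_format=None)
pbar.attach(trainer)

# alternatively, inside a Jupyter notebook an empty bar_format works around the tqdm_notebook rendering issue
# pbar = ProgressBar(bar_format="")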
- attach(engine, metric_names=None, output_transform=None, event_name=Events.ITERATION_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)[source]#
Attaches the progress bar to an engine object.
- Parameters
engine (Engine) – engine object.
metric_names (list of str, optional) – list of metric names to plot or a string “all” to plot all available metrics.
output_transform (callable, optional) – a function to select what you want to print from the engine’s output. This function may return either a dictionary with entries in the format of {name: value}, or a single scalar, which will be displayed with the default name output.
event_name – event’s name on which the progress bar advances. Valid events are from Events.
closing_event_name – event’s name on which the progress bar is closed. Valid events are from Events.
Note: accepted output value types are numbers, 0d and 1d torch tensors and strings
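As a sketch of the custom event arguments (the choice of events below is only an assumption for an epoch-level bar):

from ignite.engine import Events
from ignite.contrib.handlers import ProgressBar

# advance the bar once per epoch and close it when the run completes
pbar = ProgressBar(persist=True)
pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)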
polyaxon_logger#
- class ignite.contrib.handlers.polyaxon_logger.OptimizerParamsHandler(optimizer, param_name='lr', tag=None)[source]#
Helper handler to log optimizer parameters
Examples
from ignite.contrib.handlers.polyaxon_logger import *

# Create a logger
plx_logger = PolyaxonLogger()

# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
plx_logger.attach(trainer,
                  log_handler=OptimizerParamsHandler(optimizer),
                  event_name=Events.ITERATION_STARTED)
- Parameters
optimizer (torch.optim.Optimizer) – torch optimizer whose parameters to log
param_name (str) – parameter name
tag (str, optional) – common title for all produced plots. For example, ‘generator’
- class ignite.contrib.handlers.polyaxon_logger.OutputHandler(tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None)[source]#
Helper handler to log engine’s output and/or metrics.
Examples
from ignite.contrib.handlers.polyaxon_logger import *

# Create a logger
plx_logger = PolyaxonLogger()

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
# of the `trainer`:
plx_logger.attach(evaluator,
                  log_handler=OutputHandler(tag="validation",
                                            metric_names=["nll", "accuracy"],
                                            global_step_transform=global_step_from_engine(trainer)),
                  event_name=Events.EPOCH_COMPLETED)
Example with CustomPeriodicEvent, where model is evaluated every 500 iterations:
from ignite.contrib.handlers import CustomPeriodicEvent

cpe = CustomPeriodicEvent(n_iterations=500)
cpe.attach(trainer)

@trainer.on(cpe.Events.ITERATIONS_500_COMPLETED)
def evaluate(engine):
    evaluator.run(validation_set, max_epochs=1)

from ignite.contrib.handlers.polyaxon_logger import *

plx_logger = PolyaxonLogger()

def global_step_transform(*args, **kwargs):
    return trainer.state.iteration

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# every 500 iterations. Since evaluator engine does not have CustomPeriodicEvent attached to it, we
# provide a global_step_transform to return the trainer.state.iteration for the global_step, each time
# evaluator metrics are plotted on Polyaxon.
plx_logger.attach(evaluator,
                  log_handler=OutputHandler(tag="validation",
                                            metric_names=["nll", "accuracy"],
                                            global_step_transform=global_step_transform),
                  event_name=Events.EPOCH_COMPLETED)
- Parameters
tag (str) – common title for all produced plots. For example, ‘training’
metric_names (list of str, optional) – list of metric names to plot or a string “all” to plot all available metrics.
output_transform (callable, optional) – output transform function to prepare engine.state.output as a number. For example, output_transform = lambda output: output. This function can also return a dictionary, e.g. {'loss': loss1, 'another_loss': loss2}, to label the plot with the corresponding keys.
another_engine (Engine) – Deprecated (see global_step_transform). Another engine to use to provide the value of event. Typically, the user can provide the trainer if this handler is attached to an evaluator, so that it logs the proper trainer epoch/iteration value.
global_step_transform (callable, optional) – global step transform function to output a desired global step. Input of the function is (engine, event_name). Output of the function should be an integer. Default is None, i.e. the global_step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().
Note
Example of global_step_transform:
def global_step_transform(engine, event_name):
    return engine.state.get_event_attrib_value(event_name)
- class ignite.contrib.handlers.polyaxon_logger.PolyaxonLogger[source]#
Polyaxon tracking client handler to log parameters and metrics during the training and validation.
This class requires polyaxon-client package to be installed:
pip install polyaxon-client
Examples
from ignite.contrib.handlers.polyaxon_logger import *

# Create a logger
plx_logger = PolyaxonLogger()

# Log experiment parameters:
plx_logger.log_params(**{
    "seed": seed,
    "batch_size": batch_size,
    "model": model.__class__.__name__,
    "pytorch version": torch.__version__,
    "ignite version": ignite.__version__,
    "cuda version": torch.version.cuda,
    "device name": torch.cuda.get_device_name(0)
})

# Attach the logger to the evaluator on the training dataset and log NLL, Accuracy metrics after each epoch
# We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
# of the `trainer` instead of `train_evaluator`.
plx_logger.attach(train_evaluator,
                  log_handler=OutputHandler(tag="training",
                                            metric_names=["nll", "accuracy"],
                                            global_step_transform=global_step_from_engine(trainer)),
                  event_name=Events.EPOCH_COMPLETED)

# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch of the
# `trainer` instead of `evaluator`.
plx_logger.attach(evaluator,
                  log_handler=OutputHandler(tag="validation",
                                            metric_names=["nll", "accuracy"],
                                            global_step_transform=global_step_from_engine(trainer)),
                  event_name=Events.EPOCH_COMPLETED)
- attach(engine, log_handler, event_name)#
Attach the logger to the engine and execute log_handler function at event_name events.
- Parameters
engine (Engine) – engine object.
log_handler (callable) – a logging handler to execute
event_name – event to attach the logging handler to. Valid events are from Events or any event_name added by register_events().
- ignite.contrib.handlers.polyaxon_logger.global_step_from_engine(engine)[source]#
Helper method to setup global_step_transform function using another engine. This can be helpful for logging trainer epoch/iteration while output handler is attached to an evaluator.
- Parameters
engine (Engine) – engine whose state is used to provide the global step
- Returns
global step
More on parameter scheduling#
This section shows visual examples of the various parameter schedules that can be achieved.
Example with ignite.contrib.handlers.CosineAnnealingScheduler#
import numpy as np
import matplotlib.pylab as plt
from ignite.contrib.handlers import CosineAnnealingScheduler
lr_values_1 = np.array(CosineAnnealingScheduler.simulate_values(num_events=75, param_name='lr',
start_value=1e-1, end_value=2e-2, cycle_size=20))
lr_values_2 = np.array(CosineAnnealingScheduler.simulate_values(num_events=75, param_name='lr',
start_value=1e-1, end_value=2e-2, cycle_size=20, cycle_mult=1.3))
lr_values_3 = np.array(CosineAnnealingScheduler.simulate_values(num_events=75, param_name='lr',
start_value=1e-1, end_value=2e-2,
cycle_size=20, start_value_mult=0.7))
lr_values_4 = np.array(CosineAnnealingScheduler.simulate_values(num_events=75, param_name='lr',
start_value=1e-1, end_value=2e-2,
cycle_size=20, end_value_mult=0.1))
plt.figure(figsize=(25, 5))
plt.subplot(141)
plt.title("Cosine annealing")
plt.plot(lr_values_1[:, 0], lr_values_1[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.ylim([0.0, 0.12])
plt.subplot(142)
plt.title("Cosine annealing with cycle_mult=1.3")
plt.plot(lr_values_2[:, 0], lr_values_2[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.ylim([0.0, 0.12])
plt.subplot(143)
plt.title("Cosine annealing with start_value_mult=0.7")
plt.plot(lr_values_3[:, 0], lr_values_3[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.ylim([0.0, 0.12])
plt.subplot(144)
plt.title("Cosine annealing with end_value_mult=0.1")
plt.plot(lr_values_4[:, 0], lr_values_4[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.ylim([0.0, 0.12])
Example with ignite.contrib.handlers.LinearCyclicalScheduler#
import numpy as np
import matplotlib.pylab as plt
from ignite.contrib.handlers import LinearCyclicalScheduler
lr_values_1 = np.array(LinearCyclicalScheduler.simulate_values(num_events=75, param_name='lr',
start_value=1e-1, end_value=2e-2, cycle_size=20))
lr_values_2 = np.array(LinearCyclicalScheduler.simulate_values(num_events=75, param_name='lr',
start_value=1e-1, end_value=2e-2, cycle_size=20, cycle_mult=1.3))
lr_values_3 = np.array(LinearCyclicalScheduler.simulate_values(num_events=75, param_name='lr',
start_value=1e-1, end_value=2e-2,
cycle_size=20, start_value_mult=0.7))
lr_values_4 = np.array(LinearCyclicalScheduler.simulate_values(num_events=75, param_name='lr',
start_value=1e-1, end_value=2e-2,
cycle_size=20, end_value_mult=0.1))
plt.figure(figsize=(25, 5))
plt.subplot(141)
plt.title("Linear cyclical scheduler")
plt.plot(lr_values_1[:, 0], lr_values_1[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.ylim([0.0, 0.12])
plt.subplot(142)
plt.title("Linear cyclical scheduler with cycle_mult=1.3")
plt.plot(lr_values_2[:, 0], lr_values_2[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.ylim([0.0, 0.12])
plt.subplot(143)
plt.title("Linear cyclical scheduler with start_value_mult=0.7")
plt.plot(lr_values_3[:, 0], lr_values_3[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.ylim([0.0, 0.12])
plt.subplot(144)
plt.title("Linear cyclical scheduler with end_value_mult=0.1")
plt.plot(lr_values_4[:, 0], lr_values_4[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.ylim([0.0, 0.12])
Example with ignite.contrib.handlers.ConcatScheduler#
import numpy as np
import matplotlib.pylab as plt
from ignite.contrib.handlers import LinearCyclicalScheduler, CosineAnnealingScheduler, ConcatScheduler
import torch
t1 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([t1], lr=0.1)
scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.1, end_value=0.5, cycle_size=30)
scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.5, end_value=0.01, cycle_size=50)
durations = [15, ]
lr_values_1 = np.array(ConcatScheduler.simulate_values(num_events=100, schedulers=[scheduler_1, scheduler_2], durations=durations))
t1 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([t1], lr=0.1)
scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.1, end_value=0.5, cycle_size=30)
scheduler_2 = CosineAnnealingScheduler(optimizer, "momentum", start_value=0.5, end_value=0.01, cycle_size=50)
durations = [15, ]
lr_values_2 = np.array(ConcatScheduler.simulate_values(num_events=100, schedulers=[scheduler_1, scheduler_2], durations=durations,
param_names=["lr", "momentum"]))
plt.figure(figsize=(25, 5))
plt.subplot(131)
plt.title("Concat scheduler of linear + cosine annealing")
plt.plot(lr_values_1[:, 0], lr_values_1[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.subplot(132)
plt.title("Concat scheduler of linear LR scheduler\n and cosine annealing on momentum")
plt.plot(lr_values_2[:, 0], lr_values_2[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.subplot(133)
plt.title("Concat scheduler of linear LR scheduler\n and cosine annealing on momentum")
plt.plot(lr_values_2[:, 0], lr_values_2[:, 2], label="momentum")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
Piecewise linear scheduler#
import numpy as np
import matplotlib.pylab as plt
from ignite.contrib.handlers import LinearCyclicalScheduler, ConcatScheduler
import torch
t1 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([t1], lr=0.1)
scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.0, end_value=0.6, cycle_size=50)
scheduler_2 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.6, end_value=0.0, cycle_size=150)
durations = [25, ]
lr_values = np.array(ConcatScheduler.simulate_values(num_events=100, schedulers=[scheduler_1, scheduler_2], durations=durations))
plt.title("Piecewise linear scheduler")
plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
Example with ignite.contrib.handlers.LRScheduler#
import numpy as np
import matplotlib.pylab as plt
from ignite.contrib.handlers import LRScheduler
import torch
from torch.optim.lr_scheduler import ExponentialLR, StepLR, CosineAnnealingLR
tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)
lr_scheduler_1 = StepLR(optimizer=optimizer, step_size=10, gamma=0.77)
lr_scheduler_2 = ExponentialLR(optimizer=optimizer, gamma=0.98)
lr_scheduler_3 = CosineAnnealingLR(optimizer=optimizer, T_max=10, eta_min=0.01)
lr_values_1 = np.array(LRScheduler.simulate_values(num_events=100, lr_scheduler=lr_scheduler_1))
lr_values_2 = np.array(LRScheduler.simulate_values(num_events=100, lr_scheduler=lr_scheduler_2))
lr_values_3 = np.array(LRScheduler.simulate_values(num_events=100, lr_scheduler=lr_scheduler_3))
plt.figure(figsize=(25, 5))
plt.subplot(131)
plt.title("Torch LR scheduler wrapping StepLR")
plt.plot(lr_values_1[:, 0], lr_values_1[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.subplot(132)
plt.title("Torch LR scheduler wrapping ExponentialLR")
plt.plot(lr_values_2[:, 0], lr_values_2[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.subplot(133)
plt.title("Torch LR scheduler wrapping CosineAnnealingLR")
plt.plot(lr_values_3[:, 0], lr_values_3[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
Concatenate with torch schedulers#
import numpy as np
import matplotlib.pylab as plt
from ignite.contrib.handlers import LinearCyclicalScheduler, LRScheduler, ConcatScheduler
import torch
from torch.optim.lr_scheduler import ExponentialLR, StepLR
t1 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([t1], lr=0.1)
scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.001, end_value=0.1, cycle_size=30)
lr_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.7)
scheduler_2 = LRScheduler(lr_scheduler=lr_scheduler)
durations = [15, ]
lr_values_1 = np.array(ConcatScheduler.simulate_values(num_events=30, schedulers=[scheduler_1, scheduler_2], durations=durations))
scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.001, end_value=0.1, cycle_size=30)
lr_scheduler = StepLR(optimizer=optimizer, step_size=10, gamma=0.7)
scheduler_2 = LRScheduler(lr_scheduler=lr_scheduler)
durations = [15, ]
lr_values_2 = np.array(ConcatScheduler.simulate_values(num_events=75, schedulers=[scheduler_1, scheduler_2], durations=durations))
plt.figure(figsize=(15, 5))
plt.subplot(121)
plt.title("Concat scheduler of linear + ExponentialLR")
plt.plot(lr_values_1[:, 0], lr_values_1[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()
plt.subplot(122)
plt.title("Concat scheduler of linear + StepLR")
plt.plot(lr_values_2[:, 0], lr_values_2[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()