Source code for ignite.contrib.handlers.tqdm_logger

    from tqdm import tqdm
except ImportError:
    raise RuntimeError("This contrib module requires tqdm to be installed")

from ignite.engine import Events

class ProgressBar:
    """
    TQDM progress bar handler to log training progress and computed metrics.

    Args:
        persist (bool, optional): set to ``True`` to persist the progress bar after completion (default = ``False``)
        bar_format (str, optional): Specify a custom bar string formatting. May impact performance.
            [default: '{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]'].
            Set to ``None`` to use ``tqdm`` default bar formatting: '{l_bar}{bar}{r_bar}', where
            l_bar='{desc}: {percentage:3.0f}%|' and
            r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
            Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt, percentage, rate, rate_fmt, rate_noinv,
            rate_noinv_fmt, rate_inv, rate_inv_fmt, elapsed, remaining, desc, postfix.
            Note that a trailing ": " is automatically removed after {desc} if the latter is empty.
        **tqdm_kwargs: kwargs passed to tqdm progress bar

    Examples:

        Simple progress bar

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            pbar = ProgressBar()
            pbar.attach(trainer)

        Attach metrics that already have been computed at `ITERATION_COMPLETED` (such as `RunningAverage`)

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

            pbar = ProgressBar()
            pbar.attach(trainer, ['loss'])

        Directly attach the engine's output

        .. code-block:: python

            trainer = create_supervised_trainer(model, optimizer, loss)

            pbar = ProgressBar()
            pbar.attach(trainer, output_transform=lambda x: {'loss': x})

    Note:
        When attaching the progress bar to an engine, it is recommended that you replace
        every print operation in the engine's handlers triggered every iteration with
        ``pbar.log_message`` to guarantee the correct format of the stdout.
    """

    def __init__(self, persist=False,
                 bar_format='{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]',
                 **tqdm_kwargs):
        # The bar itself is created lazily on the first iteration of each epoch
        # (see ``_update``) and dropped again by ``_close``.
        self.pbar = None
        self.persist = persist
        self.bar_format = bar_format
        self.tqdm_kwargs = tqdm_kwargs

    def _reset(self, engine):
        """Create a fresh tqdm bar for the epoch that is starting.

        Args:
            engine (Engine): engine whose ``state.dataloader`` is being iterated
        """
        try:
            total = len(engine.state.dataloader)
        except TypeError:
            # Data without ``__len__`` (generators, iterable-style loaders):
            # let tqdm show a bar with an unknown total instead of crashing.
            total = None

        self.pbar = tqdm(
            total=total,
            leave=self.persist,
            bar_format=self.bar_format,
            **self.tqdm_kwargs
        )

    def _close(self, engine):
        """Close the current bar, if any, and forget it.

        Args:
            engine (Engine): unused; present because this is registered as an event handler
        """
        # An epoch may complete without a single iteration, in which case no
        # bar was ever created. Compare against None explicitly rather than
        # relying on the truthiness of the bar object.
        if self.pbar is not None:
            self.pbar.close()
        self.pbar = None

    def _update(self, engine, metric_names=None, output_transform=None):
        """Advance the bar by one step and refresh its description and postfix.

        Args:
            engine (Engine): engine that just completed an iteration
            metric_names (list, optional): names looked up in ``engine.state.metrics``
            output_transform (Callable, optional): applied to ``engine.state.output``
                to obtain the value(s) to display

        Raises:
            KeyError: if any of ``metric_names`` is missing from ``engine.state.metrics``
        """
        if self.pbar is None:
            self._reset(engine)

        # A user-supplied ``desc`` in tqdm_kwargs takes precedence over the epoch counter.
        if 'desc' not in self.tqdm_kwargs:
            self.pbar.set_description('Epoch [{}/{}]'.format(engine.state.epoch, engine.state.max_epochs))

        metrics = {}
        if metric_names is not None:
            if not all(metric in engine.state.metrics for metric in metric_names):
                # Close the bar before raising so it is not left open on the terminal.
                self._close(engine)
                raise KeyError("metrics not found in engine.state.metrics")

            metrics.update({name: '{:.2e}'.format(engine.state.metrics[name]) for name in metric_names})

        if output_transform is not None:
            output_dict = output_transform(engine.state.output)

            # A non-dict result is displayed under the default name "output".
            if not isinstance(output_dict, dict):
                output_dict = {"output": output_dict}

            metrics.update({name: '{:.2e}'.format(value) for name, value in output_dict.items()})

        if metrics:
            self.pbar.set_postfix(**metrics)

        self.pbar.update()

    @staticmethod
    def log_message(message):
        """
        Logs a message, preserving the progress bar correct output format

        Args:
            message (str): string you wish to log
        """
        tqdm.write(message)

    def attach(self, engine, metric_names=None, output_transform=None):
        """
        Attaches the progress bar to an engine object

        Args:
            engine (Engine): engine object
            metric_names (list, optional): list of the metrics names to log as the bar progresses
            output_transform (Callable, optional): a function to select what you want to print from the engine's
                output. This function may return either a dictionary with entries in the format of ``{name: value}``,
                or a single scalar, which will be displayed with the default name `output`.

        Raises:
            TypeError: if ``metric_names`` is given but is not a list, or if
                ``output_transform`` is given but is not callable
        """
        if metric_names is not None and not isinstance(metric_names, list):
            raise TypeError("metric_names should be a list, got {} instead".format(type(metric_names)))

        if output_transform is not None and not callable(output_transform):
            raise TypeError("output_transform should be a function, got {} instead"
                            .format(type(output_transform)))

        # Update the bar after every iteration; close it when the epoch ends.
        engine.add_event_handler(Events.ITERATION_COMPLETED, self._update, metric_names, output_transform)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._close)

© Copyright 2022, PyTorch-Ignite Contributors. Last updated on 05/04/2022, 8:31:16 PM.

Built with Sphinx using a theme provided by Read the Docs.