
MetricsLambda

class ignite.metrics.metrics_lambda.MetricsLambda(f, *args, **kwargs)

Apply a function to other metrics to obtain a new metric. The result of the new metric is the result of applying the function to the results of the argument metrics.

On update, this metric recursively updates the metrics it depends on. On reset, all of its dependency metrics are reset as well. On attach, all of its dependency metrics are attached automatically, but only partially: calling is_attached() on a dependency will return False.

Parameters
  • f (Callable) – the function that defines the computation

  • args (Any) – sequence of other metrics or other values that will be fed to f as positional arguments.

  • kwargs (Any) – other metrics or other values that will be fed to f as keyword arguments.

Examples

For more information on how metrics work with Engine, see Attach Engine API.

from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

# create default evaluator for doctests

def eval_step(engine, batch):
    return batch

default_evaluator = Engine(eval_step)

# create default optimizer for doctests

param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup::`

def get_default_trainer():

    def train_step(engine, batch):
        return batch

    return Engine(train_step)

# create default model for doctests

default_model = nn.Sequential(OrderedDict([
    ('base', nn.Linear(4, 2)),
    ('fc', nn.Linear(2, 1))
]))

manual_seed(666)
precision = Precision(average=False)
recall = Recall(average=False)

def Fbeta(r, p, beta):
    return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()

F1 = MetricsLambda(Fbeta, recall, precision, 1)
F2 = MetricsLambda(Fbeta, recall, precision, 2)
F3 = MetricsLambda(Fbeta, recall, precision, 3)
F4 = MetricsLambda(Fbeta, recall, precision, 4)

F1.attach(default_evaluator, "F1")
F2.attach(default_evaluator, "F2")
F3.attach(default_evaluator, "F3")
F4.attach(default_evaluator, "F4")

y_true = torch.tensor([1, 0, 1, 0, 0, 1])
y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics["F1"])  # 0.8571...
print(state.metrics["F2"])  # 0.9375...
print(state.metrics["F3"])  # 0.9677...
print(state.metrics["F4"])  # 0.9807...

When checking whether the metric is attached, if any of its dependency metrics is detached, the metric is considered detached too.

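from ignite.engine import Engine
from ignite.metrics import Precision

# any engine works here; this one returns each batch as-is
engine = Engine(lambda engine, batch: batch)
precision = Precision(average=False)

aP = precision.mean()

aP.attach(engine, "aP")

assert aP.is_attached(engine)
# the dependency is only partially attached
assert not precision.is_attached(engine)

precision.detach(engine)

assert not aP.is_attached(engine)
# the dependency is now fully detached
assert not precision.is_attached(engine)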

Methods

  • attach – Attaches the current metric to the provided engine.

  • compute – Computes the metric based on its accumulated state.

  • detach – Detaches the current metric from the engine so that no metric computation is done during the run.

  • is_attached – Checks whether the current metric is attached to the provided engine.

  • reset – Resets the metric to its initial state.

  • update – Updates the metric’s state using the passed batch output.

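attach(engine, name, usage=EpochWise())

Attaches the current metric to the provided engine. At the end of the engine’s run, the engine.state.metrics dictionary will contain the computed metric’s value under the provided name.

Parameters
  • engine (Engine) – the engine to which the metric must be attached

  • name (str) – the name of the metric to attach

  • usage (Union[str, MetricUsage]) – the usage of the metric. Valid string values should be ‘epoch_wise’ (default) or ‘batch_wise’.

Return type

None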

Examples

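import torch
from ignite.engine import Engine
from ignite.metrics import Accuracy

engine = Engine(lambda engine, batch: batch)  # returns each (y_pred, y) batch as-is
data = [(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))]
metric = Accuracy()  # any metric works; Accuracy is used for illustration
metric.attach(engine, "mymetric")

assert "mymetric" in engine.run(data).metrics

assert metric.is_attached(engine)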

Example with usage:

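import torch
from ignite.engine import Engine
from ignite.metrics import Accuracy
from ignite.metrics.metric import BatchWise

engine = Engine(lambda engine, batch: batch)
data = [(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))]
metric = Accuracy()
metric.attach(engine, "mymetric", usage=BatchWise.usage_name)
assert "mymetric" in engine.run(data).metrics

assert metric.is_attached(engine, usage=BatchWise.usage_name)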

compute()

Computes the metric based on its accumulated state.

By default, this is called at the end of each epoch.

Returns

the actual quantity of interest. However, if a Mapping is returned, it will be (shallow) flattened into engine.state.metrics when completed() is called.

Return type

Any

Raises

NotComputableError – raised when the metric cannot be computed.

detach(engine, usage=EpochWise())

Detaches the current metric from the engine so that no metric computation is done during the run. Together with attach(), this method is useful when several metrics need to be computed at different periods: for example, one metric is computed every training epoch, while another (e.g. a more expensive one) is computed only every n-th training epoch.

Parameters
  • engine (Engine) – the engine from which the metric must be detached

  • usage (Union[str, MetricUsage]) – the usage of the metric. Valid string values should be ‘epoch_wise’ (default) or ‘batch_wise’.

Return type

None

Examples

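import torch
from ignite.engine import Engine
from ignite.metrics import Accuracy

engine = Engine(lambda engine, batch: batch)
data = [(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))]
metric = Accuracy()
metric.attach(engine, "mymetric")  # attach first so there is something to detach
metric.detach(engine)

assert "mymetric" not in engine.run(data).metrics

assert not metric.is_attached(engine)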

Example with usage:

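import torch
from ignite.engine import Engine
from ignite.metrics import Accuracy

engine = Engine(lambda engine, batch: batch)
data = [(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))]
metric = Accuracy()
metric.attach(engine, "mymetric", usage="batch_wise")
metric.detach(engine, usage="batch_wise")

assert "mymetric" not in engine.run(data).metrics

assert not metric.is_attached(engine, usage="batch_wise")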

is_attached(engine, usage=EpochWise())

Checks whether the current metric is attached to the provided engine. If it is, the metric’s computed value is written to the engine.state.metrics dictionary.

Parameters
  • engine (Engine) – the engine on which to check whether the metric is attached

  • usage (Union[str, MetricUsage]) – the usage of the metric. Valid string values should be ‘epoch_wise’ (default) or ‘batch_wise’.

Return type

bool

reset()

Resets the metric to its initial state.

By default, this is called at the start of each epoch.

Return type

None

update(output)

Updates the metric’s state using the passed batch output.

By default, this is called once for each batch.

Parameters

output (Any) – the output of the engine’s process function.

Return type

None