Shortcuts

Source code for ignite.engine

import torch

from ignite.engine.engine import Engine, State, Events
from ignite.utils import convert_tensor


def _prepare_batch(batch, device=None, non_blocking=False):
    """Convert an ``(x, y)`` batch with ``convert_tensor``, forwarding device options.

    Used as the default ``prepare_batch`` hook by the supervised trainer and
    evaluator factories in this module.

    Args:
        batch: a two-element sequence ``(x, y)``; it is unpacked before any
            conversion happens, so a batch of any other length raises
            ``ValueError`` on unpacking.
        device: forwarded unchanged to ``convert_tensor`` as ``device``.
        non_blocking (bool): forwarded unchanged to ``convert_tensor`` as
            ``non_blocking``.

    Returns:
        tuple: ``(converted_x, converted_y)``, with ``x`` converted before ``y``.
    """
    inputs, targets = batch
    # Both halves of the batch receive identical conversion options.
    options = dict(device=device, non_blocking=non_blocking)
    return tuple(convert_tensor(item, **options) for item in (inputs, targets))


def create_supervised_trainer(model, optimizer, loss_fn, device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              output_transform=lambda x, y, y_pred, loss: loss.item()):
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur
            asynchronously with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and
            outputs tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and
            returns value to be assigned to engine's state.output after each iteration. Default is
            returning `loss.item()`.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is the
        loss of the processed batch by default.

    Returns:
        Engine: a trainer engine with supervised update function.
    """
    # The model is moved once, at construction time; batches are moved on
    # every iteration through `prepare_batch` with the same `device`.
    if device:
        model.to(device)

    def _update(engine, batch):
        # Re-assert training mode on every iteration in case other code
        # (e.g. an evaluator sharing this model) switched it to eval mode.
        model.train()
        # Clear gradients left over from the previous step before the new
        # backward pass accumulates into them.
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return Engine(_update)
def create_supervised_evaluator(model, metrics=None, device=None, non_blocking=False,
                                prepare_batch=_prepare_batch,
                                output_transform=lambda x, y, y_pred: (y_pred, y,)):
    """
    Factory function for creating an evaluator for supervised models.

    Args:
        model (`torch.nn.Module`): the model to evaluate.
        metrics (dict of str - :class:`~ignite.metrics.Metric`): a map of metric names to Metrics.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur
            asynchronously with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and
            outputs tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning
            `(y_pred, y,)` which fits output expected by metrics. If you change it you should use
            `output_transform` in metrics.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is a tuple
        of `(batch_pred, batch_y)` by default.

    Returns:
        Engine: an evaluator engine with supervised inference function.
    """
    # `None` (the default) is normalised to an empty mapping here rather than
    # using a mutable `{}` default argument.
    metrics = metrics or {}

    # The model is moved once, at construction time; batches are moved on
    # every iteration through `prepare_batch` with the same `device`.
    if device:
        model.to(device)

    def _inference(engine, batch):
        # Re-assert eval mode on every iteration in case other code
        # (e.g. a trainer sharing this model) switched it to train mode.
        model.eval()
        # Gradient tracking is disabled for the whole inference step,
        # including batch preparation.
        with torch.no_grad():
            x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            y_pred = model(x)
            return output_transform(x, y, y_pred)

    engine = Engine(_inference)

    # Attach each metric to the evaluator under its key in `metrics`.
    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine

© Copyright 2024, PyTorch-Ignite Contributors. Last updated on 05/08/2024, 8:23:57 AM.

Built with Sphinx using a theme provided by Read the Docs.