Source code for ignite.engine

from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch

import ignite.distributed as idist
from ignite.engine.deterministic import DeterministicEngine
from ignite.engine.engine import Engine
from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, State
from ignite.metrics import Metric
from ignite.utils import convert_tensor

if idist.has_xla_support:
    import torch_xla.core.xla_model as xm


__all__ = [
    "State",
    "create_supervised_trainer",
    "create_supervised_evaluator",
    "Engine",
    "DeterministicEngine",
    "Events",
    "EventEnum",
    "CallableEventWithFilter",
]


def _prepare_batch(
    batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
):
    """Prepare batch for training: pass to a device with options.

    """
    x, y = batch
    return (
        convert_tensor(x, device=device, non_blocking=non_blocking),
        convert_tensor(y, device=device, non_blocking=non_blocking),
    )
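

# Illustrative sketch (not part of the original module): a custom ``prepare_batch``
# passed to the factories below must follow the same signature as ``_prepare_batch``,
# i.e. accept ``batch``, ``device``, ``non_blocking`` and return ``(x, y)``. The
# helper name and the dict keys used here are hypothetical, for demonstration only.
def _example_prepare_dict_batch(
    batch: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Unpack a dict-style batch instead of an (x, y) pair.
    x = convert_tensor(batch["features"], device=device, non_blocking=non_blocking)
    y = convert_tensor(batch["target"], device=device, non_blocking=non_blocking)
    return x, y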


def create_supervised_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
) -> Engine:
    """Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, GPU or TPU.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU,
            the copy may occur asynchronously with respect to the host. For other cases,
            this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`,
            `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred',
            'loss' and returns value to be assigned to engine's state.output after each
            iteration. Default is returning `loss.item()`.
        deterministic (bool, optional): if True, returns deterministic engine of type
            :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise
            :class:`~ignite.engine.engine.Engine` (default: False).

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter
        and is the loss of the processed batch by default.

    .. warning::

        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.

        For more information see:

        - `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        - `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    Returns:
        Engine: a trainer engine with supervised update function.
    """

    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False

    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")

    def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()

        if on_tpu:
            xm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()

        return output_transform(x, y, y_pred, loss)

    trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)

    return trainer
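

# Illustrative sketch (not part of the original module): minimal end-to-end use of
# ``create_supervised_trainer``. The helper name, toy model, data and hyperparameters
# are assumptions for demonstration; per the warning above, the model is moved to
# the device *before* the optimizer is created.
def _example_trainer_usage() -> Engine:
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(10, 1).to(device)  # move the model first ...
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # ... then build the optimizer
    loss_fn = torch.nn.MSELoss()

    loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8)
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    trainer.run(loader, max_epochs=2)  # engine.state.output is loss.item() each iteration
    return trainer

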
def create_supervised_evaluator(
    model: torch.nn.Module,
    metrics: Optional[Dict[str, Metric]] = None,
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
) -> Engine:
    """
    Factory function for creating an evaluator for supervised models.

    Args:
        model (`torch.nn.Module`): the model to evaluate.
        metrics (dict of str - :class:`~ignite.metrics.Metric`): a map of metric names to Metrics.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU,
            the copy may occur asynchronously with respect to the host. For other cases,
            this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`,
            `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred'
            and returns value to be assigned to engine's state.output after each iteration.
            Default is returning `(y_pred, y,)` which fits output expected by metrics.
            If you change it you should use `output_transform` in metrics.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter
        and is a tuple of `(batch_pred, batch_y)` by default.

    .. warning::

        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before being passed to this factory.

        For more information see:

        - `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        - `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    Returns:
        Engine: an evaluator engine with supervised inference function.
    """
    metrics = metrics or {}

    def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.eval()
        with torch.no_grad():
            x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            y_pred = model(x)
            return output_transform(x, y, y_pred)

    evaluator = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    return evaluator
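

# Illustrative sketch (not part of the original module): building an evaluator with
# ``create_supervised_evaluator`` and attaching a metric. The helper name and toy
# setup are assumptions for demonstration; the ``Loss`` metric consumes the default
# ``(y_pred, y)`` output.
def _example_evaluator_usage() -> Engine:
    from torch.utils.data import DataLoader, TensorDataset

    from ignite.metrics import Loss

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(10, 1).to(device)
    loss_fn = torch.nn.MSELoss()

    val_loader = DataLoader(TensorDataset(torch.randn(32, 10), torch.randn(32, 1)), batch_size=8)
    evaluator = create_supervised_evaluator(model, metrics={"loss": Loss(loss_fn)}, device=device)
    evaluator.run(val_loader)
    print(evaluator.state.metrics["loss"])  # averaged loss over the validation set
    return evaluator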
