import collections.abc as collections
import functools
import logging
import random
import sys
import warnings
from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Type, TypeVar, Union, cast
import torch
# Public API of this module. `deprecated` is not listed here, so a star-import does not export it.
__all__ = [
    "convert_tensor",
    "apply_to_tensor",
    "apply_to_type",
    "to_onehot",
    "setup_logger",
    "manual_seed",
]
def convert_tensor(
    input_: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    """Move tensors to relevant device.

    Every ``torch.Tensor`` reached by :func:`apply_to_tensor` in ``input_`` is passed through
    ``Tensor.to(device=device, non_blocking=non_blocking)``.

    Args:
        input_: tensor, or sequence/mapping of tensors (``str``/``bytes`` are accepted too).
        device: target device. If None, tensors are returned unchanged.
        non_blocking: forwarded as-is to ``Tensor.to``.

    Returns:
        the input with its tensors moved to ``device``.
    """

    def _move(t: torch.Tensor) -> torch.Tensor:
        # Without a target device there is nothing to do: hand the tensor back untouched.
        if device is None:
            return t
        return t.to(device=device, non_blocking=non_blocking)

    return apply_to_tensor(input_, _move)
def apply_to_tensor(
    input_: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
    func: Callable,
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    """Apply a function on a tensor or mapping, or sequence of tensors.

    Thin specialisation of :func:`apply_to_type` with ``input_type`` fixed to ``torch.Tensor``.

    Args:
        input_: tensor, or (nested) sequence/mapping of tensors.
        func: callable applied to every tensor found.

    Returns:
        the result of :func:`apply_to_type` for these arguments.
    """
    return apply_to_type(input_=input_, input_type=torch.Tensor, func=func)
def apply_to_type(
    input_: Union[Any, collections.Sequence, collections.Mapping, str, bytes],
    input_type: Union[Type, Tuple[Type[Any], Any]],
    func: Callable,
) -> Union[Any, collections.Sequence, collections.Mapping, str, bytes]:
    """Apply a function on a object of `input_type` or mapping, or sequence of objects of `input_type`.

    The input is traversed recursively:

    - an instance of ``input_type`` is replaced by ``func(input_)``;
    - ``str`` and ``bytes`` are returned unchanged (they are not treated as sequences);
    - a mapping is rebuilt with the same type and keys; a ``collections.defaultdict`` keeps its
      ``default_factory``;
    - a namedtuple is rebuilt field by field;
    - any other sequence is rebuilt by calling its type on a list of the converted items.

    Args:
        input_: object, or (nested) container of objects, to convert.
        input_type: type or tuple of types, as accepted by ``isinstance``.
        func: callable applied to every object of ``input_type``.

    Returns:
        the converted object or container.

    Raises:
        TypeError: if ``input_`` (or a nested item) is neither an ``input_type`` instance nor one of
            the supported containers.
    """
    if isinstance(input_, input_type):
        return func(input_)
    if isinstance(input_, (str, bytes)):
        return input_
    if isinstance(input_, collections.Mapping):
        converted = {k: apply_to_type(sample, input_type, func) for k, sample in input_.items()}
        # In this module the name ``collections`` is bound to ``collections.abc``, so import the
        # concrete class from the real package.
        from collections import defaultdict

        if isinstance(input_, defaultdict):
            # defaultdict's first positional argument is its default_factory: passing only the
            # converted dict would raise "first argument must be callable or None".
            return type(input_)(input_.default_factory, converted)
        return cast(Callable, type(input_))(converted)
    if isinstance(input_, tuple) and hasattr(input_, "_fields"):  # namedtuple
        # namedtuple constructors take fields positionally, not a single iterable.
        return cast(Callable, type(input_))(*(apply_to_type(sample, input_type, func) for sample in input_))
    if isinstance(input_, collections.Sequence):
        return cast(Callable, type(input_))([apply_to_type(sample, input_type, func) for sample in input_])
    raise TypeError(f"input must contain {input_type}, dicts or lists; found {type(input_)}")
def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert a tensor of indices of any shape `(N, ...)` to a tensor of one-hot indicators
    of shape `(N, num_classes, ...)` and of type uint8. Output's device is equal to the
    input's device.

    Args:
        indices: integer (int64, as required by ``Tensor.scatter_``) tensor of class indices with
            shape `(N, ...)`; values are expected to lie in `[0, num_classes)`.
        num_classes: number of classes, i.e. the size of the new dimension 1.

    Returns:
        uint8 tensor of shape `(N, num_classes, ...)` holding 1 at each index position along
        dimension 1 and 0 elsewhere.

    .. versionchanged:: 0.4.3
        This function is now torchscriptable.
    """
    # Insert the class dimension right after the leading (batch) dimension.
    out_shape = (indices.shape[0], num_classes) + indices.shape[1:]
    result = torch.zeros(out_shape, dtype=torch.uint8, device=indices.device)
    # unsqueeze(1) gives the index tensor the same rank as `result`; scatter_ writes 1 in place.
    result.scatter_(1, indices.unsqueeze(1), 1)
    return result
def setup_logger(
    name: Optional[str] = None,
    level: int = logging.INFO,
    stream: Optional[TextIO] = None,
    format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
    filepath: Optional[str] = None,
    distributed_rank: Optional[int] = None,
) -> logging.Logger:
    """Setups logger: name, level, format etc.

    Handlers already attached to the logger are detached and closed before the new ones are
    added. Calling this function several times for the same ``name`` therefore neither duplicates
    output nor leaves previously opened log files open.

    Args:
        name (str, optional): new name for the logger. If None, the root logger
            (``logging.getLogger(None)``) is used. A named logger does not propagate records
            to its ancestors.
        level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
        stream (TextIO, optional): logging stream. If None, the standard stream is used (sys.stderr).
        format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`.
        filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
        distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger
            setup for workers. If None, distributed_rank is initialized to the rank of process.
            For a rank greater than 0, only a ``logging.NullHandler`` is attached and the level is
            left unchanged.

    Returns:
        logging.Logger

    For example, to improve logs readability when training with a trainer and evaluator:

    .. code-block:: python

        from ignite.utils import setup_logger

        trainer = ...
        evaluator = ...

        trainer.logger = setup_logger("trainer")
        evaluator.logger = setup_logger("evaluator")

        trainer.run(data, max_epochs=10)

        # Logs will look like
        # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
        # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
        # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
        # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
        # ...

    .. versionchanged:: 0.4.3
        Added ``stream`` parameter.
    """
    logger = logging.getLogger(name)

    # don't propagate to ancestors
    # the problem here is to attach handlers to loggers
    # should we provide a default configuration less open ?
    if name is not None:
        logger.propagate = False

    # Remove previous handlers and release their resources. Detaching alone would keep, e.g.,
    # a FileHandler's file open until garbage collection. StreamHandler.close() does not close
    # its stream, so sys.stderr and user streams stay usable.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(format)

    if distributed_rank is None:
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    if distributed_rank > 0:
        # Worker ranks get a NullHandler only, so this logger produces no output there.
        logger.addHandler(logging.NullHandler())
        return logger

    logger.setLevel(level)

    def _attach(handler: logging.Handler) -> None:
        # Every output handler shares the same level and formatter.
        handler.setLevel(level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    _attach(logging.StreamHandler(stream=stream))
    if filepath is not None:
        _attach(logging.FileHandler(filepath))
    return logger
def manual_seed(seed: int) -> None:
    """Setup random state from a seed for `torch`, `random` and optionally `numpy` (if can be imported).

    Seeds Python's ``random`` module and ``torch``. When CUDA is available, also seeds all CUDA
    devices via ``torch.cuda.manual_seed_all``. ``numpy.random`` is seeded only if NumPy can be
    imported; otherwise it is silently skipped.

    Args:
        seed (int): Random state seed

    .. versionchanged:: 0.4.3
        Added ``torch.cuda.manual_seed_all(seed)``.
    """
    # Built-in RNG first, then torch.
    random.seed(seed)
    torch.manual_seed(seed)

    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        torch.cuda.manual_seed_all(seed)

    # NumPy is an optional dependency here.
    try:
        import numpy

        numpy.random.seed(seed)
    except ImportError:
        pass
def deprecated(
    deprecated_in: str, removed_in: str = "", reasons: Tuple[str, ...] = (), raise_exception: bool = False
) -> Callable:
    """Build a decorator that marks a function as deprecated.

    Args:
        deprecated_in: version since which the function is deprecated.
        removed_in: version in which the function will be removed. It is left out of the
            message when empty.
        reasons: extra lines appended as a bullet list to the decorated function's docstring.
        raise_exception: if True, calling the decorated function raises ``DeprecationWarning``
            instead of warning and calling through.

    Returns:
        A decorator. The wrapped function emits (or raises) a ``DeprecationWarning`` on every call,
        and its docstring is prefixed with a deprecation notice.
    """
    F = TypeVar("F", bound=Callable[..., Any])

    def decorator(func: F) -> F:
        func_doc = func.__doc__ if func.__doc__ else ""
        deprecation_warning = (
            f"This function has been deprecated since version {deprecated_in}"
            + (f" and will be removed in version {removed_in}" if removed_in else "")
            + ".\n Please refer to the documentation for more details."
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if raise_exception:
                raise DeprecationWarning(deprecation_warning)
            # stacklevel=2 attributes the warning to the caller of the deprecated function.
            warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        appended_doc = f".. deprecated:: {deprecated_in}" + ("\n\n\t" if reasons else "")
        for reason in reasons:
            appended_doc += "\n\t- " + reason
        # NOTE(review): the directive is appended directly after the original docstring, with no
        # separating newline. The exact text is kept unchanged for compatibility with existing
        # callers and tests.
        wrapper.__doc__ = f"**Deprecated function**.\n\n {func_doc}{appended_doc}"
        return cast(F, wrapper)

    return decorator