Source code for ignite.utils

import sys

import torch
from torch._six import string_classes

# True when running under Python 2; selects where the abstract container
# base classes (Mapping, Sequence) used below are imported from.
IS_PYTHON2 = sys.version_info[0] < 3

if IS_PYTHON2:
    import collections
else:
    # The ABC aliases on ``collections`` (e.g. ``collections.Mapping``) were
    # removed in Python 3.10; ``collections.abc`` provides them on Python 3,
    # so bind it under the same name the rest of the module uses.
    import collections.abc as collections

def convert_tensor(input_, device=None, non_blocking=False):
    """Move tensors to relevant device.

    Args:
        input_: a tensor, or a (possibly nested) mapping or sequence of tensors.
            Container traversal is delegated to ``apply_to_tensor``.
        device: target device forwarded to ``torch.Tensor.to``. If ``None``,
            tensors are returned unchanged.
        non_blocking (bool): forwarded to ``torch.Tensor.to``.

    Returns:
        ``input_`` with every tensor moved to ``device`` (or left as-is when
        ``device`` is ``None``).
    """
    def _func(tensor):
        # Compare against None explicitly: a falsy but valid device such as
        # the integer CUDA index ``0`` must still trigger the move.
        if device is None:
            return tensor
        return tensor.to(device=device, non_blocking=non_blocking)

    return apply_to_tensor(input_, _func)
def apply_to_tensor(input_, func):
    """Apply ``func`` to every ``torch.Tensor`` found in ``input_``.

    ``input_`` may be a single tensor or a nested mapping / sequence of
    tensors; this is ``apply_to_type`` specialised to ``torch.Tensor``.
    """
    return apply_to_type(input_=input_, input_type=torch.Tensor, func=func)
def apply_to_type(input_, input_type, func):
    """Apply ``func`` to every object of ``input_type`` inside ``input_``.

    Args:
        input_: a single object, or an arbitrarily nested mapping / sequence.
        input_type: type (or tuple of types) accepted by ``isinstance``;
            matching objects are passed to ``func``.
        func: callable applied to each matching object.

    Returns:
        ``func(input_)`` for a matching object; strings (``string_classes``)
        unchanged; a ``dict`` for any mapping and a ``list`` for any other
        sequence, with their elements processed recursively.

    Raises:
        TypeError: if ``input_`` (or a nested element) matches none of the above.
    """
    if isinstance(input_, input_type):
        return func(input_)
    # Strings are sequences too; hand them back untouched rather than
    # recursing into their characters.
    if isinstance(input_, string_classes):
        return input_

    def recurse(item):
        return apply_to_type(item, input_type, func)

    if isinstance(input_, collections.Mapping):
        return dict((key, recurse(value)) for key, value in input_.items())
    if isinstance(input_, collections.Sequence):
        return list(map(recurse, input_))
    raise TypeError("input must contain {}, dicts or lists; found {}".format(input_type, type(input_)))
def to_onehot(indices, num_classes, dtype=None):
    """Convert a tensor of indices of any shape `(N, ...)` to a
    tensor of one-hot indicators of shape `(N, num_classes, ...)`.

    Args:
        indices (torch.Tensor): class indices of shape `(N, ...)`; used as the
            index argument of ``scatter_``, which expects an int64 tensor.
        num_classes (int): size of the inserted class dimension (dim 1).
        dtype (torch.dtype, optional): dtype of the returned tensor. ``None``
            (the default) keeps the previous behaviour of using torch's
            default dtype.

    Returns:
        torch.Tensor: tensor of shape `(N, num_classes, ...)` on the same
        device as ``indices``, holding 1 at each indexed class and 0 elsewhere.
    """
    new_shape = (indices.shape[0], num_classes) + tuple(indices.shape[1:])
    onehot = torch.zeros(new_shape, dtype=dtype, device=indices.device)
    # scatter_ needs the index to have the same rank as ``onehot``, so add
    # the class axis at dim 1 before writing the 1s along it.
    return onehot.scatter_(1, indices.unsqueeze(1), 1)

© Copyright 2022, PyTorch-Ignite Contributors. Last updated on 05/04/2022, 8:31:22 PM.

Built with Sphinx using a theme provided by Read the Docs.