def convert_tensor(
    x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    """Move every tensor contained in ``x`` to ``device``.

    Args:
        x: a tensor, or a (possibly nested) mapping / sequence of tensors.
        device: target device. When ``None``, tensors are left where they are.
        non_blocking: forwarded to :meth:`torch.Tensor.to`; lets a pinned-memory CPU tensor be
            copied to CUDA asynchronously with respect to the host when possible.

    Returns:
        ``x`` with each tensor replaced by its copy on ``device`` (or by itself when
        ``device`` is ``None``).
    """

    def _to_device(tensor: torch.Tensor) -> torch.Tensor:
        # No target device requested: hand the tensor back untouched.
        if device is None:
            return tensor
        return tensor.to(device=device, non_blocking=non_blocking)

    return apply_to_tensor(x, _to_device)
def apply_to_tensor(
    x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    """Apply ``func`` to every :class:`torch.Tensor` found in ``x``.

    This is ``apply_to_type`` specialised to ``input_type=torch.Tensor``.

    Args:
        x: a tensor, or a (possibly nested) mapping / sequence containing tensors.
        func: callable invoked on each tensor; its result takes the tensor's place.

    Returns:
        ``x`` with every tensor replaced by ``func(tensor)``.
    """
    return apply_to_type(x=x, input_type=torch.Tensor, func=func)
def apply_to_type(
    x: Union[Any, collections.Sequence, collections.Mapping, str, bytes],
    input_type: Union[Type, Tuple[Type[Any], Any]],
    func: Callable,
) -> Union[Any, collections.Sequence, collections.Mapping, str, bytes]:
    """Apply ``func`` to ``x`` or to every object of ``input_type`` nested inside it.

    Resolution order for each visited object:

    1. an instance of ``input_type`` is replaced by ``func(obj)``;
    2. ``str`` / ``bytes`` are returned unchanged (they are not descended into);
    3. mappings, namedtuples and other sequences are rebuilt with the same type
       from their recursively processed items;
    4. anything else raises :class:`TypeError`.

    Args:
        x: object, mapping or sequence to process.
        input_type: type (or tuple of types) whose instances ``func`` is applied to.
        func: callable applied to each matching object.

    Returns:
        ``x`` with every matching object replaced by ``func(obj)``.

    Raises:
        TypeError: if a visited object is neither ``input_type``, ``str``/``bytes``,
            a mapping nor a sequence.
    """
    if isinstance(x, input_type):
        return func(x)
    if isinstance(x, (str, bytes)):
        return x

    def _recurse(item: Any) -> Any:
        return apply_to_type(item, input_type, func)

    container = cast(Callable, type(x))
    if isinstance(x, collections.Mapping):
        return container({key: _recurse(val) for key, val in x.items()})
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple constructors take the fields positionally, not a single iterable
        return container(*map(_recurse, x))
    if isinstance(x, collections.Sequence):
        return container(list(map(_recurse, x)))
    raise TypeError(f"x must contain {input_type}, dicts or lists; found {type(x)}")
def_tree_map(func:Callable,x:Union[Any,collections.Sequence,collections.Mapping],key:Optional[Union[int,str]]=None)->Union[Any,collections.Sequence,collections.Mapping]:ifisinstance(x,collections.Mapping):returncast(Callable,type(x))({k:_tree_map(func,sample,key=k)fork,sampleinx.items()})ifisinstance(x,tuple)andhasattr(x,"_fields"):# namedtuplereturncast(Callable,type(x))(*(_tree_map(func,sample)forsampleinx))ifisinstance(x,collections.Sequence):returncast(Callable,type(x))([_tree_map(func,sample,key=i)fori,sampleinenumerate(x)])returnfunc(x,key=key)class_CollectionItem:types_as_collection_item:Tuple=(int,float,torch.Tensor)def__init__(self,collection:Union[Dict,List],key:Union[int,str])->None:ifnotisinstance(collection,(dict,list)):raiseTypeError(f"Input type is expected to be a mapping or list, but got {type(collection)} "f"for input key '{key}'.")ifisinstance(collection,list)andisinstance(key,str):raiseValueError("Key should be int for collection of type list")self.collection=collectionself.key=keydefload_value(self,value:Any)->None:self.collection[self.key]=value# type: ignore[index]defvalue(self)->Any:returnself.collection[self.key]# type: ignore[index]@staticmethoddefwrap(object:Union[Dict,List],key:Union[int,str],value:Any)->Union[Any,"_CollectionItem"]:return(_CollectionItem(object,key)ifvalueisNoneorisinstance(value,_CollectionItem.types_as_collection_item)elsevalue)def_tree_apply2(func:Callable,x:Union[Any,List,Dict],y:Union[Any,collections.Sequence,collections.Mapping],)->None:ifisinstance(x,dict)andisinstance(y,collections.Mapping):fork,vinx.items():ifknotiny:raiseValueError(f"Key '{k}' from x is not found in y: {y.keys()}")_tree_apply2(func,_CollectionItem.wrap(x,k,v),y[k])elifisinstance(x,list)andisinstance(y,collections.Sequence):iflen(x)!=len(y):raiseValueError(f"Size of y: {len(y)} does not match the size of x: '{len(x)}'")fori,(v1,v2)inenumerate(zip(x,y)):_tree_apply2(func,_CollectionItem.wrap(x,i,v1),v2)else:returnfunc(x,y)
def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Turn class indices of shape ``(N, ...)`` into uint8 one-hot indicators of shape ``(N, num_classes, ...)``.

    The class axis is inserted at dimension 1, and the result lives on the same device as
    ``indices``. ``indices`` is expected to be an int64 tensor with values in
    ``[0, num_classes)``, as required by :meth:`torch.Tensor.scatter_`.

    Args:
        indices: tensor of class indices to encode.
        num_classes: size of the one-hot (class) dimension.

    .. versionchanged:: 0.4.3
        This function is now torchscriptable.
    """
    # (N, d1, d2, ...) -> (N, num_classes, d1, d2, ...)
    target_shape = (indices.shape[0], num_classes) + indices.shape[1:]
    encoded = torch.zeros(target_shape, dtype=torch.uint8, device=indices.device)
    # Add a length-1 class axis so scatter_ along dim 1 writes one 1 per position.
    class_positions = indices.unsqueeze(1)
    return encoded.scatter_(1, class_positions, 1)
def setup_logger(
    name: Optional[str] = "ignite",
    level: int = logging.INFO,
    stream: Optional[TextIO] = None,
    format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
    filepath: Optional[str] = None,
    distributed_rank: Optional[int] = None,
    reset: bool = False,
) -> logging.Logger:
    """Configure a logger (level, format, handlers), creating it if it does not exist yet.

    An already existing logger (the root logger, selected by ``name=None``, always counts as
    existing) is returned without reconfiguration unless ``reset=True``. On processes whose
    distributed rank is greater than 0, the logger's own handlers are always replaced by a
    :class:`logging.NullHandler`, so that workers stay silent.

    Args:
        name: name of the logger; ``None`` selects the root logger.
        level: logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
        stream: stream for the console handler; ``None`` lets :class:`logging.StreamHandler`
            use its default (``sys.stderr``).
        format: format string applied to every handler created here. Defaults to
            ``%(asctime)s %(name)s %(levelname)s: %(message)s``.
        filepath: optional path of a log file; when given, messages are also written there.
        distributed_rank: rank of the current process in a distributed configuration. When
            ``None`` it is obtained from ``ignite.distributed.get_rank()``.
        reset: if True, remove the logger's existing handlers and configure it again instead
            of keeping its current format, handlers and level.

    Returns:
        logging.Logger

    Examples:
        Separate trainer and evaluator logs:

        .. code-block:: python

            from ignite.utils import setup_logger

            trainer = ...
            evaluator = ...

            trainer.logger = setup_logger("trainer")
            evaluator.logger = setup_logger("evaluator")

            trainer.run(data, max_epochs=10)

            # Logs will look like
            # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
            # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
            # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
            # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
            # ...

        Reconfigure an existing logger:

        .. code-block:: python

            logger = setup_logger(name="my-logger", format="=== %(name)s %(message)s")
            logger.info("first message")
            setup_logger(name="my-logger", format="+++ %(name)s %(message)s", reset=True)
            logger.info("second message")

            # Logs will look like
            # === my-logger first message
            # +++ my-logger second message

        Change the level of an existing internal logger:

        .. code-block:: python

            setup_logger(
                name="ignite.distributed.launcher.Parallel",
                level=logging.WARNING
            )

    .. versionchanged:: 0.4.3
        Added ``stream`` parameter.

    .. versionchanged:: 0.4.5
        Added ``reset`` parameter.
    """
    # Must be evaluated before getLogger(), which registers the name in loggerDict.
    already_exists = name is None or name in logging.root.manager.loggerDict
    logger = logging.getLogger(name)

    if distributed_rank is None:
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    is_worker = distributed_rank > 0

    # Workers always lose their handlers; rank 0 only when a reset is requested.
    if (is_worker or reset) and logger.hasHandlers():
        for old_handler in list(logger.handlers):
            logger.removeHandler(old_handler)

    if is_worker:
        # A NullHandler keeps worker processes from emitting duplicate parallel messages.
        logger.addHandler(logging.NullHandler())

    if already_exists and not reset:
        return logger

    if distributed_rank == 0:
        logger.setLevel(level)
        formatter = logging.Formatter(format)

        def _attach(handler: logging.Handler) -> None:
            handler.setLevel(level)
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        _attach(logging.StreamHandler(stream=stream))
        if filepath is not None:
            _attach(logging.FileHandler(filepath))

    # Do not propagate to ancestors: their handlers would print every message a second time.
    # NOTE(review): whether a less open default configuration is preferable is still an open question.
    if name is not None:
        logger.propagate = False

    return logger
def manual_seed(seed: int) -> None:
    """Seed ``random`` and ``torch``, plus ``torch_xla`` and ``numpy`` when they can be imported.

    CUDA generators are not seeded separately: :func:`torch.manual_seed` seeds the
    random number generators of all devices.

    Args:
        seed: value used to seed every generator.

    .. versionchanged:: 0.4.3
        Added ``torch.cuda.manual_seed_all(seed)``.

    .. versionchanged:: 0.4.5
        Added ``torch_xla.core.xla_model.set_rng_state(seed)``.
    """
    # Generators that are always available.
    random.seed(seed)
    torch.manual_seed(seed)

    # Optional backends: silently skipped when the package is not installed.
    try:
        import torch_xla.core.xla_model as xla_model

        xla_model.set_rng_state(seed)
    except ImportError:
        pass

    try:
        import numpy

        numpy.random.seed(seed)
    except ImportError:
        pass
def deprecated(
    deprecated_in: str, removed_in: str = "", reasons: Tuple[str, ...] = (), raise_exception: bool = False
) -> Callable:
    """Build a decorator that marks a function as deprecated.

    The decorated function emits a :class:`DeprecationWarning` on every call (or raises it
    when ``raise_exception`` is True) before delegating to the original function. Its
    docstring is prefixed with ``**Deprecated function**.`` and suffixed with a
    ``.. deprecated::`` directive that lists ``reasons``.

    Args:
        deprecated_in: version since which the function is deprecated.
        removed_in: version in which the function will be removed; omitted from the message
            when empty.
        reasons: extra explanations, each rendered as a bullet in the docstring.
        raise_exception: if True, raise the ``DeprecationWarning`` instead of issuing it.

    Returns:
        The decorator.
    """
    F = TypeVar("F", bound=Callable[..., Any])

    def decorator(func: F) -> F:
        func_doc = func.__doc__ if func.__doc__ else ""
        deprecation_warning = (
            f"This function has been deprecated since version {deprecated_in}"
            + (f" and will be removed in version {removed_in}" if removed_in else "")
            + ".\n Please refer to the documentation for more details."
        )

        # kwargs values and the return value are whatever func accepts/returns.
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if raise_exception:
                raise DeprecationWarning(deprecation_warning)
            # stacklevel=2 attributes the warning to the caller of the deprecated function.
            warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        appended_doc = f".. deprecated:: {deprecated_in}" + ("\n\n\t" if len(reasons) > 0 else "")
        for reason in reasons:
            appended_doc += "\n\t- " + reason
        # NOTE(review): the directive is appended right after func_doc with no separator, so a
        # docstring without a trailing newline gets it glued onto its last line.
        wrapper.__doc__ = f"**Deprecated function**.\n\n{func_doc}{appended_doc}"
        return cast(F, wrapper)

    return decorator
def hash_checkpoint(checkpoint_path: Union[str, Path], output_dir: Union[str, Path]) -> Tuple[Path, str]:
    """
    Hash a checkpoint file and move it to ``output_dir`` as ``<stem>-<hash8>.pt``, the naming
    expected by ``check_hash`` of :func:`torch.hub.load_state_dict_from_url`.

    ``<stem>`` is the original file name without its extension and ``<hash8>`` is the first 8
    hex characters of the SHA256 digest. The new suffix is always ``.pt``, whatever the
    original extension was.

    Args:
        checkpoint_path: Path to the checkpoint file (``str`` or any path-like). The file is
            *moved*, not copied.
        output_dir: Output directory to store the hashed checkpoint file (created, with
            parents, if it does not exist).

    Returns:
        Path to the hashed checkpoint file, and the full hex-encoded SHA256 digest of its
        content (only its first 8 characters appear in the file name).

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.

    .. versionadded:: 0.4.8
    """
    # Path() accepts str, Path and any other os.PathLike.
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"{checkpoint_path.name} does not exist in {checkpoint_path.parent}.")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    hash_obj = hashlib.sha256()
    # Approach taken from https://github.com/pytorch/vision/blob/main/references/classification/utils.py
    # Stream the file in 64 KiB blocks so large checkpoints are never fully loaded in memory.
    with checkpoint_path.open("rb") as f:
        for byte_block in iter(lambda: f.read(1 << 16), b""):
            hash_obj.update(byte_block)
    sha_hash = hash_obj.hexdigest()

    new_filename = "-".join((checkpoint_path.stem, sha_hash[:8])) + ".pt"
    hash_checkpoint_path = output_dir / new_filename
    shutil.move(str(checkpoint_path), hash_checkpoint_path)

    return hash_checkpoint_path, sha_hash