def convert_tensor(
    x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    """Move tensors to relevant device.

    Args:
        x: input tensor or mapping, or sequence of tensors.
        device: device type to move ``x``.
        non_blocking: convert a CPU Tensor with pinned memory to a CUDA Tensor
            asynchronously with respect to the host if possible
    """

    def _func(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(device=device, non_blocking=non_blocking) if device is not None else tensor

    return apply_to_tensor(x, _func)
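
# A minimal usage sketch for ``convert_tensor`` (illustrative only; the ``_example_convert_tensor``
# helper and the sample batch below are hypothetical). Tensors nested inside dicts/lists are moved
# to the target device, while non-tensor leaves such as strings pass through unchanged.
def _example_convert_tensor() -> None:
    import torch

    batch = {"image": torch.rand(2, 3, 4, 4), "label": torch.tensor([0, 1]), "id": "sample-0"}
    device = "cuda" if torch.cuda.is_available() else "cpu"
    moved = convert_tensor(batch, device=device, non_blocking=True)
    assert moved["label"].device.type == device
    assert moved["id"] == "sample-0"  # strings are returned as-is
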
def apply_to_tensor(
    x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    """Apply a function on a tensor or mapping, or sequence of tensors.

    Args:
        x: input tensor or mapping, or sequence of tensors.
        func: the function to apply on ``x``.
    """
    return apply_to_tensor(x, func) if False else apply_to_type(x, torch.Tensor, func)
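
# A minimal usage sketch for ``apply_to_tensor`` (illustrative only; ``_example_apply_to_tensor``
# is a hypothetical helper): detach every tensor found in a nested structure.
def _example_apply_to_tensor() -> None:
    import torch

    output = {"loss": torch.tensor(0.5, requires_grad=True), "preds": [torch.rand(3), torch.rand(3)]}
    detached = apply_to_tensor(output, lambda t: t.detach())
    assert not detached["loss"].requires_grad
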
def apply_to_type(
    x: Union[Any, collections.Sequence, collections.Mapping, str, bytes],
    input_type: Union[Type, Tuple[Type[Any], Any]],
    func: Callable,
) -> Union[Any, collections.Sequence, collections.Mapping, str, bytes]:
    """Apply a function on an object of `input_type` or mapping, or sequence of objects of `input_type`.

    Args:
        x: object or mapping or sequence.
        input_type: data type of ``x``.
        func: the function to apply on ``x``.
    """
    if isinstance(x, input_type):
        return func(x)
    if isinstance(x, (str, bytes)):
        return x
    if isinstance(x, collections.Mapping):
        return cast(Callable, type(x))({k: apply_to_type(sample, input_type, func) for k, sample in x.items()})
    if isinstance(x, tuple) and hasattr(x, "_fields"):  # namedtuple
        return cast(Callable, type(x))(*(apply_to_type(sample, input_type, func) for sample in x))
    if isinstance(x, collections.Sequence):
        return cast(Callable, type(x))([apply_to_type(sample, input_type, func) for sample in x])
    raise TypeError(f"x must contain {input_type}, dicts or lists; found {type(x)}")
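
# A minimal usage sketch for ``apply_to_type`` (illustrative only; ``_example_apply_to_type`` is a
# hypothetical helper): apply a function to every leaf of the given type, rebuilding the containers
# (dict, list, tuple, namedtuple) around the results.
def _example_apply_to_type() -> None:
    data = {"a": [1, 2, 3], "b": (4, 5), "name": "ignored"}
    doubled = apply_to_type(data, int, lambda v: v * 2)
    assert doubled == {"a": [2, 4, 6], "b": (8, 10), "name": "ignored"}
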
def _tree_map(
    func: Callable, x: Union[Any, collections.Sequence, collections.Mapping], key: Optional[Union[int, str]] = None
) -> Union[Any, collections.Sequence, collections.Mapping]:
    if isinstance(x, collections.Mapping):
        return cast(Callable, type(x))({k: _tree_map(func, sample, key=k) for k, sample in x.items()})
    if isinstance(x, tuple) and hasattr(x, "_fields"):  # namedtuple
        return cast(Callable, type(x))(*(_tree_map(func, sample) for sample in x))
    if isinstance(x, collections.Sequence):
        return cast(Callable, type(x))([_tree_map(func, sample, key=i) for i, sample in enumerate(x)])
    return func(x, key=key)


def _to_str_list(data: Any) -> List[str]:
    """
    Recursively flattens and formats complex data structures, including keys for dictionaries, into a list of
    human-readable strings. This function processes nested dictionaries, lists, tuples, numbers, and PyTorch tensors,
    formatting numbers to four decimal places and handling tensors with special formatting rules. It's particularly
    useful for logging, debugging, or any scenario where a human-readable representation of complex, nested data
    structures is required.

    The function handles the following types:

    - Numbers: Formatted to four decimal places.
    - PyTorch tensors:
        - Scalars are formatted to four decimal places.
        - 1D tensors with more than 10 elements show the first 10 elements followed by an ellipsis.
        - 1D tensors with 10 or fewer elements are fully listed.
        - Multi-dimensional tensors display their shape.
    - Dictionaries: Each key-value pair is included in the output with the key as a prefix.
    - Lists and tuples: Flattened and included in the output. Empty lists/tuples are represented by an empty string.
    - None values: Represented by an empty string.

    Args:
        data: The input data to be flattened and formatted. It can be a nested combination of dictionaries, lists,
            tuples, numbers, and PyTorch tensors.

    Returns:
        A list of formatted strings, each representing a part of the input data structure.
    """
    formatted_items: List[str] = []

    def format_item(item: Any, prefix: str = "") -> Optional[str]:
        if isinstance(item, numbers.Number):
            return f"{prefix}{item:.4f}"
        elif torch.is_tensor(item):
            if item.dim() == 0:
                return f"{prefix}{item.item():.4f}"  # Format scalar tensor without brackets
            elif item.dim() == 1 and item.size(0) > 10:
                return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item[:10]) + ", ...]"
            elif item.dim() == 1:
                return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item) + "]"
            else:
                return f"{prefix}Shape{list(item.shape)}"
        elif isinstance(item, dict):
            for key, value in item.items():
                formatted_value = format_item(value, f"{key}: ")
                if formatted_value is not None:
                    formatted_items.append(formatted_value)
        elif isinstance(item, (list, tuple)):
            if not item:
                if prefix:
                    formatted_items.append(f"{prefix}")
            else:
                values = [format_item(x) for x in item]
                values_str = [v for v in values if v is not None]
                if values_str:
                    formatted_items.append(f"{prefix}" + ", ".join(values_str))
        elif item is None:
            if prefix:
                formatted_items.append(f"{prefix}")
        return None

    # Directly handle single numeric values
    if isinstance(data, numbers.Number):
        return [f"{data:.4f}"]

    format_item(data)
    return formatted_items


class _CollectionItem:
    types_as_collection_item: Tuple = (int, float, torch.Tensor)

    def __init__(self, collection: Union[Dict, List], key: Union[int, str]) -> None:
        if not isinstance(collection, (dict, list)):
            raise TypeError(
                f"Input type is expected to be a mapping or list, but got {type(collection)} "
                f"for input key '{key}'."
            )
        if isinstance(collection, list) and isinstance(key, str):
            raise ValueError("Key should be int for collection of type list")

        self.collection = collection
        self.key = key

    def load_value(self, value: Any) -> None:
        self.collection[self.key] = value  # type: ignore[index]

    def value(self) -> Any:
        return self.collection[self.key]  # type: ignore[index]

    @staticmethod
    def wrap(object: Union[Dict, List], key: Union[int, str], value: Any) -> Union[Any, "_CollectionItem"]:
        return (
            _CollectionItem(object, key)
            if value is None or isinstance(value, _CollectionItem.types_as_collection_item)
            else value
        )


def _tree_apply2(
    func: Callable,
    x: Union[Any, List, Dict],
    y: Union[Any, collections.Sequence, collections.Mapping],
) -> None:
    if isinstance(x, dict) and isinstance(y, collections.Mapping):
        for k, v in x.items():
            if k not in y:
                raise ValueError(f"Key '{k}' from x is not found in y: {y.keys()}")
            _tree_apply2(func, _CollectionItem.wrap(x, k, v), y[k])
    elif isinstance(x, list) and isinstance(y, collections.Sequence):
        if len(x) != len(y):
            raise ValueError(f"Size of y: {len(y)} does not match the size of x: '{len(x)}'")
        for i, (v1, v2) in enumerate(zip(x, y)):
            _tree_apply2(func, _CollectionItem.wrap(x, i, v1), v2)
    else:
        return func(x, y)
def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert a tensor of indices of any shape `(N, ...)` to a tensor of one-hot indicators of shape
    `(N, num_classes, ...)` and of type uint8. Output's device is equal to the input's device.

    Args:
        indices: input tensor to convert.
        num_classes: number of classes for one-hot tensor.

    .. versionchanged:: 0.4.3
        This function is now torchscriptable.
    """
    new_shape = (indices.shape[0], num_classes) + indices.shape[1:]
    onehot = torch.zeros(new_shape, dtype=torch.uint8, device=indices.device)
    return onehot.scatter_(1, indices.unsqueeze(1), 1)
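
# A minimal usage sketch for ``to_onehot`` (illustrative only; ``_example_to_onehot`` is a
# hypothetical helper).
def _example_to_onehot() -> None:
    import torch

    indices = torch.tensor([0, 2, 1])
    onehot = to_onehot(indices, num_classes=3)
    # tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.uint8)
    assert onehot.shape == (3, 3) and onehot.dtype == torch.uint8
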
def setup_logger(
    name: Optional[str] = "ignite",
    level: int = logging.INFO,
    stream: Optional[TextIO] = None,
    format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
    filepath: Optional[str] = None,
    distributed_rank: Optional[int] = None,
    reset: bool = False,
    encoding: Optional[str] = "utf-8",
) -> logging.Logger:
    """Sets up the logger: name, level, format, etc.

    Args:
        name: new name for the logger. If None, the standard logger is used.
        level: logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
        stream: logging stream. If None, the standard stream is used (sys.stderr).
        format: logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`.
        filepath: Optional logging file path. If not None, logs are written to the file.
        distributed_rank: Optional, rank in distributed configuration to avoid logger setup for workers.
            If None, distributed_rank is initialized to the rank of the process.
        reset: if True, reset an existing logger rather than keep its format, handlers, and level.
        encoding: encoding used to open the log file. By default, 'utf-8'.

    Returns:
        logging.Logger

    Examples:
        Improve logs readability when training with a trainer and evaluator:

        .. code-block:: python

            from ignite.utils import setup_logger

            trainer = ...
            evaluator = ...

            trainer.logger = setup_logger("trainer")
            evaluator.logger = setup_logger("evaluator")

            trainer.run(data, max_epochs=10)

            # Logs will look like
            # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=10.
            # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:05:23
            # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
            # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
            # ...

        Every existing logger can be reset if needed

        .. code-block:: python

            logger = setup_logger(name="my-logger", format="=== %(name)s %(message)s")
            logger.info("first message")
            setup_logger(name="my-logger", format="+++ %(name)s %(message)s", reset=True)
            logger.info("second message")

            # Logs will look like
            # === my-logger first message
            # +++ my-logger second message

        Change the level of an existing internal logger

        .. code-block:: python

            setup_logger(
                name="ignite.distributed.launcher.Parallel",
                level=logging.WARNING
            )

    .. versionchanged:: 0.4.3
        Added ``stream`` parameter.

    .. versionchanged:: 0.4.5
        Added ``reset`` parameter.

    .. versionchanged:: 0.5.1
        Argument ``encoding`` added to correctly handle special characters in the file, default "utf-8".
    """
    # check if the logger already exists
    existing = name is None or name in logging.root.manager.loggerDict

    # if existing, get the logger, otherwise create a new one
    logger = logging.getLogger(name)

    if distributed_rank is None:
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    # Remove previous handlers
    if distributed_rank > 0 or reset:
        if logger.hasHandlers():
            for h in list(logger.handlers):
                logger.removeHandler(h)

    if distributed_rank > 0:
        # Add null handler to avoid multiple parallel messages
        logger.addHandler(logging.NullHandler())

    # Keep the existing configuration if not reset
    if existing and not reset:
        return logger

    if distributed_rank == 0:
        logger.setLevel(level)

        formatter = logging.Formatter(format)

        ch = logging.StreamHandler(stream=stream)
        ch.setLevel(level)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

        if filepath is not None:
            fh = logging.FileHandler(filepath, encoding=encoding)
            fh.setLevel(level)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    # don't propagate to ancestors
    # the problem here is attaching handlers to loggers
    # should we provide a less open default configuration?
    if name is not None:
        logger.propagate = False

    return logger
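
# A minimal sketch of the ``filepath`` argument, which the docstring examples above do not show
# (illustrative only; ``_example_setup_logger_file`` and the log file name are hypothetical).
def _example_setup_logger_file() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        logger = setup_logger(name="example-file-logger", filepath=f"{tmpdir}/run.log", reset=True)
        logger.info("written to stderr and to run.log")
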
def manual_seed(seed: int) -> None:
    """Set up random state from a seed for `torch`, `random` and optionally `numpy` (if it can be imported).

    Args:
        seed: Random state seed

    .. versionchanged:: 0.4.3
        Added ``torch.cuda.manual_seed_all(seed)``.

    .. versionchanged:: 0.4.5
        Added ``torch_xla.core.xla_model.set_rng_state(seed)``.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    try:
        import torch_xla.core.xla_model as xm

        xm.set_rng_state(seed)
    except ImportError:
        pass

    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass
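
# A minimal usage sketch for ``manual_seed`` (illustrative only; ``_example_manual_seed`` is a
# hypothetical helper): the same seed reproduces the same sequence of random numbers.
def _example_manual_seed() -> None:
    import torch

    manual_seed(42)
    first = torch.rand(3)
    manual_seed(42)
    assert torch.equal(first, torch.rand(3))
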
def deprecated(
    deprecated_in: str, removed_in: str = "", reasons: Tuple[str, ...] = (), raise_exception: bool = False
) -> Callable:
    F = TypeVar("F", bound=Callable[..., Any])

    def decorator(func: F) -> F:
        func_doc = func.__doc__ if func.__doc__ else ""
        deprecation_warning = (
            f"This function has been deprecated since version {deprecated_in}"
            + (f" and will be removed in version {removed_in}" if removed_in else "")
            + ".\n Please refer to the documentation for more details."
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> Callable:
            if raise_exception:
                raise DeprecationWarning(deprecation_warning)
            warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        appended_doc = f".. deprecated:: {deprecated_in}" + ("\n\n\t" if len(reasons) > 0 else "")

        for reason in reasons:
            appended_doc += "\n\t- " + reason

        wrapper.__doc__ = f"**Deprecated function**.\n\n{func_doc}{appended_doc}"
        return cast(F, wrapper)

    return decorator
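
# A minimal usage sketch for the ``deprecated`` decorator (illustrative only; ``_example_deprecated``
# and ``old_helper`` are hypothetical). Calling the wrapped function emits a ``DeprecationWarning``
# (or raises it when ``raise_exception=True``) and its docstring gains a ``.. deprecated::`` note.
def _example_deprecated() -> None:
    import warnings

    @deprecated("0.4.0", "0.6.0", reasons=("use a hypothetical ``new_helper`` instead",))
    def old_helper(x: int) -> int:
        """Return ``x`` unchanged."""
        return x

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert old_helper(1) == 1
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
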
def hash_checkpoint(checkpoint_path: Union[str, Path], output_dir: Union[str, Path]) -> Tuple[Path, str]:
    """
    Hash the checkpoint file in the format of ``<filename>-<hash>.<ext>``
    to be used with ``check_hash`` of :func:`torch.hub.load_state_dict_from_url`.

    Args:
        checkpoint_path: Path to the checkpoint file.
        output_dir: Output directory to store the hashed checkpoint file
            (will be created if it does not exist).

    Returns:
        Path to the hashed checkpoint file and the first 8 characters of its SHA256 hash.

    .. versionadded:: 0.4.8
    """
    if isinstance(checkpoint_path, str):
        checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        raise FileNotFoundError(f"{checkpoint_path.name} does not exist in {checkpoint_path.parent}.")

    if isinstance(output_dir, str):
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    hash_obj = hashlib.sha256()

    # taken from https://github.com/pytorch/vision/blob/main/references/classification/utils.py
    with checkpoint_path.open("rb") as f:
        # Read and update hash string value in blocks of 4KB
        for byte_block in iter(lambda: f.read(4096), b""):
            hash_obj.update(byte_block)
        sha_hash = hash_obj.hexdigest()

    old_filename = checkpoint_path.stem
    new_filename = "-".join((old_filename, sha_hash[:8])) + ".pt"
    hash_checkpoint_path = output_dir / new_filename
    shutil.move(str(checkpoint_path), hash_checkpoint_path)

    return hash_checkpoint_path, sha_hash
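
# A minimal usage sketch for ``hash_checkpoint`` (illustrative only; ``_example_hash_checkpoint``
# and the file names are hypothetical). Note that the original checkpoint file is moved, not
# copied, into ``output_dir``.
def _example_hash_checkpoint() -> None:
    import tempfile
    from pathlib import Path

    import torch

    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt = Path(tmpdir) / "model.pt"
        torch.save({"weight": torch.zeros(2, 2)}, ckpt)
        hashed_path, sha256_hash = hash_checkpoint(ckpt, Path(tmpdir) / "hashed")
        # e.g. .../hashed/model-<first 8 hex characters of the hash>.pt
        assert hashed_path.name == f"model-{sha256_hash[:8]}.pt"
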