from__future__importannotationsimportcollections.abcimportnumbersfromcontextlibimportsuppressfromtypingimportAny,Callable,Dict,List,Literal,Sequence,Tuple,Type,UnionimportPIL.Imageimporttorchfromtorchvisionimporttv_tensorsfromtorchvision._utilsimportsequence_to_strfromtorchvision.transforms.transformsimport_check_sequence_input,_setup_angle,_setup_size# noqa: F401fromtorchvision.transforms.v2.functionalimportget_dimensions,get_size,is_pure_tensorfromtorchvision.transforms.v2.functional._utilsimport_FillType,_FillTypeJITdef_setup_number_or_seq(arg:Union[int,float,Sequence[Union[int,float]]],name:str)->Sequence[float]:ifnotisinstance(arg,(int,float,Sequence)):raiseTypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}")ifisinstance(arg,Sequence)andlen(arg)notin(1,2):raiseValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}")ifisinstance(arg,Sequence):forelementinarg:ifnotisinstance(element,(int,float)):raiseValueError(f"{name} should be a sequence of numbers. 
Got {type(element)}")ifisinstance(arg,(int,float)):arg=[float(arg),float(arg)]elifisinstance(arg,Sequence):iflen(arg)==1:arg=[float(arg[0]),float(arg[0])]else:arg=[float(arg[0]),float(arg[1])]returnargdef_check_fill_arg(fill:Union[_FillType,Dict[Union[Type,str],_FillType]])->None:ifisinstance(fill,dict):forvalueinfill.values():_check_fill_arg(value)else:iffillisnotNoneandnotisinstance(fill,(numbers.Number,tuple,list)):raiseTypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")def_convert_fill_arg(fill:_FillType)->_FillTypeJIT:# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517# So, we can't reassign fill to 0# if fill is None:# fill = 0iffillisNone:returnfillifnotisinstance(fill,(int,float)):fill=[float(v)forvinlist(fill)]returnfill# type: ignore[return-value]def_setup_fill_arg(fill:Union[_FillType,Dict[Union[Type,str],_FillType]])->Dict[Union[Type,str],_FillTypeJIT]:_check_fill_arg(fill)ifisinstance(fill,dict):fork,vinfill.items():fill[k]=_convert_fill_arg(v)returnfill# type: ignore[return-value]else:return{"others":_convert_fill_arg(fill)}def_get_fill(fill_dict,inpt_type):ifinpt_typeinfill_dict:returnfill_dict[inpt_type]elif"others"infill_dict:returnfill_dict["others"]else:RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.")def_check_padding_arg(padding:Union[int,Sequence[int]])->None:ifnotisinstance(padding,(numbers.Number,tuple,list)):raiseTypeError("Got inappropriate padding arg")ifisinstance(padding,(tuple,list))andlen(padding)notin[1,2,4]:raiseValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)# 
https://github.com/pytorch/vision/issues/6250def_check_padding_mode_arg(padding_mode:Literal["constant","edge","reflect","symmetric"])->None:ifpadding_modenotin["constant","edge","reflect","symmetric"]:raiseValueError("Padding mode should be either constant, edge, reflect or symmetric")def_find_labels_default_heuristic(inputs:Any)->torch.Tensor:""" This heuristic covers three cases: 1. The input is tuple or list whose second item is a labels tensor. This happens for already batched classification inputs for MixUp and CutMix (typically after the Dataloder). 2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor under a label-like (see below) key. This happens for the inputs of detection models. 3. The input is a dictionary that is structured as the one from 2. What is "label-like" key? We first search for an case-insensitive match of 'labels' inside the keys of the dictionary. This is the name our detection models expect. If we can't find that, we look for a case-insensitive match of the term 'label' anywhere inside the key, i.e. 'FooLaBeLBar'. If we can't find that either, the dictionary contains no "label-like" key. """ifisinstance(inputs,(tuple,list)):inputs=inputs[1]# MixUp, CutMixifis_pure_tensor(inputs):returninputsifnotisinstance(inputs,collections.abc.Mapping):raiseValueError(f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple "f"whose second item is a dictionary or a tensor, but got {inputs} instead.")candidate_key=Nonewithsuppress(StopIteration):candidate_key=next(keyforkeyininputs.keys()ifkey.lower()=="labels")ifcandidate_keyisNone:withsuppress(StopIteration):candidate_key=next(keyforkeyininputs.keys()if"label"inkey.lower())ifcandidate_keyisNone:raiseValueError("Could not infer where the labels are in the sample. 
Try passing a callable as the labels_getter parameter?""If there are no labels in the sample by design, pass labels_getter=None.")returninputs[candidate_key]def_parse_labels_getter(labels_getter:Union[str,Callable[[Any],Any],None])->Callable[[Any],Any]:iflabels_getter=="default":return_find_labels_default_heuristicelifcallable(labels_getter):returnlabels_gettereliflabels_getterisNone:returnlambda_:Noneelse:raiseValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
    """Return the Bounding Boxes in the input.

    Assumes only one ``BoundingBoxes`` object is present.
    """
    # By the general convention there is at most one bbox object per sample,
    # so the first match is the answer.
    matches = (candidate for candidate in flat_inputs if isinstance(candidate, tv_tensors.BoundingBoxes))
    try:
        return next(matches)
    except StopIteration:
        raise ValueError("No bounding boxes were found in the sample")
def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
    """Return Channel, Height, and Width."""
    # De-duplicate the CxHxW dimensions of every image- or video-like entry so
    # that conflicting dimensions across the sample can be detected.
    unique_dims = {
        tuple(get_dimensions(entry))
        for entry in flat_inputs
        if check_type(entry, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
    }
    if not unique_dims:
        raise TypeError("No image or video was found in the sample")
    if len(unique_dims) > 1:
        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(unique_dims))}")
    c, h, w = unique_dims.pop()
    return c, h, w
def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
    """Return Height and Width."""
    # De-duplicate the HxW sizes of every spatially-sized entry (images, videos,
    # masks, and bounding boxes) so that conflicting sizes can be detected.
    unique_sizes = {
        tuple(get_size(entry))
        for entry in flat_inputs
        if check_type(
            entry,
            (
                is_pure_tensor,
                tv_tensors.Image,
                PIL.Image.Image,
                tv_tensors.Video,
                tv_tensors.Mask,
                tv_tensors.BoundingBoxes,
            ),
        )
    }
    if not unique_sizes:
        raise TypeError("No image, video, mask or bounding box was found in the sample")
    if len(unique_sizes) > 1:
        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(unique_sizes))}")
    h, w = unique_sizes.pop()
    return h, w
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.