class INaturalist(VisionDataset):
    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where the image files are stored.
            This class does not require/use annotation files.
        version (string, optional): Which version of the dataset to download/use. One of
            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
            Default: `2021_train`.
        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:

            - ``full``: the full category (species)
            - ``kingdom``: e.g. "Animalia"
            - ``phylum``: e.g. "Arthropoda"
            - ``class``: e.g. "Insecta"
            - ``order``: e.g. "Coleoptera"
            - ``family``: e.g. "Cleridae"
            - ``genus``: e.g. "Trichodes"

            for 2017-2019 versions, one of:

            - ``full``: the full (numeric) category
            - ``super``: the super category, e.g. "Amphibians"

            Can also be a list to output a tuple with all specified target types.
            Defaults to ``full``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the dataset directory is missing or empty (and ``download`` is
            false or did not produce it), or if the on-disk directory layout does not
            match what the selected version expects.
    """

    def __init__(
        self,
        root: Union[str, Path],
        version: str = "2021_train",
        target_type: Union[List[str], str] = "full",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

        # The dataset itself lives in a per-version sub-directory of ``root``.
        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)

        os.makedirs(root, exist_ok=True)
        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # list indexed by full category id, containing the category directory (relative to self.root)
        self.all_categories: List[str] = []

        # map: category type -> name of category -> index
        self.categories_index: Dict[str, Dict[str, int]] = {}

        # list indexed by category id, containing mapping from category type -> index
        self.categories_map: List[Dict[str, int]] = []

        if not isinstance(target_type, list):
            target_type = [target_type]
        if self.version[:4] == "2021":
            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
            self._init_2021()
        else:
            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
            self._init_pre2021()

        # index of all files: (full category id, filename)
        self.index: List[Tuple[int, str]] = []

        for dir_index, dir_name in enumerate(self.all_categories):
            files = os.listdir(os.path.join(self.root, dir_name))
            self.index.extend((dir_index, fname) for fname in files)

    def _init_2021(self) -> None:
        """Initialize based on 2021 layout.

        Every directory directly under ``self.root`` is one full category. Its name must
        split on ``_`` into exactly 8 pieces, the first being the zero-padded category id
        (equal to the directory's position in sorted order); pieces 1 through 6 are paired
        with ``CATEGORIES_2021`` to fill in the per-type category indices.

        Raises:
            RuntimeError: If a directory name has the wrong number of pieces or an
                unexpected id prefix.
        """
        self.all_categories = sorted(os.listdir(self.root))

        # map: category type -> name of category -> index
        self.categories_index = {k: {} for k in CATEGORIES_2021}

        for dir_index, dir_name in enumerate(self.all_categories):
            pieces = dir_name.split("_")
            if len(pieces) != 8:
                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
            if pieces[0] != f"{dir_index:05d}":
                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
            cat_map = {}
            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
                type_index = self.categories_index[cat]
                # Reuse the id of an already-seen name, otherwise assign the next free one.
                # ``len(type_index)`` is evaluated before any insertion happens.
                cat_map[cat] = type_index.setdefault(name, len(type_index))
            self.categories_map.append(cat_map)

    def _init_pre2021(self) -> None:
        """Initialize based on 2017-2019 layout.

        The layout is ``<root>/<super category>/<category>``. For version ``2017`` the
        category directories are numbered consecutively in sorted traversal order; for the
        other versions the directory name itself must be the integer category id.

        Raises:
            RuntimeError: If a category directory name is not numeric (non-2017 versions),
                if a category id appears twice, or if an id in the covered range is missing.
        """
        # map: category type -> name of category -> index
        self.categories_index = {"super": {}}

        cat_index = 0
        super_categories = sorted(os.listdir(self.root))
        for sindex, scat in enumerate(super_categories):
            self.categories_index["super"][scat] = sindex
            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
            for subcat in subcategories:
                if self.version == "2017":
                    # this version does not use ids as directory names
                    subcat_i = cat_index
                    cat_index += 1
                else:
                    try:
                        subcat_i = int(subcat)
                    except ValueError as err:
                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}") from err

                # Grow both tables so that ``subcat_i`` is a valid position. Empty
                # placeholders mark ids that have not been seen (yet).
                missing = subcat_i - len(self.categories_map) + 1
                if missing > 0:
                    # Distinct dicts per slot: ``[{}] * n`` would alias a single object.
                    self.categories_map.extend({} for _ in range(missing))
                    self.all_categories.extend([""] * missing)

                if self.categories_map[subcat_i]:
                    raise RuntimeError(f"Duplicate category {subcat}")
                self.categories_map[subcat_i] = {"super": sindex}
                self.all_categories[subcat_i] = os.path.join(scat, subcat)

        # validate the dictionary: a placeholder left empty means that id never showed up
        for cindex, c in enumerate(self.categories_map):
            if not c:
                raise RuntimeError(f"Missing category {cindex}")

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """

        cat_id, fname = self.index[index]
        img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))

        target: Any = []
        for t in self.target_type:
            if t == "full":
                target.append(cat_id)
            else:
                target.append(self.categories_map[cat_id][t])
        # A single requested target type is returned bare, several as a tuple.
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of indexed image files."""
        return len(self.index)

    def category_name(self, category_type: str, category_id: int) -> str:
        """
        Args:
            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
            category_id(int): an index (class id) from this category

        Returns:
            the name of the category

        Raises:
            ValueError: If ``category_type`` is unknown for this dataset version, or if
                ``category_id`` does not exist within that category type.
        """
        if category_type == "full":
            return self.all_categories[category_id]

        if category_type not in self.categories_index:
            raise ValueError(f"Invalid category type '{category_type}'")

        # Reverse lookup: the index maps name -> id, so scan for the matching id.
        for name, cat_id in self.categories_index[category_type].items():
            if cat_id == category_id:
                return name

        raise ValueError(f"Invalid category id {category_id} for {category_type}")

    def _check_exists(self) -> bool:
        """Return True if ``self.root`` exists and contains at least one entry."""
        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0

    def download(self) -> None:
        """Download and extract the archive for ``self.version`` unless already present.

        The archive is extracted next to ``self.root`` and the extracted directory
        (named after the archive, without its extension) is then renamed to ``self.root``.

        Raises:
            RuntimeError: If the extracted directory cannot be found where expected.
        """
        if self._check_exists():
            return

        base_root = os.path.dirname(self.root)

        download_and_extract_archive(
            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
        )

        # Strip the archive extension as a suffix. ``str.rstrip(".tar.gz")`` would instead
        # strip any trailing run of the characters ".", "t", "a", "r", "g", "z", which can
        # eat into the directory name itself (e.g. "data.tar.gz" -> "d").
        extracted_name = os.path.basename(DATASET_URLS[self.version])
        for suffix in (".tar.gz", ".tgz"):
            if extracted_name.endswith(suffix):
                extracted_name = extracted_name[: -len(suffix)]
                break

        orig_dir_name = os.path.join(base_root, extracted_name)
        if not os.path.exists(orig_dir_name):
            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
        os.rename(orig_dir_name, self.root)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.