class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    .. note::
        Before using this class, download the ImageNet 2012 dataset from
        `here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_.
        Place ``ILSVRC2012_devkit_t12.tar.gz`` in the root directory, together with
        ``ILSVRC2012_img_train.tar`` or ``ILSVRC2012_img_val.tar`` depending on ``split``.

    Args:
        root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset.
        split (str, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            or torch.Tensor, depending on the given loader, and returns a transformed
            version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.

        ``transform``, ``target_transform`` and ``loader`` are forwarded unchanged to
        ``ImageFolder`` through ``**kwargs``.

    Attributes:
        classes (list): List of the class name tuples (one tuple per WordNet ID).
        class_to_idx (dict): Dict with items (class_name, class_index); every name in a
            class tuple maps to that tuple's index.
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None:
        dataset_root = os.path.expanduser(root)
        self.root = dataset_root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # Make sure the meta file and the split folder exist before ImageFolder scans it.
        self.parse_archives()
        wnid_to_classes = load_meta_file(self.root)[0]

        super().__init__(self.split_folder, **kwargs)
        # The base class was handed the split sub-folder; re-assert the dataset root.
        self.root = dataset_root

        # ImageFolder discovered folder names, which are WordNet IDs; keep them and
        # replace the public class attributes with human-readable name tuples.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]

        name_to_index = {}
        for index, names in enumerate(self.classes):
            for name in names:
                name_to_index[name] = index
        self.class_to_idx = name_to_index

    def parse_archives(self) -> None:
        """Create the meta file and extract the requested split if they are missing."""
        meta_path = os.path.join(self.root, META_FILE)
        if not check_integrity(meta_path):
            parse_devkit_archive(self.root)

        if os.path.isdir(self.split_folder):
            return

        if self.split == "train":
            parse_train_archive(self.root)
        elif self.split == "val":
            parse_val_archive(self.root)

    @property
    def split_folder(self) -> str:
        """Path of the folder holding the images for the selected split."""
        return os.path.join(self.root, self.split)

    def extra_repr(self) -> str:
        """Extra line shown in the dataset's ``repr``."""
        return "Split: {split}".format_map(self.__dict__)
def load_meta_file(
    root: Union[str, Path], file: Optional[str] = None
) -> Tuple[Dict[str, Tuple[str, ...]], List[str]]:
    """Load the meta information written by :func:`parse_devkit_archive`.

    Args:
        root (str or ``pathlib.Path``): Directory containing the meta file.
        file (str, optional): Name of the meta file inside ``root``.
            Defaults to ``META_FILE``.

    Returns:
        tuple: ``(wnid_to_classes, val_wnids)``. ``wnid_to_classes`` maps each WordNet ID
        to a tuple of class names. ``val_wnids`` holds the WordNet ID of every validation
        image, in the order of the devkit ground-truth file.

    Raises:
        RuntimeError: If the meta file is missing or fails the integrity check.
    """
    path = os.path.join(root, META_FILE if file is None else file)
    if not check_integrity(path):
        raise RuntimeError(
            f"The meta file {path} is not present in the root directory or is corrupted. "
            "This file is automatically created by the ImageNet dataset."
        )
    return torch.load(path, weights_only=True)


def _verify_archive(root: Union[str, Path], file: str, md5: str) -> None:
    """Raise ``RuntimeError`` unless ``root/file`` exists and matches ``md5``."""
    if not check_integrity(os.path.join(root, file), md5):
        msg = (
            "The archive {} is not present in the root directory or is corrupted. "
            "You need to download it externally and place it in {}."
        )
        raise RuntimeError(msg.format(file, root))


def parse_devkit_archive(root: Union[str, Path], file: Optional[str] = None) -> None:
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

    The saved object is ``(wnid_to_classes, val_wnids)``, written to ``root/META_FILE``.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'

    Raises:
        RuntimeError: If the archive is missing or fails the MD5 check.
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, Tuple[str, ...]]]:
        metafile = os.path.join(devkit_root, "data", "meta.mat")
        meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
        # Field 4 of each synset record is its number of children; only synsets
        # without children are turned into the class mappings below.
        nums_children = list(zip(*meta))[4]
        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
        idcs, wnids, classes = list(zip(*meta))[:3]
        classes = [tuple(clss.split(", ")) for clss in classes]
        idx_to_wnid = dict(zip(idcs, wnids))
        wnid_to_classes = dict(zip(wnids, classes))
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
        gt_file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
        with open(gt_file, encoding="utf-8") as txtfh:
            # One integer id per line; blank lines (e.g. a trailing one) are skipped
            # instead of raising ValueError in int().
            return [int(line) for line in txtfh if line.strip()]

    archive_meta = ARCHIVE_META["devkit"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    # TemporaryDirectory removes the scratch extraction even if parsing fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_idcs = parse_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))


def parse_train_archive(root: Union[str, Path], file: Optional[str] = None, folder: str = "train") -> None:
    """Parse the train images archive of the ImageNet2012 classification dataset and
    prepare it for usage with the ImageNet dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the train images archive
        file (str, optional): Name of train images archive. Defaults to
            'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder. Defaults to
            'train'

    Raises:
        RuntimeError: If the archive is missing or fails the MD5 check.
    """
    archive_meta = ARCHIVE_META["train"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # The outer archive contains nested archives. Each one is extracted into a folder
    # named after it (minus extension) and then deleted. Only regular files are
    # considered, so folders left by an earlier, interrupted run are not passed to
    # extract_archive.
    archives = [os.path.join(train_root, name) for name in os.listdir(train_root)]
    for archive in archives:
        if os.path.isfile(archive):
            extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)


def parse_val_archive(
    root: Union[str, Path], file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
) -> None:
    """Parse the validation images archive of the ImageNet2012 classification dataset
    and prepare it for usage with the ImageNet dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive. Defaults to
            'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images, one per
            image. If None is given, the IDs are loaded from the meta file in the root
            directory.
        folder (str, optional): Optional name for validation images folder. Defaults to
            'val'

    Raises:
        RuntimeError: If the archive is missing or fails the MD5 check, or if the number
            of extracted images differs from the number of WordNet IDs.
    """
    archive_meta = ARCHIVE_META["val"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]
    if wnids is None:
        wnids = load_meta_file(root)[1]

    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    # Images are paired with wnids by sorted file name, so both must be in the same
    # order. Directories (e.g. class folders from a previous run) are excluded.
    candidates = (os.path.join(val_root, name) for name in os.listdir(val_root))
    images = sorted(path for path in candidates if os.path.isfile(path))

    # zip() would silently truncate and shift every label on a count mismatch.
    if len(images) != len(wnids):
        raise RuntimeError(
            f"Found {len(images)} validation images in {val_root}, but {len(wnids)} "
            "WordNet IDs; cannot assign labels reliably."
        )

    for wnid in set(wnids):
        os.makedirs(os.path.join(val_root, wnid), exist_ok=True)

    for wnid, img_file in zip(wnids, images):
        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.