import os
import os.path
from pathlib import Path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset


def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Checks if a file is an allowed extension.

    The comparison is case-insensitive on the filename side only: the filename is
    lowercased before matching, so ``extensions`` must already be lowercase.

    Args:
        filename (string): path to a file
        extensions (string or tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # ``str.endswith`` accepts a str or a tuple; any other iterable is converted to a tuple.
    return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))


def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    Every immediate sub-directory of ``directory`` is treated as one class. Class
    names are sorted and indexed in that order, starting at 0.

    See :class:`DatasetFolder` for details.

    Args:
        directory (str or ``pathlib.Path``): Root directory path.

    Raises:
        FileNotFoundError: If ``directory`` has no sub-directories.

    Returns:
        (Tuple[List[str], Dict[str, int]]): Sorted list of class names and a dictionary
        mapping each class name to its index.
    """
    # Use the iterator as a context manager so the directory handle is released deterministically.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.

    Args:
        directory (str or ``pathlib.Path``): Root dataset directory. A leading ``~`` is expanded.
        class_to_idx (Dict[str, int], optional): Mapping from class name to class index. If None,
            it is computed with :func:`find_classes`.
        extensions (str or tuple of str, optional): Allowed (lowercase) file extensions.
            Exactly one of ``extensions`` and ``is_valid_file`` must be given.
        is_valid_file (callable, optional): Predicate taking a file path and returning whether
            the file should be included. Exactly one of ``extensions`` and ``is_valid_file``
            must be given.
        allow_empty (bool, optional): If True, classes without any valid file are accepted.
            An error is raised for them if False (default).

    Raises:
        ValueError: If ``class_to_idx`` is given but empty.
        ValueError: If ``extensions`` and ``is_valid_file`` are both None or both not None.
        FileNotFoundError: If some class has no valid file and ``allow_empty`` is False
            (a class whose directory does not exist counts as having no valid file).

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class), ordered by class name,
        then by directory, then by file name.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances: List[Tuple[str, int]] = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            # Missing class directory: contributes no samples and is reported below as empty.
            continue
        # Sorting the walk and the file names makes the sample order deterministic.
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    # ``set.add`` is idempotent, so no membership test is needed first.
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
class DatasetFolder(VisionDataset):
    """A generic data loader.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Args:
        root (str or ``pathlib.Path``): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
        allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
            An error is raised on empty folders if False (default).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: Union[str, Path],
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Go through the (overridable) methods rather than the module-level functions so
        # that subclasses can customize class discovery and sample collection.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(
            self.root,
            class_to_idx=class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: Union[str, Path],
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check of corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.
            allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
                An error is raised on empty folders if False (default).

        Raises:
            ValueError: In case ``class_to_idx`` is None or empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(
            directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
        )

    def find_classes(self, directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load one sample and apply the configured transforms.

        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.samples)
# Lowercase file extensions recognised as images; used by ``is_image_file`` and ``ImageFolder``.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def pil_loader(path: Union[str, Path]) -> Image.Image:
    """Load the image at ``path`` with PIL and convert it to RGB.

    Args:
        path (str or ``pathlib.Path``): Path to the image file.

    Returns:
        PIL.Image.Image: The image converted to mode ``"RGB"``.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        # Convert while the file is still open; the converted image is returned to the caller.
        return img.convert("RGB")


# TODO: specify the return type
def accimage_loader(path: Union[str, Path]) -> Any:
    """Load the image at ``path`` with ``accimage``, falling back to :func:`pil_loader`.

    Args:
        path (str or ``pathlib.Path``): Path to the image file.

    Returns:
        An ``accimage.Image``, or the result of :func:`pil_loader` if ``accimage``
        raises ``OSError``.
    """
    # Imported lazily so that accimage is only required when this loader is actually used.
    import accimage

    try:
        return accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: Union[str, Path]) -> Any:
    """Load the image at ``path`` with the loader matching torchvision's image backend.

    Args:
        path (str or ``pathlib.Path``): Path to the image file.

    Returns:
        The result of :func:`accimage_loader` when the image backend is ``"accimage"``,
        otherwise the result of :func:`pil_loader`.
    """
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path)
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on
            the given loader, and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)
        allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
            An error is raised on empty folders if False (default).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> None:
        super().__init__(
            root,
            loader,
            # Exactly one of extensions / is_valid_file may be set: fall back to the
            # known image extensions only when no custom predicate is supplied.
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )
        # ``imgs`` is an alias for ``samples`` (same list object).
        self.imgs = self.samples
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.