class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Base URLs tried in order for every file in ``resources``.
    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]

    # (archive file name, expected MD5 checksum) pairs.
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    # Legacy pre-processed caches in ``processed_folder``; only read for backward compatibility.
    training_file = "training.pt"
    test_file = "test.pt"
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    # Deprecated aliases kept for backward compatibility. ``stacklevel=2`` makes the
    # warning point at the caller's attribute access instead of at this file.
    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data", stacklevel=2)
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data", stacklevel=2)
        return self.data

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set

        # A valid legacy cache takes precedence over the raw files.
        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self):
        """Return True only if both legacy ``.pt`` caches exist and pass the integrity check."""
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file), weights_only=True)

    def _load_data(self):
        """Read the images and labels of the selected split from the raw IDX files."""
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image.
        # ``read_image_file`` guarantees uint8 data, so each image is a 2-D uint8 array
        # and Pillow infers mode "L" on its own (legacy caches are presumably uint8 as
        # well). Passing ``mode=`` explicitly is deprecated since Pillow 11.3.
        img = Image.fromarray(img.numpy())

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        """Return True if every extracted raw file (archive name minus extension) passes the integrity check."""
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            errors = []
            for mirror in self.mirrors:
                url = f"{mirror}{filename}"
                try:
                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                except URLError as e:
                    # Remember the failure and try the next mirror.
                    errors.append(e)
                else:
                    break
            else:
                # No mirror succeeded: report every attempt in a single error.
                s = f"Error downloading {filename}:\n"
                for mirror, err in zip(self.mirrors, errors):
                    s += f"Tried {mirror}, got:\n{str(err)}\n"
                raise RuntimeError(s)

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Reuses all loading and downloading logic of :class:`MNIST`; only the mirror,
    the file checksums and the class names differ.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
            and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Single download location for all four archives.
    mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]

    # (archive file name, expected MD5 checksum) pairs.
    resources = [
        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
    ]

    # Human-readable names, indexed by label value.
    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Shares the whole implementation of :class:`MNIST`; this subclass only swaps in
    its own mirror, checksums and class labels.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
            and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Single download location for all four archives.
    mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]

    # (archive file name, expected MD5 checksum) pairs.
    resources = [
        ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
        ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
        ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
    ]

    # Romanized names of the ten characters, indexed by label value.
    classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where the raw files
            ``EMNIST/raw/emnist-<split>-train-images-idx3-ubyte`` and
            ``EMNIST/raw/emnist-<split>-test-images-idx3-ubyte`` (plus the matching
            ``labels-idx1-ubyte`` files) exist.
        split (str): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from the ``emnist-<split>-train-*``
            raw files, otherwise from the ``emnist-<split>-test-*`` ones. A legacy
            ``training_<split>.pt`` / ``test_<split>.pt`` cache is used instead when both
            caches are present.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    url = "https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip"
    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
    # Lowercase letters whose shape is assumed to match their uppercase version. The merged
    # splits ("bymerge", "balanced") drop these lowercase classes and keep only the uppercase one.
    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
    _all_classes = set(string.digits + string.ascii_letters)
    classes_split_dict = {
        "byclass": sorted(_all_classes),
        "bymerge": sorted(_all_classes - _merged_classes),
        "balanced": sorted(_all_classes - _merged_classes),
        # Index 0 is a placeholder, so the 26 letters occupy indices 1-26.
        "letters": ["N/A"] + list(string.ascii_lowercase),
        "digits": list(string.digits),
        "mnist": list(string.digits),
    }

    def __init__(self, root: Union[str, Path], split: str, **kwargs: Any) -> None:
        """Validate ``split`` and forward the remaining keyword arguments to :class:`MNIST`."""
        self.split = verify_str_arg(split, "split", self.splits)
        # Legacy cache names must be set before MNIST.__init__ looks for them.
        self.training_file = self._training_file(self.split)
        self.test_file = self._test_file(self.split)
        super().__init__(root, **kwargs)
        self.classes = self.classes_split_dict[self.split]

    @staticmethod
    def _training_file(split) -> str:
        return f"training_{split}.pt"

    @staticmethod
    def _test_file(split) -> str:
        return f"test_{split}.pt"

    @property
    def _file_prefix(self) -> str:
        return f"emnist-{self.split}-{'train' if self.train else 'test'}"

    @property
    def images_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")

    @property
    def labels_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")

    def _load_data(self):
        return read_image_file(self.images_file), read_label_file(self.labels_file)

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def download(self) -> None:
        """Download the EMNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # The zip unpacks into a "gzip" sub-folder holding one .gz archive per raw file;
        # extract those into raw_folder, then delete the intermediate folder.
        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
        gzip_folder = os.path.join(self.raw_folder, "gzip")
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith(".gz"):
                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
        shutil.rmtree(gzip_folder)
class QMNIST(MNIST):
    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset whose ``raw``
            subdir contains binary files of the datasets.
        what (str, optional): Can be 'train', 'test', 'test10k',
            'test50k', or 'nist' for respectively the mnist compatible
            training set, the 60k qmnist testing set, the 10k qmnist
            examples that match the mnist testing set, the 50k
            remaining qmnist testing examples, or all the nist
            digits. The default is to select 'train' or 'test'
            according to the compatibility argument 'train'.
        compat (bool, optional): A boolean that says whether the target
            for each example is class number (for compatibility with
            the MNIST dataloader) or a torch vector containing the
            full qmnist information. Default=True.
        download (bool, optional): If True, downloads the dataset from
            the internet and puts it in root directory. If dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that
            takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform
            that takes in the target and transforms it.
        train (bool, optional, compatibility): When argument 'what' is
            not specified, this boolean decides whether to load the
            training set or the testing set. Default: True.
    """

    # Maps each value of ``what`` to the key in ``resources`` holding its raw files.
    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
    # Per subset: [(images URL, MD5), (labels URL, MD5)].
    resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
        "train": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
                "ed72d4157d28c017586c42bc6afe6370",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
                "0058f8dd561b90ffdd0f734c6a30e5e4",
            ),
        ],
        "test": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
                "1394631089c404de565df7b7aeaf9412",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
                "5b5b05890a5e13444e108efe57b788aa",
            ),
        ],
        "nist": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
                "7f124b3b8ab81486c9d8c2749c17f834",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
                "5ed0e788978e45d4a8bd4b7caec3d79d",
            ),
        ],
    }
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    def __init__(
        self, root: Union[str, Path], what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
    ) -> None:
        if what is None:
            what = "train" if train else "test"
        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
        self.compat = compat
        # Both legacy cache names point at the same file because ``what`` already selects the subset.
        self.data_file = what + ".pt"
        self.training_file = self.data_file
        self.test_file = self.data_file
        super().__init__(root, train, **kwargs)

    @property
    def images_file(self) -> str:
        """Path of the extracted images file (download URL basename without its compression extension)."""
        (url, _), _ = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    @property
    def labels_file(self) -> str:
        """Path of the extracted labels file (download URL basename without its compression extension)."""
        _, (url, _) = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def _load_data(self):
        """Read images (uint8, 3-D) and the full label matrix (2-D) for the selected subset.

        Raises:
            TypeError: If the images are not stored as uint8.
            ValueError: If the images are not 3-D or the labels are not 2-D.
        """
        data = read_sn3_pascalvincent_tensor(self.images_file)
        if data.dtype != torch.uint8:
            raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
        if data.ndimension() != 3:
            raise ValueError(f"data should have 3 dimensions instead of {data.ndimension()}")

        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
        if targets.ndimension() != 2:
            raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")

        # "test10k" and "test50k" are the first 10000 rows and the remaining rows of the test files.
        if self.what == "test10k":
            data = data[0:10000, :, :].clone()
            targets = targets[0:10000, :].clone()
        elif self.what == "test50k":
            data = data[10000:, :, :].clone()
            targets = targets[10000:, :].clone()

        return data, targets

    def download(self) -> None:
        """Download the QMNIST data if it doesn't exist already.
        Note that we only download what has been asked for (argument 'what').
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        split = self.resources[self.subsets[self.what]]

        for url, md5 in split:
            download_and_extract_archive(url, self.raw_folder, md5=md5)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # redefined to handle the compat flag
        img, target = self.data[index], self.targets[index]
        # ``_load_data`` guarantees 2-D uint8 images here, so Pillow infers mode "L";
        # passing ``mode=`` explicitly is deprecated since Pillow 11.3.
        img = Image.fromarray(img.numpy())
        if self.transform is not None:
            img = self.transform(img)
        if self.compat:
            # Only the class number (first column of the label row) is returned in compat mode.
            target = int(target[0])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
def get_int(b: bytes) -> int:
    """Interpret ``b`` as an unsigned big-endian integer, independently of the host byte order."""
    return int.from_bytes(b, byteorder="big", signed=False)


# Maps the IDX type code (from the file's magic number) to the corresponding torch dtype.
SN3_PASCALVINCENT_TYPEMAP = {
    8: torch.uint8,
    9: torch.int8,
    11: torch.int16,
    12: torch.int32,
    13: torch.float32,
    14: torch.float64,
}


def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').

    Args:
        path (str): Path of an uncompressed IDX file. The file is opened directly with
            ``open(path, "rb")``, so compressed files and file objects are not supported.
        strict (bool): If True, assert that the number of payload elements equals the
            product of the dimensions declared in the header.

    Returns:
        Tensor: The payload, in host byte order, viewed with the header's shape.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()

    # parse
    # The header is stored big-endian. ``get_int`` always decodes big-endian, so the same
    # parse is correct on little- and big-endian hosts alike: the last byte of the magic
    # number is the number of dimensions, the bytes before it the type code.
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3, f"Expected 1 to 3 dimensions in the IDX header, got {nd}"
    assert ty in SN3_PASCALVINCENT_TYPEMAP, f"Unsupported IDX type code {ty}"
    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
    # One 4-byte big-endian size per dimension follows the magic number.
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

    # ``bytearray`` gives frombuffer a writable buffer.
    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))

    # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
    # that is little endian and the dtype has more than one byte, we need to flip them.
    if sys.byteorder == "little" and parsed.element_size() > 1:
        parsed = _flip_byte_order(parsed)

    assert parsed.shape[0] == np.prod(s) or not strict
    return parsed.view(*s)


def read_label_file(path: str) -> torch.Tensor:
    """Read a 1-D uint8 IDX label file and return it as an int64 tensor.

    Raises:
        TypeError: If the payload is not uint8.
        ValueError: If the payload is not 1-D.
    """
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 1:
        raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    """Read a 3-D uint8 IDX image file and return it unchanged.

    Raises:
        TypeError: If the payload is not uint8.
        ValueError: If the payload is not 3-D.
    """
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 3:
        raise ValueError(f"x should have 3 dimensions instead of {x.ndimension()}")
    return x
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.