Shortcuts

Source code for torchvision.datasets.caltech

import os
import os.path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class Caltech101(VisionDataset):
    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types. ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``. An empty list is rejected with ``ValueError``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    .. warning::

        To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
    """

    def __init__(
        self,
        root: str,
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
        # __getitem__ returns target[0] when a single type is requested, so an
        # empty list would otherwise only fail later with an opaque IndexError.
        if not self.target_type:
            raise ValueError("target_type must contain at least one of 'category' or 'annotation'")

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        categories_dir = os.path.join(self.root, "101_ObjectCategories")
        self.categories = sorted(os.listdir(categories_dir))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # Flat sample list: self.index[k] is the 1-based image number inside the
        # category folder, self.y[k] the position of that category in self.categories.
        self.index: List[int] = []
        self.y: List[int] = []
        for (i, c) in enumerate(self.categories):
            # Count only ``.jpg`` files (as Caltech256 does). __getitem__ opens
            # ``image_{k:04d}.jpg`` for k in 1..n, so counting stray entries such
            # as ``.DS_Store`` would create indices for images that do not exist.
            n = sum(1 for f in os.listdir(os.path.join(categories_dir, c)) if f.endswith(".jpg"))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target is specified by target_type.
            A single target type yields the bare target; several yield a tuple in
            ``target_type`` order.
        """
        import scipy.io

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[self.y[index]],
                f"image_{self.index[index]:04d}.jpg",
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(self.y[index])
            elif t == "annotation":
                # Annotation folders use the (possibly remapped) names computed in __init__.
                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[self.y[index]],
                        f"annotation_{self.index[index]:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image and annotation archives into ``self.root``.

        Skipped entirely when the ``101_ObjectCategories`` folder already exists.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
            self.root,
            filename="Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        return "Target type: {target_type}".format(**self.__dict__)
class Caltech256(VisionDataset):
    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        image_root = os.path.join(self.root, "256_ObjectCategories")
        self.categories = sorted(os.listdir(image_root))

        # One entry per sample: the 1-based image number within its folder and
        # the position of its folder in the sorted category list.
        self.index: List[int] = []
        self.y = []
        for label, category in enumerate(self.categories):
            entries = os.listdir(os.path.join(image_root, category))
            num_images = sum(1 for entry in entries if entry.endswith(".jpg"))
            # __getitem__ assumes the images of each category are numbered 1..num_images.
            self.index += range(1, num_images + 1)
            self.y += [label] * num_images

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        label = self.y[index]
        category = self.categories[label]
        # File names are "<label+1 as 3 digits>_<image number as 4 digits>.jpg".
        file_name = f"{label + 1:03d}_{self.index[index]:04d}.jpg"
        img = Image.open(os.path.join(self.root, "256_ObjectCategories", category, file_name))

        target = label
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def _check_integrity(self) -> bool:
        # Only tests that the extracted image folder exists; file hashes are not checked.
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Fetch and extract the image archive unless its folder is already present."""
        already_present = self._check_integrity()
        if already_present:
            print("Files already downloaded and verified")
        else:
            download_and_extract_archive(
                "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
                self.root,
                filename="256_ObjectCategories.tar",
                md5="67b4f42ca05d46448c6bb8ecd2220f6d",
            )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources