# Source code for torchvision.datasets.stanford_cars
import pathlib
from typing import Any, Callable, Optional, Union
from .folder import default_loader
from .utils import verify_str_arg
from .vision import VisionDataset
class StanfordCars(VisionDataset):
    """Stanford Cars Dataset

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split.

    The original URL is https://ai.stanford.edu/~jkrause/cars/car_dataset.html,
    the dataset isn't available online anymore.

    The files are looked up under ``<root>/stanford_cars`` with this layout::

        stanford_cars/
            devkit/
                cars_meta.mat
                cars_train_annos.mat
            cars_train/
            cars_test/
            cars_test_annos_withlabels.mat

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset
        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): This parameter exists for backward compatibility but it does not
            download the dataset, since the original URL is not available anymore.
            Passing ``True`` raises a ``ValueError``.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.

    Raises:
        RuntimeError: If scipy cannot be imported, or if the devkit folder, the
            annotation file or the image folder of the requested split is missing.
        ValueError: If ``download`` is ``True``.
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        # Imported here rather than at module level, so scipy is only needed
        # once a StanfordCars instance is actually created.
        try:
            import scipy.io as sio
        except ImportError as err:
            # Chain explicitly so the original import failure stays attached as the cause.
            raise RuntimeError(
                "Scipy is not found. This dataset needs to have scipy installed: pip install scipy"
            ) from err

        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "stanford_cars"
        devkit = self._base_folder / "devkit"

        # The train annotations live inside the devkit folder, while the labelled
        # test annotations sit directly in the base folder.
        if self._split == "train":
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found.")

        # Both .mat files are handed to scipy as plain strings for consistency.
        annotations = sio.loadmat(str(self._annotations_mat_path), squeeze_me=True)["annotations"]
        self._samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
            )
            for annotation in annotations
        ]

        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.loader = loader

    def __len__(self) -> int:
        """Returns the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> tuple[Any, Any]:
        """Returns the loaded image and its class id for the given index.

        The image is whatever ``loader`` returns for the stored path; ``transform``
        and ``target_transform`` are applied to the image and the target
        respectively when they are set.
        """
        image_path, target = self._samples[idx]
        image = self.loader(image_path)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def _check_exists(self) -> bool:
        """Returns True if the devkit folder, the split's annotation file and its image folder all exist."""
        if not (self._base_folder / "devkit").is_dir():
            return False

        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()

    def download(self):
        """Always raises ``ValueError``: the dataset cannot be downloaded anymore."""
        raise ValueError("The original URL is broken so the StanfordCars dataset cannot be downloaded anymore.")