Shortcuts

Source code for torchvision.datasets.kitti

import csv
import os
from typing import Any, Callable, List, Optional, Tuple

from PIL import Image

from .utils import download_and_extract_archive
from .vision import VisionDataset


class Kitti(VisionDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.

    It corresponds to the "left color images of object" dataset, for object detection.

    Args:
        root (string): Root directory where images are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                <root>
                    └── Kitti
                        └─ raw
                            ├── training
                            |   ├── image_2
                            |   └── label_2
                            └── testing
                                └── image_2
        train (bool, optional): Use ``train`` split if true, else ``test`` split.
            Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
    # Archive names appended to ``data_url`` by ``download``.
    resources = [
        "data_object_image_2.zip",
        "data_object_label_2.zip",
    ]
    # Sub-directory names inside ``<raw>/training`` and ``<raw>/testing``.
    image_dir_name = "image_2"
    labels_dir_name = "label_2"

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        download: bool = False,
    ):
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )
        # Index-aligned lists: images[i] is an image path and, for the train split,
        # targets[i] is the path of its label file. targets stays empty for test.
        self.images: List[str] = []
        self.targets: List[str] = []
        # NOTE(review): ``self.root`` is intentionally NOT re-assigned from the raw
        # ``root`` argument. The base-class constructor receives ``root`` and presumably
        # stores it (possibly normalized, e.g. user-home expansion); overwriting it here
        # would discard that. TODO confirm against VisionDataset.
        self.train = train
        self._location = "training" if self.train else "testing"

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found. You may use download=True to download it."
            )

        image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
        labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
        # os.listdir order is arbitrary; sort so sample indices are reproducible.
        for img_file in sorted(os.listdir(image_dir)):
            self.images.append(os.path.join(image_dir, img_file))
            if self.train:
                # splitext only strips the final extension, so names containing
                # extra dots still map to the correct label file.
                stem = os.path.splitext(img_file)[0]
                self.targets.append(os.path.join(labels_dir, f"{stem}.txt"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get item at a given index.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target), where
            target is a list of dictionaries with the following keys:

            - type: str
            - truncated: float
            - occluded: int
            - alpha: float
            - bbox: float[4]
            - dimensions: float[3]
            - location: float[3]
            - rotation_y: float

            For the test split (``train=False``) target is ``None`` before any
            ``transforms`` are applied, since no labels are loaded.
        """
        image = Image.open(self.images[index])
        target = self._parse_target(index) if self.train else None
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def _parse_target(self, index: int) -> List:
        """Parse the space-delimited label file for sample ``index``.

        Each non-empty row becomes one dict (see ``__getitem__`` for the keys);
        blank rows are skipped.
        """
        target = []
        # newline="" is the documented way to open files handed to csv.reader.
        with open(self.targets[index], newline="") as inp:
            for line in csv.reader(inp, delimiter=" "):
                if not line:
                    # csv.reader yields [] for blank lines (e.g. a trailing newline);
                    # indexing it would raise IndexError.
                    continue
                target.append(
                    {
                        "type": line[0],
                        "truncated": float(line[1]),
                        "occluded": int(line[2]),
                        "alpha": float(line[3]),
                        "bbox": [float(x) for x in line[4:8]],
                        "dimensions": [float(x) for x in line[8:11]],
                        "location": [float(x) for x in line[11:14]],
                        "rotation_y": float(line[14]),
                    }
                )
        return target

    def __len__(self) -> int:
        return len(self.images)

    @property
    def _raw_folder(self) -> str:
        """Directory ``<root>/<ClassName>/raw`` holding the extracted archives."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    def _check_exists(self) -> bool:
        """Check if the data directory exists."""
        folders = [self.image_dir_name]
        if self.train:
            # Labels are only required (and only loaded) for the train split.
            folders.append(self.labels_dir_name)
        return all(
            os.path.isdir(os.path.join(self._raw_folder, self._location, fname))
            for fname in folders
        )

    def download(self) -> None:
        """Download the KITTI data if it doesn't exist already."""
        if self._check_exists():
            return

        os.makedirs(self._raw_folder, exist_ok=True)

        # download files
        for fname in self.resources:
            download_and_extract_archive(
                url=f"{self.data_url}{fname}",
                download_root=self._raw_folder,
                filename=fname,
            )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources