Shortcuts

Source code for torchvision.datasets.coco

import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset


class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    The `COCO API <https://github.com/pdollar/coco/tree/master/PythonAPI>`_
    must be installed; it is imported when the dataset is constructed.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)

        # Imported here rather than at module level, so pycocotools is only
        # needed when a COCO dataset is actually instantiated.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted image ids: positional index ``i`` of the dataset maps to ``self.ids[i]``.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        """Open the image registered under COCO image id ``id`` and convert it to RGB."""
        image_record = self.coco.loadImgs(id)[0]
        file_name = image_record["file_name"]
        full_path = os.path.join(self.root, file_name)
        return Image.open(full_path).convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return the annotations the COCO API loads for image id ``id``."""
        annotation_ids = self.coco.getAnnIds(id)
        return self.coco.loadAnns(annotation_ids)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at position ``index``.

        Args:
            index (int): Position in the sorted list of image ids.

        Returns:
            tuple: ``(image, target)``, passed through ``self.transforms`` when it is set.

        Raises:
            ValueError: If ``index`` is not an ``int``.
        """
        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        image_id = self.ids[index]
        image = self._load_image(image_id)
        target = self._load_target(image_id)

        if self.transforms is None:
            return image, target

        image, target = self.transforms(image, target)
        return image, target

    def __len__(self) -> int:
        """Return the number of image ids held by the dataset."""
        return len(self.ids)
class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    The `COCO API <https://github.com/pdollar/coco/tree/master/PythonAPI>`_
    must be installed to use this dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        """Return the ``"caption"`` field of each annotation loaded for image id ``id``."""
        annotations = super()._load_target(id)
        return [entry["caption"] for entry in annotations]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources