class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
    which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))

    def _load_image(self, id: int) -> Image.Image:
        path = self.coco.loadImgs(id)[0]["file_name"]
        return Image.open(os.path.join(self.root, path)).convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.ids)


class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
    which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        return [ann["caption"] for ann in super()._load_target(id)]
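
The lookup path used by both classes (integer index, then sorted image id, then the annotations for that image, reduced to caption strings in ``CocoCaptions``) can be sketched without pycocotools or any image files. The snippet below is a minimal illustration only, not the library's implementation: ``TinyCaptions`` and ``ANNOTATIONS`` are hypothetical names, the annotation records are hand-written, and the file name is returned in place of a decoded PIL image.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical, hand-written records following the COCO captions JSON layout
# ("images" and "annotations" keys). Not taken from the real dataset.
ANNOTATIONS: Dict[str, List[Dict[str, Any]]] = {
    "images": [
        {"id": 42, "file_name": "b.jpg"},
        {"id": 7, "file_name": "a.jpg"},
    ],
    "annotations": [
        {"id": 1, "image_id": 42, "caption": "A dog on a beach."},
        {"id": 2, "image_id": 7, "caption": "A plane over a mountain."},
        {"id": 3, "image_id": 42, "caption": "A dog running on sand."},
    ],
}


class TinyCaptions:
    """Stand-in that mirrors the index -> image id -> captions lookup."""

    def __init__(self, data: Dict[str, List[Dict[str, Any]]]) -> None:
        # Index images by id, and group annotations by the image they describe.
        self.imgs = {img["id"]: img for img in data["images"]}
        self.anns_by_image: Dict[int, List[Dict[str, Any]]] = {}
        for ann in data["annotations"]:
            self.anns_by_image.setdefault(ann["image_id"], []).append(ann)
        # Same ordering rule as CocoDetection: positions follow sorted image ids.
        self.ids = sorted(self.imgs.keys())

    def __getitem__(self, index: int) -> Tuple[str, List[str]]:
        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")
        image_id = self.ids[index]
        file_name = self.imgs[image_id]["file_name"]
        # CocoCaptions keeps only the "caption" field of each annotation.
        captions = [ann["caption"] for ann in self.anns_by_image.get(image_id, [])]
        return file_name, captions

    def __len__(self) -> int:
        return len(self.ids)


demo = TinyCaptions(ANNOTATIONS)
first_name, first_captions = demo[0]
second_name, second_captions = demo[1]
```

Because the ids are sorted, index ``0`` refers to image id ``7`` even though that image appears second in the annotation data. An index is therefore a position in the sorted id list, not a COCO image id.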