import glob
import os
from collections import defaultdict
from html.parser import HTMLParser
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset


class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    Only text that appears inside a ``<table>`` element is considered:

    * text directly inside an ``<a>`` tag is treated as an image link; its
      second-to-last ``/``-separated component is used as the image id and is
      resolved to a file on disk matching ``<root>/<id>_*.jpg``;
    * text directly inside an ``<li>`` tag is recorded as a caption of the
      most recently resolved image;
    * the literal text ``"Image Not Found"`` clears the current image so that
      the captions that follow are ignored.

    Args:
        root (str or ``pathlib.Path``): Directory that is searched for the
            image files referenced by the page.

    Attributes populated while feeding:
        annotations: maps the resolved image path to its list of captions.
    """

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # Data structure to store captions: resolved image path -> captions
        self.annotations: Dict[str, List[str]] = {}

        # State variables
        self.in_table = False
        self.current_tag: Optional[str] = None
        self.current_img: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Remember the tag that was just opened and track entry into a table."""
        self.current_tag = tag

        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        """Forget the current tag and track exit from a table."""
        # Any closing tag resets the current tag, so text that follows it is
        # not attributed to an <a> or <li> element.
        self.current_tag = None

        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        """Interpret a text node as an image link or a caption.

        Raises:
            IndexError: if the link text has fewer than two ``/``-separated
                components, or if no file under ``root`` matches the image id.
        """
        # Text outside of a table is never relevant.
        if not self.in_table:
            return

        if data == "Image Not Found":
            # Drop the current image so the captions that follow are skipped.
            self.current_img = None
        elif self.current_tag == "a":
            # The image id is the second-to-last component of the link text.
            img_id = data.split("/")[-2]
            pattern = os.path.join(self.root, img_id + "_*.jpg")
            # First match wins; an empty result raises IndexError here.
            img_path = glob.glob(pattern)[0]
            self.current_img = img_path
            # Seeing the same image again resets its caption list.
            self.annotations[img_path] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())
class Flickr8k(VisionDataset):
    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict. The parser resolves every
        # image against ``self.root``, so the dict keys are file paths.
        parser = Flickr8kParser(self.root)
        with open(self.ann_file) as fh:
            parser.feed(fh.read())
        self.annotations = parser.annotations

        # Sorted so that indexing is deterministic.
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        # For this dataset the id is already the path of the image file.
        img_id = self.ids[index]

        # Image
        img = Image.open(img_id).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)
class Flickr30k(VisionDataset):
    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict: image id -> list of captions.
        # Each non-blank line is expected to be "<key>\t<caption>".
        self.annotations = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Split on the first tab only so captions may contain tabs.
                img_id, caption = line.split("\t", 1)
                # The last two characters of the key are dropped; presumably
                # they are a "#<n>" caption-index suffix -- confirm against
                # the annotation file format.
                self.annotations[img_id[:-2]].append(caption)

        # Sorted so that indexing is deterministic.
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image: ids are file names relative to the dataset root.
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.