Source code for torchvision.datasets.widerface
import os
from os.path import abspath, expanduser
from typing import Any, Callable, List, Dict, Optional, Tuple, Union
import torch
from PIL import Image
from .utils import (
download_file_from_google_drive,
download_and_extract_archive,
extract_archive,
verify_str_arg,
)
from .vision import VisionDataset
class WIDERFace(VisionDataset):
"""`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.
Args:
root (string): Root directory where images and annotations are downloaded to.
Expects the following folder structure if download=False:
.. code::
<root>
└── widerface
├── wider_face_split ('wider_face_split.zip' if compressed)
├── WIDER_train ('WIDER_train.zip' if compressed)
├── WIDER_val ('WIDER_val.zip' if compressed)
└── WIDER_test ('WIDER_test.zip' if compressed)
split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is
            not downloaded again.
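
    Example:
        A minimal usage sketch; ``"data"`` below is a placeholder root path, not
        part of this module::

            ds = WIDERFace(root="data", split="val", download=True)
            img, target = ds[0]     # PIL image and a dict of annotation tensors
            boxes = target["bbox"]  # shape (num_faces, 4): x, y, width, height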
"""
BASE_FOLDER = "widerface"
FILE_LIST = [
# File ID MD5 Hash Filename
("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
]
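    # (URL, MD5 hash, filename) of the bounding-box annotation archive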
ANNOTATIONS_FILE = (
"http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
"0e3767bcf0e326556d407bf5bff5d27c",
"wider_face_split.zip",
)
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(
root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
)
# check arguments
self.split = verify_str_arg(split, "split", ("train", "val", "test"))
if download:
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")
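        # each entry stores an image path and, for train/val, a dict of annotation tensors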
self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
if self.split in ("train", "val"):
self.parse_train_val_annotations_file()
else:
self.parse_test_annotations_file()
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
            tuple: (image, target) where target is a dict of annotations for all
            faces in the image. target is ``None`` for the test split.
"""
# stay consistent with other datasets and return a PIL Image
img = Image.open(self.img_info[index]["img_path"])
if self.transform is not None:
img = self.transform(img)
target = None if self.split == "test" else self.img_info[index]["annotations"]
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.img_info)
def extra_repr(self) -> str:
lines = ["Split: {split}"]
return "\n".join(lines).format(**self.__dict__)
def parse_train_val_annotations_file(self) -> None:
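        # The annotation file is a sequence of records, each laid out as:
        #   <image path, relative to WIDER_<split>/images>
        #   <number of boxes>
        #   <one line per box: 10 space-separated integers>
        # The three boolean flags below implement a small state machine that
        # walks through these records line by line.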
filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
filepath = os.path.join(self.root, "wider_face_split", filename)
with open(filepath) as f:
lines = f.readlines()
file_name_line, num_boxes_line, box_annotation_line = True, False, False
num_boxes, box_counter = 0, 0
labels = []
for line in lines:
line = line.rstrip()
if file_name_line:
img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
img_path = abspath(expanduser(img_path))
file_name_line = False
num_boxes_line = True
elif num_boxes_line:
num_boxes = int(line)
num_boxes_line = False
box_annotation_line = True
elif box_annotation_line:
box_counter += 1
line_split = line.split(" ")
line_values = [int(x) for x in line_split]
labels.append(line_values)
if box_counter >= num_boxes:
box_annotation_line = False
file_name_line = True
labels_tensor = torch.tensor(labels)
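                    # labels_tensor has shape (num_boxes, 10); split its columns into named fields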
self.img_info.append(
{
"img_path": img_path,
"annotations": {
"bbox": labels_tensor[:, 0:4], # x, y, width, height
"blur": labels_tensor[:, 4],
"expression": labels_tensor[:, 5],
"illumination": labels_tensor[:, 6],
"occlusion": labels_tensor[:, 7],
"pose": labels_tensor[:, 8],
"invalid": labels_tensor[:, 9],
},
}
)
box_counter = 0
labels.clear()
else:
raise RuntimeError(f"Error parsing annotation file {filepath}")
def parse_test_annotations_file(self) -> None:
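        # the test split ships only a list of image paths; there are no box annotations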
filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
filepath = abspath(expanduser(filepath))
with open(filepath) as f:
lines = f.readlines()
for line in lines:
line = line.rstrip()
img_path = os.path.join(self.root, "WIDER_test", "images", line)
img_path = abspath(expanduser(img_path))
self.img_info.append({"img_path": img_path})
def _check_integrity(self) -> bool:
        # Allow the original archives (zip) to be deleted; only the extracted
        # directories need to be present.
all_files = self.FILE_LIST.copy()
all_files.append(self.ANNOTATIONS_FILE)
        for _, _, filename in all_files:
            extracted_dir = os.path.join(self.root, os.path.splitext(filename)[0])
if not os.path.exists(extracted_dir):
return False
return True
def download(self) -> None:
if self._check_integrity():
print("Files already downloaded and verified")
return
# download and extract image data
for (file_id, md5, filename) in self.FILE_LIST:
download_file_from_google_drive(file_id, self.root, filename, md5)
filepath = os.path.join(self.root, filename)
extract_archive(filepath)
# download and extract annotation files
download_and_extract_archive(
url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
)
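

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. Because each
    # target dict holds a per-image number of boxes, the default collate
    # function cannot stack samples into a batch; a collate_fn that returns
    # plain tuples sidesteps that. "data" is a placeholder root directory.
    from torch.utils.data import DataLoader
    from torchvision import transforms

    dataset = WIDERFace(root="data", split="val", transform=transforms.ToTensor(), download=True)
    loader = DataLoader(dataset, batch_size=4, collate_fn=lambda batch: tuple(zip(*batch)))
    images, targets = next(iter(loader))
    print(len(images), targets[0]["bbox"].shape)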