Source code for torchvision.datasets._optical_flow
import itertools
import os
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from ..io.image import decode_png, read_file
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset


T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]


__all__ = (
    "KittiFlow",
    "Sintel",
    "FlyingThings3D",
    "FlyingChairs",
    "HD1K",
)


class FlowDataset(ABC, VisionDataset):
    # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid.
    # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
    _has_builtin_flow_mask = False

    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root)
        self.transforms = transforms

        self._flow_list: List[str] = []
        self._image_list: List[List[str]] = []

    def _read_img(self, file_name: str) -> Image.Image:
        img = Image.open(file_name)
        if img.mode != "RGB":
            img = img.convert("RGB")  # type: ignore[assignment]
        return img

    @abstractmethod
    def _read_flow(self, file_name: str):
        # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
        pass

    def __getitem__(self, index: int) -> Union[T1, T2]:
        img1 = self._read_img(self._image_list[index][0])
        img2 = self._read_img(self._image_list[index][1])

        if self._flow_list:  # it will be empty for some dataset when split="test"
            flow = self._read_flow(self._flow_list[index])
            if self._has_builtin_flow_mask:
                flow, valid_flow_mask = flow
            else:
                valid_flow_mask = None
        else:
            flow = valid_flow_mask = None

        if self.transforms is not None:
            img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)

        if self._has_builtin_flow_mask or valid_flow_mask is not None:
            # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
            return img1, img2, flow, valid_flow_mask
        else:
            return img1, img2, flow

    def __len__(self) -> int:
        return len(self._image_list)

    def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
        return torch.utils.data.ConcatDataset([self] * v)

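# Illustrative sketch (not part of torchvision): FlowDataset.__getitem__ always calls
# self.transforms(img1, img2, flow, valid_flow_mask), so a custom transform must accept and return
# all four values, even for datasets whose valid_flow_mask is None. The helper below is a
# hypothetical example of such a transform; flow and the mask may be None for test splits.
def _example_to_tensor_transform(img1, img2, flow, valid_flow_mask):
    import torchvision.transforms.functional as F

    # Convert PIL images to float tensors in [0, 1]; pass flow/mask through as tensors if present.
    img1 = F.to_tensor(img1)
    img2 = F.to_tensor(img2)
    if flow is not None:
        flow = torch.from_numpy(flow)
    if valid_flow_mask is not None:
        valid_flow_mask = torch.from_numpy(valid_flow_mask)
    return img1, img2, flow, valid_flow_mask
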
class Sintel(FlowDataset):
    """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                testing
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                training
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                    flow
                        scene_1
                        scene_2
                        ...

    Args:
        root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above
            for details on the different passes.
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        passes = ["clean", "final"] if pass_name == "both" else [pass_name]

        root = Path(root) / "Sintel"
        flow_root = root / "training" / "flow"

        for pass_name in passes:
            split_dir = "training" if split == "train" else split
            image_root = root / split_dir / pass_name
            for scene in os.listdir(image_root):
                image_list = sorted(glob(str(image_root / scene / "*.png")))
                for i in range(len(image_list) - 1):
                    self._image_list += [[image_list[i], image_list[i + 1]]]

                if split == "train":
                    self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` is None if ``split="test"``.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_flo(file_name)

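# Illustrative usage sketch (the "datasets/" root is a hypothetical example path): indexing a
# Sintel instance yields two consecutive PIL frames and the forward flow between them as a
# (2, H, W) numpy array.
def _example_sintel_usage():
    dataset = Sintel(root="datasets", split="train", pass_name="clean")
    img1, img2, flow = dataset[0]
    print(len(dataset), img1.size, flow.shape)  # flow.shape is (2, H, W), e.g. (2, 436, 1024) for Sintel
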
class KittiFlow(FlowDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).

    The dataset is expected to have the following structure: ::

        root
            KittiFlow
                testing
                    image_2
                training
                    image_2
                    flow_occ

    Args:
        root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "KittiFlow" / (split + "ing")
        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))

        if not images1 or not images2:
            raise FileNotFoundError(
                "Could not find the Kitti flow images. Please make sure the directory structure is correct."
            )

        for img1, img2 in zip(images1, images2):
            self._image_list += [[img1, img2]]

        if split == "train":
            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
            where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` and ``valid_flow_mask`` are None if ``split="test"``.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

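# Illustrative sketch: KittiFlow has a built-in validity mask, so indexing yields a 4-tuple.
# The "datasets/" root is a hypothetical example path.
def _example_kitti_usage():
    dataset = KittiFlow(root="datasets", split="train")
    img1, img2, flow, valid_flow_mask = dataset[0]
    # flow is a (2, H, W) float numpy array; valid_flow_mask is an (H, W) boolean numpy array
    # marking the pixels whose ground-truth flow should be used when computing a loss or EPE.
    valid_ratio = valid_flow_mask.mean()
    print(flow.shape, valid_ratio)
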
class FlyingChairs(FlowDataset):
    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.

    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.

    The dataset is expected to have the following structure: ::

        root
            FlyingChairs
                data
                    00001_flow.flo
                    00001_img1.ppm
                    00001_img2.ppm
                    ...
                FlyingChairs_train_val.txt

    Args:
        root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))

        root = Path(root) / "FlyingChairs"
        images = sorted(glob(str(root / "data" / "*.ppm")))
        flows = sorted(glob(str(root / "data" / "*.flo")))

        split_file_name = "FlyingChairs_train_val.txt"

        if not os.path.exists(root / split_file_name):
            raise FileNotFoundError(
                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
            )

        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
        for i in range(len(flows)):
            split_id = split_list[i]
            if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
                self._flow_list += [flows[i]]
                self._image_list += [[images[2 * i], images[2 * i + 1]]]
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            Ground-truth flow is available for both the "train" and "val" splits, so ``flow`` is never None.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_flo(file_name)

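# Illustrative sketch: wrapping the dataset in a DataLoader. The images must be converted to
# tensors first (e.g. via a transforms callable such as the hypothetical _example_to_tensor_transform
# above), since the default collate function cannot batch PIL images. The "datasets/" root is a
# hypothetical example path.
def _example_flying_chairs_loader():
    dataset = FlyingChairs(root="datasets", split="train", transforms=_example_to_tensor_transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    # FlyingChairs has no built-in mask, so each sample is a 3-tuple after the transform.
    img1, img2, flow = next(iter(loader))  # batched tensors: img1/img2 (4, 3, H, W), flow (4, 2, H, W)
    return img1, img2, flow
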
class FlyingThings3D(FlowDataset):
    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            FlyingThings3D
                frames_cleanpass
                    TEST
                    TRAIN
                frames_finalpass
                    TEST
                    TRAIN
                optical_flow
                    TEST
                    TRAIN

    Args:
        root (str or ``pathlib.Path``): Root directory of the FlyingThings3D Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above
            for details on the different passes.
        camera (string, optional): Which camera to return images from. Can be either "left" (default), "right", or "both".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        pass_name: str = "clean",
        camera: str = "left",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        split = split.upper()

        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
        cameras = ["left", "right"] if camera == "both" else [camera]

        root = Path(root) / "FlyingThings3D"

        directions = ("into_future", "into_past")
        for pass_name, camera, direction in itertools.product(passes, cameras, directions):
            image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
            image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs)

            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
            flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs)

            if not image_dirs or not flow_dirs:
                raise FileNotFoundError(
                    "Could not find the FlyingThings3D flow images. "
                    "Please make sure the directory structure is correct."
                )

            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
                images = sorted(glob(str(image_dir / "*.png")))
                flows = sorted(glob(str(flow_dir / "*.pfm")))
                for i in range(len(flows) - 1):
                    if direction == "into_future":
                        self._image_list += [[images[i], images[i + 1]]]
                        self._flow_list += [flows[i]]
                    elif direction == "into_past":
                        self._image_list += [[images[i + 1], images[i]]]
                        self._flow_list += [flows[i + 1]]
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            Ground-truth flow is available for both the "train" and "test" splits, so ``flow`` is never None.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        return _read_pfm(file_name)

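# Illustrative sketch: with pass_name="both" and camera="both", the itertools.product loop above
# enumerates every (pass, camera, direction) combination, so the dataset holds roughly four times
# as many image pairs as with the defaults (two passes x two cameras). The "datasets/" root is a
# hypothetical example path.
def _example_flying_things_usage():
    default_variant = FlyingThings3D(root="datasets", split="train", pass_name="clean", camera="left")
    full_variant = FlyingThings3D(root="datasets", split="train", pass_name="both", camera="both")
    print(len(default_variant), len(full_variant))
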
class HD1K(FlowDataset):
    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            hd1k
                hd1k_challenge
                    image_2
                hd1k_flow_gt
                    flow_occ
                hd1k_input
                    image_2

    Args:
        root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "hd1k"
        if split == "train":
            # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
            for seq_idx in range(36):
                flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
                images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
                for i in range(len(flows) - 1):
                    self._flow_list += [flows[i]]
                    self._image_list += [[images[i], images[i + 1]]]
        else:
            images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
            images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
            for image1, image2 in zip(images1, images2):
                self._image_list += [[image1, image2]]

        if not self._image_list:
            raise FileNotFoundError(
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        return _read_16bits_png_with_flow_and_valid_mask(file_name)
    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
            where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` and ``valid_flow_mask`` are None if ``split="test"``.
        """
        return super().__getitem__(index)

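# Illustrative sketch: FlowDataset.__rmul__ lets an integer multiply a dataset into a
# torch.utils.data.ConcatDataset of repeated copies, which makes it easy to over-sample one
# dataset when mixing several of them. Roots below are hypothetical, and in practice the
# transforms should make every dataset return the same tuple structure (e.g. by generating a
# valid_flow_mask for datasets that lack a built-in one).
def _example_mixed_training_set():
    sintel = Sintel(root="datasets", split="train", pass_name="both")
    kitti = KittiFlow(root="datasets", split="train")
    hd1k = HD1K(root="datasets", split="train")
    # Repeating Kitti over-samples it relative to the larger datasets.
    mixed = 5 * kitti + sintel + hd1k  # a torch.utils.data.ConcatDataset
    return mixed
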
def _read_flo(file_name: str) -> np.ndarray:
    """Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # Everything needs to be in little Endian according to
    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
    with open(file_name, "rb") as f:
        magic = np.fromfile(f, "c", count=4).tobytes()
        if magic != b"PIEH":
            raise ValueError("Magic number incorrect. Invalid .flo file")

        w = int(np.fromfile(f, "<i4", count=1))
        h = int(np.fromfile(f, "<i4", count=1))
        data = np.fromfile(f, "<f4", count=2 * w * h)
        return data.reshape(h, w, 2).transpose(2, 0, 1)


def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
    flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
    flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
    flow = (flow - 2**15) / 64  # This conversion is explained somewhere on the kitti archive
    valid_flow_mask = valid_flow_mask.bool()

    # For consistency with other datasets, we convert to numpy
    return flow.numpy(), valid_flow_mask.numpy()
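
# Illustrative sketch: the Middlebury .flo layout read by _read_flo is the magic bytes b"PIEH",
# the width and height as little-endian int32, then H*W*2 little-endian float32 values stored as
# (H, W, 2). The round-trip below writes a tiny synthetic file and reads it back; the output path
# is a hypothetical example.
def _example_flo_round_trip(path="example.flo"):
    h, w = 4, 5
    flow_hw2 = np.random.rand(h, w, 2).astype("<f4")
    with open(path, "wb") as f:
        f.write(b"PIEH")
        np.array([w, h], dtype="<i4").tofile(f)
        flow_hw2.tofile(f)
    flow = _read_flo(path)  # shape (2, H, W), matching what the datasets return
    assert flow.shape == (2, h, w)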