from __future__ import annotations

import contextlib
from collections import defaultdict

import torch

from torch.utils.data import Dataset

from torchvision import datapoints, datasets
from torchvision.transforms.v2 import functional as F


def wrap_dataset_for_transforms_v2(dataset):
    """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.

    .. v2betastatus:: wrap_dataset_for_transforms_v2 function

    Example:
        >>> dataset = torchvision.datasets.CocoDetection(...)
        >>> dataset = wrap_dataset_for_transforms_v2(dataset)

    .. note::

       For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
       configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting
       you to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.

    The dataset samples are wrapped according to the description below.

    Special cases:

        * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as a list of dicts, the
          wrapper returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate
          format), ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding
          ``torchvision.datapoints``. The original keys are preserved.
        * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
          the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
          preserved.
        * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
        * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as a list of dicts, the wrapper
          returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap
          the data in the corresponding ``torchvision.datapoints``. The original keys are preserved.
        * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into
          a :class:`~torchvision.datapoints.Mask` datapoint.
        * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
          :class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
          a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
          ``"labels"``.
        * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to the
          ``XYXY`` coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.

    Image classification datasets

        This wrapper is a no-op for image classification datasets, since they were already fully supported by
        :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.

    Segmentation datasets

        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
        :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
        segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).

    Video classification datasets

        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing
        a :class:`torch.Tensor` for the video and audio and an :class:`int` as label. This wrapper wraps the video
        into a :class:`~torchvision.datapoints.Video` while leaving the other items as is.

        .. note::

            Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
            ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.

    Args:
        dataset: the dataset instance to wrap for compatibility with transforms v2.
    """
    return VisionDatasetDatapointWrapper(dataset)
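# A minimal usage sketch (editor's addition, not part of the upstream module): constructing a
# detection dataset with joint v2 transforms and then wrapping it. The dataset paths are
# placeholders and the chosen transforms are just examples.
def _example_usage_sketch():
    from torchvision.transforms import v2 as transforms_v2

    dataset = datasets.CocoDetection(
        "path/to/images",
        "path/to/annotations.json",
        transforms=transforms_v2.Compose([transforms_v2.RandomHorizontalFlip(p=0.5)]),
    )
    # After wrapping, each target is a dict containing "boxes", "masks" and "labels" datapoints, so
    # the transforms above are applied jointly to the image and the target.
    dataset = wrap_dataset_for_transforms_v2(dataset)

    image, target = dataset[0]
    return image, target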
class WrapperFactories(dict):
    def register(self, dataset_cls):
        def decorator(wrapper_factory):
            self[dataset_cls] = wrapper_factory
            return wrapper_factory

        return decorator


# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on the
# dataset instance rather than just the class, since they require the user defined instance attributes. Thus, we can
# provide a wrapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime when
# we have access to the dataset instance.
WRAPPER_FACTORIES = WrapperFactories()


class VisionDatasetDatapointWrapper(Dataset):
    def __init__(self, dataset):
        dataset_cls = type(dataset)

        if not isinstance(dataset, datasets.VisionDataset):
            raise TypeError(
                f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                f"but got a '{dataset_cls.__name__}' instead."
            )

        for cls in dataset_cls.mro():
            if cls in WRAPPER_FACTORIES:
                wrapper_factory = WRAPPER_FACTORIES[cls]
                break
            elif cls is datasets.VisionDataset:
                # TODO: If we have documentation on how to do that, put a link in the error message.
                msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
                if dataset_cls in datasets.__dict__.values():
                    msg = (
                        f"{msg} If an automated wrapper for this dataset would be useful for you, "
                        f"please open an issue at https://github.com/pytorch/vision/issues."
                    )
                raise TypeError(msg)

        self._dataset = dataset
        self._wrapper = wrapper_factory(dataset)

        # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
        # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
        # `transforms`
        # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
        # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
        # disable all three here to be able to extract the untransformed sample to wrap.
        self.transform, dataset.transform = dataset.transform, None
        self.target_transform, dataset.target_transform = dataset.target_transform, None
        self.transforms, dataset.transforms = dataset.transforms, None

    def __getattr__(self, item):
        with contextlib.suppress(AttributeError):
            return object.__getattribute__(self, item)

        return getattr(self._dataset, item)

    def __getitem__(self, idx):
        # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
        # of this class
        sample = self._dataset[idx]

        sample = self._wrapper(idx, sample)

        # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
        # or joint (`transforms`), we can access the full functionality through `transforms`
        if self.transforms is not None:
            sample = self.transforms(*sample)

        return sample

    def __len__(self):
        return len(self._dataset)
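# Editor's sketch (not part of the upstream module): the two-stage registry above maps a dataset
# *class* to a factory, while the factory receives the dataset *instance* and returns the per-sample
# wrapper. A project-local `VisionDataset` subclass could plug into the same mechanism like this;
# `_ExampleMaskDataset` and its attributes are hypothetical.
class _ExampleMaskDataset(datasets.VisionDataset):
    def __init__(self, root):
        super().__init__(root)
        self._items = []  # would hold (PIL image, PIL mask) pairs

    def __getitem__(self, idx):
        return self._items[idx]

    def __len__(self):
        return len(self._items)


@WRAPPER_FACTORIES.register(_ExampleMaskDataset)
def _example_mask_dataset_wrapper_factory(dataset):
    # The factory runs once per dataset instance; the returned wrapper runs once per sample.
    def wrapper(idx, sample):
        image, mask = sample
        return image, datapoints.Mask(mask)

    return wrapper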
"f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues.")defidentity(item):returnitemdefidentity_wrapper_factory(dataset):defwrapper(idx,sample):returnsamplereturnwrapperdefpil_image_to_mask(pil_image):returndatapoints.Mask(pil_image)deflist_of_dicts_to_dict_of_lists(list_of_dicts):dict_of_lists=defaultdict(list)fordctinlist_of_dicts:forkey,valueindct.items():dict_of_lists[key].append(value)returndict(dict_of_lists)defwrap_target_by_type(target,*,target_types,type_wrappers):ifnotisinstance(target,(tuple,list)):target=[target]wrapped_target=tuple(type_wrappers.get(target_type,identity)(item)fortarget_type,iteminzip(target_types,target))iflen(wrapped_target)==1:wrapped_target=wrapped_target[0]returnwrapped_targetdefclassification_wrapper_factory(dataset):returnidentity_wrapper_factory(dataset)fordataset_clsin[datasets.Caltech256,datasets.CIFAR10,datasets.CIFAR100,datasets.ImageNet,datasets.MNIST,datasets.FashionMNIST,datasets.GTSRB,datasets.DatasetFolder,datasets.ImageFolder,]:WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)defsegmentation_wrapper_factory(dataset):defwrapper(idx,sample):image,mask=samplereturnimage,pil_image_to_mask(mask)returnwrapperfordataset_clsin[datasets.VOCSegmentation,]:WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)defvideo_classification_wrapper_factory(dataset):ifdataset.video_clips.output_format=="THWC":raiseRuntimeError(f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead.")defwrapper(idx,sample):video,audio,label=samplevideo=datapoints.Video(video)returnvideo,audio,labelreturnwrapperfordataset_clsin[datasets.HMDB51,datasets.Kinetics,datasets.UCF101,]:WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)@WRAPPER_FACTORIES.register(datasets.Caltech101)defcaltech101_wrapper_factory(dataset):if"annotation"indataset.target_type:raise_not_supported("Caltech101 dataset with `target_type=['annotation', 
...]`")returnclassification_wrapper_factory(dataset)@WRAPPER_FACTORIES.register(datasets.CocoDetection)defcoco_dectection_wrapper_factory(dataset):defsegmentation_to_mask(segmentation,*,spatial_size):frompycocotoolsimportmasksegmentation=(mask.frPyObjects(segmentation,*spatial_size)ifisinstance(segmentation,dict)elsemask.merge(mask.frPyObjects(segmentation,*spatial_size)))returntorch.from_numpy(mask.decode(segmentation))defwrapper(idx,sample):image_id=dataset.ids[idx]image,target=sampleifnottarget:returnimage,dict(image_id=image_id)batched_target=list_of_dicts_to_dict_of_lists(target)batched_target["image_id"]=image_idspatial_size=tuple(F.get_spatial_size(image))batched_target["boxes"]=F.convert_format_bounding_box(datapoints.BoundingBox(batched_target["bbox"],format=datapoints.BoundingBoxFormat.XYWH,spatial_size=spatial_size,),new_format=datapoints.BoundingBoxFormat.XYXY,)batched_target["masks"]=datapoints.Mask(torch.stack([segmentation_to_mask(segmentation,spatial_size=spatial_size)forsegmentationinbatched_target["segmentation"]]),)batched_target["labels"]=torch.tensor(batched_target["category_id"])returnimage,batched_targetreturnwrapperWRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)VOC_DETECTION_CATEGORIES=["__background__","aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair","cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep","sofa","train","tvmonitor",]VOC_DETECTION_CATEGORY_TO_IDX=dict(zip(VOC_DETECTION_CATEGORIES,range(len(VOC_DETECTION_CATEGORIES))))@WRAPPER_FACTORIES.register(datasets.VOCDetection)defvoc_detection_wrapper_factory(dataset):defwrapper(idx,sample):image,target=samplebatched_instances=list_of_dicts_to_dict_of_lists(target["annotation"]["object"])target["boxes"]=datapoints.BoundingBox([[int(bndbox[part])forpartin("xmin","ymin","xmax","ymax")]forbndboxinbatched_instances["bndbox"]],format=datapoints.BoundingBoxFormat.XYXY,spatial_size=(image.height,image.width),)target["labels"]=torch.tensor([VOC_DETECTION_CATEGORY_TO_IDX[category]forcategoryinbatched_instances["name"]])returnimage,targetreturnwrapper@WRAPPER_FACTORIES.register(datasets.SBDataset)defsbd_wrapper(dataset):ifdataset.mode=="boundaries":raise_not_supported("SBDataset with mode='boundaries'")returnsegmentation_wrapper_factory(dataset)@WRAPPER_FACTORIES.register(datasets.CelebA)defceleba_wrapper_factory(dataset):ifany(target_typeindataset.target_typefortarget_typein["attr","landmarks"]):raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', 
...]`")defwrapper(idx,sample):image,target=sampletarget=wrap_target_by_type(target,target_types=dataset.target_type,type_wrappers={"bbox":lambdaitem:F.convert_format_bounding_box(datapoints.BoundingBox(item,format=datapoints.BoundingBoxFormat.XYWH,spatial_size=(image.height,image.width),),new_format=datapoints.BoundingBoxFormat.XYXY,),},)returnimage,targetreturnwrapperKITTI_CATEGORIES=["Car","Van","Truck","Pedestrian","Person_sitting","Cyclist","Tram","Misc","DontCare"]KITTI_CATEGORY_TO_IDX=dict(zip(KITTI_CATEGORIES,range(len(KITTI_CATEGORIES))))@WRAPPER_FACTORIES.register(datasets.Kitti)defkitti_wrapper_factory(dataset):defwrapper(idx,sample):image,target=sampleiftargetisnotNone:target=list_of_dicts_to_dict_of_lists(target)target["boxes"]=datapoints.BoundingBox(target["bbox"],format=datapoints.BoundingBoxFormat.XYXY,spatial_size=(image.height,image.width))target["labels"]=torch.tensor([KITTI_CATEGORY_TO_IDX[category]forcategoryintarget["type"]])returnimage,targetreturnwrapper@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)defoxford_iiit_pet_wrapper_factor(dataset):defwrapper(idx,sample):image,target=sampleiftargetisnotNone:target=wrap_target_by_type(target,target_types=dataset._target_types,type_wrappers={"segmentation":pil_image_to_mask,},)returnimage,targetreturnwrapper@WRAPPER_FACTORIES.register(datasets.Cityscapes)defcityscapes_wrapper_factory(dataset):ifany(target_typeindataset.target_typefortarget_typein["polygon","color"]):raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")definstance_segmentation_wrapper(mask):# See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21data=pil_image_to_mask(mask)masks=[]labels=[]foridindata.unique():masks.append(data==id)label=idiflabel>=1_000:label//=1_000labels.append(label)returndict(masks=datapoints.Mask(torch.stack(masks)),labels=torch.stack(labels))defwrapper(idx,sample):image,target=sampletarget=wrap_target_by_type(target,target_types=dataset.target_type,type_wrappers={"instance":instance_segmentation_wrapper,"semantic":pil_image_to_mask,},)returnimage,targetreturnwrapper@WRAPPER_FACTORIES.register(datasets.WIDERFace)defwiderface_wrapper(dataset):defwrapper(idx,sample):image,target=sampleiftargetisnotNone:target["bbox"]=F.convert_format_bounding_box(datapoints.BoundingBox(target["bbox"],format=datapoints.BoundingBoxFormat.XYWH,spatial_size=(image.height,image.width)),new_format=datapoints.BoundingBoxFormat.XYXY,)returnimage,targetreturnwrapper