Shortcuts

Source code for ts.torch_handler.object_detector

"""
Module for object detection default handler
"""
import torch
from torchvision import transforms
from torchvision import __version__ as torchvision_version
from packaging import version
from .vision_handler import VisionHandler
from ..utils.util import map_class_to_label


[docs]class ObjectDetector(VisionHandler): """ ObjectDetector handler class. This handler takes an image and returns list of detected classes and bounding boxes respectively """ image_processing = transforms.Compose([transforms.ToTensor()]) threshold = 0.5
[docs] def initialize(self, context): super().initialize(context) properties = context.system_properties # Torchvision breaks with object detector models before 0.6.0 if version.parse(torchvision_version) < version.parse("0.6.0"): self.initialized = False self.device = torch.device( "cuda" if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu" ) self.model.to(self.device) self.model.eval() self.initialized = True
[docs] def postprocess(self, data): result = [] box_filters = [row["scores"] >= self.threshold for row in data] filtered_boxes, filtered_classes, filtered_scores = [ [ row[key][box_filter].tolist() for row, box_filter in zip(data, box_filters) ] for key in ["boxes", "labels", "scores"] ] for classes, boxes, scores in zip( filtered_classes, filtered_boxes, filtered_scores ): retval = [] for _class, _box, _score in zip(classes, boxes, scores): _retval = map_class_to_label([[_box]], self.mapping, [[_class]])[0] _retval["score"] = _score retval.append(_retval) result.append(retval) return result

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources