
Source code for ts.torch_handler.image_segmenter

Module for image segmentation default handler
import torch
from torchvision import transforms as T
from .vision_handler import VisionHandler

class ImageSegmenter(VisionHandler):
    """
    Default handler for image segmentation models.

    The handler works on a batch of images. The model output stored under
    the ``'out'`` key is expected to have shape [N K H W], where N is the
    batch size, K the number of classes, H the height and W the width.
    ``postprocess`` condenses that output to the single most likely class
    (and its probability) for every pixel.
    """

    # Transform pipeline: resize to 256, center-crop to 224, convert to a
    # tensor, then normalize each channel with the mean/std listed below.
    # NOTE(review): presumably consumed by VisionHandler during
    # preprocessing -- that code is not visible here, confirm.
    image_processing = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    def postprocess(self, data):
        """
        Reduce the raw model output to the best class per pixel.

        Args:
            data: Mapping whose ``'out'`` entry holds the class scores
                (assumed [N K H W], see the class docstring).

        Returns:
            A nested list in which the innermost entries are
            ``[class_index, probability]`` pairs, one per pixel. The class
            index is cast to the probability dtype so both values can
            share one tensor before conversion to a list.
        """
        # Sending the scores of every class for every pixel makes the
        # response size too big (> 24mb), so only the top class and its
        # probability are kept for each pixel.
        scores = data['out']

        # Turn the scores into probabilities along the class axis.
        probabilities = torch.nn.functional.softmax(scores, dim=1)

        # torch.max over a dim yields (values, indices): the winning
        # probability and the winning class index along the class axis.
        best_probability, best_class = torch.max(probabilities, dim=1)

        # Pair index and probability along a new trailing axis (dim 3).
        class_as_float = best_class.type(best_probability.dtype)
        paired = torch.stack([class_as_float, best_probability], dim=3)
        return paired.tolist()


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources