Source code for ts.torch_handler.densenet_handler

Module defining the DenseNetHandler image classification handler.
import inspect
import logging
import os
import importlib.util
import time
import io
import torch

logger = logging.getLogger(__name__)

class DenseNetHandler:
    """
    DenseNetHandler handler class.

    Each request row is expected to carry a tensor serialized with
    ``torch.save`` under the ``"data"`` key (falling back to ``"body"``).
    The rows are deserialized, stacked into one batch with ``torch.stack``,
    passed through the loaded model, and the raw model output is returned.
    """

    def __init__(self):
        self.model = None
        self.device = None
        self.initialized = False
        self.context = None
        self.manifest = None
        self.map_location = None

    def initialize(self, context):
        """Load the model described by the serving context.

        An eager-mode ``state_dict`` model is loaded when the manifest names a
        ``modelFile``; otherwise the serialized file is loaded as TorchScript.

        :param context: serving context; ``context.system_properties`` supplies
            ``gpu_id`` and ``model_dir``, and ``context.manifest["model"]``
            supplies ``serializedFile`` (required) and ``modelFile`` (optional).
        :raises RuntimeError: if the serialized model file does not exist.
        """
        properties = context.system_properties
        gpu_id = properties.get("gpu_id")
        use_cuda = torch.cuda.is_available() and gpu_id is not None
        self.map_location = "cuda" if use_cuda else "cpu"
        self.device = torch.device(
            self.map_location + ":" + str(gpu_id) if use_cuda else self.map_location
        )
        self.manifest = context.manifest

        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the file")

        # model def file
        model_file = self.manifest["model"].get("modelFile", "")
        if model_file:
            logger.debug("Loading eager model")
            self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
            # load_state_dict copies values into the module's existing
            # parameters, so the module stays wherever model_class() built it
            # (normally the CPU). Move it to the device inference() sends
            # inputs to.
            self.model.to(self.device)
        else:
            logger.debug("Loading torchscript model")
            self.model = self._load_torchscript_model(model_pt_path)

        self.model.eval()
        logger.debug("Model file %s loaded successfully", model_pt_path)
        self.initialized = True

    def _load_torchscript_model(self, model_pt_path):
        """Load a TorchScript archive directly onto ``self.device``."""
        # Use the full device (including the GPU index) rather than the bare
        # "cuda" string, which refers to the current default CUDA device
        # instead of the one selected by gpu_id.
        return torch.jit.load(model_pt_path, map_location=self.device)

    def _load_pickled_model(self, model_dir, model_file, model_pt_path):
        """Instantiate the single class defined in ``model_file`` and load its weights.

        :param model_dir: directory containing the model artifacts.
        :param model_file: file name of the Python model definition.
        :param model_pt_path: path of the saved ``state_dict``.
        :return: the model instance with the state dict loaded.
        :raises RuntimeError: if ``model_file`` is missing from ``model_dir``.
        :raises ValueError: if the module does not define exactly one class.
        """
        model_def_path = os.path.join(model_dir, model_file)
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the file")
        # Imported by bare module name, so this presumably relies on model_dir
        # being on sys.path. NOTE(review): confirm.
        module = importlib.import_module(model_file.split(".")[0])
        model_class_definitions = list_classes_from_module(module)
        if len(model_class_definitions) != 1:
            raise ValueError(
                "Expected only one class as model definition. {}".format(
                    model_class_definitions
                )
            )
        model_class = model_class_definitions[0]
        state_dict = torch.load(model_pt_path, map_location=self.device)
        model = model_class()
        model.load_state_dict(state_dict)
        return model

    def inference(self, data, *args, **kwargs):
        """
        Override to customize the inference

        :param data: Torch tensor, matching the model input shape
        :return: Prediction output as Torch tensor
        """
        # Inputs must live on the same device as the model.
        marshalled_data = data.to(self.device)
        with torch.no_grad():
            results = self.model(marshalled_data, *args, **kwargs)
        return results

    def handle(self, data, context):
        """
        Entry point for default handler.

        :param data: list of request rows; each row maps ``"data"`` or
            ``"body"`` to bytes produced by ``torch.save`` of a tensor.
        :param context: serving context; its ``metrics.add_time`` records
            ``HandlerTime`` in milliseconds.
        :return: a one-element list holding the batched model output.
        :raises ValueError: if a row has neither a ``"data"`` nor a ``"body"`` payload.
        """
        # It can be used for pre or post processing if needed as additional request
        # information is available in context
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        self.context = context
        metrics = self.context.metrics

        values = []
        for row in data:
            payload = row.get("data") or row.get("body")
            if payload is None:
                raise ValueError("Request row has neither 'data' nor 'body' payload")
            # SECURITY NOTE(review): torch.load unpickles request bytes; on
            # untrusted input consider weights_only=True (if the installed torch
            # supports it).
            values.append(torch.load(io.BytesIO(payload)))

        output = self.inference(torch.stack(values))
        stop_time = time.perf_counter()
        metrics.add_time(
            "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms"
        )
        # NOTE(review): the whole batch output is returned as a single list
        # entry; if the caller expects one entry per request row, batches
        # larger than one would mismatch. Confirm.
        return [output]
def list_classes_from_module(module, parent_class=None):
    """
    Collect the classes that *module* itself defines.

    A member counts only when it is a class whose ``__module__`` equals
    ``module.__name__``, so classes merely imported into the module are
    skipped. Results follow the name-sorted order of ``inspect.getmembers``.

    :param module: the (user-defined) module object to scan.
    :param parent_class: if given, keep only subclasses of this class
        (``issubclass`` semantics, so the class itself also qualifies).
    :return: list of matching class objects.
    """
    def _defined_here(member):
        # Accept only classes whose defining module is the scanned module.
        return inspect.isclass(member) and member.__module__ == module.__name__

    found = [klass for _name, klass in inspect.getmembers(module, _defined_here)]
    if parent_class is None:
        return found
    return [klass for klass in found if issubclass(klass, parent_class)]


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources