Shortcuts

Source code for ts.torch_handler.text_classifier

# pylint: disable=E1102
# TODO: remove the pylint disable comment once https://github.com/pytorch/pytorch/issues/24807 is resolved.
"""
Module for text classification default handler
DOES NOT SUPPORT BATCH!
"""
import logging

import torch
import torch.nn.functional as F
from captum.attr import TokenReferenceBase
from torchtext.data.utils import ngrams_iterator

from ..utils.util import map_class_to_label
from .text_handler import TextHandler

logger = logging.getLogger(__name__)


class TextClassifier(TextHandler):
    """Default handler for n-gram based text classification.

    The handler takes a single text input (``str`` or UTF-8 encoded bytes),
    normalizes and tokenizes it, converts the resulting n-grams to indices
    using ``self.source_vocab`` and returns the class probabilities mapped
    to labels. Batch inference is not supported: only the first request of
    a batch is processed.
    """

    # n-gram size handed to ``ngrams_iterator`` when building the model input.
    ngrams = 2

    def preprocess(self, data):
        """Normalize the input text and convert it to a tensor of indices.

        The following clean-up operations are applied, in this order:

        - remove html tags
        - lowercase all text
        - expand contractions [like I'd -> I would, don't -> do not]
        - remove accented characters
        - remove punctuations

        The cleaned text is then tokenized and each token/n-gram is looked
        up in ``self.source_vocab``.

        Args:
            data (list): The incoming requests. Only ``data[0]`` is used; it
                must provide the text under the ``"data"`` or ``"body"`` key,
                either as ``str`` or as UTF-8 encoded ``bytes``/``bytearray``.

        Returns:
            tuple: ``(text_tensor, text)`` where ``text_tensor`` is the
            tensor of vocabulary indices placed on ``self.device`` and
            ``text`` is the value returned by ``self._tokenize`` for the
            cleaned input.
        """
        # Compat layer: normally the envelope should just return the data
        # directly, but older versions of Torchserve didn't have envelope.
        # Processing only the first input, not handling batch inference.
        line = data[0]
        text = line.get("data") or line.get("body")

        # Decode text if not a str but bytes or bytearray.
        if isinstance(text, (bytes, bytearray)):
            text = text.decode("utf-8")

        text = self._remove_html_tags(text)
        text = text.lower()
        text = self._expand_contractions(text)
        text = self._remove_accented_characters(text)
        text = self._remove_punctuation(text)
        text = self._tokenize(text)

        text_tensor = torch.as_tensor(
            [self.source_vocab[token] for token in ngrams_iterator(text, self.ngrams)],
            device=self.device,
        )
        return text_tensor, text

    def inference(self, data, *args, **kwargs):
        """Run the model on the pre-processed input.

        Args:
            data (tuple): The ``(text_tensor, text)`` pair returned by
                :meth:`preprocess`. Only the tensor is forwarded to the model.
            *args: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The value returned by the parent class ``inference`` when called
            with the index tensor and an offsets tensor.
        """
        text_tensor, _ = data
        # A single offset of 0: the whole index tensor is one sequence.
        offsets = torch.as_tensor([0], device=self.device)
        return super().inference(text_tensor, offsets)

    def postprocess(self, data):
        """Convert the raw model output into a Torchserve compatible response.

        Args:
            data (torch.Tensor): The prediction output of the model. The
                class scores are expected along the last dimension.

        Returns:
            The result of ``map_class_to_label`` applied to the softmax
            probabilities (as nested lists) and ``self.mapping``.
        """
        # Name the dimension explicitly: relying on the implicit choice is
        # deprecated and emits a warning on every call. For 1-D and 2-D
        # outputs the last dimension is what the implicit rule selected.
        probabilities = F.softmax(data, dim=-1)
        return map_class_to_label(probabilities.tolist(), self.mapping)

    def get_insights(self, text_preprocess, _, target=0):
        """Calculate the captum insights for one pre-processed input.

        Args:
            text_preprocess (tuple): The ``(text_tensor, text)`` pair
                returned by :meth:`preprocess`.
            _ : Unused; kept so the signature stays compatible with callers.
            target (int): Defaults to 0, the user needs to specify the
                target for the captum explanation.

        Returns:
            list: A single-element list holding a dictionary with the keys
            ``"importances"`` (summarized attributions as a list) and
            ``"words"`` (the value returned by ``self.get_word_token``).
        """
        text_tensor, all_tokens = text_preprocess
        token_reference = TokenReferenceBase()
        logger.info("input_text shape %s", len(text_tensor.shape))
        logger.info("get_insights target %s", target)
        offsets = torch.as_tensor([0], device=self.device)

        all_tokens = self.get_word_token(all_tokens)
        logger.info("text_tensor tokenized shape %s", text_tensor.shape)

        # Baseline input with one reference token per input index.
        reference_indices = token_reference.generate_reference(
            text_tensor.shape[0], device=self.device
        ).squeeze(0)
        logger.info("reference indices shape %s", reference_indices.shape)

        attributions = self.lig.attribute(
            text_tensor,
            reference_indices,
            # One-element tuple: ``(offsets)`` alone is not a tuple.
            additional_forward_args=(offsets,),
            return_convergence_delta=False,
            target=target,
        )
        logger.info("attributions shape %s", attributions.shape)

        attributions_sum = self.summarize_attributions(attributions)
        response = {
            "importances": attributions_sum.tolist(),
            "words": all_tokens,
        }
        return [response]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources