Shortcuts

Source code for ts.protocol.otf_message_handler

"""
OTF Codec
"""

import io
import json
import logging
import os
import struct
import sys
import time
from builtins import bytearray, bytes

import torch

from ts.utils.util import deprecated

# Wire sizes, in bytes, of the fixed-width primitives read from the frontend.
bool_size = 1  # one "!?" boolean, read by _retrieve_bool
int_size = 4  # one "!i" big-endian signed int, read by _retrieve_int
# Length value that terminates a length-prefixed list on the wire.
# NOTE(review): the visible encoders/decoders compare against the literal -1
# instead of referencing this constant.
END_OF_LIST = -1
# One-byte command codes that start every frontend frame (dispatched in retrieve_msg).
LOAD_MSG = b"L"
PREDICT_MSG = b"I"
# NOTE(review): not referenced in this module's visible code; meaning unclear.
RESPONSE = 3


def retrieve_msg(conn):
    """
    Read one command frame from the frontend connection.

    The first byte selects the command: ``LOAD_MSG`` is followed by a
    model-load payload and ``PREDICT_MSG`` by a batch of inference requests.

    :param conn: connected frontend socket
    :return: tuple ``(cmd, msg)`` of the raw command byte and its decoded payload
    :raises ValueError: if the command byte is not recognised
    """
    command = _retrieve_buffer(conn, 1)

    if command == LOAD_MSG:
        return command, _retrieve_load_msg(conn)

    if command == PREDICT_MSG:
        batch = _retrieve_inference_msg(conn)
        logging.info("Backend received inference at: %d", time.time())
        return command, batch

    raise ValueError("Invalid command: {}".format(command))
def encode_response_headers(resp_hdr_map):
    """
    Serialize a response-header mapping into the frontend wire format.

    Layout: ``int count`` followed, for each header, by
    ``int key_len | key | int value_len | value``. Lengths are big-endian
    4-byte ints counting UTF-8 bytes.

    :param resp_hdr_map: mapping of header name (str) to header value (str)
    :return: bytearray holding the encoded headers
    """

    def _length_prefixed(text):
        # Encode once, then prefix with the byte length of the encoded form.
        raw = text.encode("utf-8")
        return struct.pack("!i", len(raw)) + raw

    encoded = bytearray(struct.pack("!i", len(resp_hdr_map)))
    for name, value in resp_hdr_map.items():
        encoded += _length_prefixed(name)
        encoded += _length_prefixed(value)
    return encoded
def create_predict_response(
    ret, req_id_map, message, code, context=None, ts_stream_next=False
):
    """
    Create the inference response frame sent back to the frontend.

    All ints are big-endian 4-byte signed values and all strings are UTF-8
    with a byte-length prefix::

        int code | int msg_len | message
        for each key in req_id_map:
            int id_len | request_id
            int content_type_len | content_type
            int http_code | int phrase_len | phrase
            response headers (encode_response_headers layout)
            int payload_len | payload
        int -1   (end of list)

    If an output cannot be serialized, the partially built frame is discarded.
    A batch-level 503 "Unsupported model output data type." response is
    returned instead.

    :param ret: per-request outputs, indexed by the keys of ``req_id_map``;
        ``None`` sends the literal payload ``b"error"`` for every request.
    :param req_id_map: mapping whose keys index into ``ret`` and whose values
        are request-id strings.
    :param message: batch-level status message.
    :param code: batch-level status code; also used as each request's HTTP
        code when ``context`` is ``None``.
    :param context: optional request context that supplies per-request content
        type, status, headers and stopping criteria.
    :param ts_stream_next: when exactly ``True``, sets each request's
        ``ts_stream_next`` response header to ``"true"``.
    :return: the encoded ``bytearray``, or ``None`` when ``LOCAL_RANK`` is not 0.
    """
    # Only local rank 0 builds a response; other ranks return nothing.
    if str(os.getenv("LOCAL_RANK", 0)) != "0":
        return None

    msg = bytearray()
    msg += struct.pack("!i", code)
    msg += _pack_utf8(message)

    for idx in req_id_map:
        msg += _pack_utf8(req_id_map.get(idx))

        if context is None:
            msg += struct.pack("!i", 0)  # no content type
            msg += struct.pack("!i", code)  # per-request status mirrors batch code
            msg += struct.pack("!i", 0)  # no reason phrase
            msg += struct.pack("!i", 0)  # no response headers
        else:
            _update_stream_header(context, idx, ret, ts_stream_next)
            msg += _encode_request_metadata(context, idx)

        if ret is None:
            msg += _pack_bytes(b"error")
            continue

        payload = _serialize_prediction(ret[idx])
        if payload is None:
            # Unserializable output: drop the partial frame and report failure.
            return create_predict_response(
                None, req_id_map, "Unsupported model output data type.", 503
            )
        msg += _pack_bytes(payload)

    msg += struct.pack("!i", -1)  # End of list
    return msg


def _pack_bytes(data):
    """Return ``data`` prefixed with its length as a big-endian 4-byte int."""
    return struct.pack("!i", len(data)) + data


def _pack_utf8(text):
    """UTF-8 encode ``text`` and prefix it with the encoded *byte* length."""
    return _pack_bytes(text.encode("utf-8"))


def _update_stream_header(context, idx, ret, ts_stream_next):
    """Set or clear the ``ts_stream_next`` response header for request ``idx``."""
    if ts_stream_next is True:
        # Intermediate streamed chunk.
        context.set_response_header(idx, "ts_stream_next", "true")
    elif context.stopping_criteria:
        # The per-request criterion decides; a None result leaves the header alone.
        is_stop = context.stopping_criteria[idx](ret[idx])
        if is_stop is not None:
            context.set_response_header(
                idx, "ts_stream_next", "false" if is_stop else "true"
            )
    elif "true" == context.get_response_headers(idx).get("ts_stream_next"):
        # Closing response of a stream previously marked as continuing.
        context.set_response_header(idx, "ts_stream_next", "false")


def _encode_request_metadata(context, idx):
    """Encode content type, HTTP status/phrase and headers for request ``idx``."""
    out = bytearray()
    content_type = context.get_response_content_type(idx)
    # A missing or empty content type is sent as a zero-length string.
    out += _pack_utf8(content_type if content_type else "")
    sc, phrase = context.get_response_status(idx)
    out += struct.pack("!i", sc if sc is not None else 200)
    out += _pack_utf8(phrase if phrase is not None else "")
    out += encode_response_headers(context.get_response_headers(idx))
    return out


def _serialize_prediction(val):
    """Encode one model output as bytes, or return ``None`` if unsupported."""
    # Raw bytes are forwarded unchanged, so test for them before str.
    if isinstance(val, (bytes, bytearray)):
        return val
    if isinstance(val, str):
        return val.encode("utf-8")
    if isinstance(val, torch.Tensor):
        buff = io.BytesIO()
        torch.save(val, buff)
        return buff.getvalue()
    try:
        return json.dumps(val, indent=2).encode("utf-8")
    except (TypeError, ValueError):
        # TypeError: non-JSON type; ValueError: e.g. circular reference.
        logging.warning("Unable to serialize model output.", exc_info=True)
        return None
def create_load_model_response(code, message):
    """
    Build the reply frame for a model-load request.

    Layout: ``int code | int message_len | message | int -1``. The message is
    UTF-8 with a byte-length prefix. The trailing -1 marks an empty prediction
    list.

    :param code: status code of the load operation
    :param message: human-readable status message
    :return: bytearray holding the encoded frame
    """
    status = struct.pack("!i", code)
    body = message.encode("utf-8")
    return bytearray(
        status + struct.pack("!i", len(body)) + body + struct.pack("!i", -1)
    )
def _retrieve_buffer(conn, length):
    """
    Read exactly ``length`` bytes from the frontend connection.

    ``conn.recv`` may return fewer bytes than requested, so this loops until
    the full amount has arrived. An empty chunk from ``recv`` is treated as
    the frontend disconnecting, and the process exits with status 0.

    :param conn: connected socket (anything exposing ``recv(n)``)
    :param length: number of bytes to read; a value <= 0 reads nothing
    :return: bytearray holding the bytes read
    """
    data = bytearray()
    while length > 0:
        pkt = conn.recv(length)
        if len(pkt) == 0:
            logging.info("Frontend disconnected.")
            sys.exit(0)
        data += pkt
        length -= len(pkt)
    return data


def _retrieve_int(conn):
    """Read one big-endian 4-byte signed integer from ``conn``."""
    data = _retrieve_buffer(conn, int_size)
    return struct.unpack("!i", data)[0]


def _retrieve_bool(conn):
    """Read one single-byte boolean from ``conn``."""
    data = _retrieve_buffer(conn, bool_size)
    return struct.unpack("!?", data)[0]


def _retrieve_load_msg(conn):
    """
    Decode a model-load request body.

    The leading command byte has already been consumed by ``retrieve_msg``.

    MSG Frame Format:
    | int model-name length | model-name value |
    | int model-path length | model-path value |
    | int batch-size |
    | int handler length | handler value |
    | int gpu id |
    | int envelope length | envelope value |
    | bool limitMaxImagePixels |

    :param conn: connected frontend socket
    :return: dict with the raw (undecoded bytearray) values ``modelName``,
        ``modelPath``, ``handler`` and ``envelope``, the int ``batchSize``, the
        bool ``limitMaxImagePixels`` and, only when the gpu id is >= 0, ``gpu``
    """
    msg = {}
    length = _retrieve_int(conn)
    msg["modelName"] = _retrieve_buffer(conn, length)
    length = _retrieve_int(conn)
    msg["modelPath"] = _retrieve_buffer(conn, length)
    msg["batchSize"] = _retrieve_int(conn)
    length = _retrieve_int(conn)
    msg["handler"] = _retrieve_buffer(conn, length)
    gpu_id = _retrieve_int(conn)
    if gpu_id >= 0:
        # A negative id (presumably "no GPU") leaves the "gpu" key out.
        msg["gpu"] = gpu_id
    length = _retrieve_int(conn)
    msg["envelope"] = _retrieve_buffer(conn, length)
    msg["limitMaxImagePixels"] = _retrieve_bool(conn)
    return msg


def _retrieve_inference_msg(conn):
    """
    Decode a batch of inference requests.

    MSG Frame Format:
    | batch: list of requests, terminated by a request-id length of -1 |

    :param conn: connected frontend socket
    :return: list of request dicts as produced by ``_retrieve_request``
    """
    msg = []
    while True:
        request = _retrieve_request(conn)
        if request is None:
            break
        msg.append(request)
    return msg


def _retrieve_request(conn):
    """
    Decode one inference request.

    MSG Frame Format:
    | int request_id length | request_id |
    | request_headers: list of request headers, terminated by -1 |
    | parameters: list of request parameters, terminated by -1 |

    :param conn: connected frontend socket
    :return: dict with ``requestId`` (bytearray), ``headers`` and
        ``parameters`` lists, or ``None`` when the end-of-batch marker (-1)
        is read instead of a request-id length
    """
    length = _retrieve_int(conn)
    if length == -1:
        return None

    request = {}
    request["requestId"] = _retrieve_buffer(conn, length)

    headers = []
    while True:
        header = _retrieve_reqest_header(conn)
        if header is None:
            break
        headers.append(header)
    request["headers"] = headers

    model_inputs = []
    while True:
        input_data = _retrieve_input_data(conn)
        if input_data is None:
            break
        model_inputs.append(input_data)
    request["parameters"] = model_inputs
    return request


def _retrieve_reqest_header(conn):
    """
    Decode one request header. The misspelled name is kept for compatibility.

    MSG Frame Format:
    | int name length | name |
    | int value length | value |

    :param conn: connected frontend socket
    :return: dict with raw bytearray ``name`` and ``value``, or ``None`` when
        the end-of-list marker (-1) is read instead of a name length
    """
    length = _retrieve_int(conn)
    if length == -1:
        return None

    header = {}
    header["name"] = _retrieve_buffer(conn, length)
    length = _retrieve_int(conn)
    header["value"] = _retrieve_buffer(conn, length)
    return header


def _retrieve_input_data(conn):
    """
    Decode one request parameter (model input).

    MSG Frame Format:
    | int parameter_name length | parameter_name |
    | int content_type length | content_type |
    | int data length | input data in bytes |

    When TS_DECODE_INPUT_REQUEST is unset or exactly "true", an
    ``application/json`` payload is parsed with ``json.loads`` and a
    ``text*`` payload is decoded as UTF-8. On failure, or when decoding is
    disabled, the raw bytearray is forwarded.

    :param conn: connected frontend socket
    :return: dict with ``name`` (str), ``contentType`` (str) and ``value``,
        or ``None`` when the end-of-list marker (-1) is read
    """
    decode_req = os.environ.get("TS_DECODE_INPUT_REQUEST")
    should_decode = decode_req is None or decode_req == "true"

    length = _retrieve_int(conn)
    if length == -1:
        return None

    model_input = {}
    model_input["name"] = _retrieve_buffer(conn, length).decode("utf-8")
    length = _retrieve_int(conn)
    content_type = _retrieve_buffer(conn, length).decode("utf-8")
    model_input["contentType"] = content_type
    length = _retrieve_int(conn)
    value = _retrieve_buffer(conn, length)

    if should_decode and content_type == "application/json":
        try:
            model_input["value"] = json.loads(value.decode("utf-8"))
        except Exception:
            # Best effort: malformed JSON is forwarded undecoded.
            model_input["value"] = value
            logging.warning(
                "Failed json decoding of input data. Forwarding encoded payload",
                exc_info=True,
            )
    elif should_decode and content_type.startswith("text"):
        try:
            model_input["value"] = value.decode("utf-8")
        except UnicodeDecodeError:
            model_input["value"] = value
            logging.warning(
                "Failed utf-8 decoding of input data. Forwarding encoded payload",
                exc_info=True,
            )
    else:
        model_input["value"] = value
    return model_input
@deprecated(
    version=1.0,
    replacement="ts.handler_utils.utils.send_intermediate_predict_response",
)
def send_intermediate_predict_response(ret, req_id_map, message, code, context=None):
    """
    Build an intermediate (streamed) inference response and send it over
    ``context.cl_socket``.

    Deprecated in favour of
    ``ts.handler_utils.utils.send_intermediate_predict_response``. The frame
    is built with ``ts_stream_next=True``. Nothing is sent when ``LOCAL_RANK``
    is not 0.

    NOTE(review): ``context`` defaults to ``None`` but is dereferenced
    unconditionally on rank 0, so passing ``None`` raises AttributeError.

    :return: always ``None``
    """
    if str(os.getenv("LOCAL_RANK", 0)) == "0":
        frame = create_predict_response(ret, req_id_map, message, code, context, True)
        context.cl_socket.sendall(frame)
    return None

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources