# Source code for ts.protocol.otf_message_handler
"""
OTF Codec
"""
import io
import json
import logging
import os
import struct
import sys
import time
from builtins import bytearray, bytes
import torch
from ts.utils.util import deprecated
bool_size = 1
int_size = 4
END_OF_LIST = -1
LOAD_MSG = b"L"
PREDICT_MSG = b"I"
RESPONSE = 3

def retrieve_msg(conn):
    """
    Retrieve a message from the socket channel.

    :param conn: socket connection to the frontend
    :return: tuple of the command byte and the decoded message
    """
    cmd = _retrieve_buffer(conn, 1)
    if cmd == LOAD_MSG:
        msg = _retrieve_load_msg(conn)
    elif cmd == PREDICT_MSG:
        msg = _retrieve_inference_msg(conn)
        logging.info("Backend received inference at: %d", time.time())
    else:
        raise ValueError("Invalid command: {}".format(cmd))

    return cmd, msg
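
# Usage sketch (hypothetical backend loop; the real wiring lives in
# ts.model_service_worker, and cl_socket below is an illustrative name only):
#
#     cmd, msg = retrieve_msg(cl_socket)
#     if cmd == LOAD_MSG:
#         ...  # msg describes the model to load
#     elif cmd == PREDICT_MSG:
#         ...  # msg is the batch of requests to run inference on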

def encode_response_headers(resp_hdr_map):
    """
    Encode a map of response headers as a header count followed by
    length-prefixed UTF-8 key/value pairs.
    """
    msg = bytearray()
    msg += struct.pack("!i", len(resp_hdr_map))
    for k, v in resp_hdr_map.items():
        msg += struct.pack("!i", len(k.encode("utf-8")))
        msg += k.encode("utf-8")
        msg += struct.pack("!i", len(v.encode("utf-8")))
        msg += v.encode("utf-8")
    return msg
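
# Illustrative counterpart to encode_response_headers, included only as a
# sketch of the wire layout (it is not part of the TorchServe protocol module
# and is never called by the backend): a big-endian int header count, followed
# by length-prefixed UTF-8 key/value pairs.
def _decode_response_headers_example(buf):
    headers = {}
    offset = 0
    (count,) = struct.unpack_from("!i", buf, offset)
    offset += int_size
    for _ in range(count):
        (klen,) = struct.unpack_from("!i", buf, offset)
        offset += int_size
        key = bytes(buf[offset : offset + klen]).decode("utf-8")
        offset += klen
        (vlen,) = struct.unpack_from("!i", buf, offset)
        offset += int_size
        value = bytes(buf[offset : offset + vlen]).decode("utf-8")
        offset += vlen
        headers[key] = value
    return headers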

def create_predict_response(
    ret, req_id_map, message, code, context=None, ts_stream_next=False
):
    """
    Create an inference response.

    :param ret: list of prediction results, indexed like req_id_map
    :param req_id_map: map of batch index to request id
    :param message: response message
    :param code: response status code
    :param context: optional request context carrying per-request headers and status
    :param ts_stream_next: set to True to mark this as an intermediate streaming response
    :return: encoded response as a bytearray, or None on non-zero local ranks
    """
    if str(os.getenv("LOCAL_RANK", 0)) != "0":
        return None
    msg = bytearray()
    msg += struct.pack("!i", code)
    buf = message.encode("utf-8")
    msg += struct.pack("!i", len(buf))
    msg += buf

    for idx in req_id_map:
        req_id = req_id_map.get(idx).encode("utf-8")
        msg += struct.pack("!i", len(req_id))
        msg += req_id

        if context is None:
            # Encoding Content-Type
            msg += struct.pack("!i", 0)  # content_type

            # Encoding the per prediction HTTP response code
            # status code and reason phrase set to none
            msg += struct.pack("!i", code)
            msg += struct.pack("!i", 0)  # No code phrase is returned
            # Response headers none
            msg += struct.pack("!i", 0)
        else:
            if ts_stream_next is True:
                context.set_response_header(idx, "ts_stream_next", "true")
            elif context.stopping_criteria:
                is_stop = context.stopping_criteria[idx](ret[idx])
                if is_stop is not None:
                    ts_stream_next = "false" if is_stop else "true"
                    context.set_response_header(idx, "ts_stream_next", ts_stream_next)
            elif "true" == context.get_response_headers(idx).get("ts_stream_next"):
                context.set_response_header(idx, "ts_stream_next", "false")

            content_type = context.get_response_content_type(idx)
            if content_type is None or len(content_type) == 0:
                msg += struct.pack("!i", 0)  # content_type
            else:
                msg += struct.pack("!i", len(content_type))
                msg += content_type.encode("utf-8")

            sc, phrase = context.get_response_status(idx)
            http_code = sc if sc is not None else 200
            http_phrase = phrase if phrase is not None else ""

            msg += struct.pack("!i", http_code)
            msg += struct.pack("!i", len(http_phrase))
            msg += http_phrase.encode("utf-8")
            # Response headers
            msg += encode_response_headers(context.get_response_headers(idx))

        if ret is None:
            buf = b"error"
            msg += struct.pack("!i", len(buf))
            msg += buf
        else:
            val = ret[idx]
            # NOTE: Process bytes/bytearray case before processing the string case.
            if isinstance(val, (bytes, bytearray)):
                msg += struct.pack("!i", len(val))
                msg += val
            elif isinstance(val, str):
                buf = val.encode("utf-8")
                msg += struct.pack("!i", len(buf))
                msg += buf
            elif isinstance(val, torch.Tensor):
                buff = io.BytesIO()
                torch.save(val, buff)
                buff.seek(0)
                val_bytes = buff.read()
                msg += struct.pack("!i", len(val_bytes))
                msg += val_bytes
            else:
                try:
                    json_value = json.dumps(val, indent=2).encode("utf-8")
                    msg += struct.pack("!i", len(json_value))
                    msg += json_value
                except TypeError:
                    logging.warning("Unable to serialize model output.", exc_info=True)
                    return create_predict_response(
                        None, req_id_map, "Unsupported model output data type.", 503
                    )

    msg += struct.pack("!i", -1)  # End of list
    return msg
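
# Worked layout sketch (illustrative only): create_predict_response(
#     ["ok"], {0: "r-1"}, "Prediction success", 200) with no context packs,
# in order,
#     !i 200 | !i 18 | b"Prediction success"
#     !i 3 | b"r-1" | !i 0 (content type) | !i 200 | !i 0 (no reason phrase)
#     !i 0 (no response headers) | !i 2 | b"ok"
#     !i -1 (end of list)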

def create_load_model_response(code, message):
    """
    Create a load model response.

    :param code: response status code
    :param message: response message
    :return: encoded response as a bytearray
    """
    msg = bytearray()
    msg += struct.pack("!i", code)
    buf = message.encode("utf-8")
    msg += struct.pack("!i", len(buf))
    msg += buf
    msg += struct.pack("!i", -1)  # no predictions

    return msg
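
# Example (sketch): create_load_model_response(200, "OK") packs
#     !i 200 | !i 2 | b"OK" | !i -1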

def _retrieve_buffer(conn, length):
    data = bytearray()

    while length > 0:
        # recv() may return fewer bytes than requested, so keep reading until
        # the full frame has arrived; an empty read means the frontend closed
        # the connection.
        pkt = conn.recv(length)
        if len(pkt) == 0:
            logging.info("Frontend disconnected.")
            sys.exit(0)

        data += pkt
        length -= len(pkt)

    return data

def _retrieve_int(conn):
    data = _retrieve_buffer(conn, int_size)
    return struct.unpack("!i", data)[0]


def _retrieve_bool(conn):
    data = _retrieve_buffer(conn, bool_size)
    return struct.unpack("!?", data)[0]

def _retrieve_load_msg(conn):
    """
    MSG Frame Format:

    | cmd value |
    | int model-name length | model-name value |
    | int model-path length | model-path value |
    | int batch-size |
    | int handler length | handler value |
    | int gpu id |
    | int envelope length | envelope value |
    | bool limitMaxImagePixels |

    :param conn: socket connection to the frontend
    :return: load-model request as a dict
    """
    msg = {}
    length = _retrieve_int(conn)
    msg["modelName"] = _retrieve_buffer(conn, length)
    length = _retrieve_int(conn)
    msg["modelPath"] = _retrieve_buffer(conn, length)
    msg["batchSize"] = _retrieve_int(conn)
    length = _retrieve_int(conn)
    msg["handler"] = _retrieve_buffer(conn, length)
    gpu_id = _retrieve_int(conn)
    if gpu_id >= 0:
        msg["gpu"] = gpu_id
    length = _retrieve_int(conn)
    msg["envelope"] = _retrieve_buffer(conn, length)
    msg["limitMaxImagePixels"] = _retrieve_bool(conn)

    return msg
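
# Illustrative encoder for the load frame above, a test/debugging sketch only
# (the real frame is produced by the TorchServe frontend on the Java side, and
# this helper is not part of the protocol module):
def _encode_load_msg_example(
    model_name, model_path, batch_size, handler, gpu_id, envelope, limit_max_image_pixels
):
    msg = bytearray()
    msg += LOAD_MSG  # cmd byte, consumed separately by retrieve_msg
    for text in (model_name, model_path):
        buf = text.encode("utf-8")
        msg += struct.pack("!i", len(buf))
        msg += buf
    msg += struct.pack("!i", batch_size)
    buf = handler.encode("utf-8")
    msg += struct.pack("!i", len(buf))
    msg += buf
    msg += struct.pack("!i", gpu_id)
    buf = envelope.encode("utf-8")
    msg += struct.pack("!i", len(buf))
    msg += buf
    msg += struct.pack("!?", limit_max_image_pixels)
    return msg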

def _retrieve_inference_msg(conn):
    """
    MSG Frame Format:

    | cmd value |
    | batch: list of requests |
    """
    msg = []

    while True:
        request = _retrieve_request(conn)
        if request is None:
            break
        msg.append(request)

    return msg

def _retrieve_request(conn):
    """
    MSG Frame Format:

    | request_id |
    | request_headers: list of request headers |
    | parameters: list of request parameters |
    """
    length = _retrieve_int(conn)
    if length == -1:
        return None

    request = {}
    request["requestId"] = _retrieve_buffer(conn, length)

    headers = []
    while True:
        header = _retrieve_request_header(conn)
        if header is None:
            break
        headers.append(header)
    request["headers"] = headers

    model_inputs = []
    while True:
        input_data = _retrieve_input_data(conn)
        if input_data is None:
            break
        model_inputs.append(input_data)
    request["parameters"] = model_inputs

    return request
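
# Batch layout sketch (as consumed by _retrieve_inference_msg and
# _retrieve_request above): each request on the wire is
#     !i len(request_id) | request_id
#     headers, repeated:    !i len(name) | name | !i len(value) | value, ended by !i -1
#     parameters, repeated: !i len(name) | name | !i len(content_type) | content_type
#                           | !i len(value) | value, ended by !i -1
# and the batch itself is terminated by a request-id length of -1.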

def _retrieve_request_header(conn):
    """
    MSG Frame Format:

    | header_name |
    | header_value |
    """
    length = _retrieve_int(conn)
    if length == -1:
        return None

    header = {}
    header["name"] = _retrieve_buffer(conn, length)
    length = _retrieve_int(conn)
    header["value"] = _retrieve_buffer(conn, length)

    return header

def _retrieve_input_data(conn):
    """
    MSG Frame Format:

    | parameter_name |
    | content_type |
    | input data in bytes |
    """
    decode_req = os.environ.get("TS_DECODE_INPUT_REQUEST")
    length = _retrieve_int(conn)
    if length == -1:
        return None

    model_input = {}
    model_input["name"] = _retrieve_buffer(conn, length).decode("utf-8")

    length = _retrieve_int(conn)
    content_type = _retrieve_buffer(conn, length).decode("utf-8")
    model_input["contentType"] = content_type

    length = _retrieve_int(conn)
    value = _retrieve_buffer(conn, length)
    if content_type == "application/json" and (
        decode_req is None or decode_req == "true"
    ):
        try:
            model_input["value"] = json.loads(value.decode("utf-8"))
        except Exception:
            model_input["value"] = value
            logging.warning(
                "Failed json decoding of input data. Forwarding encoded payload",
                exc_info=True,
            )
    elif content_type.startswith("text") and (
        decode_req is None or decode_req == "true"
    ):
        try:
            model_input["value"] = value.decode("utf-8")
        except Exception:
            model_input["value"] = value
            logging.warning(
                "Failed utf-8 decoding of input data. Forwarding encoded payload",
                exc_info=True,
            )
    else:
        model_input["value"] = value

    return model_input
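
# Decoding behaviour sketch: with TS_DECODE_INPUT_REQUEST unset or set to
# "true", an "application/json" body reaches the handler as a parsed Python
# object and a "text/*" body as a str; any other content type, or any other
# value of TS_DECODE_INPUT_REQUEST, leaves model_input["value"] as the raw
# bytes (a bytearray).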

@deprecated(
    version=1.0,
    replacement="ts.handler_utils.utils.send_intermediate_predict_response",
)
def send_intermediate_predict_response(ret, req_id_map, message, code, context=None):
    if str(os.getenv("LOCAL_RANK", 0)) != "0":
        return None
    msg = create_predict_response(ret, req_id_map, message, code, context, True)
    context.cl_socket.sendall(msg)
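
# Streaming usage sketch (hypothetical handler code; prefer the non-deprecated
# helper in ts.handler_utils.utils, and note that generate_chunks below is an
# illustrative name, not a real API):
#
#     for chunk in generate_chunks(prompt):
#         send_intermediate_predict_response(
#             [chunk], context.request_ids, "Intermediate Prediction success", 200, context
#         )
#     return [last_chunk]  # the final response is returned from the handler as usual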