Commit 6f76942

fix: more error handling for openai_plugin, fix nits
This commit adds a try/except to the openai_plugin to handle empty
chunks in the response, which cause the requests package to raise
an exception. It also fixes two minor logging nits.
dagrayvid committed Jun 6, 2024
1 parent 3d76871 commit 6f76942
Showing 3 changed files with 74 additions and 63 deletions.
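
For context, here is a minimal, self-contained sketch of the pattern the openai_plugin change introduces (the URL, payload, and helper name are hypothetical; this is not the plugin's actual code). Iterating a streamed requests response can raise requests.exceptions.ChunkedEncodingError when the server emits an empty or truncated chunk, so the read loop is wrapped in a try/except that salvages whatever tokens arrived before the failure:

import json
import time

import requests

def stream_tokens(url: str, payload: dict) -> dict:
    """Stream an SSE completion, tolerating empty or truncated chunks."""
    tokens = []
    result = {"tokens": tokens, "error": None, "end_time": None}
    response = requests.post(url, json=payload, stream=True)
    try:
        for line in response.iter_lines():
            # SSE lines look like b"data: {...}"; skip keep-alives and the terminator.
            _, found, data = line.partition(b"data: ")
            if not found or data == b"[DONE]":
                continue
            try:
                message = json.loads(data)
                tokens.append(message["choices"][0]["text"])
            except (json.JSONDecodeError, KeyError):
                continue  # tolerate a malformed line instead of aborting the run
    except requests.exceptions.ChunkedEncodingError as err:
        # An empty or truncated chunk ends the stream early; keep the partial output.
        result["error"] = repr(err)
    result["end_time"] = time.time()
    return result

Catching the exception outside the loop means a half-streamed response is still returned as a (failed) result with partial token counts, rather than crashing the worker.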
load_test.py (2 changes: 1 addition & 1 deletion)

@@ -157,7 +157,7 @@ def main(args):
         config = utils.yaml_load(args.config)
         concurrency, duration, plugin = utils.parse_config(config)
     except Exception as e:
-        logging.error("Exiting due to invalid input: %s", e)
+        logging.error("Exiting due to invalid input: %s", repr(e))
         exit_gracefully(procs, warmup_q, dataset_q, stop_q, logger_q, log_reader_thread, 1)
 
     try:
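
The repr(e) change matters because str() of an exception can be empty (for example, an exception raised with no arguments), which produces a blank log line, while repr() always names the exception type. A quick illustration of the difference (plain Python behavior, not project code):

e = ValueError()                  # exception raised with no message
print("str:  [%s]" % e)           # prints: str:  []
print("repr: [%s]" % repr(e))     # prints: repr: [ValueError()]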
plugins/openai_plugin.py (133 changes: 72 additions & 61 deletions)

@@ -172,69 +172,80 @@ def streaming_request_http(self, query: dict, user_id: int, test_end_time: float
 
         logger.debug("Response: %s", response)
         message = None
-        for line in response.iter_lines():
-            logger.debug("response line: %s", line)
-            _, found, data = line.partition(b"data: ")
-            if found and data != b"[DONE]":
-                try:
-                    message = json.loads(data)
-                    logger.debug("Message: %s", message)
-                    if "/v1/chat/completions" in self.host and not message["choices"][0]['delta'].get('content'):
-                        message["choices"][0]['delta']['content']=""
-                    error = message.get("error")
-                    if error is None:
-                        if "/v1/chat/completions" in self.host:
-                            token = message["choices"][0]['delta']['content']
-                        else:
-                            token = message["choices"][0]["text"]
-                        logger.debug("Token: %s", token)
-                    else:
-                        result.error_code = response.status_code
-                        result.error_text = error
-                        logger.error("Error received in response message: %s", error)
-                        break
-                except json.JSONDecodeError:
-                    logger.exception("Response line could not be json decoded: %s", line)
-                except KeyError:
-                    logger.exception(
-                        "KeyError, unexpected response format in line: %s", line
-                    )
-                    continue
-            else:
-                continue
-
-            try:
-                # First chunk may not be a token, just a connection ack
-                if not result.ack_time:
-                    result.ack_time = time.time()
-
-                # First non empty token is the first token
-                if not result.first_token_time and token != "":
-                    result.first_token_time = time.time()
-
-                # If the current token time is outside the test duration, record the total tokens received before
-                # the current token.
-                if (
-                    not result.output_tokens_before_timeout
-                    and time.time() > test_end_time
-                ):
-                    result.output_tokens_before_timeout = len(tokens)
-
-                tokens.append(token)
-
-                # Last token comes with finish_reason set.
-                if message.get("choices", [])[0].get("finish_reason", None):
-                    result.output_tokens = message["usage"]["completion_tokens"]
-                    result.input_tokens = message["usage"]["prompt_tokens"]
-                    result.stop_reason = message["choices"][0]["finish_reason"]
-
-                    # If test duration timeout didn't happen before the last token is received,
-                    # total tokens before the timeout will be equal to the total tokens in the response.
-                    if not result.output_tokens_before_timeout:
-                        result.output_tokens_before_timeout = result.output_tokens
-
-            except KeyError:
-                logging.exception("KeyError, unexpected response format in line: %s", line)
+        try:
+            for line in response.iter_lines():
+                logger.debug("response line: %s", line)
+                _, found, data = line.partition(b"data: ")
+                if found and data != b"[DONE]":
+                    try:
+                        message = json.loads(data)
+                        logger.debug("Message: %s", message)
+                        if "/v1/chat/completions" in self.host and not message["choices"][0]['delta'].get('content'):
+                            message["choices"][0]['delta']['content']=""
+                        error = message.get("error")
+                        if error is None:
+                            if "/v1/chat/completions" in self.host:
+                                token = message["choices"][0]['delta']['content']
+                            else:
+                                token = message["choices"][0]["text"]
+                            logger.debug("Token: %s", token)
+                        else:
+                            result.error_code = response.status_code
+                            result.error_text = error
+                            logger.error("Error received in response message: %s", error)
+                            break
+                    except json.JSONDecodeError:
+                        logger.exception("Response line could not be json decoded: %s", line)
+                    except KeyError:
+                        logger.exception(
+                            "KeyError, unexpected response format in line: %s", line
+                        )
+                        continue
+                else:
+                    continue
+
+                try:
+                    # First chunk may not be a token, just a connection ack
+                    if not result.ack_time:
+                        result.ack_time = time.time()
+
+                    # First non empty token is the first token
+                    if not result.first_token_time and token != "":
+                        result.first_token_time = time.time()
+
+                    # If the current token time is outside the test duration, record the total tokens received before
+                    # the current token.
+                    if (
+                        not result.output_tokens_before_timeout
+                        and time.time() > test_end_time
+                    ):
+                        result.output_tokens_before_timeout = len(tokens)
+
+                    tokens.append(token)
+
+                    # Last token comes with finish_reason set.
+                    if message.get("choices", [])[0].get("finish_reason", None):
+                        result.output_tokens = message["usage"]["completion_tokens"]
+                        result.input_tokens = message["usage"]["prompt_tokens"]
+                        result.stop_reason = message["choices"][0]["finish_reason"]
+
+                        # If test duration timeout didn't happen before the last token is received,
+                        # total tokens before the timeout will be equal to the total tokens in the response.
+                        if not result.output_tokens_before_timeout:
+                            result.output_tokens_before_timeout = result.output_tokens
+
+                except KeyError:
+                    logger.exception("KeyError, unexpected response format in line: %s", line)
+        except requests.exceptions.ChunkedEncodingError as err:
+            result.end_time = time.time()
+            result.error_text = repr(err)
+            result.output_text = "".join(tokens)
+            result.output_tokens = len(tokens)
+            if response is not None:
+                result.error_code = response.status_code
+            logger.exception("ChunkedEncodingError while streaming response")
+            return result
 
         # Full response received, return
         result.end_time = time.time()
plugins/tgis_grpc_plugin.py (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@
 
 logger = logging.getLogger("user")
 
-required_args = ["model_name", "host", "port", "streaming"]
+required_args = ["model_name", "host", "port", "streaming", "use_tls"]
 
 """
 This plugin currently only supports grpc requests for a standalone TGI server.
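
The tgis_grpc_plugin change simply registers use_tls as a mandatory config key. Presumably the harness checks these keys against the user's config before constructing the plugin; a sketch of that kind of validation (hypothetical helper, not necessarily how this repo implements it):

def check_required_args(config: dict, required_args: list) -> None:
    # Fail fast with a clear message if the plugin config is incomplete.
    missing = [key for key in required_args if key not in config]
    if missing:
        raise ValueError(f"Missing required plugin args: {', '.join(missing)}")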
