Skip to content

Commit

Permalink
Refactor OpenAI exception handling from instruction parser (#328)
Browse files Browse the repository at this point in the history
* refactor openai exception handling from instruction parser

* minor fixes not resolved properly in merge conflic

* reduce branching in prompt parser

* add comments to clarify changes

* pre-commit checks

* linting issues
  • Loading branch information
saum7800 authored Sep 7, 2023
1 parent a21ef8d commit 072f36f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 46 deletions.
65 changes: 35 additions & 30 deletions prompt2model/prompt_parser/instr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@
from prompt2model.prompt_parser.instr_parser_prompt import ( # isort: split
construct_prompt_for_instruction_parsing,
)
from prompt2model.utils import (
API_ERRORS,
APIAgent,
get_formatted_logger,
handle_api_error,
)

from prompt2model.utils import APIAgent, get_formatted_logger

logger = get_formatted_logger("PromptParser")

Expand Down Expand Up @@ -90,29 +86,38 @@ def parse_from_prompt(self, prompt: str) -> None:

chat_api = APIAgent()
while True:
try:
self.api_call_counter += 1
response = chat_api.generate_one_completion(
parsing_prompt_for_chatgpt,
temperature=0,
presence_penalty=0,
frequency_penalty=0,
)
extraction = self.extract_response(response)
if extraction is not None:
self._instruction, self._examples = extraction
return None
else:
if (
self.max_api_calls
and self.api_call_counter == self.max_api_calls
):
logger.warning(
"Maximum number of API calls reached for PromptParser."
)
return None
except API_ERRORS as e:
self.api_call_counter = handle_api_error(e, self.api_call_counter)
self.api_call_counter += 1
response = chat_api.generate_one_completion(
parsing_prompt_for_chatgpt,
temperature=0,
presence_penalty=0,
frequency_penalty=0,
)

if isinstance(response, Exception):
# Generation failed due to an API related error and requires retry.

if self.max_api_calls and self.api_call_counter >= self.max_api_calls:
# In case we reach maximum number of API calls, we raise an error.
logger.error("Maximum number of API calls reached.")
raise ValueError("Maximum number of API calls reached.") from e
raise ValueError(
"Maximum number of API calls reached."
) from response

continue # no need to proceed with extracting
# response if API call failed.

extraction = self.extract_response(response)

if extraction is not None:
# extraction is successful

self._instruction, self._examples = extraction
return None

if self.max_api_calls and self.api_call_counter == self.max_api_calls:
# In case we reach maximum number of API calls without a
# successful extraction, we return None.

logger.warning("Maximum number of API calls reached for PromptParser.")
return None
37 changes: 22 additions & 15 deletions prompt2model/utils/api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,24 @@ def generate_one_completion(
the model's likelihood of repeating the same line verbatim.
Returns:
A response object.
An OpenAI-like response object if there were no errors in generation.
In case of API-specific error, Exception object is captured and returned.
"""
response = completion( # completion gets the key from os.getenv
model=self.model_name,
messages=[
{"role": "user", "content": f"{prompt}"},
],
temperature=temperature,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
return response
try:
response = completion( # completion gets the key from os.getenv
model=self.model_name,
messages=[
{"role": "user", "content": f"{prompt}"},
],
temperature=temperature,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
return response
except API_ERRORS as e:
err = handle_api_error(e)

return err

async def generate_batch_completion(
self,
Expand Down Expand Up @@ -173,7 +179,7 @@ async def _throttled_completion_acreate(
return responses


def handle_api_error(e, api_call_counter):
def handle_api_error(e, api_call_counter=0):
"""Handle OpenAI errors or related errors that the API may raise.
Args:
Expand All @@ -182,7 +188,8 @@ def handle_api_error(e, api_call_counter):
api_call_counter: The number of API calls made so far.
Returns:
The api_call_counter (if no error was raised), else raise the error.
The captured exception (if error is API related and can be retried),
else raise the error.
"""
logging.error(e)
if isinstance(
Expand All @@ -193,8 +200,8 @@ def handle_api_error(e, api_call_counter):
time.sleep(1)

if isinstance(e, API_ERRORS):
# For these errors, we can increment a counter and retry the API call.
return api_call_counter
# For these errors, we can return the API related error and retry the API call.
return e
else:
# For all other errors, immediately throw an exception.
raise e
Expand Down
2 changes: 1 addition & 1 deletion tests/prompt_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method):

@patch("time.sleep")
@patch(
"prompt2model.utils.APIAgent.generate_one_completion",
"openai.ChatCompletion.create",
side_effect=openai.error.Timeout("timeout"),
)
def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_method):
Expand Down

0 comments on commit 072f36f

Please sign in to comment.