Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor OpenAI exception handling from instruction parser #328

Merged
merged 8 commits into from
Sep 7, 2023
Merged
65 changes: 35 additions & 30 deletions prompt2model/prompt_parser/instr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@
from prompt2model.prompt_parser.instr_parser_prompt import ( # isort: split
construct_prompt_for_instruction_parsing,
)
from prompt2model.utils import (
API_ERRORS,
APIAgent,
get_formatted_logger,
handle_api_error,
)

from prompt2model.utils import APIAgent, get_formatted_logger

logger = get_formatted_logger("PromptParser")

Expand Down Expand Up @@ -90,29 +86,38 @@ def parse_from_prompt(self, prompt: str) -> None:

chat_api = APIAgent()
while True:
try:
self.api_call_counter += 1
response = chat_api.generate_one_completion(
parsing_prompt_for_chatgpt,
temperature=0,
presence_penalty=0,
frequency_penalty=0,
)
extraction = self.extract_response(response)
if extraction is not None:
self._instruction, self._examples = extraction
return None
else:
if (
self.max_api_calls
and self.api_call_counter == self.max_api_calls
):
logger.warning(
"Maximum number of API calls reached for PromptParser."
)
return None
except API_ERRORS as e:
self.api_call_counter = handle_api_error(e, self.api_call_counter)
self.api_call_counter += 1
response = chat_api.generate_one_completion(
parsing_prompt_for_chatgpt,
temperature=0,
presence_penalty=0,
frequency_penalty=0,
)

if isinstance(response, Exception):
neubig marked this conversation as resolved.
Show resolved Hide resolved
# Generation failed due to an API related error and requires retry.

if self.max_api_calls and self.api_call_counter >= self.max_api_calls:
# In case we reach maximum number of API calls, we raise an error.
logger.error("Maximum number of API calls reached.")
raise ValueError("Maximum number of API calls reached.") from e
raise ValueError(
"Maximum number of API calls reached."
) from response

continue # no need to proceed with extracting
# response if API call failed.

extraction = self.extract_response(response)

if extraction is not None:
# extraction is successful

self._instruction, self._examples = extraction
return None

if self.max_api_calls and self.api_call_counter == self.max_api_calls:
# In case we reach maximum number of API calls without a
# successful extraction, we return None.

logger.warning("Maximum number of API calls reached for PromptParser.")
return None
37 changes: 22 additions & 15 deletions prompt2model/utils/api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,24 @@ def generate_one_completion(
the model's likelihood of repeating the same line verbatim.
Returns:
A response object.
An OpenAI-like response object if there were no errors in generation.
In case of API-specific error, Exception object is captured and returned.
"""
response = completion( # completion gets the key from os.getenv
model=self.model_name,
messages=[
{"role": "user", "content": f"{prompt}"},
],
temperature=temperature,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
return response
try:
response = completion( # completion gets the key from os.getenv
model=self.model_name,
messages=[
{"role": "user", "content": f"{prompt}"},
],
temperature=temperature,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
return response
except API_ERRORS as e:
err = handle_api_error(e)

return err

async def generate_batch_completion(
self,
Expand Down Expand Up @@ -173,7 +179,7 @@ async def _throttled_completion_acreate(
return responses


def handle_api_error(e, api_call_counter):
def handle_api_error(e, api_call_counter=0):
"""Handle OpenAI errors or related errors that the API may raise.
Args:
Expand All @@ -182,7 +188,8 @@ def handle_api_error(e, api_call_counter):
api_call_counter: The number of API calls made so far.
Returns:
The api_call_counter (if no error was raised), else raise the error.
The captured exception (if error is API related and can be retried),
else raise the error.
"""
logging.error(e)
if isinstance(
Expand All @@ -193,8 +200,8 @@ def handle_api_error(e, api_call_counter):
time.sleep(1)

if isinstance(e, API_ERRORS):
# For these errors, we can increment a counter and retry the API call.
return api_call_counter
# For these errors, we can return the API related error and retry the API call.
return e
else:
# For all other errors, immediately throw an exception.
raise e
Expand Down
2 changes: 1 addition & 1 deletion tests/prompt_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method):

@patch("time.sleep")
@patch(
"prompt2model.utils.APIAgent.generate_one_completion",
"openai.ChatCompletion.create",
side_effect=openai.error.Timeout("timeout"),
)
def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_method):
Expand Down