Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor OpenAI exception handling from instruction parser #328

Merged
merged 8 commits into from
Sep 7, 2023
57 changes: 27 additions & 30 deletions prompt2model/prompt_parser/instr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
from prompt2model.prompt_parser.instr_parser_prompt import ( # isort: split
construct_prompt_for_instruction_parsing,
)
from prompt2model.utils import (
OPENAI_ERRORS,
ChatGPTAgent,
get_formatted_logger,
handle_openai_error,
)
from prompt2model.utils import ChatGPTAgent, get_formatted_logger

logger = get_formatted_logger("PromptParser")

Expand Down Expand Up @@ -100,29 +95,31 @@ def parse_from_prompt(self, prompt: str) -> None:

chat_api = ChatGPTAgent(self.api_key)
while True:
try:
self.api_call_counter += 1
response = chat_api.generate_one_openai_chat_completion(
parsing_prompt_for_chatgpt,
temperature=0,
presence_penalty=0,
frequency_penalty=0,
)
extraction = self.extract_response(response)
if extraction is not None:
self._instruction, self._examples = extraction
return None
else:
if (
self.max_api_calls
and self.api_call_counter == self.max_api_calls
):
logger.warning(
"Maximum number of API calls reached for PromptParser."
)
return None
except OPENAI_ERRORS as e:
self.api_call_counter = handle_openai_error(e, self.api_call_counter)
self.api_call_counter += 1
response = chat_api.generate_one_openai_chat_completion(
parsing_prompt_for_chatgpt,
temperature=0,
presence_penalty=0,
frequency_penalty=0,
)

if isinstance(response, Exception):
neubig marked this conversation as resolved.
Show resolved Hide resolved
extraction = None

if self.max_api_calls and self.api_call_counter >= self.max_api_calls:
logger.error("Maximum number of API calls reached.")
raise ValueError("Maximum number of API calls reached.") from e
raise ValueError(
"Maximum number of API calls reached."
) from response
else:
extraction = self.extract_response(response)

if extraction is not None:
self._instruction, self._examples = extraction
return None
else:
if self.max_api_calls and self.api_call_counter == self.max_api_calls:
logger.warning(
"Maximum number of API calls reached for PromptParser."
)
return None
29 changes: 17 additions & 12 deletions prompt2model/utils/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,21 @@ def generate_one_openai_chat_completion(
Returns:
A response object.
"""
response = completion( # completion gets the key from os.getenv
model=self.model_name,
messages=[
{"role": "user", "content": f"{prompt}"},
],
temperature=temperature,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
return response
try:
response = completion( # completion gets the key from os.getenv
model=self.model_name,
messages=[
{"role": "user", "content": f"{prompt}"},
],
temperature=temperature,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
return response
except OPENAI_ERRORS as e:
err = handle_openai_error(e)

return err

async def generate_batch_openai_chat_completion(
self,
Expand Down Expand Up @@ -179,7 +184,7 @@ async def _throttled_openai_chat_completion_acreate(
return responses


def handle_openai_error(e, api_call_counter):
def handle_openai_error(e, api_call_counter=0):
"""Handle OpenAI errors or related errors that the OpenAI API may raise.

Args:
Expand All @@ -200,7 +205,7 @@ def handle_openai_error(e, api_call_counter):

if isinstance(e, OPENAI_ERRORS):
# For these errors, we can increment a counter and retry the API call.
return api_call_counter
return e
else:
# For all other errors, immediately throw an exception.
raise e
Expand Down
2 changes: 1 addition & 1 deletion tests/prompt_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method):

@patch("time.sleep")
@patch(
"prompt2model.utils.ChatGPTAgent.generate_one_openai_chat_completion",
"openai.ChatCompletion.create",
side_effect=openai.error.Timeout("timeout"),
)
def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_method):
Expand Down