From 072f36fe53eb93527a6ccc98cd2bad14ac75a940 Mon Sep 17 00:00:00 2001
From: Saumya Gandhi
Date: Thu, 7 Sep 2023 12:26:19 -0400
Subject: [PATCH] Refactor OpenAI exception handling from instruction parser (#328)

* refactor openai exception handling from instruction parser
* minor fixes not resolved properly in merge conflict
* reduce branching in prompt parser
* add comments to clarify changes
* pre-commit checks
* linting issues
---
 prompt2model/prompt_parser/instr_parser.py | 65 ++++++++++++----------
 prompt2model/utils/api_tools.py            | 37 +++++++-----
 tests/prompt_parser_test.py                |  2 +-
 3 files changed, 58 insertions(+), 46 deletions(-)

diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py
index fc9815ed1..2afbed4f7 100644
--- a/prompt2model/prompt_parser/instr_parser.py
+++ b/prompt2model/prompt_parser/instr_parser.py
@@ -12,12 +12,8 @@
 from prompt2model.prompt_parser.instr_parser_prompt import (  # isort: split
     construct_prompt_for_instruction_parsing,
 )
-from prompt2model.utils import (
-    API_ERRORS,
-    APIAgent,
-    get_formatted_logger,
-    handle_api_error,
-)
+
+from prompt2model.utils import APIAgent, get_formatted_logger
 
 logger = get_formatted_logger("PromptParser")
 
@@ -90,29 +86,38 @@ def parse_from_prompt(self, prompt: str) -> None:
         chat_api = APIAgent()
 
         while True:
-            try:
-                self.api_call_counter += 1
-                response = chat_api.generate_one_completion(
-                    parsing_prompt_for_chatgpt,
-                    temperature=0,
-                    presence_penalty=0,
-                    frequency_penalty=0,
-                )
-                extraction = self.extract_response(response)
-                if extraction is not None:
-                    self._instruction, self._examples = extraction
-                    return None
-                else:
-                    if (
-                        self.max_api_calls
-                        and self.api_call_counter == self.max_api_calls
-                    ):
-                        logger.warning(
-                            "Maximum number of API calls reached for PromptParser."
-                        )
-                        return None
-            except API_ERRORS as e:
-                self.api_call_counter = handle_api_error(e, self.api_call_counter)
+            self.api_call_counter += 1
+            response = chat_api.generate_one_completion(
+                parsing_prompt_for_chatgpt,
+                temperature=0,
+                presence_penalty=0,
+                frequency_penalty=0,
+            )
+
+            if isinstance(response, Exception):
+                # Generation failed due to an API-related error and requires a retry.
+                if self.max_api_calls and self.api_call_counter >= self.max_api_calls:
+                    # In case we reach the maximum number of API calls, we raise an error.
                     logger.error("Maximum number of API calls reached.")
-                raise ValueError("Maximum number of API calls reached.") from e
+                    raise ValueError(
+                        "Maximum number of API calls reached."
+                    ) from response
+
+                continue  # no need to proceed with extracting the
+                # response if the API call failed.
+
+            extraction = self.extract_response(response)
+
+            if extraction is not None:
+                # extraction succeeded
+
+                self._instruction, self._examples = extraction
+                return None
+
+            if self.max_api_calls and self.api_call_counter == self.max_api_calls:
+                # In case we reach the maximum number of API calls without a
+                # successful extraction, we return None.
+
+                logger.warning("Maximum number of API calls reached for PromptParser.")
+                return None
diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py
index 1b3baaec4..44f5f9d6f 100644
--- a/prompt2model/utils/api_tools.py
+++ b/prompt2model/utils/api_tools.py
@@ -70,18 +70,24 @@ def generate_one_completion(
             the model's likelihood of repeating the same line verbatim.
 
         Returns:
-            A response object.
+            An OpenAI-like response object if there were no errors in generation.
+            In case of an API-specific error, the captured exception is returned.
         """
-        response = completion(  # completion gets the key from os.getenv
-            model=self.model_name,
-            messages=[
-                {"role": "user", "content": f"{prompt}"},
-            ],
-            temperature=temperature,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-        )
-        return response
+        try:
+            response = completion(  # completion gets the key from os.getenv
+                model=self.model_name,
+                messages=[
+                    {"role": "user", "content": f"{prompt}"},
+                ],
+                temperature=temperature,
+                presence_penalty=presence_penalty,
+                frequency_penalty=frequency_penalty,
+            )
+            return response
+        except API_ERRORS as e:
+            err = handle_api_error(e)
+
+            return err
 
     async def generate_batch_completion(
         self,
@@ -173,7 +179,7 @@ async def _throttled_completion_acreate(
     return responses
 
 
-def handle_api_error(e, api_call_counter):
+def handle_api_error(e, api_call_counter=0):
     """Handle OpenAI errors or related errors that the API may raise.
 
     Args:
@@ -182,7 +188,8 @@ def handle_api_error(e, api_call_counter):
         api_call_counter: The number of API calls made so far.
 
     Returns:
-        The api_call_counter (if no error was raised), else raise the error.
+        The captured exception (if the error is API-related and can be retried);
+        otherwise the error is raised.
     """
     logging.error(e)
     if isinstance(
@@ -193,8 +200,8 @@ def handle_api_error(e, api_call_counter):
         time.sleep(1)
 
     if isinstance(e, API_ERRORS):
-        # For these errors, we can increment a counter and retry the API call.
-        return api_call_counter
+        # For these errors, we can return the API-related error and retry the API call.
+        return e
     else:
         # For all other errors, immediately throw an exception.
         raise e
diff --git a/tests/prompt_parser_test.py b/tests/prompt_parser_test.py
index 6bc10706d..6b4a4f363 100644
--- a/tests/prompt_parser_test.py
+++ b/tests/prompt_parser_test.py
@@ -126,7 +126,7 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method):
 
 @patch("time.sleep")
 @patch(
-    "prompt2model.utils.APIAgent.generate_one_completion",
+    "openai.ChatCompletion.create",
     side_effect=openai.error.Timeout("timeout"),
 )
 def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_method):
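
Usage note (not part of the patch): with this change, generate_one_completion returns the captured exception instead of raising it, so callers branch on isinstance(response, Exception) rather than wrapping the call in try/except. The sketch below is a minimal illustration of that contract; the retry cap of 3, the example prompt, and the response.choices[0]["message"]["content"] access are illustrative assumptions, not values taken from the patch.

    # Minimal sketch of the new calling convention introduced by this patch.
    from prompt2model.utils import APIAgent

    agent = APIAgent()  # assumes the default model configuration is usable
    max_api_calls = 3   # illustrative retry cap, not defined by the patch

    for attempt in range(1, max_api_calls + 1):
        response = agent.generate_one_completion(
            "Label the sentiment of: 'I loved this movie.'",  # hypothetical prompt
            temperature=0,
            presence_penalty=0,
            frequency_penalty=0,
        )
        if isinstance(response, Exception):
            # The API error was captured and returned; retry until the cap is hit.
            if attempt == max_api_calls:
                raise ValueError("Maximum number of API calls reached.") from response
            continue
        # OpenAI-like response object; field access assumed to mirror extract_response.
        print(response.choices[0]["message"]["content"])
        break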