From 7cd9110e46e73936ca01e44510b3a1bf72ae2be8 Mon Sep 17 00:00:00 2001 From: saum7800 Date: Wed, 6 Sep 2023 15:33:31 -0400 Subject: [PATCH 1/6] refactor openai exception handling from instruction parser --- prompt2model/prompt_parser/instr_parser.py | 57 ++++++++++------------ prompt2model/utils/openai_tools.py | 29 ++++++----- tests/prompt_parser_test.py | 2 +- 3 files changed, 45 insertions(+), 43 deletions(-) diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py index 571450b64..52e2f3b2e 100644 --- a/prompt2model/prompt_parser/instr_parser.py +++ b/prompt2model/prompt_parser/instr_parser.py @@ -12,12 +12,7 @@ from prompt2model.prompt_parser.instr_parser_prompt import ( # isort: split construct_prompt_for_instruction_parsing, ) -from prompt2model.utils import ( - OPENAI_ERRORS, - ChatGPTAgent, - get_formatted_logger, - handle_openai_error, -) +from prompt2model.utils import ChatGPTAgent, get_formatted_logger logger = get_formatted_logger("PromptParser") @@ -100,29 +95,31 @@ def parse_from_prompt(self, prompt: str) -> None: chat_api = ChatGPTAgent(self.api_key) while True: - try: - self.api_call_counter += 1 - response = chat_api.generate_one_openai_chat_completion( - parsing_prompt_for_chatgpt, - temperature=0, - presence_penalty=0, - frequency_penalty=0, - ) - extraction = self.extract_response(response) - if extraction is not None: - self._instruction, self._examples = extraction - return None - else: - if ( - self.max_api_calls - and self.api_call_counter == self.max_api_calls - ): - logger.warning( - "Maximum number of API calls reached for PromptParser." - ) - return None - except OPENAI_ERRORS as e: - self.api_call_counter = handle_openai_error(e, self.api_call_counter) + self.api_call_counter += 1 + response = chat_api.generate_one_openai_chat_completion( + parsing_prompt_for_chatgpt, + temperature=0, + presence_penalty=0, + frequency_penalty=0, + ) + + if isinstance(response, Exception): + extraction = None + if self.max_api_calls and self.api_call_counter >= self.max_api_calls: logger.error("Maximum number of API calls reached.") - raise ValueError("Maximum number of API calls reached.") from e + raise ValueError( + "Maximum number of API calls reached." + ) from response + else: + extraction = self.extract_response(response) + + if extraction is not None: + self._instruction, self._examples = extraction + return None + else: + if self.max_api_calls and self.api_call_counter == self.max_api_calls: + logger.warning( + "Maximum number of API calls reached for PromptParser." + ) + return None diff --git a/prompt2model/utils/openai_tools.py b/prompt2model/utils/openai_tools.py index b347ef2bc..e7829c529 100644 --- a/prompt2model/utils/openai_tools.py +++ b/prompt2model/utils/openai_tools.py @@ -78,16 +78,21 @@ def generate_one_openai_chat_completion( Returns: A response object. """ - response = openai.ChatCompletion.create( - model=self.model_name, - messages=[ - {"role": "user", "content": f"{prompt}"}, - ], - temperature=temperature, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - ) - return response + try: + response = openai.ChatCompletion.create( + model=self.model_name, + messages=[ + {"role": "user", "content": f"{prompt}"}, + ], + temperature=temperature, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + ) + return response + except OPENAI_ERRORS as e: + err = handle_openai_error(e) + + return err async def generate_batch_openai_chat_completion( self, @@ -179,7 +184,7 @@ async def _throttled_openai_chat_completion_acreate( return responses -def handle_openai_error(e, api_call_counter): +def handle_openai_error(e, api_call_counter=0): """Handle OpenAI errors or related errors that the OpenAI API may raise. Args: @@ -200,7 +205,7 @@ def handle_openai_error(e, api_call_counter): if isinstance(e, OPENAI_ERRORS): # For these errors, we can increment a counter and retry the API call. - return api_call_counter + return e else: # For all other errors, immediately throw an exception. raise e diff --git a/tests/prompt_parser_test.py b/tests/prompt_parser_test.py index 4bacdf7a7..7c1aba28d 100644 --- a/tests/prompt_parser_test.py +++ b/tests/prompt_parser_test.py @@ -132,7 +132,7 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method): @patch("time.sleep") @patch( - "prompt2model.utils.ChatGPTAgent.generate_one_openai_chat_completion", + "openai.ChatCompletion.create", side_effect=openai.error.Timeout("timeout"), ) def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_method): From 6a19836fb3c81a9348eac08982e3e46cc46c83ac Mon Sep 17 00:00:00 2001 From: saum7800 Date: Thu, 7 Sep 2023 10:25:40 -0400 Subject: [PATCH 2/6] minor fixes not resolved properly in merge conflic --- prompt2model/prompt_parser/instr_parser.py | 2 +- prompt2model/utils/api_tools.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py index 70135c82d..a75e1c074 100644 --- a/prompt2model/prompt_parser/instr_parser.py +++ b/prompt2model/prompt_parser/instr_parser.py @@ -87,7 +87,7 @@ def parse_from_prompt(self, prompt: str) -> None: chat_api = APIAgent() while True: self.api_call_counter += 1 - response = chat_api.generate_one_openai_chat_completion( + response = chat_api.generate_one_completion( parsing_prompt_for_chatgpt, temperature=0, presence_penalty=0, diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py index ca99a96ad..6a3d3e18d 100644 --- a/prompt2model/utils/api_tools.py +++ b/prompt2model/utils/api_tools.py @@ -84,7 +84,7 @@ def generate_one_completion( ) return response except API_ERRORS as e: - err = handle_openai_error(e) + err = handle_api_error(e) return err @@ -178,7 +178,7 @@ async def _throttled_completion_acreate( return responses -def handle_openai_error(e, api_call_counter=0): +def handle_api_error(e, api_call_counter=0): """Handle OpenAI errors or related errors that the API may raise. Args: From 8708bba1fa5aafa2bb766c6345a40991c10d4bc0 Mon Sep 17 00:00:00 2001 From: saum7800 Date: Thu, 7 Sep 2023 10:42:46 -0400 Subject: [PATCH 3/6] reduce branching in prompt parser --- prompt2model/prompt_parser/instr_parser.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py index a75e1c074..12139de1f 100644 --- a/prompt2model/prompt_parser/instr_parser.py +++ b/prompt2model/prompt_parser/instr_parser.py @@ -95,21 +95,22 @@ def parse_from_prompt(self, prompt: str) -> None: ) if isinstance(response, Exception): - extraction = None if self.max_api_calls and self.api_call_counter >= self.max_api_calls: logger.error("Maximum number of API calls reached.") raise ValueError( "Maximum number of API calls reached." ) from response - else: - extraction = self.extract_response(response) + + continue + + extraction = self.extract_response(response) if extraction is not None: self._instruction, self._examples = extraction return None - else: - if self.max_api_calls and self.api_call_counter == self.max_api_calls: - logger.warning( - "Maximum number of API calls reached for PromptParser." - ) - return None + + if self.max_api_calls and self.api_call_counter == self.max_api_calls: + logger.warning( + "Maximum number of API calls reached for PromptParser." + ) + return None From e60e177238d5a8e27b1887c8a0b34b06f6803f70 Mon Sep 17 00:00:00 2001 From: saum7800 Date: Thu, 7 Sep 2023 11:46:18 -0400 Subject: [PATCH 4/6] add comments to clarify changes --- prompt2model/prompt_parser/instr_parser.py | 9 ++++++++- prompt2model/utils/api_tools.py | 7 ++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py index 12139de1f..10629d493 100644 --- a/prompt2model/prompt_parser/instr_parser.py +++ b/prompt2model/prompt_parser/instr_parser.py @@ -95,21 +95,28 @@ def parse_from_prompt(self, prompt: str) -> None: ) if isinstance(response, Exception): + # Generation failed due to an API related error and requires retry. + if self.max_api_calls and self.api_call_counter >= self.max_api_calls: + # In case we reach maximum number of API calls, we raise an error. logger.error("Maximum number of API calls reached.") raise ValueError( "Maximum number of API calls reached." ) from response - continue + continue # no need to proceed with extracting response if API call failed extraction = self.extract_response(response) if extraction is not None: + # extraction is successful + self._instruction, self._examples = extraction return None if self.max_api_calls and self.api_call_counter == self.max_api_calls: + # In case we reach maximum number of API calls without a successful extraction, we return None. + logger.warning( "Maximum number of API calls reached for PromptParser." ) diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py index 6a3d3e18d..3e19f3cda 100644 --- a/prompt2model/utils/api_tools.py +++ b/prompt2model/utils/api_tools.py @@ -70,7 +70,8 @@ def generate_one_completion( the model's likelihood of repeating the same line verbatim. Returns: - A response object. + An OpenAI-like response object if there were no errors in generating completion. + In case of API-specific error, Exception object is captured and returned. """ try: response = completion( # completion gets the key from os.getenv @@ -187,7 +188,7 @@ def handle_api_error(e, api_call_counter=0): api_call_counter: The number of API calls made so far. Returns: - The api_call_counter (if no error was raised), else raise the error. + The captured exception (if error is API related and can be retried), else raise the error. """ logging.error(e) if isinstance( @@ -198,7 +199,7 @@ def handle_api_error(e, api_call_counter=0): time.sleep(1) if isinstance(e, API_ERRORS): - # For these errors, we can increment a counter and retry the API call. + # For these errors, we can return the API related error and retry the API call. return e else: # For all other errors, immediately throw an exception. From 706a348953060c65bfc30073e4045e5c1e35339e Mon Sep 17 00:00:00 2001 From: saum7800 Date: Thu, 7 Sep 2023 12:02:29 -0400 Subject: [PATCH 5/6] pre-commit checks --- prompt2model/prompt_parser/instr_parser.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py index 10629d493..2afbed4f7 100644 --- a/prompt2model/prompt_parser/instr_parser.py +++ b/prompt2model/prompt_parser/instr_parser.py @@ -103,9 +103,10 @@ def parse_from_prompt(self, prompt: str) -> None: raise ValueError( "Maximum number of API calls reached." ) from response - - continue # no need to proceed with extracting response if API call failed - + + continue # no need to proceed with extracting + # response if API call failed. + extraction = self.extract_response(response) if extraction is not None: @@ -115,9 +116,8 @@ def parse_from_prompt(self, prompt: str) -> None: return None if self.max_api_calls and self.api_call_counter == self.max_api_calls: - # In case we reach maximum number of API calls without a successful extraction, we return None. - - logger.warning( - "Maximum number of API calls reached for PromptParser." - ) + # In case we reach maximum number of API calls without a + # successful extraction, we return None. + + logger.warning("Maximum number of API calls reached for PromptParser.") return None From 56ef493979541d30df8dec8957386c8ad03b342f Mon Sep 17 00:00:00 2001 From: saum7800 Date: Thu, 7 Sep 2023 12:10:41 -0400 Subject: [PATCH 6/6] linting issues --- prompt2model/utils/api_tools.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py index 3e19f3cda..44f5f9d6f 100644 --- a/prompt2model/utils/api_tools.py +++ b/prompt2model/utils/api_tools.py @@ -70,7 +70,7 @@ def generate_one_completion( the model's likelihood of repeating the same line verbatim. Returns: - An OpenAI-like response object if there were no errors in generating completion. + An OpenAI-like response object if there were no errors in generation. In case of API-specific error, Exception object is captured and returned. """ try: @@ -188,7 +188,8 @@ def handle_api_error(e, api_call_counter=0): api_call_counter: The number of API calls made so far. Returns: - The captured exception (if error is API related and can be retried), else raise the error. + The captured exception (if error is API related and can be retried), + else raise the error. """ logging.error(e) if isinstance(