diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py
index 0a2748329..fb3e8d17f 100644
--- a/prompt2model/utils/api_tools.py
+++ b/prompt2model/utils/api_tools.py
@@ -10,7 +10,6 @@
 import aiolimiter
 import litellm.utils
 import openai
-import openai.error
 import tiktoken
 from aiohttp import ClientSession
 from litellm import acompletion, completion
@@ -19,22 +18,22 @@
 # Note that litellm converts all API errors into openai errors,
 # so openai errors are valid even when using other services.
 API_ERRORS = (
-    openai.error.APIError,
-    openai.error.Timeout,
-    openai.error.RateLimitError,
-    openai.error.ServiceUnavailableError,
-    openai.error.InvalidRequestError,
+    openai.APIError,
+    openai.APITimeoutError,
+    openai.RateLimitError,
+    openai.BadRequestError,
+    openai.APIStatusError,
     json.decoder.JSONDecodeError,
     AssertionError,
 )
 
 ERROR_ERRORS_TO_MESSAGES = {
-    openai.error.InvalidRequestError: "API Invalid Request: Prompt was filtered",
-    openai.error.RateLimitError: "API rate limit exceeded. Sleeping for 10 seconds.",
-    openai.error.APIConnectionError: "Error Communicating with API",
-    openai.error.Timeout: "API Timeout Error: API Timeout",
-    openai.error.ServiceUnavailableError: "API service unavailable error: {e}",
-    openai.error.APIError: "API error: {e}",
+    openai.BadRequestError: "API Invalid Request: Prompt was filtered",
+    openai.RateLimitError: "API rate limit exceeded. Sleeping for 10 seconds.",
+    openai.APIConnectionError: "Error Communicating with API",
+    openai.APITimeoutError: "API Timeout Error: API Timeout",
+    openai.APIStatusError: "API service unavailable error: {e}",
+    openai.APIError: "API error: {e}",
 }
@@ -170,14 +169,14 @@ async def _throttled_completion_acreate(
                 if isinstance(
                     e,
                     (
-                        openai.error.ServiceUnavailableError,
-                        openai.error.APIError,
+                        openai.APIStatusError,
+                        openai.APIError,
                     ),
                 ):
                     logging.warning(
                         ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
                     )
-                elif isinstance(e, openai.error.InvalidRequestError):
+                elif isinstance(e, openai.BadRequestError):
                     logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
                     return {
                         "choices": [
@@ -231,7 +230,7 @@ def handle_api_error(e) -> None:
         raise e
     if isinstance(
         e,
-        (openai.error.APIError, openai.error.Timeout, openai.error.RateLimitError),
+        (openai.APIError, openai.APITimeoutError, openai.RateLimitError),
     ):
         # For these errors, OpenAI recommends waiting before retrying.
         time.sleep(1)
diff --git a/pyproject.toml b/pyproject.toml
index 87f5642d2..5c0f34556 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
     "gradio==3.38.0",
     "torch",
     "pytest",
-    "openai==0.27.10",
+    "openai",
     "sentencepiece",
     "bert_score",
     "sacrebleu",
@@ -45,7 +45,7 @@ dependencies = [
     "psutil",
     "protobuf",
     "nest-asyncio",
-    "litellm==0.1.583",
+    "litellm",
     "peft"
 ]
diff --git a/tests/prompt_parser_test.py b/tests/prompt_parser_test.py
index 1cf2daf1c..f33321a69 100644
--- a/tests/prompt_parser_test.py
+++ b/tests/prompt_parser_test.py
@@ -136,7 +136,7 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method):
 @patch("time.sleep")
 @patch(
     "prompt2model.utils.APIAgent.generate_one_completion",
-    side_effect=openai.error.Timeout("timeout"),
+    side_effect=openai.APITimeoutError("timeout"),
 )
 def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_method):
     """Verify that we wait and retry (a set number of times) if the API times out.
@@ -165,7 +165,7 @@ def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_met
     assert isinstance(exc_info.value, RuntimeError)
     # Check if the original exception (e) is present as the cause
     original_exception = exc_info.value.__cause__
-    assert isinstance(original_exception, openai.error.Timeout)
+    assert isinstance(original_exception, openai.APITimeoutError)
     gc.collect()
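
Migration note: the renames above follow the exception hierarchy introduced in openai>=1.0. Below is a minimal sanity-check sketch (illustrative only, not part of the patch) of the old-to-new mapping and the subclass relationships the retry logic relies on, assuming openai>=1.0 is installed:

```python
# Illustrative sketch, assuming openai>=1.0 is installed; not part of the patch.
#
# Old (openai<1.0)                         New (openai>=1.0)
# openai.error.InvalidRequestError     ->  openai.BadRequestError (HTTP 400)
# openai.error.Timeout                 ->  openai.APITimeoutError
# openai.error.RateLimitError          ->  openai.RateLimitError (HTTP 429)
# openai.error.ServiceUnavailableError ->  openai.InternalServerError, a
#                                          subclass of openai.APIStatusError
import openai

# Status-code errors (400/429/5xx) all derive from APIStatusError ...
assert issubclass(openai.BadRequestError, openai.APIStatusError)
assert issubclass(openai.RateLimitError, openai.APIStatusError)
assert issubclass(openai.InternalServerError, openai.APIStatusError)

# ... and every API-side failure ultimately derives from APIError, so the
# broad `except API_ERRORS` paths above still catch the same failure modes.
assert issubclass(openai.APIStatusError, openai.APIError)
assert issubclass(openai.APITimeoutError, openai.APIConnectionError)
assert issubclass(openai.APIConnectionError, openai.APIError)
```

One consequence worth noting: `APITimeoutError` now derives from `APIConnectionError` rather than standing alone, but since both ultimately subclass `openai.APIError`, they remain covered by the `openai.APIError` entry in `API_ERRORS`.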