Skip to content

Commit

Permalink
Migrate to the latest openAI version (#397)
Browse files Browse the repository at this point in the history
* run migration

* bump up versions

* minor fixes and migrations

* reset token count

* minor fix

* correct timeout error

* fix prompt parser test
  • Loading branch information
saum7800 authored Apr 16, 2024
1 parent 25e0a96 commit 04debc9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 20 deletions.
31 changes: 15 additions & 16 deletions prompt2model/utils/api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import aiolimiter
import litellm.utils
import openai
import openai.error
import tiktoken
from aiohttp import ClientSession
from litellm import acompletion, completion
Expand All @@ -19,22 +18,22 @@
# Note that litellm converts all API errors into openai errors,
# so openai errors are valid even when using other services.
API_ERRORS = (
openai.error.APIError,
openai.error.Timeout,
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
openai.error.InvalidRequestError,
openai.APIError,
openai.APITimeoutError,
openai.RateLimitError,
openai.BadRequestError,
openai.APIStatusError,
json.decoder.JSONDecodeError,
AssertionError,
)

ERROR_ERRORS_TO_MESSAGES = {
openai.error.InvalidRequestError: "API Invalid Request: Prompt was filtered",
openai.error.RateLimitError: "API rate limit exceeded. Sleeping for 10 seconds.",
openai.error.APIConnectionError: "Error Communicating with API",
openai.error.Timeout: "API Timeout Error: API Timeout",
openai.error.ServiceUnavailableError: "API service unavailable error: {e}",
openai.error.APIError: "API error: {e}",
openai.BadRequestError: "API Invalid Request: Prompt was filtered",
openai.RateLimitError: "API rate limit exceeded. Sleeping for 10 seconds.",
openai.APIConnectionError: "Error Communicating with API",
openai.APITimeoutError: "API Timeout Error: API Timeout",
openai.APIStatusError: "API service unavailable error: {e}",
openai.APIError: "API error: {e}",
}


Expand Down Expand Up @@ -170,14 +169,14 @@ async def _throttled_completion_acreate(
if isinstance(
e,
(
openai.error.ServiceUnavailableError,
openai.error.APIError,
openai.APIStatusError,
openai.APIError,
),
):
logging.warning(
ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
)
elif isinstance(e, openai.error.InvalidRequestError):
elif isinstance(e, openai.BadRequestError):
logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
return {
"choices": [
Expand Down Expand Up @@ -231,7 +230,7 @@ def handle_api_error(e) -> None:
raise e
if isinstance(
e,
(openai.error.APIError, openai.error.Timeout, openai.error.RateLimitError),
(openai.APIError, openai.APITimeoutError, openai.RateLimitError),
):
# For these errors, OpenAI recommends waiting before retrying.
time.sleep(1)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"gradio==3.38.0",
"torch",
"pytest",
"openai==0.27.10",
"openai",
"sentencepiece",
"bert_score",
"sacrebleu",
Expand All @@ -45,7 +45,7 @@ dependencies = [
"psutil",
"protobuf",
"nest-asyncio",
"litellm==0.1.583",
"litellm",
"peft"
]

Expand Down
4 changes: 2 additions & 2 deletions tests/prompt_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method):
@patch("time.sleep")
@patch(
"prompt2model.utils.APIAgent.generate_one_completion",
side_effect=openai.error.Timeout("timeout"),
side_effect=openai.APITimeoutError("timeout"),
)
def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_method):
"""Verify that we wait and retry (a set number of times) if the API times out.
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_met
assert isinstance(exc_info.value, RuntimeError)
# Check if the original exception (e) is present as the cause
original_exception = exc_info.value.__cause__
assert isinstance(original_exception, openai.error.Timeout)
assert isinstance(original_exception, openai.APITimeoutError)
gc.collect()


Expand Down

0 comments on commit 04debc9

Please sign in to comment.