Skip to content

Commit

Permalink
chore: migrated to openai>=1.0 and langchain-openai libs (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Mar 5, 2024
1 parent 68f3873 commit 2ec8712
Show file tree
Hide file tree
Showing 6 changed files with 346 additions and 368 deletions.
636 changes: 307 additions & 329 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ anthropic = "0.2.10"
colorama = "0.4.4"
fastapi = "0.109.2"
flask = "2.3.2"
openai = "0.27.8"
openai = "1.13.3"
uvicorn = "0.23.2"
pydantic = "1.10.12"
defusedxml = "^0.7.1"
Expand All @@ -45,7 +45,8 @@ flake8 = "6.0.0"

[tool.poetry.group.dev.dependencies]
nox = "^2023.4.22"
langchain = "0.0.329"
langchain-openai = "0.0.8"
langchain-core = "0.1.29"

[tool.pytest.ini_options]
addopts = "--doctest-modules"
Expand Down
13 changes: 6 additions & 7 deletions tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
from dataclasses import dataclass
from typing import Callable, List, Optional

import openai
import openai.error
import pytest
from langchain.schema import BaseMessage
from langchain_core.messages import BaseMessage
from openai import APIStatusError

from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment
from tests.conftest import TEST_SERVER_URL
from tests.utils.langchain import (
ai,
create_model,
create_chat_model,
run_model,
sanitize_test_name,
sys,
Expand Down Expand Up @@ -193,16 +192,16 @@ def get_test_cases(
ids=lambda test: test.get_id(),
)
async def test_chat_completion_langchain(server, test: TestCase):
model = create_model(
model = create_chat_model(
TEST_SERVER_URL, test.deployment.value, test.streaming, test.max_tokens
)

if isinstance(test.test, Exception):
with pytest.raises(Exception) as exc_info:
await run_model(model, test.messages, test.streaming, test.stop)

assert isinstance(exc_info.value, openai.error.OpenAIError)
assert exc_info.value.http_status == 422
assert isinstance(exc_info.value, APIStatusError)
assert exc_info.value.status_code == 422
assert re.search(str(test.test), str(exc_info.value))
else:
actual_output = await run_model(
Expand Down
22 changes: 11 additions & 11 deletions tests/integration_tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
from typing import Any
from typing import List

import openai
import openai.error
import requests
from openai import AzureOpenAI

from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment
from tests.conftest import DEFAULT_API_VERSION, TEST_SERVER_URL


def models_request_http() -> Any:
def models_request_http() -> List[str]:
response = requests.get(f"{TEST_SERVER_URL}/openai/models")
assert response.status_code == 200
return response.json()
data = response.json()["data"]
return [model["id"] for model in data]


def models_request_openai() -> Any:
return openai.Model.list(
api_type="azure",
api_base=TEST_SERVER_URL,
def models_request_openai() -> List[str]:
client = AzureOpenAI(
azure_endpoint=TEST_SERVER_URL,
api_version=DEFAULT_API_VERSION,
api_key="dummy_key",
)
data = client.models.list().data
return [model.id for model in data]


def assert_models_subset(models: Any):
actual_models = [model["id"] for model in models["data"]]
def assert_models_subset(actual_models: List[str]):
expected_models = [option.value for option in BedrockDeployment]

assert set(expected_models).issubset(
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/callback.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.outputs import LLMResult
from typing_extensions import override

from aidial_adapter_bedrock.utils.printing import print_ai
Expand Down
34 changes: 17 additions & 17 deletions tests/utils/langchain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import re
from typing import List, Optional

from langchain.callbacks.base import Callbacks
from langchain.chat_models import AzureChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.callbacks import Callbacks
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_openai import AzureChatOpenAI

from tests.conftest import DEFAULT_API_VERSION
from tests.utils.callback import CallbackWithNewLines
Expand Down Expand Up @@ -47,28 +52,23 @@ async def run_model(
return llm_result.generations[0][-1].text


def create_model(
def create_chat_model(
base_url: str,
model_id: str,
streaming: bool,
max_tokens: Optional[int],
) -> BaseChatModel:
callbacks: Callbacks = [CallbackWithNewLines()]
return AzureChatOpenAI(
deployment_name=model_id,
azure_endpoint=base_url,
azure_deployment=model_id,
callbacks=callbacks,
openai_api_base=base_url,
openai_api_version=DEFAULT_API_VERSION,
openai_api_key="dummy_openai_api_key",
model_kwargs={
"deployment_id": model_id,
"api_key": "dummy_api_key",
},
api_version=DEFAULT_API_VERSION,
api_key="dummy_key",
verbose=True,
streaming=streaming,
max_tokens=max_tokens,
temperature=0.0,
request_timeout=10,
client=None,
temperature=0,
max_retries=0,
max_tokens=max_tokens,
request_timeout=10, # type: ignore
)

0 comments on commit 2ec8712

Please sign in to comment.