diff --git a/libs/langchain/langchain/chat_models/azure_openai.py b/libs/langchain/langchain/chat_models/azure_openai.py index 9e232224deb58..fe52489e10db4 100644 --- a/libs/langchain/langchain/chat_models/azure_openai.py +++ b/libs/langchain/langchain/chat_models/azure_openai.py @@ -2,10 +2,10 @@ from __future__ import annotations import logging -from typing import Any, Dict, Mapping +from typing import Any, Dict, Union -from langchain.chat_models.openai import ChatOpenAI -from langchain.pydantic_v1 import root_validator +from langchain.chat_models.openai import ChatOpenAI, _is_openai_v1 +from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.schema import ChatResult from langchain.utils import get_from_dict_or_env @@ -51,13 +51,13 @@ class AzureChatOpenAI(ChatOpenAI): in, even if not explicitly saved on this class. """ - deployment_name: str = "" + deployment_name: str = Field(default="", alias="azure_deployment") model_version: str = "" openai_api_type: str = "" - openai_api_base: str = "" - openai_api_version: str = "" - openai_api_key: str = "" - openai_organization: str = "" + openai_api_base: str = Field(default="", alias="azure_endpoint") + openai_api_version: str = Field(default="", alias="api_version") + openai_api_key: str = Field(default="", alias="api_key") + openai_organization: str = Field(default="", alias="organization") openai_proxy: str = "" @root_validator() @@ -101,14 +101,27 @@ def validate_environment(cls, values: Dict) -> Dict: "Could not import openai python package. " "Please install it with `pip install openai`." ) - try: + if _is_openai_v1(): + values["client"] = openai.AzureOpenAI( + azure_endpoint=values["openai_api_base"], + api_key=values["openai_api_key"], + timeout=values["request_timeout"], + max_retries=values["max_retries"], + organization=values["openai_organization"], + api_version=values["openai_api_version"], + azure_deployment=values["deployment_name"], + ).chat.completions + values["async_client"] = openai.AsyncAzureOpenAI( + azure_endpoint=values["openai_api_base"], + api_key=values["openai_api_key"], + timeout=values["request_timeout"], + max_retries=values["max_retries"], + organization=values["openai_organization"], + api_version=values["openai_api_version"], + azure_deployment=values["deployment_name"], + ).chat.completions + else: values["client"] = openai.ChatCompletion - except AttributeError: - raise ValueError( - "`openai` has no `ChatCompletion` attribute, this is likely " - "due to an old version of the openai package. Try upgrading it " - "with `pip install --upgrade openai`." - ) if values["n"] < 1: raise ValueError("n must be at least 1.") if values["n"] > 1 and values["streaming"]: @@ -118,10 +131,13 @@ def validate_environment(cls, values: Dict) -> Dict: @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" - return { - **super()._default_params, - "engine": self.deployment_name, - } + if _is_openai_v1(): + return super()._default_params + else: + return { + **super()._default_params, + "engine": self.deployment_name, + } @property def _identifying_params(self) -> Dict[str, Any]: @@ -131,11 +147,14 @@ def _identifying_params(self) -> Dict[str, Any]: @property def _client_params(self) -> Dict[str, Any]: """Get the config params used for the openai client.""" - return { - **super()._client_params, - "api_type": self.openai_api_type, - "api_version": self.openai_api_version, - } + if _is_openai_v1(): + return super()._client_params + else: + return { + **super()._client_params, + "api_type": self.openai_api_type, + "api_version": self.openai_api_version, + } @property def _llm_type(self) -> str: @@ -148,7 +167,9 @@ def lc_attributes(self) -> Dict[str, Any]: "openai_api_version": self.openai_api_version, } - def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: + if not isinstance(response, dict): + response = response.dict() for res in response["choices"]: if res.get("finish_reason", None) == "content_filter": raise ValueError( diff --git a/libs/langchain/langchain/chat_models/konko.py b/libs/langchain/langchain/chat_models/konko.py index aeb14c187ac3f..6c5c5ef2db50d 100644 --- a/libs/langchain/langchain/chat_models/konko.py +++ b/libs/langchain/langchain/chat_models/konko.py @@ -21,8 +21,8 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) -from langchain.chat_models.base import _generate_from_stream -from langchain.chat_models.openai import ChatOpenAI, _convert_delta_to_message_chunk +from langchain.chat_models.base import BaseChatModel, _generate_from_stream +from langchain.chat_models.openai import _convert_delta_to_message_chunk from langchain.pydantic_v1 import Field, root_validator from langchain.schema import ChatGeneration, ChatResult from langchain.schema.messages import AIMessageChunk, BaseMessage @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) -class ChatKonko(ChatOpenAI): +class ChatKonko(BaseChatModel): """`ChatKonko` Chat large language models API. To use, you should have the ``konko`` python package installed, and the diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py index 3d80ad01e9c87..5fc9e6c41e00d 100644 --- a/libs/langchain/langchain/chat_models/openai.py +++ b/libs/langchain/langchain/chat_models/openai.py @@ -3,6 +3,7 @@ import logging import sys +from importlib.metadata import version from typing import ( TYPE_CHECKING, Any, @@ -19,6 +20,8 @@ Union, ) +from packaging.version import Version, parse + from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -44,9 +47,13 @@ ) from langchain.schema.output import ChatGenerationChunk from langchain.schema.runnable import Runnable -from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, +) if TYPE_CHECKING: + import httpx import tiktoken @@ -91,6 +98,9 @@ async def acompletion_with_retry( **kwargs: Any, ) -> Any: """Use tenacity to retry the async completion call.""" + if _is_openai_v1(): + return await llm.async_client.create(**kwargs) + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) @retry_decorator @@ -108,6 +118,11 @@ def _convert_delta_to_message_chunk( content = _dict.get("content") or "" if _dict.get("function_call"): additional_kwargs = {"function_call": dict(_dict["function_call"])} + if ( + "name" in additional_kwargs["function_call"] + and additional_kwargs["function_call"]["name"] is None + ): + additional_kwargs["function_call"]["name"] = "" else: additional_kwargs = {} @@ -125,6 +140,11 @@ def _convert_delta_to_message_chunk( return default_class(content=content) +def _is_openai_v1() -> bool: + _version = parse(version("openai")) + return _version >= Version("1.0.0") + + class ChatOpenAI(BaseChatModel): """`OpenAI` Chat large language models API. @@ -166,6 +186,7 @@ def is_lc_serializable(cls) -> bool: return True client: Any = None #: :meta private: + async_client: Any = None #: :meta private: model_name: str = Field(default="gpt-3.5-turbo", alias="model") """Model name to use.""" temperature: float = 0.7 @@ -175,16 +196,18 @@ def is_lc_serializable(cls) -> bool: # When updating this to use a SecretStr # Check for classes that derive from this class (as some of them # may assume openai_api_key is a str) - openai_api_key: Optional[str] = None + openai_api_key: Optional[str] = Field(default=None, alias="api_key") """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" - openai_api_base: Optional[str] = None - openai_organization: Optional[str] = None + openai_api_base: Optional[str] = Field(default=None, alias="base_url") + openai_organization: Optional[str] = Field(default=None, alias="organization") # to support explicit proxy for OpenAI openai_proxy: Optional[str] = None - request_timeout: Optional[Union[float, Tuple[float, float]]] = None + request_timeout: Union[float, Tuple[float, float], httpx.Timeout, None] = Field( + default=None, alias="timeout" + ) """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" - max_retries: int = 6 + max_retries: int = 2 """Maximum number of retries to make when generating.""" streaming: bool = False """Whether to stream the results or not.""" @@ -266,14 +289,24 @@ def validate_environment(cls, values: Dict) -> Dict: "Could not import openai python package. " "Please install it with `pip install openai`." ) - try: + + if _is_openai_v1(): + values["client"] = openai.OpenAI( + api_key=values["openai_api_key"], + timeout=values["request_timeout"], + max_retries=values["max_retries"], + organization=values["openai_organization"], + base_url=values["openai_api_base"] or None, + ).chat.completions + values["async_client"] = openai.AsyncOpenAI( + api_key=values["openai_api_key"], + timeout=values["request_timeout"], + max_retries=values["max_retries"], + organization=values["openai_organization"], + base_url=values["openai_api_base"] or None, + ).chat.completions + else: values["client"] = openai.ChatCompletion - except AttributeError: - raise ValueError( - "`openai` has no `ChatCompletion` attribute, this is likely " - "due to an old version of the openai package. Try upgrading it " - "with `pip install --upgrade openai`." - ) if values["n"] < 1: raise ValueError("n must be at least 1.") if values["n"] > 1 and values["streaming"]: @@ -285,7 +318,6 @@ def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" return { "model": self.model_name, - "request_timeout": self.request_timeout, "max_tokens": self.max_tokens, "stream": self.streaming, "n": self.n, @@ -297,6 +329,9 @@ def completion_with_retry( self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any ) -> Any: """Use tenacity to retry the completion call.""" + if _is_openai_v1(): + return self.client.create(**kwargs) + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) @retry_decorator @@ -333,6 +368,8 @@ def _stream( for chunk in self.completion_with_retry( messages=message_dicts, run_manager=run_manager, **params ): + if not isinstance(chunk, dict): + chunk = chunk.dict() if len(chunk["choices"]) == 0: continue choice = chunk["choices"][0] @@ -381,8 +418,10 @@ def _create_message_dicts( message_dicts = [convert_message_to_dict(m) for m in messages] return message_dicts, params - def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: generations = [] + if not isinstance(response, dict): + response = response.dict() for res in response["choices"]: message = convert_dict_to_message(res["message"]) gen = ChatGeneration( @@ -408,6 +447,8 @@ async def _astream( async for chunk in await acompletion_with_retry( self, messages=message_dicts, run_manager=run_manager, **params ): + if not isinstance(chunk, dict): + chunk = chunk.dict() if len(chunk["choices"]) == 0: continue choice = chunk["choices"][0] @@ -455,11 +496,16 @@ def _identifying_params(self) -> Dict[str, Any]: def _client_params(self) -> Dict[str, Any]: """Get the parameters used for the openai client.""" openai_creds: Dict[str, Any] = { - "api_key": self.openai_api_key, - "api_base": self.openai_api_base, - "organization": self.openai_organization, "model": self.model_name, } + if not _is_openai_v1(): + openai_creds.update( + { + "api_key": self.openai_api_key, + "api_base": self.openai_api_base, + "organization": self.openai_organization, + } + ) if self.openai_proxy: import openai diff --git a/libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py b/libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py index 921ec0ad68bf6..fd1ec775b00c9 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py @@ -1,6 +1,5 @@ import json import os -from typing import Any, Mapping, cast from unittest import mock import pytest @@ -48,9 +47,8 @@ def test_model_name_set_on_chat_result_when_present_in_response( """ # convert sample_response_text to instance of Mapping[str, Any] sample_response = json.loads(sample_response_text) - mock_response = cast(Mapping[str, Any], sample_response) mock_chat = AzureChatOpenAI() - chat_result = mock_chat._create_chat_result(mock_response) + chat_result = mock_chat._create_chat_result(sample_response) assert ( chat_result.llm_output is not None and chat_result.llm_output["model_name"] == model_name