Skip to content

Commit

Permalink
openai[patch]: fix azure open lc serialization, release 0.1.5 (#21159)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored May 1, 2024
1 parent 94a8387 commit 6fa8626
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 67 deletions.
15 changes: 13 additions & 2 deletions libs/partners/openai/langchain_openai/chat_models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_openai.chat_models.base import ChatOpenAI
from langchain_openai.chat_models.base import BaseChatOpenAI

logger = logging.getLogger(__name__)


class AzureChatOpenAI(ChatOpenAI):
class AzureChatOpenAI(BaseChatOpenAI):
"""`Azure OpenAI` Chat Completion API.
To use this class you
Expand Down Expand Up @@ -100,6 +100,17 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "azure_openai"]

@property
def lc_secrets(self) -> Dict[str, str]:
return {
"openai_api_key": "AZURE_OPENAI_API_KEY",
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
}

@classmethod
def is_lc_serializable(cls) -> bool:
return True

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
Expand Down
94 changes: 48 additions & 46 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,52 +291,7 @@ class _AllReturnType(TypedDict):
parsing_error: Optional[BaseException]


class ChatOpenAI(BaseChatModel):
"""`OpenAI` Chat large language models API.
To use, you should have the environment variable ``OPENAI_API_KEY``
set with your API key, or pass it as a named parameter to the constructor.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-3.5-turbo")
"""

@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "openai"]

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}

if self.openai_organization:
attributes["openai_organization"] = self.openai_organization

if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base

if self.openai_proxy:
attributes["openai_proxy"] = self.openai_proxy

return attributes

@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True

class BaseChatOpenAI(BaseChatModel):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
Expand Down Expand Up @@ -1093,6 +1048,53 @@ class AnswerWithJustification(BaseModel):
return llm | output_parser


class ChatOpenAI(BaseChatOpenAI):
"""`OpenAI` Chat large language models API.
To use, you should have the environment variable ``OPENAI_API_KEY``
set with your API key, or pass it as a named parameter to the constructor.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-3.5-turbo")
"""

@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "openai"]

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}

if self.openai_organization:
attributes["openai_organization"] = self.openai_organization

if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base

if self.openai_proxy:
attributes["openai_proxy"] = self.openai_proxy

return attributes

@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True


def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)

Expand Down
12 changes: 12 additions & 0 deletions libs/partners/openai/langchain_openai/llms/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "openai"]

@property
def lc_secrets(self) -> Dict[str, str]:
return {
"openai_api_key": "AZURE_OPENAI_API_KEY",
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
}

@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
Expand Down
36 changes: 18 additions & 18 deletions libs/partners/openai/langchain_openai/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,6 @@ def _stream_response_to_generation_chunk(
class BaseOpenAI(BaseLLM):
"""Base OpenAI large language model class."""

@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base

if self.openai_organization:
attributes["openai_organization"] = self.openai_organization

if self.openai_proxy:
attributes["openai_proxy"] = self.openai_proxy

return attributes

client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo-instruct", alias="model")
Expand Down Expand Up @@ -649,3 +631,21 @@ def is_lc_serializable(cls) -> bool:
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**{"model": self.model_name}, **super()._invocation_params}

@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base

if self.openai_organization:
attributes["openai_organization"] = self.openai_organization

if self.openai_proxy:
attributes["openai_proxy"] = self.openai_proxy

return attributes
2 changes: 1 addition & 1 deletion libs/partners/openai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-openai"
version = "0.1.4"
version = "0.1.5"
description = "An integration package connecting OpenAI and LangChain"
authors = []
readme = "README.md"
Expand Down
17 changes: 17 additions & 0 deletions libs/partners/openai/tests/unit_tests/test_secrets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Type, cast

import pytest
from langchain_core.load import dumpd
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

Expand Down Expand Up @@ -187,3 +188,19 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: Type) -> No
"""Test that the actual secret value is correctly retrieved."""
model = model_class(openai_api_key="secret-api-key")
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"


@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
def test_azure_serialized_secrets(model_class: Type) -> None:
"""Test that the actual secret value is correctly retrieved."""
model = model_class(
openai_api_key="secret-api-key", api_version="foo", azure_endpoint="foo"
)
serialized = dumpd(model)
assert serialized["kwargs"]["openai_api_key"]["id"] == ["AZURE_OPENAI_API_KEY"]

model = model_class(
azure_ad_token="secret-token", api_version="foo", azure_endpoint="foo"
)
serialized = dumpd(model)
assert serialized["kwargs"]["azure_ad_token"]["id"] == ["AZURE_OPENAI_AD_TOKEN"]

0 comments on commit 6fa8626

Please sign in to comment.