Skip to content

Commit

Permalink
Fetch up-to-date attributes for env-pulled kwargs during serialisatio…
Browse files Browse the repository at this point in the history
…n of OpenAI classes (#11499)
  • Loading branch information
dqbd authored Oct 6, 2023
1 parent c3d2b01 commit 484947c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 4 deletions.
7 changes: 7 additions & 0 deletions libs/langchain/langchain/chat_models/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ def _client_params(self) -> Dict[str, Any]:
def _llm_type(self) -> str:
return "azure-openai-chat"

@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version,
}

def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
for res in response["choices"]:
if res.get("finish_reason", None) == "content_filter":
Expand Down
15 changes: 15 additions & 0 deletions libs/langchain/langchain/chat_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ class ChatOpenAI(BaseChatModel):
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}

if self.openai_organization != "":
attributes["openai_organization"] = self.openai_organization

if self.openai_api_base != "":
attributes["openai_api_base"] = self.openai_api_base

if self.openai_proxy != "":
attributes["openai_proxy"] = self.openai_proxy

return attributes

@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
Expand Down
21 changes: 21 additions & 0 deletions libs/langchain/langchain/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,20 @@ class BaseOpenAI(BaseLLM):
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.openai_api_base != "":
attributes["openai_api_base"] = self.openai_api_base

if self.openai_organization != "":
attributes["openai_organization"] = self.openai_organization

if self.openai_proxy != "":
attributes["openai_proxy"] = self.openai_proxy

return attributes

@classmethod
def is_lc_serializable(cls) -> bool:
return True
Expand Down Expand Up @@ -692,6 +706,13 @@ def _llm_type(self) -> str:
"""Return type of llm."""
return "azure"

@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version,
}


class OpenAIChat(BaseLLM):
"""OpenAI Chat large language models.
Expand Down
13 changes: 9 additions & 4 deletions libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import json
import os
from typing import Any, Mapping, cast
from unittest import mock

import pytest

from langchain.chat_models.azure_openai import AzureChatOpenAI

os.environ["OPENAI_API_KEY"] = "test"
os.environ["OPENAI_API_BASE"] = "https://oai.azure.com/"
os.environ["OPENAI_API_VERSION"] = "2023-05-01"


@mock.patch.dict(
os.environ,
{
"OPENAI_API_KEY": "test",
"OPENAI_API_BASE": "https://oai.azure.com/",
"OPENAI_API_VERSION": "2023-05-01",
},
)
@pytest.mark.requires("openai")
@pytest.mark.parametrize(
"model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"]
Expand Down

0 comments on commit 484947c

Please sign in to comment.