From 484947c49228e35e47d0bf66f5ecaf887bce3011 Mon Sep 17 00:00:00 2001 From: David Duong Date: Fri, 6 Oct 2023 23:43:29 +0200 Subject: [PATCH] Fetch up-to-date attributes for env-pulled kwargs during serialisation of OpenAI classes (#11499) --- .../langchain/chat_models/azure_openai.py | 7 +++++++ .../langchain/langchain/chat_models/openai.py | 15 +++++++++++++ libs/langchain/langchain/llms/openai.py | 21 +++++++++++++++++++ .../chat_models/test_azureopenai.py | 13 ++++++++---- 4 files changed, 52 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chat_models/azure_openai.py b/libs/langchain/langchain/chat_models/azure_openai.py index c5090fd845189..9e232224deb58 100644 --- a/libs/langchain/langchain/chat_models/azure_openai.py +++ b/libs/langchain/langchain/chat_models/azure_openai.py @@ -141,6 +141,13 @@ def _client_params(self) -> Dict[str, Any]: def _llm_type(self) -> str: return "azure-openai-chat" + @property + def lc_attributes(self) -> Dict[str, Any]: + return { + "openai_api_type": self.openai_api_type, + "openai_api_version": self.openai_api_version, + } + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: for res in response["choices"]: if res.get("finish_reason", None) == "content_filter": diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py index f22b55a1030c0..a09dfd3e9cbeb 100644 --- a/libs/langchain/langchain/chat_models/openai.py +++ b/libs/langchain/langchain/chat_models/openai.py @@ -141,6 +141,21 @@ class ChatOpenAI(BaseChatModel): def lc_secrets(self) -> Dict[str, str]: return {"openai_api_key": "OPENAI_API_KEY"} + @property + def lc_attributes(self) -> Dict[str, Any]: + attributes: Dict[str, Any] = {} + + if self.openai_organization != "": + attributes["openai_organization"] = self.openai_organization + + if self.openai_api_base != "": + attributes["openai_api_base"] = self.openai_api_base + + if self.openai_proxy != "": + attributes["openai_proxy"] = self.openai_proxy + + return attributes + @classmethod def is_lc_serializable(cls) -> bool: """Return whether this model can be serialized by Langchain.""" diff --git a/libs/langchain/langchain/llms/openai.py b/libs/langchain/langchain/llms/openai.py index 462dfe4a6d4d0..7ff51f442cf8c 100644 --- a/libs/langchain/langchain/llms/openai.py +++ b/libs/langchain/langchain/llms/openai.py @@ -138,6 +138,20 @@ class BaseOpenAI(BaseLLM): def lc_secrets(self) -> Dict[str, str]: return {"openai_api_key": "OPENAI_API_KEY"} + @property + def lc_attributes(self) -> Dict[str, Any]: + attributes: Dict[str, Any] = {} + if self.openai_api_base != "": + attributes["openai_api_base"] = self.openai_api_base + + if self.openai_organization != "": + attributes["openai_organization"] = self.openai_organization + + if self.openai_proxy != "": + attributes["openai_proxy"] = self.openai_proxy + + return attributes + @classmethod def is_lc_serializable(cls) -> bool: return True @@ -692,6 +706,13 @@ def _llm_type(self) -> str: """Return type of llm.""" return "azure" + @property + def lc_attributes(self) -> Dict[str, Any]: + return { + "openai_api_type": self.openai_api_type, + "openai_api_version": self.openai_api_version, + } + class OpenAIChat(BaseLLM): """OpenAI Chat large language models. diff --git a/libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py b/libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py index 94a8e6d44dc52..921ec0ad68bf6 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_azureopenai.py @@ -1,16 +1,21 @@ import json import os from typing import Any, Mapping, cast +from unittest import mock import pytest from langchain.chat_models.azure_openai import AzureChatOpenAI -os.environ["OPENAI_API_KEY"] = "test" -os.environ["OPENAI_API_BASE"] = "https://oai.azure.com/" -os.environ["OPENAI_API_VERSION"] = "2023-05-01" - +@mock.patch.dict( + os.environ, + { + "OPENAI_API_KEY": "test", + "OPENAI_API_BASE": "https://oai.azure.com/", + "OPENAI_API_VERSION": "2023-05-01", + }, +) @pytest.mark.requires("openai") @pytest.mark.parametrize( "model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"]