diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py index 34198d0fb958c..03a1fc734afb4 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Tuple, Type from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ChatModelUnitTests @@ -12,3 +12,21 @@ class TestOpenAIStandard(ChatModelUnitTests): @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatOpenAI + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ( + { + "OPENAI_API_KEY": "api_key", + "OPENAI_ORGANIZATION": "org_id", + "OPENAI_API_BASE": "api_base", + "OPENAI_PROXY": "https://proxy.com", + }, + {}, + { + "openai_api_key": "api_key", + "openai_organization": "org_id", + "openai_api_base": "api_base", + "openai_proxy": "https://proxy.com", + }, + ) diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py index 6597b16177be4..f3a49b6c341ec 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py @@ -1,11 +1,12 @@ """Unit tests for chat models.""" - +import os from abc import abstractmethod -from typing import Any, List, Literal, Optional, Type +from typing import Any, List, Literal, Optional, Tuple, Type +from unittest import mock import pytest from langchain_core.language_models import BaseChatModel -from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr from langchain_core.runnables import RunnableBinding from langchain_core.tools import tool @@ -132,12 +133,30 @@ def standard_chat_model_params(self) -> dict: params["api_key"] = "test" return params + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + """Return env vars, init args, and expected instance attrs for initializing + from env vars.""" + return {}, {}, {} + def test_init(self) -> None: model = self.chat_model_class( **{**self.standard_chat_model_params, **self.chat_model_params} ) assert model is not None + def test_init_from_env(self) -> None: + env_params, model_params, expected_attrs = self.init_from_env_params + if env_params: + with mock.patch.dict(os.environ, env_params): + model = self.chat_model_class(**model_params) + assert model is not None + for k, expected in expected_attrs.items(): + actual = getattr(model, k) + if isinstance(actual, SecretStr): + actual = actual.get_secret_value() + assert actual == expected + def test_init_streaming( self, ) -> None: diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/embeddings.py b/libs/standard-tests/langchain_standard_tests/unit_tests/embeddings.py index 28e8c66bbd6fd..0a6e793c06327 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/embeddings.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/embeddings.py @@ -1,8 +1,11 @@ +import os from abc import abstractmethod -from typing import Type +from typing import Tuple, Type +from unittest import mock import pytest from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import SecretStr from langchain_standard_tests.base import BaseStandardTests @@ -26,3 +29,21 @@ class EmbeddingsUnitTests(EmbeddingsTests): def test_init(self) -> None: model = self.embeddings_class(**self.embedding_model_params) assert model is not None + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + """Return env vars, init args, and expected instance attrs for initializing + from env vars.""" + return {}, {}, {} + + def test_init_from_env(self) -> None: + env_params, embeddings_params, expected_attrs = self.init_from_env_params + if env_params: + with mock.patch.dict(os.environ, env_params): + model = self.embeddings_class(**embeddings_params) + assert model is not None + for k, expected in expected_attrs.items(): + actual = getattr(model, k) + if isinstance(actual, SecretStr): + actual = actual.get_secret_value() + assert actual == expected