Skip to content

Commit

Permalink
standard-tests[patch]: test init from env vars (#25983)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Sep 3, 2024
1 parent ac92210 commit bc3b026
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""

from typing import Type
from typing import Tuple, Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests
Expand All @@ -12,3 +12,21 @@ class TestOpenAIStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatOpenAI

@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"OPENAI_API_KEY": "api_key",
"OPENAI_ORGANIZATION": "org_id",
"OPENAI_API_BASE": "api_base",
"OPENAI_PROXY": "https://proxy.com",
},
{},
{
"openai_api_key": "api_key",
"openai_organization": "org_id",
"openai_api_base": "api_base",
"openai_proxy": "https://proxy.com",
},
)
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Unit tests for chat models."""

import os
from abc import abstractmethod
from typing import Any, List, Literal, Optional, Type
from typing import Any, List, Literal, Optional, Tuple, Type
from unittest import mock

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool

Expand Down Expand Up @@ -132,12 +133,30 @@ def standard_chat_model_params(self) -> dict:
params["api_key"] = "test"
return params

@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}

def test_init(self) -> None:
model = self.chat_model_class(
**{**self.standard_chat_model_params, **self.chat_model_params}
)
assert model is not None

def test_init_from_env(self) -> None:
env_params, model_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.chat_model_class(**model_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected

def test_init_streaming(
self,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
from abc import abstractmethod
from typing import Type
from typing import Tuple, Type
from unittest import mock

import pytest
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import SecretStr

from langchain_standard_tests.base import BaseStandardTests

Expand All @@ -26,3 +29,21 @@ class EmbeddingsUnitTests(EmbeddingsTests):
def test_init(self) -> None:
model = self.embeddings_class(**self.embedding_model_params)
assert model is not None

@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return {}, {}, {}

def test_init_from_env(self) -> None:
env_params, embeddings_params, expected_attrs = self.init_from_env_params
if env_params:
with mock.patch.dict(os.environ, env_params):
model = self.embeddings_class(**embeddings_params)
assert model is not None
for k, expected in expected_attrs.items():
actual = getattr(model, k)
if isinstance(actual, SecretStr):
actual = actual.get_secret_value()
assert actual == expected

0 comments on commit bc3b026

Please sign in to comment.