From 9d2d45a441e644ea94863dfb7b2883092b095ce5 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Fri, 16 Aug 2024 11:06:53 -0700 Subject: [PATCH] json mode standard test --- .../tests/integration_tests/test_standard.py | 4 ++ .../tests/integration_tests/test_standard.py | 4 ++ .../tests/integration_tests/test_standard.py | 4 ++ .../chat_models/test_azure_standard.py | 8 ++++ .../chat_models/test_base_standard.py | 4 ++ .../integration_tests/chat_models.py | 41 +++++++++++++++++++ .../unit_tests/chat_models.py | 4 ++ 7 files changed, 69 insertions(+) diff --git a/libs/partners/fireworks/tests/integration_tests/test_standard.py b/libs/partners/fireworks/tests/integration_tests/test_standard.py index cfefb2445e6a3..3f5f6de25e643 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_standard.py +++ b/libs/partners/fireworks/tests/integration_tests/test_standard.py @@ -26,3 +26,7 @@ def chat_model_params(self) -> dict: @pytest.mark.xfail(reason="Not yet implemented.") def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None: super().test_tool_message_histories_list_content(model) + + @property + def supports_json_mode(self) -> bool: + return True diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py index e701c726f1d7b..58c2852bbfa61 100644 --- a/libs/partners/groq/tests/integration_tests/test_standard.py +++ b/libs/partners/groq/tests/integration_tests/test_standard.py @@ -23,6 +23,10 @@ def chat_model_class(self) -> Type[BaseChatModel]: def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None: super().test_tool_message_histories_list_content(model) + @property + def supports_json_mode(self) -> bool: + return True + class TestGroqLlama(BaseTestGroq): @property diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py index 965cd03c4b178..2f47f669a80fd 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_standard.py +++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py @@ -18,3 +18,7 @@ def chat_model_class(self) -> Type[BaseChatModel]: @property def chat_model_params(self) -> dict: return {"model": "mistral-large-latest", "temperature": 0} + + @property + def supports_json_mode(self) -> bool: + return True diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py index 41c20a942ac52..9e636fc182150 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py @@ -35,3 +35,11 @@ def chat_model_params(self) -> dict: @pytest.mark.xfail(reason="Not yet supported.") def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: super().test_usage_metadata_streaming(model) + + @property + def supports_image_inputs(self) -> bool: + return True + + @property + def supports_json_mode(self) -> bool: + return True diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py index ddd33952c0c34..070f05067645a 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py @@ -20,3 +20,7 @@ def chat_model_params(self) -> dict: @property def supports_image_inputs(self) -> bool: return True + + @property + def supports_json_mode(self) -> bool: + return True diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index bcb47a4c151a7..4d367c5c70fb7 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -509,3 +509,44 @@ def test_tool_message_error_status(self, model: BaseChatModel) -> None: ] result = model_with_tools.invoke(messages) assert isinstance(result, AIMessage) + + def test_json_mode(self, model: BaseChatModel) -> None: + if not self.supports_json_mode: + pytest.skip("Test requires json mode support.") + + from pydantic import BaseModel as BaseModelProper + from pydantic import Field as FieldProper + + class Joke(BaseModelProper): + """Joke to tell user.""" + + setup: str = FieldProper(description="question to set up a joke") + punchline: str = FieldProper(description="answer to resolve the joke") + + # Pydantic class + # Type ignoring since the interface only officially supports pydantic 1 + # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2. + # We'll need to do a pass updating the type signatures. + chat = model.with_structured_output(Joke, method="json_mode") # type: ignore[arg-type] + msg = ( + "Tell me a joke about cats. Return the result as a JSON with 'setup' and " + "'punchline' keys. Return nothing other than JSON." + ) + result = chat.invoke(msg) + assert isinstance(result, Joke) + + for chunk in chat.stream(msg): + assert isinstance(chunk, Joke) + + # Schema + chat = model.with_structured_output( + Joke.model_json_schema(), method="json_mode" + ) + result = chat.invoke(msg) + assert isinstance(result, dict) + assert set(result.keys()) == {"setup", "punchline"} + + for chunk in chat.stream(msg): + assert isinstance(chunk, dict) + assert isinstance(chunk, dict) # for mypy + assert set(chunk.keys()) == {"setup", "punchline"} diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py index ed73771dbdae0..c4a1cc531b3e0 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py @@ -119,6 +119,10 @@ def returns_usage_metadata(self) -> bool: def supports_anthropic_inputs(self) -> bool: return False + @property + def supports_json_mode(self) -> bool: + return False + class ChatModelUnitTests(ChatModelTests): @property