diff --git a/libs/databricks/langchain_databricks/chat_models.py b/libs/databricks/langchain_databricks/chat_models.py index dcbdbed..49a214f 100644 --- a/libs/databricks/langchain_databricks/chat_models.py +++ b/libs/databricks/langchain_databricks/chat_models.py @@ -452,7 +452,9 @@ def with_structured_output( self, schema: Optional[Union[Dict, Type]] = None, *, - method: Literal["function_calling", "json_mode"] = "function_calling", + method: Literal[ + "function_calling", "json_mode", "json_schema" + ] = "function_calling", include_raw: bool = False, **kwargs: Any, ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: @@ -651,10 +653,32 @@ class AnswerWithJustification(BaseModel): if is_pydantic_schema else JsonOutputParser() ) + elif method == "json_schema": + if schema is None: + raise ValueError( + "schema must be specified when method is 'json_schema'. " + "Received None." + ) + response_format = { + "type": "json_schema", + "json_schema": { + "strict": True, + "schema": ( + schema.model_json_schema() if is_pydantic_schema else schema # type: ignore[union-attr] + ), + }, + } + llm = self.bind(response_format=response_format) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + else: raise ValueError( - f"Unrecognized method argument. Expected one of 'function_calling' or " - f"'json_mode'. Received: '{method}'" + f"Unrecognized method argument. Expected one of 'function_calling', " + f"'json_mode' or 'json_schema'. Received: '{method}'" ) if include_raw: diff --git a/libs/databricks/tests/unit_tests/test_chat_models.py b/libs/databricks/tests/unit_tests/test_chat_models.py index 0f684b3..c9c3e77 100644 --- a/libs/databricks/tests/unit_tests/test_chat_models.py +++ b/libs/databricks/tests/unit_tests/test_chat_models.py @@ -156,14 +156,14 @@ def mock_client() -> Generator: @pytest.fixture def llm() -> ChatDatabricks: return ChatDatabricks( - endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks" + endpoint="databricks-meta-llama-3-1-70b-instruct", target_uri="databricks" ) def test_dict(llm: ChatDatabricks) -> None: d = llm.dict() assert d["_type"] == "chat-databricks" - assert d["endpoint"] == "databricks-meta-llama-3-70b-instruct" + assert d["endpoint"] == "databricks-meta-llama-3-1-70b-instruct" assert d["target_uri"] == "databricks" @@ -210,7 +210,7 @@ def _assert_usage(chunk, expected): # Method 2: Pass stream_usage=True to the constructor llm_with_usage = ChatDatabricks( - endpoint="databricks-meta-llama-3-70b-instruct", + endpoint="databricks-meta-llama-3-1-70b-instruct", target_uri="databricks", stream_usage=True, ) @@ -290,15 +290,19 @@ class AnswerWithJustification(BaseModel): # Raw JSON schema JSON_SCHEMA = { "title": "AnswerWithJustification", - "description": "An answer to the user question along with justification.", + "description": ( + "An answer to the user question along with justification for the answer." + ), "type": "object", "properties": { "answer": { "type": "string", + "title": "Answer", "description": "The answer to the user question.", }, "justification": { "type": "string", + "title": "Justification", "description": "The justification for the answer.", }, }, @@ -307,9 +311,9 @@ class AnswerWithJustification(BaseModel): @pytest.mark.parametrize("schema", [AnswerWithJustification, JSON_SCHEMA, None]) -@pytest.mark.parametrize("method", ["function_calling", "json_mode"]) +@pytest.mark.parametrize("method", ["function_calling", "json_mode", "json_schema"]) def test_chat_model_with_structured_output(llm, schema, method: str): - if schema is None and method == "function_calling": + if schema is None and method in ["function_calling", "json_schema"]: pytest.skip("Cannot use function_calling without schema") structured_llm = llm.with_structured_output(schema, method=method) @@ -317,6 +321,8 @@ def test_chat_model_with_structured_output(llm, schema, method: str): bind = structured_llm.first.kwargs if method == "function_calling": assert bind["tool_choice"]["function"]["name"] == "AnswerWithJustification" + elif method == "json_schema": + assert bind["response_format"]["json_schema"]["schema"] == JSON_SCHEMA else: assert bind["response_format"] == {"type": "json_object"}