Skip to content

Commit

Permalink
Add support for response_format as json schema
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMathieuIDEXX committed Dec 2, 2024
1 parent 0f4c2ef commit 518b68b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
30 changes: 27 additions & 3 deletions libs/databricks/langchain_databricks/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,9 @@ def with_structured_output(
self,
schema: Optional[Union[Dict, Type]] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
Expand Down Expand Up @@ -651,10 +653,32 @@ class AnswerWithJustification(BaseModel):
if is_pydantic_schema
else JsonOutputParser()
)
elif method == "json_schema":
if schema is None:
raise ValueError(
"schema must be specified when method is 'json_schema'. "
"Received None."
)
response_format = {
"type": "json_schema",
"json_schema": {
"strict": True,
"schema": (
schema.model_json_schema() if is_pydantic_schema else schema # type: ignore[union-attr]
),
},
}
llm = self.bind(response_format=response_format)
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)

else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
f"Unrecognized method argument. Expected one of 'function_calling', "
f"'json_mode' or 'json_schema'. Received: '{method}'"
)

if include_raw:
Expand Down
18 changes: 12 additions & 6 deletions libs/databricks/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ def mock_client() -> Generator:
@pytest.fixture
def llm() -> ChatDatabricks:
return ChatDatabricks(
endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
endpoint="databricks-meta-llama-3-1-70b-instruct", target_uri="databricks"
)


def test_dict(llm: ChatDatabricks) -> None:
d = llm.dict()
assert d["_type"] == "chat-databricks"
assert d["endpoint"] == "databricks-meta-llama-3-70b-instruct"
assert d["endpoint"] == "databricks-meta-llama-3-1-70b-instruct"
assert d["target_uri"] == "databricks"


Expand Down Expand Up @@ -210,7 +210,7 @@ def _assert_usage(chunk, expected):

# Method 2: Pass stream_usage=True to the constructor
llm_with_usage = ChatDatabricks(
endpoint="databricks-meta-llama-3-70b-instruct",
endpoint="databricks-meta-llama-3-1-70b-instruct",
target_uri="databricks",
stream_usage=True,
)
Expand Down Expand Up @@ -290,15 +290,19 @@ class AnswerWithJustification(BaseModel):
# Raw JSON schema
JSON_SCHEMA = {
"title": "AnswerWithJustification",
"description": "An answer to the user question along with justification.",
"description": (
"An answer to the user question along with justification for the answer."
),
"type": "object",
"properties": {
"answer": {
"type": "string",
"title": "Answer",
"description": "The answer to the user question.",
},
"justification": {
"type": "string",
"title": "Justification",
"description": "The justification for the answer.",
},
},
Expand All @@ -307,16 +311,18 @@ class AnswerWithJustification(BaseModel):


@pytest.mark.parametrize("schema", [AnswerWithJustification, JSON_SCHEMA, None])
@pytest.mark.parametrize("method", ["function_calling", "json_mode"])
@pytest.mark.parametrize("method", ["function_calling", "json_mode", "json_schema"])
def test_chat_model_with_structured_output(llm, schema, method: str):
if schema is None and method == "function_calling":
if schema is None and method in ["function_calling", "json_schema"]:
pytest.skip("Cannot use function_calling without schema")

structured_llm = llm.with_structured_output(schema, method=method)

bind = structured_llm.first.kwargs
if method == "function_calling":
assert bind["tool_choice"]["function"]["name"] == "AnswerWithJustification"
elif method == "json_schema":
assert bind["response_format"]["json_schema"]["schema"] == JSON_SCHEMA
else:
assert bind["response_format"] == {"type": "json_object"}

Expand Down

0 comments on commit 518b68b

Please sign in to comment.