Support structured output in ChatDatabricks #28

Merged 1 commit on Oct 18, 2024
229 changes: 228 additions & 1 deletion libs/databricks/langchain_databricks/chat_models.py
@@ -2,6 +2,7 @@

import json
import logging
from operator import itemgetter
from typing import (
Any,
Callable,
@@ -35,14 +36,19 @@
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from mlflow.deployments import BaseDeploymentClient # type: ignore
from pydantic import BaseModel, Field

@@ -398,6 +404,227 @@ def bind_tools(
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)

def with_structured_output(
self,
schema: Optional[Union[Dict, Type]] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.

Assumes the model is compatible with the OpenAI tool-calling API.

Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OpenAI function-calling spec or be a valid JSON schema
with top-level 'title' and 'description' keys specified.
method: The method for steering model generation, either "function_calling"
or "json_mode". If "function_calling" then the schema will be converted
to an OpenAI function and the returned model will make use of the
function-calling API. If "json_mode" then OpenAI's JSON mode will be
used. Note that if using "json_mode" then you must include instructions
for formatting the output into the desired schema in the model call.
include_raw: If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".

Returns:
A Runnable that takes any ChatModel input and returns as output:

If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
an instance of ``schema`` (i.e., a Pydantic object).

Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

If ``include_raw`` is True, then Runnable outputs a dict with keys:
- ``"raw"``: BaseMessage
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- ``"parsing_error"``: Optional[BaseException]

Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
.. code-block:: python

from langchain_databricks import ChatDatabricks
from pydantic import BaseModel


class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''

answer: str
justification: str


llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
structured_llm = llm.with_structured_output(AnswerWithJustification)

structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)

# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )

Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
.. code-block:: python

from langchain_databricks import ChatDatabricks
from pydantic import BaseModel


class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''

answer: str
justification: str


llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True
)

structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
# 'parsing_error': None
# }

Example: Function-calling, dict schema (method="function_calling", include_raw=False):
.. code-block:: python

from langchain_databricks import ChatDatabricks
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel


class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''

answer: str
justification: str


dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
structured_llm = llm.with_structured_output(dict_schema)

structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }

Example: JSON mode, Pydantic schema (method="json_mode", include_raw=True):
.. code-block:: python

from langchain_databricks import ChatDatabricks
from pydantic import BaseModel

class AnswerWithJustification(BaseModel):
answer: str
justification: str

llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
structured_llm = llm.with_structured_output(
AnswerWithJustification,
method="json_mode",
include_raw=True
)

structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
# 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
# 'parsing_error': None
# }

Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True):
.. code-block:: python

structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)

structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
# 'parsed': {
# 'answer': 'They are both the same weight.',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
# },
# 'parsing_error': None
# }


""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item]
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)

Review thread on first_tool_only=True (a short sketch of this parser behavior follows at the end of this file's diff):

Collaborator: In the OAI implementation, they're doing the same thing? IIRC their tool requests are sequential, but it would be good to double-check.

Collaborator (author) @B-Step62, Oct 18, 2024: Yes, see https://github.com/langchain-ai/langchain/blob/4fab8996cf3a5a34bd5333c6848b0bccf798a6a0/libs/partners/openai/langchain_openai/chat_models/base.py#L1234-L1236
I think the rationale is that it is guaranteed that only one tool matches the tool_name we pass here.

Collaborator: SGTM! Let's merge :)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)

if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser

@property
def _identifying_params(self) -> Dict[str, Any]:
return self._default_params
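
Note (not part of the PR diff): the sketch referenced in the review thread above. It is a minimal, self-contained illustration of what JsonOutputKeyToolsParser(key_name=..., first_tool_only=True) returns, assuming a recent langchain_core; the AIMessage is fabricated to stand in for a model response. Because bind_tools is called with tool_choice=tool_name, at most one tool call with that name is expected, which is the rationale discussed in the thread.

from langchain_core.messages import AIMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser

parser = JsonOutputKeyToolsParser(
    key_name="AnswerWithJustification", first_tool_only=True
)

# Hand-built response message, for illustration only.
message = AIMessage(
    content="",
    tool_calls=[
        {
            "name": "AnswerWithJustification",
            "args": {
                "answer": "They weigh the same",
                "justification": "Both are one pound.",
            },
            "id": "call_1",
        }
    ],
)

parser.invoke(message)
# -> {'answer': 'They weigh the same', 'justification': 'Both are one pound.'}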
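Note (not part of the PR diff): a self-contained sketch of the include_raw wiring in with_structured_output above, with a stub RunnableLambda standing in for the bound Databricks model so it runs without an endpoint; the fixed JSON content is fabricated. It shows how the {"raw", "parsed", "parsing_error"} dict is assembled and why a parse failure falls back to parsed=None instead of raising.

from operator import itemgetter

from langchain_core.messages import AIMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableLambda, RunnableMap, RunnablePassthrough

# Stand-in for the bound chat model; always returns a fixed JSON message.
llm = RunnableLambda(lambda _: AIMessage(content='{"answer": "Wednesday"}'))
output_parser = JsonOutputParser()

# Same pattern as in with_structured_output(include_raw=True) above.
parser_assign = RunnablePassthrough.assign(
    parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
chain = RunnableMap(raw=llm) | parser_assign.with_fallbacks(
    [parser_none], exception_key="parsing_error"
)

chain.invoke("any prompt")
# -> {'raw': AIMessage(content='{"answer": "Wednesday"}'),
#     'parsed': {'answer': 'Wednesday'},
#     'parsing_error': None}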
63 changes: 61 additions & 2 deletions libs/databricks/tests/integration_tests/test_chat_models.py
@@ -29,6 +29,7 @@
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, create_react_agent, tools_condition
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langchain_databricks.chat_models import ChatDatabricks
@@ -164,8 +165,6 @@ async def test_chat_databricks_abatch():

@pytest.mark.parametrize("tool_choice", [None, "auto", "required", "any", "none"])
def test_chat_databricks_tool_calls(tool_choice):
from pydantic import BaseModel, Field

chat = ChatDatabricks(
endpoint=_TEST_ENDPOINT,
temperature=0,
@@ -219,6 +218,66 @@ class GetWeather(BaseModel):
]


# Pydantic-based schema
class AnswerWithJustification(BaseModel):
"""An answer to the user question along with justification for the answer."""

answer: str = Field(description="The answer to the user question.")
justification: str = Field(description="The justification for the answer.")


# Raw JSON schema
JSON_SCHEMA = {
"title": "AnswerWithJustification",
"description": "An answer to the user question along with justification.",
"type": "object",
"properties": {
"answer": {
"type": "string",
"description": "The answer to the user question.",
},
"justification": {
"type": "string",
"description": "The justification for the answer.",
},
},
"required": ["answer", "justification"],
}


@pytest.mark.parametrize("schema", [AnswerWithJustification, JSON_SCHEMA, None])
@pytest.mark.parametrize("method", ["function_calling", "json_mode"])
def test_chat_databricks_with_structured_output(schema, method):
llm = ChatDatabricks(endpoint=_TEST_ENDPOINT)

if schema is None and method == "function_calling":
pytest.skip("Cannot use function_calling without schema")

structured_llm = llm.with_structured_output(schema, method=method)

if method == "function_calling":
prompt = "What day comes two days after Monday?"
else:
prompt = (
"What day comes two days after Monday? Return in JSON format with key "
"'answer' for the answer and 'justification' for the justification."
)

response = structured_llm.invoke(prompt)

if schema == AnswerWithJustification:
assert response.answer == "Wednesday"
assert response.justification is not None
else:
assert response["answer"] == "Wednesday"
assert response["justification"] is not None

# Invoke with raw output
structured_llm = llm.with_structured_output(schema, method=method, include_raw=True)
response_with_raw = structured_llm.invoke(prompt)
assert isinstance(response_with_raw["raw"], AIMessage)


def test_chat_databricks_runnable_sequence():
chat = ChatDatabricks(
endpoint=_TEST_ENDPOINT,
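Note (not part of the PR diff): a small sketch of the raw-JSON-schema case exercised by the test above. with_structured_output(method="function_calling") derives the tool name via convert_to_openai_tool, which also accepts a plain JSON schema with top-level 'title' and 'description'. The schema below is an abbreviated copy of JSON_SCHEMA so the snippet is self-contained; exact behavior depends on the installed langchain_core version.

from langchain_core.utils.function_calling import convert_to_openai_tool

# Abbreviated copy of the JSON_SCHEMA defined in the test above.
json_schema = {
    "title": "AnswerWithJustification",
    "description": "An answer to the user question along with justification.",
    "type": "object",
    "properties": {
        "answer": {"type": "string"},
        "justification": {"type": "string"},
    },
    "required": ["answer", "justification"],
}

tool = convert_to_openai_tool(json_schema)
print(tool["function"]["name"])
# -> AnswerWithJustification  (used as the forced tool_choice and the parser's key_name)
print(sorted(tool["function"]["parameters"]["properties"]))
# -> ['answer', 'justification']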