From d85ece9fe3e7c8921ae3fb0026ef9c329f45d068 Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Tue, 29 Oct 2024 19:35:51 -0700
Subject: [PATCH 1/2] rfc: AIMessage.parsed and with_structured_output(..., tools=[])

---
 libs/core/langchain_core/messages/ai.py       | 16 +++++-
 .../langchain_core/output_parsers/base.py     | 48 +++++++++++-----
 .../langchain_openai/chat_models/base.py      | 58 +++++++++++++++----
 3 files changed, 99 insertions(+), 23 deletions(-)

diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py
index 727a0045ffba5..99a6ded71c713 100644
--- a/libs/core/langchain_core/messages/ai.py
+++ b/libs/core/langchain_core/messages/ai.py
@@ -2,7 +2,7 @@
 import operator
 from typing import Any, Literal, Optional, Union, cast
 
-from pydantic import model_validator
+from pydantic import BaseModel, model_validator
 from typing_extensions import NotRequired, Self, TypedDict
 
 from langchain_core.messages.base import (
@@ -166,6 +166,8 @@ class AIMessage(BaseMessage):
 
     type: Literal["ai"] = "ai"
     """The type of the message (used for deserialization). Defaults to "ai"."""
+    parsed: Optional[Union[dict, BaseModel]] = None
+    """Structured output parsed from the message, if requested and available."""
 
     def __init__(
         self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
@@ -440,6 +442,17 @@ def add_ai_message_chunks(
     else:
         usage_metadata = None
 
+    has_parsed = [m for m in [left, *others] if m.parsed]
+    if len(has_parsed) >= 2:
+        msg = (
+            "Cannot concatenate two AIMessageChunks with non-null 'parsed' attributes."
+        )
+        raise ValueError(msg)
+    elif len(has_parsed) == 1:
+        parsed = has_parsed[0].parsed
+    else:
+        parsed = None
+
     return left.__class__(
         example=left.example,
         content=content,
@@ -448,6 +461,7 @@ def add_ai_message_chunks(
         response_metadata=response_metadata,
         usage_metadata=usage_metadata,
         id=left.id,
+        parsed=parsed,
     )
 
 
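A minimal sketch of the merge semantics intended for the new `parsed` field,
assuming this patch is applied (values are illustrative):

    from langchain_core.messages import AIMessageChunk

    left = AIMessageChunk(content="", parsed={"answer": 42})
    right = AIMessageChunk(content="done")

    # At most one operand may carry a non-null `parsed`; it propagates as-is.
    assert (left + right).parsed == {"answer": 42}

    # Two non-null `parsed` values cannot be reconciled, so this raises.
    both = left + AIMessageChunk(content="", parsed={"answer": 43})  # ValueError
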
diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py
index 9d080cef300bc..1b8270538bdac 100644
--- a/libs/core/langchain_core/output_parsers/base.py
+++ b/libs/core/langchain_core/output_parsers/base.py
@@ -61,10 +61,12 @@ async def aparse_result(
 
 
 class BaseGenerationOutputParser(
-    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
+    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, Union[AnyMessage, T]]
 ):
     """Base class to parse the output of an LLM call."""
 
+    return_message: bool = False
+
     @property
     @override
     def InputType(self) -> Any:
@@ -73,11 +75,14 @@ def InputType(self) -> Any:
 
     @property
     @override
-    def OutputType(self) -> type[T]:
+    def OutputType(self) -> Union[type[AnyMessage], type[T]]:
         """Return the output type for the parser."""
-        # even though mypy complains this isn't valid,
-        # it is good enough for pydantic to build the schema from
-        return T  # type: ignore[misc]
+        if self.return_message:
+            return AnyMessage
+        else:
+            # even though mypy complains this isn't valid,
+            # it is good enough for pydantic to build the schema from
+            return T  # type: ignore[misc]
 
     def invoke(
         self,
@@ -86,7 +91,7 @@ def invoke(
         **kwargs: Any,
     ) -> T:
         if isinstance(input, BaseMessage):
-            return self._call_with_config(
+            parsed = self._call_with_config(
                 lambda inner_input: self.parse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -94,6 +99,10 @@ def invoke(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return self._call_with_config(
                 lambda inner_input: self.parse_result([Generation(text=inner_input)]),
@@ -109,7 +118,7 @@ async def ainvoke(
         **kwargs: Optional[Any],
     ) -> T:
         if isinstance(input, BaseMessage):
-            return await self._acall_with_config(
+            parsed = await self._acall_with_config(
                 lambda inner_input: self.aparse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -117,6 +126,10 @@ async def ainvoke(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return await self._acall_with_config(
                 lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
@@ -127,7 +140,7 @@ async def ainvoke(
 
 
 class BaseOutputParser(
-    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
+    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, Union[AnyMessage, T]]
 ):
     """Base class to parse the output of an LLM call.
 
@@ -155,6 +168,8 @@ def _type(self) -> str:
                 return "boolean_output_parser"
     """  # noqa: E501
 
+    return_message: bool = False
+
     @property
     @override
     def InputType(self) -> Any:
@@ -163,7 +178,7 @@ def InputType(self) -> Any:
 
     @property
     @override
-    def OutputType(self) -> type[T]:
+    def OutputType(self) -> Union[type[AnyMessage], type[T]]:
         """Return the output type for the parser.
 
         This property is inferred from the first type argument of the class.
@@ -171,6 +186,9 @@ def OutputType(self) -> type[T]:
         Raises:
             TypeError: If the class doesn't have an inferable OutputType.
         """
+        if self.return_message:
+            return AnyMessage
+
         for base in self.__class__.mro():
             if hasattr(base, "__pydantic_generic_metadata__"):
                 metadata = base.__pydantic_generic_metadata__
@@ -190,7 +208,7 @@ def invoke(
         **kwargs: Any,
     ) -> T:
         if isinstance(input, BaseMessage):
-            return self._call_with_config(
+            parsed = self._call_with_config(
                 lambda inner_input: self.parse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -198,6 +216,10 @@ def invoke(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return self._call_with_config(
                 lambda inner_input: self.parse_result([Generation(text=inner_input)]),
@@ -213,7 +235,7 @@ async def ainvoke(
         **kwargs: Optional[Any],
     ) -> T:
         if isinstance(input, BaseMessage):
-            return await self._acall_with_config(
+            parsed = await self._acall_with_config(
                 lambda inner_input: self.aparse_result(
                     [ChatGeneration(message=inner_input)]
                 ),
@@ -221,6 +243,10 @@ async def ainvoke(
                 config,
                 run_type="parser",
             )
+            if self.return_message:
+                return input.model_copy(update={"parsed": parsed})
+            else:
+                return parsed
         else:
             return await self._acall_with_config(
                 lambda inner_input: self.aparse_result([Generation(text=inner_input)]),
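A sketch of how `return_message` changes parser output, assuming the patch
applies; `StrOutputParser` stands in for any `BaseOutputParser` subclass:

    from langchain_core.messages import AIMessage
    from langchain_core.output_parsers import StrOutputParser

    msg = AIMessage(content="hello")

    # Default behavior is unchanged: the parsed value itself is returned.
    assert StrOutputParser().invoke(msg) == "hello"

    # With return_message=True, the parse result is attached to a copy of the
    # input message under `.parsed`, and the message is returned instead.
    out = StrOutputParser(return_message=True).invoke(msg)
    assert isinstance(out, AIMessage)
    assert out.parsed == "hello"
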
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 7ad8586296005..f9a0e9ad5dfed 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -767,6 +767,7 @@ def _create_chat_result(
         message = response.choices[0].message  # type: ignore[attr-defined]
         if hasattr(message, "parsed"):
             generations[0].message.additional_kwargs["parsed"] = message.parsed
+            cast(AIMessage, generations[0].message).parsed = message.parsed
         if hasattr(message, "refusal"):
             generations[0].message.additional_kwargs["refusal"] = message.refusal
 
@@ -1144,10 +1145,18 @@ def with_structured_output(
         method: Literal[
             "function_calling", "json_mode", "json_schema"
         ] = "function_calling",
-        include_raw: bool = False,
+        include_raw: Union[
+            bool, Literal["raw_only", "parsed_only", "raw_and_parsed"]
+        ] = False,
         strict: Optional[bool] = None,
+        tools: Optional[
+            Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+        ] = None,
+        tool_choice: Optional[
+            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+        ] = None,
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
+    ) -> Runnable[LanguageModelInput, Union[_DictOrPydantic, BaseMessage]]:
         """Model wrapper that returns outputs formatted to match the given schema.
 
         Args:
@@ -1432,12 +1441,19 @@ class AnswerWithJustification(BaseModel):
                     "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
             tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            bind_kwargs = self._filter_disabled_params(
-                tool_choice=tool_name, parallel_tool_calls=False, strict=strict
-            )
+            if not tools:
+                bind_kwargs = self._filter_disabled_params(
+                    tool_choice=tool_name, parallel_tool_calls=False, strict=strict
+                )
+
+                llm = self.bind_tools([schema], **bind_kwargs)
+            else:
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict, tool_choice=tool_choice
+                )
+                llm = self.bind_tools([schema, *tools], **bind_kwargs)
 
-            llm = self.bind_tools([schema], **bind_kwargs)
             if is_pydantic_schema:
                 output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -1448,7 +1464,15 @@ class AnswerWithJustification(BaseModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            if not tools:
+                llm = self.bind(response_format={"type": "json_object"})
+            else:
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict,
+                    tool_choice=tool_choice,
+                    response_format={"type": "json_object"},
+                )
+                llm = self.bind_tools(tools, **bind_kwargs)
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
@@ -1461,7 +1485,15 @@ class AnswerWithJustification(BaseModel):
                     "Received None."
                 )
             response_format = _convert_to_openai_response_format(schema, strict=strict)
-            llm = self.bind(response_format=response_format)
+            if not tools:
+                llm = self.bind(response_format=response_format)
+            else:
+                bind_kwargs = self._filter_disabled_params(
+                    strict=strict,
+                    tool_choice=tool_choice,
+                    response_format=response_format,
+                )
+                llm = self.bind_tools(tools, **bind_kwargs)
             if is_pydantic_schema:
                 output_parser = _oai_structured_outputs_parser.with_types(
                     output_type=cast(type, schema)
@@ -1474,7 +1506,7 @@ class AnswerWithJustification(BaseModel):
                 f"'json_mode'. Received: '{method}'"
             )
 
-        if include_raw:
+        if include_raw is True or include_raw == "raw_and_parsed":
             parser_assign = RunnablePassthrough.assign(
                 parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
             )
@@ -1483,6 +1515,8 @@ class AnswerWithJustification(BaseModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
+        elif include_raw == "raw_only":
+            return llm
         else:
             return llm | output_parser
 
@@ -2174,7 +2208,9 @@ def _convert_to_openai_response_format(
 
 @chain
 def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
+    if ai_msg.parsed:
+        return cast(PydanticBaseModel, ai_msg.parsed)
+    elif ai_msg.additional_kwargs.get("parsed"):
         return ai_msg.additional_kwargs["parsed"]
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
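A hypothetical end-to-end usage of the new surface, assuming PATCH 1/2 is
applied. `get_weather` is illustrative and not part of this diff;
`AnswerWithJustification` mirrors the docstring example:

    from pydantic import BaseModel
    from langchain_core.tools import tool
    from langchain_openai import ChatOpenAI

    class AnswerWithJustification(BaseModel):
        answer: str
        justification: str

    @tool
    def get_weather(city: str) -> str:
        """Look up the current weather in a city."""
        return f"Sunny in {city}"

    # Structured output plus an extra tool the model may call instead. With
    # include_raw="raw_only" the output parser is skipped and the raw
    # AIMessage is returned: `.parsed` carries the structured result when the
    # model answered directly, `.tool_calls` any get_weather invocation.
    structured_llm = ChatOpenAI(model="gpt-4o").with_structured_output(
        AnswerWithJustification,
        method="json_schema",
        tools=[get_weather],
        tool_choice="auto",
        include_raw="raw_only",
    )
    msg = structured_llm.invoke("Should I bring an umbrella in Paris today?")
    print(msg.parsed or msg.tool_calls)
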
Received: '{method}'" ) - if include_raw: + if include_raw is True or include_raw == "raw_and_parsed": parser_assign = RunnablePassthrough.assign( parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None ) @@ -1483,6 +1515,8 @@ class AnswerWithJustification(BaseModel): [parser_none], exception_key="parsing_error" ) return RunnableMap(raw=llm) | parser_with_fallback + elif include_raw == "raw_only": + return llm else: return llm | output_parser @@ -2174,7 +2208,9 @@ def _convert_to_openai_response_format( @chain def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel: - if ai_msg.additional_kwargs.get("parsed"): + if ai_msg.parsed: + return cast(PydanticBaseModel, ai_msg.parsed) + elif ai_msg.additional_kwargs.get("parsed"): return ai_msg.additional_kwargs["parsed"] elif ai_msg.additional_kwargs.get("refusal"): raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"]) From 2957f8eebd83ce5a29df5d1669bf07d4481e054f Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 29 Oct 2024 21:53:28 -0700 Subject: [PATCH 2/2] wip --- libs/partners/openai/langchain_openai/chat_models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index f9a0e9ad5dfed..31dfa0ea2192f 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -1489,7 +1489,7 @@ class AnswerWithJustification(BaseModel): llm = self.bind(response_format=response_format) else: bind_kwargs = self._filter_disabled_params( - strict=strict, + strict=True, tool_choice=tool_choice, response_format=response_format, )