core[minor], openai[minor], langchain[patch]: BaseLanguageModel.with_structured_output (langchain-ai#17302)

```python
# Imports added here so the snippet is self-contained (they are not part of
# the original commit message).
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI


class Foo(BaseModel):
    bar: str


structured_llm = ChatOpenAI().with_structured_output(Foo)
```
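The returned object is a runnable that yields instances of the schema instead of raw messages. A hedged usage sketch continuing the snippet above (the prompt and the parsed value are illustrative assumptions, and an OpenAI API key is required):

```python
# Illustrative only: the exact prompt and result are assumptions.
result = structured_llm.invoke("The bar is 'baz'.")
assert isinstance(result, Foo)  # e.g. Foo(bar="baz")
```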

---------

Co-authored-by: Erick Friis <[email protected]>
2 people authored and al1p-R committed Feb 27, 2024
1 parent 9ea08a8 commit cf152f9
Showing 9 changed files with 448 additions and 68 deletions.
12 changes: 11 additions & 1 deletion libs/core/langchain_core/language_models/base.py
@@ -5,24 +5,27 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
)

from typing_extensions import TypeAlias

-from langchain_core._api import deprecated
+from langchain_core._api import beta, deprecated
from langchain_core.messages import (
AnyMessage,
BaseMessage,
MessageLikeRepresentation,
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names

@@ -155,6 +158,13 @@ async def agenerate_prompt(
prompt and additional model provider-specific output.
"""

@beta()
def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Implement this if there is a way of steering the model to generate responses that match a given schema.""" # noqa: E501
raise NotImplementedError()

@deprecated("0.1.7", alternative="invoke", removal="0.2.0")
@abstractmethod
def predict(
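The base method is deliberately a stub: providers override it when they have a native way (function calling, JSON mode, etc.) to steer generation toward a schema. Below is a hedged, toy sketch of what such an override could look like; the `JsonEchoChatModel` and its echo behaviour are invented for illustration and are not the `langchain_openai` implementation added in this PR.

```python
from typing import Any, Dict, List, Optional, Type, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable


class JsonEchoChatModel(BaseChatModel):
    """Toy chat model that echoes the last message back verbatim."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message = AIMessage(content=messages[-1].content)
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        return "json-echo"

    def with_structured_output(
        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        # Naive strategy: assume the model already emits JSON and parse it
        # into the requested schema.
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            return self | PydanticOutputParser(pydantic_object=schema)
        return self | JsonOutputParser()


class Person(BaseModel):
    name: str


structured = JsonEchoChatModel().with_structured_output(Person)
print(structured.invoke('{"name": "Ada"}'))  # -> Person(name='Ada')
```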
2 changes: 2 additions & 0 deletions libs/core/langchain_core/output_parsers/__init__.py
@@ -24,6 +24,7 @@
MarkdownListOutputParser,
NumberedListOutputParser,
)
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.transform import (
BaseCumulativeTransformOutputParser,
@@ -45,4 +46,5 @@
"SimpleJsonOutputParser",
"XMLOutputParser",
"JsonOutputParser",
"PydanticOutputParser",
]
8 changes: 5 additions & 3 deletions libs/core/langchain_core/output_parsers/base.py
@@ -15,15 +15,17 @@

from typing_extensions import get_args

from langchain_core.language_models import LanguageModelOutput
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
-from langchain_core.runnables import RunnableConfig, RunnableSerializable
+from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import run_in_executor

if TYPE_CHECKING:
from langchain_core.prompt_values import PromptValue

T = TypeVar("T")
OutputParserLike = Runnable[LanguageModelOutput, T]


class BaseLLMOutputParser(Generic[T], ABC):
@@ -57,7 +59,7 @@ async def aparse_result(


class BaseGenerationOutputParser(
-    BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
+    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
):
"""Base class to parse the output of an LLM call."""

@@ -116,7 +118,7 @@ async def ainvoke(


class BaseOutputParser(
-    BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
+    BaseLLMOutputParser, RunnableSerializable[LanguageModelOutput, T]
):
"""Base class to parse the output of an LLM call.
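For context, `LanguageModelOutput` covers both kinds of model output: raw strings from LLMs and messages from chat models, so this change widens the declared input type of output parsers accordingly. A small sketch of what that means in practice, using `StrOutputParser` purely as an example:

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser

parser = StrOutputParser()

# Parsers accept both kinds of language-model output:
assert parser.invoke("hello") == "hello"                      # plain LLM string
assert parser.invoke(AIMessage(content="hello")) == "hello"   # chat message
```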
62 changes: 62 additions & 0 deletions libs/core/langchain_core/output_parsers/pydantic.py
@@ -0,0 +1,62 @@
import json
from typing import Any, List, Type

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError


class PydanticOutputParser(JsonOutputParser):
"""Parse an output using a pydantic model."""

pydantic_object: Type[BaseModel]
"""The pydantic model to parse.
Attention: To avoid potential compatibility issues, it's recommended to use
pydantic <2 or leverage the v1 namespace in pydantic >= 2.
"""

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
json_object = super().parse_result(result)
try:
return self.pydantic_object.parse_obj(json_object)
except ValidationError as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
raise OutputParserException(msg, llm_output=json_object)

def get_format_instructions(self) -> str:
# Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self.pydantic_object.schema().items()}

# Remove extraneous fields.
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema_str = json.dumps(reduced_schema)

return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

@property
def _type(self) -> str:
return "pydantic"

@property
def OutputType(self) -> Type[BaseModel]:
"""Return the pydantic model."""
return self.pydantic_object


_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
Here is the output schema:
```
{schema}
```""" # noqa: E501
1 change: 1 addition & 0 deletions libs/core/tests/unit_tests/output_parsers/test_imports.py
@@ -14,6 +14,7 @@
"SimpleJsonOutputParser",
"XMLOutputParser",
"JsonOutputParser",
"PydanticOutputParser",
]


54 changes: 2 additions & 52 deletions libs/langchain/langchain/output_parsers/pydantic.py
@@ -1,53 +1,3 @@
import json
from typing import Any, List, Type
from langchain_core.output_parsers import PydanticOutputParser

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError

from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS


class PydanticOutputParser(JsonOutputParser):
"""Parse an output using a pydantic model."""

pydantic_object: Type[BaseModel]
"""The pydantic model to parse.
Attention: To avoid potential compatibility issues, it's recommended to use
pydantic <2 or leverage the v1 namespace in pydantic >= 2.
"""

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
json_object = super().parse_result(result)
try:
return self.pydantic_object.parse_obj(json_object)
except ValidationError as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
raise OutputParserException(msg, llm_output=json_object)

def get_format_instructions(self) -> str:
# Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self.pydantic_object.schema().items()}

# Remove extraneous fields.
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema_str = json.dumps(reduced_schema)

return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

@property
def _type(self) -> str:
return "pydantic"

@property
def OutputType(self) -> Type[BaseModel]:
"""Return the pydantic model."""
return self.pydantic_object
__all__ = ["PydanticOutputParser"]