Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme committed Jan 8, 2025
1 parent a10954d commit 9145a50
Showing 1 changed file with 14 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -630,14 +630,15 @@ def test_bind_tools_tool_choice() -> None:
assert not msg.tool_calls


def test_openai_structured_output() -> None:
@pytest.mark.parametrize("model", ["gpt-4o-mini", "o1"])
def test_openai_structured_output(model: str) -> None:
class MyModel(BaseModel):
"""A Person"""

name: str
age: int

llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(MyModel)
llm = ChatOpenAI(model=model).with_structured_output(MyModel)
result = llm.invoke("I'm a 27 year old named Erick")
assert isinstance(result, MyModel)
assert result.name == "Erick"
Expand Down Expand Up @@ -820,20 +821,18 @@ class magic_function(BaseModel):


@pytest.mark.parametrize(
("model", "method", "strict"),
[("gpt-4o", "function_calling", True), ("gpt-4o-2024-08-06", "json_schema", None)],
("model", "method"),
[("gpt-4o", "function_calling"), ("gpt-4o-2024-08-06", "json_schema")],
)
def test_structured_output_strict(
model: str,
method: Literal["function_calling", "json_schema"],
strict: Optional[bool],
model: str, method: Literal["function_calling", "json_schema"]
) -> None:
"""Test to verify structured output with strict=True."""

from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper

llm = ChatOpenAI(model=model, temperature=0)
llm = ChatOpenAI(model=model)

class Joke(BaseModelProper):
"""Joke to tell user."""
Expand All @@ -842,10 +841,7 @@ class Joke(BaseModelProper):
punchline: str = FieldProper(description="answer to resolve the joke")

# Pydantic class
# Type ignoring since the interface only officially supports pydantic 1
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
# We'll need to do a pass updating the type signatures.
chat = llm.with_structured_output(Joke, method=method, strict=strict)
chat = llm.with_structured_output(Joke, method=method, strict=True)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)

Expand All @@ -854,7 +850,7 @@ class Joke(BaseModelProper):

# Schema
chat = llm.with_structured_output(
Joke.model_json_schema(), method=method, strict=strict
Joke.model_json_schema(), method=method, strict=True
)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
Expand All @@ -875,26 +871,24 @@ class InvalidJoke(BaseModelProper):
default="foo", description="answer to resolve the joke"
)

chat = llm.with_structured_output(InvalidJoke, method=method, strict=strict)
chat = llm.with_structured_output(InvalidJoke, method=method, strict=True)
with pytest.raises(openai.BadRequestError):
chat.invoke("Tell me a joke about cats.")
with pytest.raises(openai.BadRequestError):
next(chat.stream("Tell me a joke about cats."))

chat = llm.with_structured_output(
InvalidJoke.model_json_schema(), method=method, strict=strict
InvalidJoke.model_json_schema(), method=method, strict=True
)
with pytest.raises(openai.BadRequestError):
chat.invoke("Tell me a joke about cats.")
with pytest.raises(openai.BadRequestError):
next(chat.stream("Tell me a joke about cats."))


@pytest.mark.parametrize(
("model", "method", "strict"), [("gpt-4o-2024-08-06", "json_schema", None)]
)
@pytest.mark.parametrize(("model", "method"), [("gpt-4o-2024-08-06", "json_schema")])
def test_nested_structured_output_strict(
model: str, method: Literal["json_schema"], strict: Optional[bool]
model: str, method: Literal["json_schema"]
) -> None:
"""Test to verify structured output with strict=True for nested object."""

Expand All @@ -914,7 +908,7 @@ class JokeWithEvaluation(TypedDict):
self_evaluation: SelfEvaluation

# Schema
chat = llm.with_structured_output(JokeWithEvaluation, method=method, strict=strict)
chat = llm.with_structured_output(JokeWithEvaluation, method=method, strict=True)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline", "self_evaluation"}
Expand Down

0 comments on commit 9145a50

Please sign in to comment.