Skip to content

Commit

Permalink
added tests to verify json feature
Browse files Browse the repository at this point in the history
  • Loading branch information
emma0925 committed Dec 11, 2024
1 parent 13184a1 commit 8eb227a
Showing 1 changed file with 65 additions and 62 deletions.
127 changes: 65 additions & 62 deletions libs/genai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,85 +516,88 @@ async def model_astream(context: str) -> List[BaseMessageChunk]:
assert isinstance(result[0], AIMessageChunk)


def test_json_formatted_output() -> None:
"""Test that json_mode works as expected with a json_schema."""

class MyModel(typing_extensions.TypedDict):
item: str
price: float
def test_output_matches_prompt_keys() -> None:
"""
Validate that when response_mime_type="application/json" is specified,
the output is a valid JSON format and contains the expected structure
based on the prompt.
"""

llm = ChatGoogleGenerativeAI(
model=_VISION_MODEL,
response_mime_type="application/json",
response_schema=list[MyModel],
)

prompt_key_names = {
"list_key": "grocery_items",
"item_key": "item",
"price_key": "price",
}

messages = [
("system", "You are a helpful assistant"),
("human", "List the prices of common grocery items"),
(
"human",
"Provide a list of grocery items with key 'grocery_items', "
"and for each item, use 'item' for the name and 'price' for the cost.",
),
]

response = llm.invoke(messages)
response_data = json.loads(response.content)
assert isinstance(response_data, list)
assert len(response.content) > 1
assert isinstance(response.content[0], MyModel)
for item in response.content:
assert isinstance(item, MyModel)


def test_json_formatted_output_with_nested_schema() -> None:
"""Test that json_mode works as expected with a nested json_schema."""

class PriceDetail(typing_extensions.TypedDict):
amount: float
currency: str

class MyModel(TypedDict):
item: str
price: PriceDetail

llm = ChatGoogleGenerativeAI(
model=_VISION_MODEL,
response_mime_type="application/json",
response_schema=list[MyModel],
)

messages = [
("system", "You are a helpful assistant"),
("human", "List the price details of a common grocery item"),
]
# Ensure the response content is a JSON string
assert isinstance(response.content, str), "Response content should be a string."

response = llm.invoke(messages)
assert isinstance(response.content, list)
assert len(response.content) > 0
assert isinstance(response.content[0], MyModel)
assert isinstance(response.content[0]["price"], PriceDetail)
# Attempt to parse the JSON
try:
response_data = json.loads(response.content)
except json.JSONDecodeError as e:
pytest.fail(f"Response is not valid JSON: {e}")

list_key = prompt_key_names["list_key"]
assert list_key in response_data, f"Expected key '{list_key}' is missing in the response."
grocery_items = response_data[list_key]
assert isinstance(grocery_items, list), f"'{list_key}' should be a list."

def test_enum_formatted_output() -> None:
"""Test that response_mime_type works as expected with text/x.enum."""
item_key = prompt_key_names["item_key"]
price_key = prompt_key_names["price_key"]

def test_enum_formatted_output() -> None:
"""Test that response_mime_type works as expected with text/x.enum."""
import enum
for item in grocery_items:
assert isinstance(item, dict), "Each grocery item should be a dictionary."
assert item_key in item, f"Each item should have the key '{item_key}'."
assert price_key in item, f"Each item should have the key '{price_key}'."

class Types(enum.Enum):
PERCUSSION = "Percussion"
STRING = "String"
WOODWIND = "Woodwind"
BRASS = "Brass"
KEYBOARD = "Keyboard"
print("Response matches the key names specified in the prompt.")

llm = ChatGoogleGenerativeAI(
model=_VISION_MODEL, response_mime_type="text/x.enum", response_schema=Types
)
def test_validate_response_mime_type_and_schema() -> None:
"""
Test that `response_mime_type` and `response_schema` are validated correctly.
Ensure valid combinations of `response_mime_type` and `response_schema` pass,
and invalid ones raise appropriate errors.
"""

messages = [
("system", "You are a helpful assistant"),
("human", "What kind of instrument is an organ?"),
]
valid_model = ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
response_mime_type="application/json",
response_schema={"type": "list", "items": {"type": "object"}}, # Example schema
)

response = llm.invoke(messages)
assert isinstance(response.content, str)
assert response.content in Types._value2member_map_
try:
valid_model.validate_environment()
except ValueError as e:
pytest.fail(f"Validation failed unexpectedly with valid parameters: {e}")

with pytest.raises(ValueError, match="response_mime_type must be either .*"):
ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
response_mime_type="invalid/type",
response_schema={"type": "list", "items": {"type": "object"}},
).validate_environment()

try:
ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
response_mime_type="application/json",
).validate_environment()
except ValueError as e:
pytest.fail(f"Validation failed unexpectedly with a valid MIME type and no schema: {e}")

0 comments on commit 8eb227a

Please sign in to comment.