diff --git a/libs/genai/tests/integration_tests/test_chat_models.py b/libs/genai/tests/integration_tests/test_chat_models.py
index 47685dd1..304c9a2a 100644
--- a/libs/genai/tests/integration_tests/test_chat_models.py
+++ b/libs/genai/tests/integration_tests/test_chat_models.py
@@ -516,85 +516,88 @@ async def model_astream(context: str) -> List[BaseMessageChunk]:
     assert isinstance(result[0], AIMessageChunk)
 
 
-def test_json_formatted_output() -> None:
-    """Test that json_mode works as expected with a json_schema."""
-
-    class MyModel(typing_extensions.TypedDict):
-        item: str
-        price: float
+def test_output_matches_prompt_keys() -> None:
+    """
+    Validate that when response_mime_type="application/json" is specified,
+    the output is a valid JSON format and contains the expected structure
+    based on the prompt.
+    """
 
     llm = ChatGoogleGenerativeAI(
         model=_VISION_MODEL,
         response_mime_type="application/json",
-        response_schema=list[MyModel],
     )
+    prompt_key_names = {
+        "list_key": "grocery_items",
+        "item_key": "item",
+        "price_key": "price",
+    }
+
     messages = [
         ("system", "You are a helpful assistant"),
-        ("human", "List the prices of common grocery items"),
+        (
+            "human",
+            "Provide a list of grocery items with key 'grocery_items', "
+            "and for each item, use 'item' for the name and 'price' for the cost.",
+        ),
     ]
 
     response = llm.invoke(messages)
-    response_data = json.loads(response.content)
-    assert isinstance(response_data, list)
-    assert len(response.content) > 1
-    assert isinstance(response.content[0], MyModel)
-    for item in response.content:
-        assert isinstance(item, MyModel)
-
-
-def test_json_formatted_output_with_nested_schema() -> None:
-    """Test that json_mode works as expected with a nested json_schema."""
-
-    class PriceDetail(typing_extensions.TypedDict):
-        amount: float
-        currency: str
-
-    class MyModel(TypedDict):
-        item: str
-        price: PriceDetail
-
-    llm = ChatGoogleGenerativeAI(
-        model=_VISION_MODEL,
-        response_mime_type="application/json",
-        response_schema=list[MyModel],
-    )
-    messages = [
-        ("system", "You are a helpful assistant"),
-        ("human", "List the price details of a common grocery item"),
-    ]
+    # Ensure the response content is a JSON string
+    assert isinstance(response.content, str), "Response content should be a string."
 
-    response = llm.invoke(messages)
-    assert isinstance(response.content, list)
-    assert len(response.content) > 0
-    assert isinstance(response.content[0], MyModel)
-    assert isinstance(response.content[0]["price"], PriceDetail)
+    # Attempt to parse the JSON
+    try:
+        response_data = json.loads(response.content)
+    except json.JSONDecodeError as e:
+        pytest.fail(f"Response is not valid JSON: {e}")
 
+    list_key = prompt_key_names["list_key"]
+    assert list_key in response_data, f"Expected key '{list_key}' is missing in the response."
+    grocery_items = response_data[list_key]
+    assert isinstance(grocery_items, list), f"'{list_key}' should be a list."
 
-def test_enum_formatted_output() -> None:
-    """Test that response_mime_type works as expected with text/x.enum."""
+    item_key = prompt_key_names["item_key"]
+    price_key = prompt_key_names["price_key"]
 
-    def test_enum_formatted_output() -> None:
-        """Test that response_mime_type works as expected with text/x.enum."""
-    import enum
+    for item in grocery_items:
+        assert isinstance(item, dict), "Each grocery item should be a dictionary."
+        assert item_key in item, f"Each item should have the key '{item_key}'."
+        assert price_key in item, f"Each item should have the key '{price_key}'."
 
-    class Types(enum.Enum):
-        PERCUSSION = "Percussion"
-        STRING = "String"
-        WOODWIND = "Woodwind"
-        BRASS = "Brass"
-        KEYBOARD = "Keyboard"
+    print("Response matches the key names specified in the prompt.")
 
-    llm = ChatGoogleGenerativeAI(
-        model=_VISION_MODEL, response_mime_type="text/x.enum", response_schema=Types
-    )
+def test_validate_response_mime_type_and_schema() -> None:
+    """
+    Test that `response_mime_type` and `response_schema` are validated correctly.
+    Ensure valid combinations of `response_mime_type` and `response_schema` pass,
+    and invalid ones raise appropriate errors.
+    """
 
-    messages = [
-        ("system", "You are a helpful assistant"),
-        ("human", "What kind of instrument is an organ?"),
-    ]
+    valid_model = ChatGoogleGenerativeAI(
+        model="gemini-1.5-pro",
+        response_mime_type="application/json",
+        response_schema={"type": "list", "items": {"type": "object"}},  # Example schema
+    )
 
-    response = llm.invoke(messages)
-    assert isinstance(response.content, str)
-    assert response.content in Types._value2member_map_
+    try:
+        valid_model.validate_environment()
+    except ValueError as e:
+        pytest.fail(f"Validation failed unexpectedly with valid parameters: {e}")
+
+    with pytest.raises(ValueError, match="response_mime_type must be either .*"):
+        ChatGoogleGenerativeAI(
+            model="gemini-1.5-pro",
+            response_mime_type="invalid/type",
+            response_schema={"type": "list", "items": {"type": "object"}},
+        ).validate_environment()
+
+    try:
+        ChatGoogleGenerativeAI(
+            model="gemini-1.5-pro",
+            response_mime_type="application/json",
+        ).validate_environment()
+    except ValueError as e:
+        pytest.fail(f"Validation failed unexpectedly with a valid MIME type and no schema: {e}")
\ No newline at end of file