diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py
index 086fa23e..70610cef 100644
--- a/libs/genai/langchain_google_genai/chat_models.py
+++ b/libs/genai/langchain_google_genai/chat_models.py
@@ -46,7 +46,11 @@
 )
 from google.generativeai.caching import CachedContent  # type: ignore[import]
 from google.generativeai.types import Tool as GoogleTool  # type: ignore[import]
-from google.generativeai.types import caching_types, content_types
+from google.generativeai.types import (
+    caching_types,
+    content_types,
+    generation_types,  # type: ignore[import]
+)
 from google.generativeai.types.content_types import (  # type: ignore[import]
     FunctionDeclarationType,
     ToolDict,
@@ -94,6 +98,7 @@
 )
 from typing_extensions import Self
 
+from langchain_google_genai import _genai_extension as genaix
 from langchain_google_genai._common import (
     GoogleGenerativeAIError,
     SafetySettingDict,
@@ -110,8 +115,6 @@
 from langchain_google_genai._image_utils import ImageBytesLoader
 from langchain_google_genai.llms import _BaseGoogleGenerativeAI
 
-from . import _genai_extension as genaix
-
 logger = logging.getLogger(__name__)
 
 
@@ -767,6 +770,52 @@ class Joke(BaseModel):
             'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]
         }
 
+    JSON Output and Schema Support:
+        .. code-block:: python
+
+            import json
+
+            import typing_extensions as typing
+
+            class Recipe(typing.TypedDict):
+                recipe_name: str
+                ingredients: list[str]
+                rating: float
+
+            llm = ChatGoogleGenerativeAI(
+                model="gemini-1.5-pro",
+                temperature=0,
+                max_tokens=None,
+                timeout=None,
+                max_retries=2,
+                response_mime_type="application/json",
+                response_schema=list[Recipe],
+            )
+
+            # Define the messages
+            messages = [
+                ("system", "You are a helpful assistant"),
+                ("human", "List 2 recipes for a healthy breakfast"),
+            ]
+
+            response = llm.invoke(messages)
+            print(json.dumps(json.loads(response.content), indent=4))
+
+        .. code-block:: python
+
+            [
+                {
+                    "ingredients": [
+                        "1/2 cup rolled oats",
+                        "1 cup unsweetened almond milk",
+                        "1/4 cup berries",
+                        "1 tablespoon chia seeds",
+                        "1/2 teaspoon cinnamon"
+                    ],
+                    "rating": 4.5,
+                    "recipe_name": "Overnight Oats"
+                }
+            ]
+
     """  # noqa: E501
 
     client: Any = Field(default=None, exclude=True)  #: :meta private:
@@ -786,6 +835,22 @@ class Joke(BaseModel):
     Gemini does not support system messages; any unsupported messages will
     raise an error."""
 
+    response_mime_type: Optional[str] = None
+    """Optional. Output response MIME type of the generated candidate text. Only
+    supported in Gemini 1.5 and later models. Supported MIME types:
+        * "text/plain": (default) Text output.
+        * "application/json": JSON response in the candidates.
+        * "text/x.enum": Enum in plain text.
+    """
+
+    response_schema: Optional[Any] = None
+    """Optional. Enforce a schema on the output.
+    The value of response_schema must be either:
+        * A type hint annotation, as defined in the Python typing module.
+        * An instance of genai.protos.Schema.
+        * An enum class.
+    """
+
     cached_content: Optional[str] = None
     """The name of the cached content used as context to serve the prediction.
@@ -825,6 +890,23 @@ def validate_environment(self) -> Self:
         if not self.model.startswith("models/"):
             self.model = f"models/{self.model}"
 
+        if self.response_mime_type is not None and self.response_mime_type not in [
+            "text/plain",
+            "application/json",
+            "text/x.enum",
+        ]:
+            raise ValueError(
+                "response_mime_type must be one of 'text/plain', "
+                "'application/json', or 'text/x.enum'"
+            )
+
+        if self.response_schema is not None:
+            if self.response_mime_type not in ["application/json", "text/x.enum"]:
+                raise ValueError(
+                    "response_schema is only supported when response_mime_type is "
+                    "'application/json' or 'text/x.enum'"
+                )
+
         additional_headers = self.additional_headers or {}
         self.default_metadata = tuple(additional_headers.items())
         client_info = get_client_info("ChatGoogleGenerativeAI")
@@ -910,9 +992,14 @@ def _prepare_params(
                 "max_output_tokens": self.max_output_tokens,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
+                "response_mime_type": self.response_mime_type,
+                "response_schema": self.response_schema,
             }.items()
             if v is not None
         }
+        if gen_config.get("response_schema"):
+            generation_types._normalize_schema(gen_config)
+
         if generation_config:
             gen_config = {**gen_config, **generation_config}
         return GenerationConfig(**gen_config)
diff --git a/libs/genai/tests/integration_tests/test_chat_models.py b/libs/genai/tests/integration_tests/test_chat_models.py
index 3a15b9d4..3215e64b 100644
--- a/libs/genai/tests/integration_tests/test_chat_models.py
+++ b/libs/genai/tests/integration_tests/test_chat_models.py
@@ -513,3 +513,93 @@ async def model_astream(context: str) -> List[BaseMessageChunk]:
     result = asyncio.run(model_astream("How can you help me?"))
     assert len(result) > 0
     assert isinstance(result[0], AIMessageChunk)
+
+
+def test_output_matches_prompt_keys() -> None:
+    """
+    Validate that when response_mime_type="application/json" is specified,
+    the output is valid JSON and contains the structure requested in the
+    prompt.
+    """
+    llm = ChatGoogleGenerativeAI(
+        model=_VISION_MODEL,
+        response_mime_type="application/json",
+    )
+
+    prompt_key_names = {
+        "list_key": "grocery_items",
+        "item_key": "item",
+        "price_key": "price",
+    }
+
+    messages = [
+        ("system", "You are a helpful assistant"),
+        (
+            "human",
+            "Provide a list of grocery items with key 'grocery_items', "
+            "and for each item, use 'item' for the name and 'price' for the cost.",
+        ),
+    ]
+
+    response = llm.invoke(messages)
+
+    # Ensure the response content is a JSON string
+    assert isinstance(response.content, str), "Response content should be a string."
+
+    # Attempt to parse the JSON
+    try:
+        response_data = json.loads(response.content)
+    except json.JSONDecodeError as e:
+        pytest.fail(f"Response is not valid JSON: {e}")
+
+    list_key = prompt_key_names["list_key"]
+    assert (
+        list_key in response_data
+    ), f"Expected key '{list_key}' is missing in the response."
+    grocery_items = response_data[list_key]
+    assert isinstance(grocery_items, list), f"'{list_key}' should be a list."
+
+    item_key = prompt_key_names["item_key"]
+    price_key = prompt_key_names["price_key"]
+
+    for item in grocery_items:
+        assert isinstance(item, dict), "Each grocery item should be a dictionary."
+        assert item_key in item, f"Each item should have the key '{item_key}'."
+        assert price_key in item, f"Each item should have the key '{price_key}'."
+ + print("Response matches the key names specified in the prompt.") + +def test_validate_response_mime_type_and_schema() -> None: + """ + Test that `response_mime_type` and `response_schema` are validated correctly. + Ensure valid combinations of `response_mime_type` and `response_schema` pass, + and invalid ones raise appropriate errors. + """ + + valid_model = ChatGoogleGenerativeAI( + model="gemini-1.5-pro", + response_mime_type="application/json", + response_schema={"type": "list", "items": {"type": "object"}}, # Example schema + ) + + try: + valid_model.validate_environment() + except ValueError as e: + pytest.fail(f"Validation failed unexpectedly with valid parameters: {e}") + + with pytest.raises(ValueError, match="response_mime_type must be either .*"): + ChatGoogleGenerativeAI( + model="gemini-1.5-pro", + response_mime_type="invalid/type", + response_schema={"type": "list", "items": {"type": "object"}}, + ).validate_environment() + + try: + ChatGoogleGenerativeAI( + model="gemini-1.5-pro", + response_mime_type="application/json", + ).validate_environment() + except ValueError as e: + pytest.fail( + f"Validation failed unexpectedly with a valid MIME type and no schema: {e}" + ) \ No newline at end of file