Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Genai: Integrate JSON Mode for Gemini 1.5 (flash) #625

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
93 changes: 90 additions & 3 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@
)
from google.generativeai.caching import CachedContent # type: ignore[import]
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
from google.generativeai.types import caching_types, content_types
from google.generativeai.types import (
caching_types,
content_types,
generation_types, # type: ignore[import]
)
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionDeclarationType,
ToolDict,
Expand Down Expand Up @@ -94,6 +98,7 @@
)
from typing_extensions import Self

from langchain_google_genai import _genai_extension as genaix
from langchain_google_genai._common import (
GoogleGenerativeAIError,
SafetySettingDict,
Expand All @@ -110,8 +115,6 @@
from langchain_google_genai._image_utils import ImageBytesLoader
from langchain_google_genai.llms import _BaseGoogleGenerativeAI

from . import _genai_extension as genaix

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -767,6 +770,52 @@ class Joke(BaseModel):
'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]
}

JSON Output and Schema Support:
.. code-block:: python
import typing_extensions as typing

class Recipe(typing.TypedDict):
recipe_name: str
ingredients: list[str]
rating: float

llm = ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
response_mime_type="application/json",
response_schema=list[Recipe]
)

# Define the messages

messages = [
("system", "You are a helpful assistant"),
("human", "List 2 recipes for a healthy breakfast"),
]

response = llm.invoke(messages)
print(json.dumps(json.loads(response.content), indent=4))

.. code-block:: python
[
{
"ingredients": [
"1/2 cup rolled oats",
"1 cup unsweetened almond milk",
"1/4 cup berries",
"1 tablespoon chia seeds",
"1/2 teaspoon cinnamon"
],
"rating": 4.5,
"recipe_name": "Overnight Oats"
},
]



""" # noqa: E501

client: Any = Field(default=None, exclude=True) #: :meta private:
Expand All @@ -786,6 +835,22 @@ class Joke(BaseModel):
Gemini does not support system messages; any unsupported messages will
raise an error."""

response_mime_type: Optional[str] = None
"""Optional. Output response mimetype of the generated candidate text. Only
supported in Gemini 1.5 and later models. Supported mimetype:
* "text/plain": (default) Text output.
* "application/json": JSON response in the candidates.
* "text/x.enum": Enum in plain text.
"""

response_schema: Optional[Any] = None
""" Optional. Enforce an schema to the output.
The value of response_schema must be a either:
* A type hint annotation, as defined in the Python typing module module.
* An instance of genai.protos.Schema.
* An enum class
"""

cached_content: Optional[str] = None
"""The name of the cached content used as context to serve the prediction.

Expand Down Expand Up @@ -825,6 +890,23 @@ def validate_environment(self) -> Self:
if not self.model.startswith("models/"):
self.model = f"models/{self.model}"

if self.response_mime_type is not None and self.response_mime_type not in [
"text/plain",
"application/json",
"text/x.enum",
]:
raise ValueError(
"response_mime_type must be either 'text/plain' "
"or 'application/json'"
)

if self.response_schema is not None:
if self.response_mime_type not in ["application/json", "text/x.enum"]:
raise ValueError(
"response_schema is only supported when response_mime_type is "
"'application/json or 'text/x.enum'"
)

additional_headers = self.additional_headers or {}
self.default_metadata = tuple(additional_headers.items())
client_info = get_client_info("ChatGoogleGenerativeAI")
Expand Down Expand Up @@ -910,9 +992,14 @@ def _prepare_params(
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"response_mime_type": self.response_mime_type,
"response_schema": self.response_schema,
}.items()
if v is not None
}
if gen_config.get("response_schema"):
generation_types._normalize_schema(gen_config)

if generation_config:
gen_config = {**gen_config, **generation_config}
return GenerationConfig(**gen_config)
Expand Down
90 changes: 90 additions & 0 deletions libs/genai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,93 @@ async def model_astream(context: str) -> List[BaseMessageChunk]:
result = asyncio.run(model_astream("How can you help me?"))
assert len(result) > 0
assert isinstance(result[0], AIMessageChunk)


def test_output_matches_prompt_keys() -> None:
    """
    Validate that when response_mime_type="application/json" is specified,
    the output is valid JSON and contains the structure requested in the
    prompt: a top-level list key whose items carry name and price keys.
    """
    llm = ChatGoogleGenerativeAI(
        model=_VISION_MODEL,
        response_mime_type="application/json",
    )

    # Key names the prompt instructs the model to use in its JSON output.
    prompt_key_names = {
        "list_key": "grocery_items",
        "item_key": "item",
        "price_key": "price",
    }

    messages = [
        ("system", "You are a helpful assistant"),
        (
            "human",
            "Provide a list of grocery items with key 'grocery_items', "
            "and for each item, use 'item' for the name and 'price' for the cost.",
        ),
    ]

    response = llm.invoke(messages)

    # Ensure the response content is a JSON string
    assert isinstance(response.content, str), "Response content should be a string."

    # Attempt to parse the JSON; fail with a clear message when it is invalid.
    try:
        response_data = json.loads(response.content)
    except json.JSONDecodeError as e:
        pytest.fail(f"Response is not valid JSON: {e}")

    list_key = prompt_key_names["list_key"]
    assert (
        list_key in response_data
    ), f"Expected key '{list_key}' is missing in the response."
    grocery_items = response_data[list_key]
    assert isinstance(grocery_items, list), f"'{list_key}' should be a list."

    item_key = prompt_key_names["item_key"]
    price_key = prompt_key_names["price_key"]

    for item in grocery_items:
        assert isinstance(item, dict), "Each grocery item should be a dictionary."
        assert item_key in item, f"Each item should have the key '{item_key}'."
        assert price_key in item, f"Each item should have the key '{price_key}'."

def test_validate_response_mime_type_and_schema() -> None:
    """
    Test that `response_mime_type` and `response_schema` are validated correctly.

    Valid combinations of `response_mime_type` and `response_schema` must pass
    validation without error; invalid combinations must raise ``ValueError``.
    """
    # Valid: JSON mime type together with an explicit schema. An unexpected
    # ValueError here fails the test with a full traceback, so no
    # try/except + pytest.fail wrapper is needed.
    ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        response_mime_type="application/json",
        response_schema={"type": "list", "items": {"type": "object"}},  # Example schema
    ).validate_environment()

    # Invalid: an unsupported mime type must be rejected.
    with pytest.raises(ValueError, match="response_mime_type must be either .*"):
        ChatGoogleGenerativeAI(
            model="gemini-1.5-pro",
            response_mime_type="invalid/type",
            response_schema={"type": "list", "items": {"type": "object"}},
        ).validate_environment()

    # Valid: JSON mime type with no schema at all.
    ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        response_mime_type="application/json",
    ).validate_environment()
Loading