From e26010b4c6acc342bd3925155b28e583b0d604f6 Mon Sep 17 00:00:00 2001
From: vishah02
Date: Thu, 28 Nov 2024 01:28:37 -0500
Subject: [PATCH] added support for JSON mode

---
 .../langchain_google_genai/chat_models.py | 86 +++++++++++++++++++-
 1 file changed, 85 insertions(+), 1 deletion(-)

diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py
index f328ec4e..ec25620e 100644
--- a/libs/genai/langchain_google_genai/chat_models.py
+++ b/libs/genai/langchain_google_genai/chat_models.py
@@ -52,6 +52,7 @@
 from google.generativeai.caching import CachedContent  # type: ignore[import]
 from google.generativeai.types import Tool as GoogleTool  # type: ignore[import]
 from google.generativeai.types import caching_types, content_types
+from google.generativeai.types import generation_types  # type: ignore[import]
 from google.generativeai.types.content_types import (  # type: ignore[import]
     FunctionDeclarationType,
     ToolDict,
@@ -114,7 +115,7 @@
 from langchain_google_genai._image_utils import ImageBytesLoader
 from langchain_google_genai.llms import _BaseGoogleGenerativeAI
 
-from . import _genai_extension as genaix
+from langchain_google_genai import _genai_extension as genaix
 
 IMAGE_TYPES: Tuple = ()
 try:
@@ -833,4 +834,52 @@ class Joke(BaseModel):
             'finish_reason': 'STOP',
             'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]
         }
+
+    JSON Output and Schema Support:
+        .. code-block:: python
+
+            import json
+
+            import typing_extensions as typing
+
+            class Recipe(typing.TypedDict):
+                recipe_name: str
+                ingredients: list[str]
+                rating: float
+
+            llm = ChatGoogleGenerativeAI(
+                model="gemini-1.5-pro",
+                temperature=0,
+                max_tokens=None,
+                timeout=None,
+                max_retries=2,
+                response_mime_type="application/json",
+                response_schema=list[Recipe],
+            )
+
+            # Define the messages
+            messages = [
+                ("system", "You are a helpful assistant"),
+                ("human", "List 2 recipes for a healthy breakfast"),
+            ]
+
+            response = llm.invoke(messages)
+            print(json.dumps(json.loads(response.content), indent=4))
+
+        .. code-block:: python
+
+            [
+                {
+                    "ingredients": [
+                        "1/2 cup rolled oats",
+                        "1 cup unsweetened almond milk",
+                        "1/4 cup berries",
+                        "1 tablespoon chia seeds",
+                        "1/2 teaspoon cinnamon"
+                    ],
+                    "rating": 4.5,
+                    "recipe_name": "Overnight Oats"
+                }
+            ]
+
     """  # noqa: E501
@@ -853,5 +902,21 @@ class Joke(BaseModel):
     Gemini does not support system messages; any unsupported messages will
     raise an error."""
 
+    response_mime_type: Optional[str] = None
+    """Optional. Output response MIME type of the generated candidate text. Only
+    supported in Gemini 1.5 and later models. Supported MIME types:
+        * "text/plain": (default) Text output.
+        * "application/json": JSON response in the candidates.
+        * "text/x.enum": Enum in plain text.
+    """
+
+    response_schema: Optional[Any] = None
+    """Optional. Enforce a schema on the output.
+    The value of response_schema must be either:
+        * A type hint annotation, as defined in the Python typing module.
+        * An instance of genai.protos.Schema.
+        * An enum class.
+    """
+
     cached_content: Optional[str] = None
     """The name of the cached content used as context to serve the prediction.
@@ -892,6 +957,20 @@ def validate_environment(self) -> Self:
         if not self.model.startswith("models/"):
             self.model = f"models/{self.model}"
 
+        if self.response_mime_type is not None and self.response_mime_type not in [
+            "text/plain", "application/json", "text/x.enum"]:
+            raise ValueError(
+                "response_mime_type must be one of 'text/plain', "
+                "'application/json', or 'text/x.enum'"
+            )
+
+        if self.response_schema is not None:
+            if self.response_mime_type not in ["application/json", "text/x.enum"]:
+                raise ValueError(
+                    "response_schema is only supported when response_mime_type is "
+                    "'application/json' or 'text/x.enum'"
+                )
+
         additional_headers = self.additional_headers or {}
         self.default_metadata = tuple(additional_headers.items())
         client_info = get_client_info("ChatGoogleGenerativeAI")
@@ -977,9 +1056,14 @@ def _prepare_params(
                 "max_output_tokens": self.max_output_tokens,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
+                "response_mime_type": self.response_mime_type,
+                "response_schema": self.response_schema,
             }.items()
             if v is not None
         }
+        if gen_config.get("response_schema"):
+            generation_types._normalize_schema(gen_config)
+
         if generation_config:
             gen_config = {**gen_config, **generation_config}
         return GenerationConfig(**gen_config)
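
A minimal usage sketch (not part of the patch itself) for the "text/x.enum"
mode, which the validation above accepts but the docstring example does not
cover. It assumes the patched ChatGoogleGenerativeAI class above; the
Sentiment enum, the prompt, and the expected output are illustrative
assumptions, not taken from this change.

    import enum

    from langchain_google_genai import ChatGoogleGenerativeAI

    class Sentiment(enum.Enum):  # hypothetical schema, for illustration only
        POSITIVE = "positive"
        NEGATIVE = "negative"

    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        response_mime_type="text/x.enum",  # enum rendered as plain text
        response_schema=Sentiment,  # an enum class, per the response_schema docstring
    )

    # The schema constrains the model to answer with one of the enum values,
    # so response.content should be the bare label, e.g. "positive".
    print(llm.invoke("Classify the sentiment of: 'I love this!'").content)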