diff --git a/libs/vertexai/langchain_google_vertexai/_base.py b/libs/vertexai/langchain_google_vertexai/_base.py index 32d895a4..69d531c4 100644 --- a/libs/vertexai/langchain_google_vertexai/_base.py +++ b/libs/vertexai/langchain_google_vertexai/_base.py @@ -24,8 +24,9 @@ from langchain_core.outputs import Generation, LLMResult from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Self -from vertexai.generative_models._generative_models import ( # type: ignore +from vertexai.generative_models._generative_models import ( # type: ignore SafetySettingsType, + _convert_schema_dict_to_gapic, ) from vertexai.language_models import ( # type: ignore[import-untyped] TextGenerationModel, @@ -284,9 +285,33 @@ def _prepare_params( stop_sequences = stop or self.stop params_mapping = {"n": "candidate_count"} params = {params_mapping.get(k, k): v for k, v in kwargs.items()} + params = {**self._default_params, "stop_sequences": stop_sequences, **params} + if stream or self.streaming: params.pop("candidate_count") + + if "response_schema" in params: + + params["response_schema"] = _convert_schema_dict_to_gapic( + params["response_schema"]) + + if "response_mime_type" not in params: + error_message = ( + "`response_mime_type` must be set when `response_schema`" + " is specified."
+ ) + raise ValueError(error_message) + + if "response_mime_type" in params: + allowed_mime_types = ("application/json", "text/x.enum") + if params["response_mime_type"] not in allowed_mime_types: + error_message = ( + "`response_schema` is only supported when " + f"`response_mime_type` is set to one of {allowed_mime_types}" + ) + raise ValueError(error_message) + return params def get_num_tokens(self, text: str) -> int: diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index ebb9464a..b3bf7f24 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -1148,22 +1148,14 @@ def _is_gemini_advanced(self) -> bool: @property def _default_params(self) -> Dict[str, Any]: + updated_params = super()._default_params if self.response_mime_type is not None: updated_params["response_mime_type"] = self.response_mime_type if self.response_schema is not None: - allowed_mime_types = ("application/json", "text/x.enum") - if self.response_mime_type not in allowed_mime_types: - error_message = ( - "`response_schema` is only supported when " - f"`response_mime_type` is set to one of {allowed_mime_types}" - ) - raise ValueError(error_message) - - gapic_response_schema = _convert_schema_dict_to_gapic(self.response_schema) - updated_params["response_schema"] = gapic_response_schema + updated_params["response_schema"] = self.response_schema return updated_params @@ -1795,13 +1787,14 @@ class AnswerWithJustification(BaseModel): # that takes care of this if necessary. 
schema_json = schema.model_json_schema() schema_json = replace_defs_in_schema(schema_json) - self.response_schema = schema_json + response_schema = schema_json parser = PydanticOutputParser(pydantic_object=schema) else: parser = JsonOutputParser() - self.response_schema = schema - self.response_mime_type = "application/json" - llm: Runnable = self + response_schema = schema + response_mime_type = "application/json" + llm: Runnable = self.bind( + response_schema=response_schema, response_mime_type=response_mime_type) else: tool_name = _get_tool_name(schema) diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py index f9c860a7..18df20e2 100644 --- a/libs/vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/vertexai/tests/integration_tests/test_chat_models.py @@ -670,6 +670,45 @@ class MyModel(BaseModel): "age": 27, } +@pytest.mark.release +def test_with_structured_output_thread_safety() -> None: + + model = ChatVertexAI( + model_name="gemini-1.5-pro-001" + ) + + structured_model = model.with_structured_output( + { + "title": "ID Extraction", + "description": "Extracts IDs from the input text.", + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": ("The type"), + "enum": [ + "ORDER_ID", + "PO_ID", + "INVOICE_ID", + "TRACKING_ID", + "OTHER", + ], + }, + "value": {"type": "string", "description": "The value of the entity."}, + }, + }, + }, + method="json_mode", + ) + response = structured_model.invoke(input="Order ID: AE22334455", temperature=0) + assert isinstance(response, list) + + response_2 = model.invoke(input="Order ID: AE22334455", temperature=0) + assert isinstance(response_2, AIMessage) + + @pytest.mark.release def test_chat_vertexai_gemini_with_structured_output_nested_model() -> None: