From daec64d8a8892aafde7f82541b9fe5e556ea8cac Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Fri, 12 Apr 2024 15:08:58 +0200
Subject: [PATCH] added a check whether tool_config is allowed

---
 .../langchain_google_vertexai/chains.py      | 19 ++++-
 .../langchain_google_vertexai/chat_models.py | 32 ++++++++----
 .../tests/integration_tests/test_chains.py   | 51 +++++++++++++++++++
 3 files changed, 92 insertions(+), 10 deletions(-)

diff --git a/libs/vertexai/langchain_google_vertexai/chains.py b/libs/vertexai/langchain_google_vertexai/chains.py
index 7d2bd318..66f924af 100644
--- a/libs/vertexai/langchain_google_vertexai/chains.py
+++ b/libs/vertexai/langchain_google_vertexai/chains.py
@@ -14,6 +14,9 @@
 from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import Runnable
+from vertexai.generative_models._generative_models import (  # type: ignore
+    ToolConfig,
+)
 
 from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
 
@@ -50,7 +53,21 @@ def _create_structured_runnable_extra_step(
     *,
     prompt: Optional[BasePromptTemplate] = None,
 ) -> Runnable:
-    llm_with_functions = llm.bind(functions=functions)
+    names = [schema.schema()["title"] for schema in functions]
+    if hasattr(llm, "_is_gemini_advanced") and llm._is_gemini_advanced:  # type: ignore
+        llm_with_functions = llm.bind(
+            functions=functions,
+            tool_config={
+                "function_calling_config": {
+                    "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
+                    "allowed_function_names": names,
+                }
+            },
+        )
+    else:
+        llm_with_functions = llm.bind(
+            functions=functions,
+        )
     parsing_prompt = ChatPromptTemplate.from_template(
         "You are a world class algorithm for recording entities.\nMake calls "
         "to the relevant function to record the entities in the following "
diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index 80b5e68c..642a7382 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -431,6 +431,17 @@ def validate_environment(cls, values: Dict) -> Dict:
         )
         return values
 
+    @property
+    def _is_gemini_advanced(self) -> bool:
+        try:
+            if float(self.model_name.split("-")[1]) > 1.0:
+                return True
+        except ValueError:
+            pass
+        except IndexError:
+            pass
+        return False
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -868,15 +879,18 @@ class AnswerWithJustification(BaseModel):
 
         parser = JsonOutputFunctionsParser()
         name = schema["name"]
-        llm = self.bind(
-            functions=[schema],
-            tool_config={
-                "function_calling_config": {
-                    "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
-                    "allowed_function_names": [name],
-                }
-            },
-        )
+        if self._is_gemini_advanced:
+            llm = self.bind(
+                functions=[schema],
+                tool_config={
+                    "function_calling_config": {
+                        "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
+                        "allowed_function_names": [name],
+                    }
+                },
+            )
+        else:
+            llm = self.bind(functions=[schema])
         if include_raw:
             parser_with_fallback = RunnablePassthrough.assign(
                 parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
diff --git a/libs/vertexai/tests/integration_tests/test_chains.py b/libs/vertexai/tests/integration_tests/test_chains.py
index 54ba6bb8..f7748ab8 100644
--- a/libs/vertexai/tests/integration_tests/test_chains.py
+++ b/libs/vertexai/tests/integration_tests/test_chains.py
@@ -1,6 +1,9 @@
 from typing import Optional
 
 import pytest
+from langchain_core.messages import (
+    AIMessage,
+)
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Field
 
@@ -47,3 +50,51 @@ def test_create_structured_runnable_with_prompt() -> None:
     )
     res = chain.invoke({"class": "person", "attr": "age"})
     assert isinstance(res, RecordPerson)
+
+
+@pytest.mark.release
+def test_reflection() -> None:
+    class Reflection(BaseModel):
+        reflections: str = Field(
+            description="The critique and reflections on the sufficiency, superfluency,"
+            " and general quality of the response"
+        )
+        score: int = Field(
+            description="Score from 0-10 on the quality of the candidate response.",
+            # gte=0,
+            # lte=10,
+        )
+        found_solution: bool = Field(
+            description="Whether the response has fully solved the question or task."
+        )
+
+        def as_message(self):
+            return AIMessage(
+                content=f"Reasoning: {self.reflections}\nScore: {self.score}"
+            )
+
+        @property
+        def normalized_score(self) -> float:
+            return self.score / 10.0
+
+    llm = ChatVertexAI(
+        model_name="gemini-1.5-pro-preview-0409",
+    )
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                "Reflect and grade the assistant response to the user question below.",
+            ),
+            (
+                "user",
+                "Which planet is the closest to the Earth?",
+            ),
+            ("ai", "{input}"),
+        ]
+    )
+
+    reflection_llm_chain = prompt | llm.with_structured_output(Reflection)
+    res = reflection_llm_chain.invoke({"input": "Mars"})
+    assert isinstance(res, Reflection)
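
Two quick sketches for review context; neither is part of the patch to apply.
First, the version gate: `_is_gemini_advanced` parses the second dash-separated
segment of the model name as a float. A minimal standalone mirror of that
logic (the free function is hypothetical, named here only for illustration):

    def is_gemini_advanced(model_name: str) -> bool:
        # Anything newer than Gemini 1.0 counts as "advanced"; names without
        # a numeric second segment fall through to False.
        try:
            return float(model_name.split("-")[1]) > 1.0
        except (ValueError, IndexError):
            return False

    assert is_gemini_advanced("gemini-1.5-pro-preview-0409")
    assert not is_gemini_advanced("gemini-1.0-pro")  # 1.0 is not > 1.0
    assert not is_gemini_advanced("chat-bison")      # no numeric segment

Second, the downstream effect: on models that pass the gate,
`with_structured_output` keeps binding the forced-function-calling
`tool_config`; on older models it binds only `functions`, since they do not
accept a `tool_config`. Sketched usage, assuming configured Google Cloud
credentials (the model name and input text are illustrative):

    from langchain_core.pydantic_v1 import BaseModel, Field

    from langchain_google_vertexai import ChatVertexAI


    class RecordPerson(BaseModel):
        """Record a person's attributes."""

        name: str = Field(description="The person's name")
        age: int = Field(description="The person's age")


    # Gemini 1.5+: the bound runnable carries tool_config with mode=ANY and
    # allowed_function_names=["RecordPerson"], forcing a function call;
    # on Gemini 1.0 the same call binds only `functions`.
    llm = ChatVertexAI(model_name="gemini-1.5-pro-preview-0409")
    structured_llm = llm.with_structured_output(RecordPerson)

    res = structured_llm.invoke("Sally is 13 years old.")
    assert isinstance(res, RecordPerson)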