diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 6383b5ea..b855e976 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -29,6 +29,7 @@ FunctionMessage, HumanMessage, SystemMessage, + ToolMessage, ) from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_functions import ( @@ -199,6 +200,16 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: }, ) ] + elif isinstance(message, ToolMessage): + role = "function" + parts = [ + Part.from_function_response( + name=message.name, + response={ + "content": message.content, + }, + ) + ] else: raise ValueError( f"Unexpected message with type {type(message)} at the position {i}." @@ -301,7 +312,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): "Underlying model name." examples: Optional[List[BaseMessage]] = None convert_system_message_to_human: bool = False - """[Deprecated] Since new Gemini models support setting a System Message, + """[Deprecated] Since new Gemini models support setting a System Message, setting this parameter to True is discouraged. """ diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index 344d7472..9265ad4f 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -9,6 +9,7 @@ from google.cloud.aiplatform_v1beta1.types import ( Content, FunctionCall, + FunctionResponse, Part, ) from google.cloud.aiplatform_v1beta1.types import ( @@ -16,8 +17,10 @@ ) from langchain_core.messages import ( AIMessage, + FunctionMessage, HumanMessage, SystemMessage, + ToolMessage, ) from vertexai.generative_models import ( # type: ignore Candidate, @@ -184,6 +187,82 @@ def test_parse_history_gemini_converted_message() -> None: assert history[1].parts[0].text == text_answer1 +def test_parse_history_gemini_function() -> None: + system_input = "You're supposed to answer math questions." + text_question1 = "Which is bigger 2+2 or 2*2?" + function_name = "calculator" + function_call_1 = { + "name": function_name, + "arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "+"}), + } + function_answer1 = json.dumps({"result": 4}) + function_call_2 = { + "name": function_name, + "arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "*"}), + } + function_answer2 = json.dumps({"result": 4}) + text_answer1 = "They are same" + + system_message = SystemMessage(content=system_input) + message1 = HumanMessage(content=text_question1) + message2 = AIMessage( + content="", + additional_kwargs={ + "function_call": function_call_1, + }, + ) + message3 = ToolMessage( + name="calculator", content=function_answer1, tool_call_id="1" + ) + message4 = AIMessage( + content="", + additional_kwargs={ + "function_call": function_call_2, + }, + ) + message5 = FunctionMessage(name="calculator", content=function_answer2) + message6 = AIMessage(content=text_answer1) + messages = [ + system_message, + message1, + message2, + message3, + message4, + message5, + message6, + ] + system_instructions, history = _parse_chat_history_gemini(messages) + assert len(history) == 6 + assert system_instructions and system_instructions.parts[0].text == system_input + assert history[0].role == "user" + assert history[0].parts[0].text == text_question1 + + assert history[1].role == "model" + assert history[1].parts[0].function_call == FunctionCall( + name=function_call_1["name"], args=json.loads(function_call_1["arguments"]) + ) + + assert history[2].role == "function" + assert history[2].parts[0].function_response == FunctionResponse( + name=function_call_1["name"], + response={"content": function_answer1}, + ) + + assert history[3].role == "model" + assert history[3].parts[0].function_call == FunctionCall( + name=function_call_2["name"], args=json.loads(function_call_2["arguments"]) + ) + + assert history[4].role == "function" + assert history[2].parts[0].function_response == FunctionResponse( + name=function_call_2["name"], + response={"content": function_answer2}, + ) + + assert history[5].role == "model" + assert history[5].parts[0].text == text_answer1 + + def test_default_params_palm() -> None: user_prompt = "Hello"