Skip to content

Commit

Permalink
added tool-message (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
alx13 authored Apr 10, 2024
1 parent aeb175c commit 8cac88a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 1 deletion.
13 changes: 12 additions & 1 deletion libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_functions import (
Expand Down Expand Up @@ -199,6 +200,16 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
},
)
]
elif isinstance(message, ToolMessage):
role = "function"
parts = [
Part.from_function_response(
name=message.name,
response={
"content": message.content,
},
)
]
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
Expand Down Expand Up @@ -301,7 +312,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
convert_system_message_to_human: bool = False
"""[Deprecated] Since new Gemini models support setting a System Message,
"""[Deprecated] Since new Gemini models support setting a System Message,
setting this parameter to True is discouraged.
"""

Expand Down
79 changes: 79 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
from google.cloud.aiplatform_v1beta1.types import (
Content,
FunctionCall,
FunctionResponse,
Part,
)
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
)
from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from vertexai.generative_models import ( # type: ignore
Candidate,
Expand Down Expand Up @@ -184,6 +187,82 @@ def test_parse_history_gemini_converted_message() -> None:
assert history[1].parts[0].text == text_answer1


def test_parse_history_gemini_function() -> None:
system_input = "You're supposed to answer math questions."
text_question1 = "Which is bigger 2+2 or 2*2?"
function_name = "calculator"
function_call_1 = {
"name": function_name,
"arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "+"}),
}
function_answer1 = json.dumps({"result": 4})
function_call_2 = {
"name": function_name,
"arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "*"}),
}
function_answer2 = json.dumps({"result": 4})
text_answer1 = "They are same"

system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(
content="",
additional_kwargs={
"function_call": function_call_1,
},
)
message3 = ToolMessage(
name="calculator", content=function_answer1, tool_call_id="1"
)
message4 = AIMessage(
content="",
additional_kwargs={
"function_call": function_call_2,
},
)
message5 = FunctionMessage(name="calculator", content=function_answer2)
message6 = AIMessage(content=text_answer1)
messages = [
system_message,
message1,
message2,
message3,
message4,
message5,
message6,
]
system_instructions, history = _parse_chat_history_gemini(messages)
assert len(history) == 6
assert system_instructions and system_instructions.parts[0].text == system_input
assert history[0].role == "user"
assert history[0].parts[0].text == text_question1

assert history[1].role == "model"
assert history[1].parts[0].function_call == FunctionCall(
name=function_call_1["name"], args=json.loads(function_call_1["arguments"])
)

assert history[2].role == "function"
assert history[2].parts[0].function_response == FunctionResponse(
name=function_call_1["name"],
response={"content": function_answer1},
)

assert history[3].role == "model"
assert history[3].parts[0].function_call == FunctionCall(
name=function_call_2["name"], args=json.loads(function_call_2["arguments"])
)

assert history[4].role == "function"
assert history[2].parts[0].function_response == FunctionResponse(
name=function_call_2["name"],
response={"content": function_answer2},
)

assert history[5].role == "model"
assert history[5].parts[0].text == text_answer1


def test_default_params_palm() -> None:
user_prompt = "Hello"

Expand Down

0 comments on commit 8cac88a

Please sign in to comment.