Skip to content

Commit

Permalink
test for mistral tools (#456)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin authored Aug 20, 2024
1 parent 3557076 commit dedfe89
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions libs/vertexai/tests/integration_tests/test_maas.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import json
from typing import List

import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
ToolMessage,
)
from langchain_core.tools import tool

from langchain_google_vertexai.model_garden_maas.mistral import (
VertexModelGardenMistral,
Expand Down Expand Up @@ -48,3 +55,53 @@ async def test_astream(model_name: str) -> None:
output = llm.astream("What is the meaning of life?")
async for chunk in output:
assert isinstance(chunk, AIMessageChunk)


@pytest.mark.extended
@pytest.mark.parametrize("model_name", model_names)
async def test_tools(model_name: str) -> None:
@tool
def search(
question: str,
) -> str:
"""
Useful for when you need to answer questions or visit websites.
You should ask targeted questions.
"""
return "brown"

tools = [search]

llm = VertexModelGardenMistral(model=model_name, location="us-central1")
llm_with_search = llm.bind_tools(
tools=tools,
)
llm_with_search_force = llm_with_search.bind(
tool_choice={"type": "function", "function": {"name": "search"}}
)
request = HumanMessage(
content="Please tell the primary color of sparrow?",
)
response = llm_with_search_force.invoke([request])

assert isinstance(response, AIMessage)
tool_calls = response.tool_calls
assert len(tool_calls) == 1

tool_response = search("sparrow")
tool_messages: List[BaseMessage] = []

for tool_call in tool_calls:
assert tool_call["name"] == "search"
tool_message = ToolMessage(
name=tool_call["name"],
content=json.dumps(tool_response),
tool_call_id=(tool_call["id"] or ""),
)
tool_messages.append(tool_message)

result = llm_with_search.invoke([request, response] + tool_messages)

assert isinstance(result, AIMessage)
assert "brown" in result.content
assert len(result.tool_calls) == 0

0 comments on commit dedfe89

Please sign in to comment.