Skip to content

Commit

Permalink
[vertexai]: Fix context caching with tools & flaky test (#602)
Browse files Browse the repository at this point in the history
  • Loading branch information
kardiff18 authored Nov 19, 2024
1 parent 89cbd3c commit 9f520cd
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 3 deletions.
2 changes: 1 addition & 1 deletion libs/vertexai/langchain_google_vertexai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def create_context_cache(
tool_config = _format_tool_config(tool_config)

if tools is not None:
tools = _format_to_gapic_tool(tools)
tools = [_format_to_gapic_tool(tools)]

cached_content = caching.CachedContent.create(
model_name=model.full_model_name,
Expand Down
2 changes: 1 addition & 1 deletion libs/vertexai/tests/integration_tests/test_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_create_structured_runnable() -> None:

@pytest.mark.release
def test_create_structured_runnable_with_prompt() -> None:
llm = ChatVertexAI(model_name=_DEFAULT_MODEL_NAME)
llm = ChatVertexAI(model_name=_DEFAULT_MODEL_NAME, temperature=0)
prompt = ChatPromptTemplate.from_template(
"Describe a random {class} and mention their name, {attr} and favorite food"
)
Expand Down
70 changes: 70 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,76 @@ def test_context_catching():
assert isinstance(response.content, str)


@pytest.mark.extended
def test_context_catching_tools():
    """Integration test: a context cache created with tools attached must be
    usable by a tool-calling agent built on top of the cached chat model.
    """
    from langchain import agents

    @tool
    def get_secret_number() -> int:
        """Gets secret number."""
        return 747

    available_tools = [get_secret_number]

    system_instruction = """
    You are an expert researcher. You always stick to the facts in the sources
    provided, and never make up new facts.
    You have a get_secret_number function available. Use this tool if someone asks
    for the secret number.
    Now look at these research papers, and answer the following questions.
    """

    # Two research-paper PDFs (GCS URIs) to seed the cached context with.
    pdf_sources = [
        {
            "type": "image_url",
            "image_url": {
                "url": "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
            },
        },
        {
            "type": "image_url",
            "image_url": {
                "url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
            },
        },
    ]

    cache_id = create_context_cache(
        model=ChatVertexAI(model_name="gemini-1.5-pro-001"),
        messages=[
            SystemMessage(content=system_instruction),
            HumanMessage(content=pdf_sources),
        ],
        tools=available_tools,
    )

    # Chat model bound to the cached context created above.
    llm_with_cache = ChatVertexAI(
        model_name="gemini-1.5-pro-001",
        cached_content=cache_id,
    )

    agent_prompt = ChatPromptTemplate.from_messages(
        [
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )
    tool_agent = agents.create_tool_calling_agent(
        llm=llm_with_cache,
        tools=available_tools,
        prompt=agent_prompt,
    )
    executor = agents.AgentExecutor(  # type: ignore[call-arg]
        agent=tool_agent, tools=available_tools, verbose=False, stream_runnable=False
    )
    result = executor.invoke({"input": "what is the secret number?"})
    assert isinstance(result["output"], str)


@pytest.mark.release
def test_json_serializable() -> None:
llm = ChatVertexAI(
Expand Down
6 changes: 5 additions & 1 deletion libs/vertexai/tests/integration_tests/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def chat_model_class(self) -> Type[BaseChatModel]:

@property
def chat_model_params(self) -> dict:
return {"model_name": "gemini-1.5-pro-001", "rate_limiter": rate_limiter}
return {
"model_name": "gemini-1.5-pro-001",
"rate_limiter": rate_limiter,
"temperature": 0,
}

@property
def supports_image_inputs(self) -> bool:
Expand Down

0 comments on commit 9f520cd

Please sign in to comment.