Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vertexai: refactor: simplify content processing in anthropic formatter #601

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Type, Union

from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call
Expand Down Expand Up @@ -55,11 +55,18 @@ def _pydantic_parse(self, tool_call: dict) -> BaseModel:
return cls_(**tool_call["args"])


def _extract_tool_calls(content: List[dict]) -> List[ToolCall]:
tool_calls = []
for block in content:
if block["type"] == "tool_use":
def _extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
"""Extract tool calls from a list of content blocks."""
if isinstance(content, list):
tool_calls = []
for block in content:
if isinstance(block, str):
continue
if block["type"] != "tool_use":
continue
tool_calls.append(
tool_call(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls
return tool_calls
else:
return []
23 changes: 16 additions & 7 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,20 @@ def validate_environment(self) -> Self:
AsyncAnthropicVertex,
)

if self.project is None:
raise ValueError("project is required for ChatAnthropicVertex")

project_id: str = self.project

self.client = AnthropicVertex(
project_id=self.project,
project_id=project_id,
region=self.location,
max_retries=self.max_retries,
access_token=self.access_token,
credentials=self.credentials,
)
self.async_client = AsyncAnthropicVertex(
project_id=self.project,
project_id=project_id,
region=self.location,
max_retries=self.max_retries,
access_token=self.access_token,
Expand Down Expand Up @@ -205,14 +210,18 @@ def _format_params(

def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add a test for _format_output, similar to the test that we have in the main LangChain library: https://github.com/langchain-ai/langchain/blob/f1222739f88bfdf37513af146da6b9dbf2a091c4/libs/partners/anthropic/tests/unit_tests/test_chat_models.py#L87-L109

We should base the input for the test here on what we noticed was causing the errors. So instead of passing only a TextBlock, we should only pass a ToolUseBlock like we do in our test. I suspect this will fail in the previous implementation.

What do you think?

data_dict = data.model_dump()
content = [c for c in data_dict["content"] if c["type"] != "tool_use"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main difference: here, content would only contain the elements whose type was not tool_use. This would break when there is only one element and that element has its type set to tool_use, causing content to be empty on L215.

Copy link
Contributor Author

@jfypk jfypk Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome -- thanks for calling this out @Reprazent! updated the description as well.

content = content[0]["text"] if len(content) == 1 else content
content = data_dict["content"]
llm_output = {
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
}
tool_calls = _extract_tool_calls(data_dict["content"])
if tool_calls:
msg = AIMessage(content=content, tool_calls=tool_calls)
if len(content) == 1 and content[0]["type"] == "text":
msg = AIMessage(content=content[0]["text"])
elif any(block["type"] == "tool_use" for block in content):
tool_calls = _extract_tool_calls(content)
msg = AIMessage(
content=content,
tool_calls=tool_calls,
)
else:
msg = AIMessage(content=content)
# Collect token usage
Expand Down
55 changes: 51 additions & 4 deletions libs/vertexai/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/vertexai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ numpy = [
google-api-python-client = "^2.117.0"
langchain = "^0.3.7"
langchain-tests = "0.3.1"
anthropic = { extras = ["vertexai"], version = ">=0.35.0,<1" }


[tool.codespell]
Expand Down
2 changes: 1 addition & 1 deletion libs/vertexai/tests/integration_tests/test_model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def test_anthropic_async() -> None:
def _check_tool_calls(response: BaseMessage, expected_name: str) -> None:
"""Check tool calls are as expected."""
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert isinstance(response.content, list)
tool_calls = response.tool_calls
assert len(tool_calls) == 1
tool_call = tool_calls[0]
Expand Down
47 changes: 47 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_parse_examples,
_parse_response_candidate,
)
from langchain_google_vertexai.model_garden import ChatAnthropicVertex


def test_init() -> None:
Expand Down Expand Up @@ -1067,3 +1068,49 @@ def test_init_client_with_custom_api() -> None:
transport = mock_prediction_service.call_args.kwargs["transport"]
assert client_options.api_endpoint == "https://example.com"
assert transport == "rest"


def test_anthropic_format_output() -> None:
    """Check `_format_output` for a response containing only a tool_use block.

    A stand-in response object provides `model_dump()` and a `usage`
    attribute. The formatted message is expected to keep the raw content
    list, expose exactly one parsed tool call, and carry usage metadata.
    """

    @dataclass
    class Usage:
        input_tokens: int
        output_tokens: int

    @dataclass
    class Message:
        usage: Usage

        def model_dump(self):
            # The dumped payload has a single tool_use block and no text block.
            tool_use_block = {
                "type": "tool_use",
                "id": "123",
                "name": "calculator",
                "input": {"number": 42},
            }
            return {
                "content": [tool_use_block],
                "model": "baz",
                "role": "assistant",
                "usage": Usage(input_tokens=2, output_tokens=1),
                "type": "message",
            }

    fake_response = Message(usage=Usage(input_tokens=2, output_tokens=1))
    chat_model = ChatAnthropicVertex(project="test-project", location="test-location")

    ai_message = chat_model._format_output(fake_response).generations[0].message

    assert isinstance(ai_message, AIMessage)
    # Content is passed through as the list of blocks, not collapsed to a string.
    assert ai_message.content == fake_response.model_dump()["content"]

    parsed_calls = ai_message.tool_calls
    assert len(parsed_calls) == 1
    assert parsed_calls[0]["name"] == "calculator"
    assert parsed_calls[0]["args"] == {"number": 42}

    assert ai_message.usage_metadata == {
        "input_tokens": 2,
        "output_tokens": 1,
        "total_tokens": 3,
    }
Loading