From 0b5d16a3d7b54c49a904de7728007084e02c2bee Mon Sep 17 00:00:00 2001 From: Jeff Park Date: Wed, 20 Nov 2024 00:37:41 -0500 Subject: [PATCH] vertexai: refactor: simplify content processing in anthropic formatter (#601) --- .../_anthropic_parsers.py | 19 +++++-- .../langchain_google_vertexai/model_garden.py | 23 +++++--- libs/vertexai/poetry.lock | 55 +++++++++++++++++-- libs/vertexai/pyproject.toml | 1 + .../integration_tests/test_model_garden.py | 2 +- .../tests/unit_tests/test_chat_models.py | 47 ++++++++++++++++ 6 files changed, 129 insertions(+), 18 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py b/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py index a31759bf..d979f20e 100644 --- a/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py +++ b/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Type +from typing import Any, List, Optional, Type, Union from langchain_core.messages import AIMessage, ToolCall from langchain_core.messages.tool import tool_call @@ -55,11 +55,18 @@ def _pydantic_parse(self, tool_call: dict) -> BaseModel: return cls_(**tool_call["args"]) -def _extract_tool_calls(content: List[dict]) -> List[ToolCall]: - tool_calls = [] - for block in content: - if block["type"] == "tool_use": +def _extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]: + """Extract tool calls from a list of content blocks.""" + if isinstance(content, list): + tool_calls = [] + for block in content: + if isinstance(block, str): + continue + if block["type"] != "tool_use": + continue tool_calls.append( tool_call(name=block["name"], args=block["input"], id=block["id"]) ) - return tool_calls + return tool_calls + else: + return [] diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py index 7afa8071..ab553f6e 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden.py @@ -153,15 +153,20 @@ def validate_environment(self) -> Self: AsyncAnthropicVertex, ) + if self.project is None: + raise ValueError("project is required for ChatAnthropicVertex") + + project_id: str = self.project + self.client = AnthropicVertex( - project_id=self.project, + project_id=project_id, region=self.location, max_retries=self.max_retries, access_token=self.access_token, credentials=self.credentials, ) self.async_client = AsyncAnthropicVertex( - project_id=self.project, + project_id=project_id, region=self.location, max_retries=self.max_retries, access_token=self.access_token, @@ -205,14 +210,18 @@ def _format_params( def _format_output(self, data: Any, **kwargs: Any) -> ChatResult: data_dict = data.model_dump() - content = [c for c in data_dict["content"] if c["type"] != "tool_use"] - content = content[0]["text"] if len(content) == 1 else content + content = data_dict["content"] llm_output = { k: v for k, v in data_dict.items() if k not in ("content", "role", "type") } - tool_calls = _extract_tool_calls(data_dict["content"]) - if tool_calls: - msg = AIMessage(content=content, tool_calls=tool_calls) + if len(content) == 1 and content[0]["type"] == "text": + msg = AIMessage(content=content[0]["text"]) + elif any(block["type"] == "tool_use" for block in content): + tool_calls = _extract_tool_calls(content) + msg = AIMessage( + content=content, + tool_calls=tool_calls, + ) else: msg = AIMessage(content=content) # Collect token usage diff --git a/libs/vertexai/poetry.lock b/libs/vertexai/poetry.lock index 96425c97..b3b85815 100644 --- a/libs/vertexai/poetry.lock +++ b/libs/vertexai/poetry.lock @@ -152,7 +152,7 @@ files = [ name = "anthropic" version = "0.37.1" description = "The official Python library for the anthropic API" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "anthropic-0.37.1-py3-none-any.whl", hash = "sha256:8f550f88906823752e2abf99fbe491fbc8d40bce4cb26b9663abdf7be990d721"}, @@ -393,7 +393,7 @@ files = [ name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, @@ -1255,7 +1255,7 @@ files = [ name = "jiter" version = "0.7.0" description = "Fast iterable JSON parser." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "jiter-0.7.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:e14027f61101b3f5e173095d9ecf95c1cac03ffe45a849279bde1d97e559e314"}, @@ -2452,14 +2452,61 @@ description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59b8f3adb3971929a3e660337f5dacc5942c2cdb760afcabb2614ffbda9f9f72"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37350015056a553e442ff672c2d20e6f4b6d0b2495691fa239d8aa18bb3bc908"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8318f4776c85abc3f40ab185e388bee7a6ea99e7fa3a30686580b209eaa35c08"}, {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c245b1fbade9c35e5bd3b64270ab49ce990369018289ecfde3f9c318411aaa07"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:69f93723edbca7342624d09f6704e7126b152eaed3cdbb634cb657a54332a3c5"}, {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f9511d8dd4a6e9271d07d150fb2f81874a3c8c95e11ff9af3a2dfc35fe42ee44"}, {file = "SQLAlchemy-2.0.36-cp310-cp310-win32.whl", hash = "sha256:c3f3631693003d8e585d4200730616b78fafd5a01ef8b698f6967da5c605b3fa"}, {file = "SQLAlchemy-2.0.36-cp310-cp310-win_amd64.whl", hash = "sha256:a86bfab2ef46d63300c0f06936bd6e6c0105faa11d509083ba8f2f9d237fb5b5"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fd3a55deef00f689ce931d4d1b23fa9f04c880a48ee97af488fd215cf24e2a6c"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4f5e9cd989b45b73bd359f693b935364f7e1f79486e29015813c338450aa5a71"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0ddd9db6e59c44875211bc4c7953a9f6638b937b0a88ae6d09eb46cced54eff"}, {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2519f3a5d0517fc159afab1015e54bb81b4406c278749779be57a569d8d1bb0d"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59b1ee96617135f6e1d6f275bbe988f419c5178016f3d41d3c0abb0c819f75bb"}, {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:39769a115f730d683b0eb7b694db9789267bcd027326cccc3125e862eb03bfd8"}, {file = "SQLAlchemy-2.0.36-cp311-cp311-win32.whl", hash = "sha256:66bffbad8d6271bb1cc2f9a4ea4f86f80fe5e2e3e501a5ae2a3dc6a76e604e6f"}, {file = "SQLAlchemy-2.0.36-cp311-cp311-win_amd64.whl", hash = "sha256:23623166bfefe1487d81b698c423f8678e80df8b54614c2bf4b4cfcd7c711959"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7b64e6ec3f02c35647be6b4851008b26cff592a95ecb13b6788a54ef80bbdd4"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46331b00096a6db1fdc052d55b101dbbfc99155a548e20a0e4a8e5e4d1362855"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdf3386a801ea5aba17c6410dd1dc8d39cf454ca2565541b5ac42a84e1e28f53"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac9dfa18ff2a67b09b372d5db8743c27966abf0e5344c555d86cc7199f7ad83a"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:90812a8933df713fdf748b355527e3af257a11e415b613dd794512461eb8a686"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1bc330d9d29c7f06f003ab10e1eaced295e87940405afe1b110f2eb93a233588"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win32.whl", hash = "sha256:79d2e78abc26d871875b419e1fd3c0bca31a1cb0043277d0d850014599626c2e"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win_amd64.whl", hash = "sha256:b544ad1935a8541d177cb402948b94e871067656b3a0b9e91dbec136b06a2ff5"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5cc79df7f4bc3d11e4b542596c03826063092611e481fcf1c9dfee3c94355ef"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3c01117dd36800f2ecaa238c65365b7b16497adc1522bf84906e5710ee9ba0e8"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9bc633f4ee4b4c46e7adcb3a9b5ec083bf1d9a97c1d3854b92749d935de40b9b"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e46ed38affdfc95d2c958de328d037d87801cfcbea6d421000859e9789e61c2"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b2985c0b06e989c043f1dc09d4fe89e1616aadd35392aea2844f0458a989eacf"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a121d62ebe7d26fec9155f83f8be5189ef1405f5973ea4874a26fab9f1e262c"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win32.whl", hash = "sha256:0572f4bd6f94752167adfd7c1bed84f4b240ee6203a95e05d1e208d488d0d436"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win_amd64.whl", hash = "sha256:8c78ac40bde930c60e0f78b3cd184c580f89456dd87fc08f9e3ee3ce8765ce88"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:be9812b766cad94a25bc63bec11f88c4ad3629a0cec1cd5d4ba48dc23860486b"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50aae840ebbd6cdd41af1c14590e5741665e5272d2fee999306673a1bb1fdb4d"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4557e1f11c5f653ebfdd924f3f9d5ebfc718283b0b9beebaa5dd6b77ec290971"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:07b441f7d03b9a66299ce7ccf3ef2900abc81c0db434f42a5694a37bd73870f2"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:28120ef39c92c2dd60f2721af9328479516844c6b550b077ca450c7d7dc68575"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win32.whl", hash = "sha256:b81ee3d84803fd42d0b154cb6892ae57ea6b7c55d8359a02379965706c7efe6c"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win_amd64.whl", hash = "sha256:f942a799516184c855e1a32fbc7b29d7e571b52612647866d4ec1c3242578fcb"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3d6718667da04294d7df1670d70eeddd414f313738d20a6f1d1f379e3139a545"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:72c28b84b174ce8af8504ca28ae9347d317f9dba3999e5981a3cd441f3712e24"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b11d0cfdd2b095e7b0686cf5fabeb9c67fae5b06d265d8180715b8cfa86522e3"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e32092c47011d113dc01ab3e1d3ce9f006a47223b18422c5c0d150af13a00687"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6a440293d802d3011028e14e4226da1434b373cbaf4a4bbb63f845761a708346"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c54a1e53a0c308a8e8a7dffb59097bff7facda27c70c286f005327f21b2bd6b1"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win32.whl", hash = "sha256:1e0d612a17581b6616ff03c8e3d5eff7452f34655c901f75d62bd86449d9750e"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win_amd64.whl", hash = "sha256:8958b10490125124463095bbdadda5aa22ec799f91958e410438ad6c97a7b793"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dc022184d3e5cacc9579e41805a681187650e170eb2fd70e28b86192a479dcaa"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b817d41d692bf286abc181f8af476c4fbef3fd05e798777492618378448ee689"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4e46a888b54be23d03a89be510f24a7652fe6ff660787b96cd0e57a4ebcb46d"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ae3005ed83f5967f961fd091f2f8c5329161f69ce8480aa8168b2d7fe37f06"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03e08af7a5f9386a43919eda9de33ffda16b44eb11f3b313e6822243770e9763"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3dbb986bad3ed5ceaf090200eba750b5245150bd97d3e67343a3cfed06feecf7"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win32.whl", hash = "sha256:9fe53b404f24789b5ea9003fc25b9a3988feddebd7e7b369c8fac27ad6f52f28"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win_amd64.whl", hash = "sha256:af148a33ff0349f53512a049c6406923e4e02bf2f26c5fb285f143faf4f0e46a"}, {file = "SQLAlchemy-2.0.36-py3-none-any.whl", hash = "sha256:fddbe92b4760c6f5d48162aef14824add991aeda8ddadb3c31d56eb15ca69f8e"}, {file = "sqlalchemy-2.0.36.tar.gz", hash = "sha256:7f2767680b6d2398aea7082e45a774b2b0767b5c8d8ffb9c8b683088ea9b29c5"}, ] @@ -2890,4 +2937,4 @@ mistral = ["langchain-mistralai"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "12cbda598b827846eef4fc9b3d6a44ec9d187e2b8d5bf37b03f1722954cfcb90" +content-hash = "8931477e695bb3f50481e75964dddfd6b644582360ac9e4b40bb8b445d8ed9f7" diff --git a/libs/vertexai/pyproject.toml b/libs/vertexai/pyproject.toml index bd80636f..916f0963 100644 --- a/libs/vertexai/pyproject.toml +++ b/libs/vertexai/pyproject.toml @@ -42,6 +42,7 @@ numpy = [ google-api-python-client = "^2.117.0" langchain = "^0.3.7" langchain-tests = "0.3.1" +anthropic = { extras = ["vertexai"], version = ">=0.35.0,<1" } [tool.codespell] diff --git a/libs/vertexai/tests/integration_tests/test_model_garden.py b/libs/vertexai/tests/integration_tests/test_model_garden.py index a0946a6c..6b90d591 100644 --- a/libs/vertexai/tests/integration_tests/test_model_garden.py +++ b/libs/vertexai/tests/integration_tests/test_model_garden.py @@ -168,7 +168,7 @@ async def test_anthropic_async() -> None: def _check_tool_calls(response: BaseMessage, expected_name: str) -> None: """Check tool calls are as expected.""" assert isinstance(response, AIMessage) - assert isinstance(response.content, str) + assert isinstance(response.content, list) tool_calls = response.tool_calls assert len(tool_calls) == 1 tool_call = tool_calls[0] diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index e9da74ea..06567e71 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -41,6 +41,7 @@ _parse_examples, _parse_response_candidate, ) +from langchain_google_vertexai.model_garden import ChatAnthropicVertex def test_init() -> None: @@ -1067,3 +1068,49 @@ def test_init_client_with_custom_api() -> None: transport = mock_prediction_service.call_args.kwargs["transport"] assert client_options.api_endpoint == "https://example.com" assert transport == "rest" + + +def test_anthropic_format_output() -> None: + """Test format output handles different content structures correctly.""" + + @dataclass + class Usage: + input_tokens: int + output_tokens: int + + @dataclass + class Message: + def model_dump(self): + return { + "content": [ + { + "type": "tool_use", + "id": "123", + "name": "calculator", + "input": {"number": 42}, + } + ], + "model": "baz", + "role": "assistant", + "usage": Usage(input_tokens=2, output_tokens=1), + "type": "message", + } + + usage: Usage + + test_msg = Message(usage=Usage(input_tokens=2, output_tokens=1)) + + model = ChatAnthropicVertex(project="test-project", location="test-location") + result = model._format_output(test_msg) + + message = result.generations[0].message + assert isinstance(message, AIMessage) + assert message.content == test_msg.model_dump()["content"] + assert len(message.tool_calls) == 1 + assert message.tool_calls[0]["name"] == "calculator" + assert message.tool_calls[0]["args"] == {"number": 42} + assert message.usage_metadata == { + "input_tokens": 2, + "output_tokens": 1, + "total_tokens": 3, + }