Skip to content

Commit

Permalink
fixed anthropic message conversion (#627)
Browse files Browse the repository at this point in the history
  • Loading branch information
alx13 authored Dec 1, 2024
1 parent 7f704aa commit b33d7b6
Show file tree
Hide file tree
Showing 4 changed files with 564 additions and 92 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jobs:
# It doesn't matter how you change it, any change will cause a cache-bust.
working-directory: ${{ inputs.working-directory }}
run: |
poetry install --with lint,typing
poetry install --with lint,typing --all-extras
- name: Install langchain editable
working-directory: ${{ inputs.working-directory }}
Expand All @@ -88,7 +88,6 @@ jobs:
${{ env.WORKDIR }}/.mypy_cache
key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}


- name: Analysing the code with our lint
working-directory: ${{ inputs.working-directory }}
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class DocumentAIWarehouseRetriever(BaseRetriever):
If nothing is provided, all documents in the project will be searched."""
qa_size_limit: int = 5
"""The limit on the number of documents returned."""
client: "DocumentServiceClient" = None #: :meta private:
client: "DocumentServiceClient" = None # type:ignore[assignment] #: :meta private:

@model_validator(mode="before")
@classmethod
Expand Down
171 changes: 82 additions & 89 deletions libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,74 @@ def _format_image(image_url: str) -> Dict:
}


def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
role = _message_type_lookups[message.type]
content: List[Dict[str, Any]] = []

if isinstance(message.content, str):
if not message.content.strip():
return None
content.append({"type": "text", "text": message.content})
elif isinstance(message.content, list):
for block in message.content:
if isinstance(block, str):
# Only add non-empty strings for now as empty ones are not
# accepted.
# https://github.com/anthropics/anthropic-sdk-python/issues/461
if not block.strip():
continue
content.append({"type": "text", "text": block})

if isinstance(block, dict):
if "type" not in block:
raise ValueError("Dict content block must have a type key")

new_block = {}

for copy_attr in ["type", "cache_control"]:
if copy_attr in block:
new_block[copy_attr] = block[copy_attr]

if block["type"] == "text":
text: str = block.get("text", "")
# Only add non-empty strings for now as empty ones are not
# accepted.
# https://github.com/anthropics/anthropic-sdk-python/issues/461
if text.strip():
new_block["text"] = text
content.append(new_block)
continue

if block["type"] == "image_url":
# convert format
new_block["source"] = _format_image(block["image_url"]["url"])
content.append(new_block)
continue

if block["type"] == "tool_use":
# If a tool_call with the same id as a tool_use content block
# exists, the tool_call is preferred.
if isinstance(message, AIMessage) and message.tool_calls:
is_unique = block["id"] not in [
tc["id"] for tc in message.tool_calls
]
if not is_unique:
continue

# all other block types
content.append(block)
else:
raise ValueError("Message should be a str, list of str or list of dicts")

# adding all tool calls
if isinstance(message, AIMessage) and message.tool_calls:
for tc in message.tool_calls:
tu = cast(Dict[str, Any], _lc_tool_call_to_anthropic_tool_use_block(tc))
content.append(tu)

return {"role": role, "content": content}


def _format_messages_anthropic(
messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
Expand All @@ -79,81 +147,11 @@ def _format_messages_anthropic(
system_message = message.content
continue

role = _message_type_lookups[message.type]
content: Union[str, List]

if not isinstance(message.content, str):
# parse as dict
assert isinstance(
message.content, list
), "Anthropic message content must be str or list of dicts"

# populate content
content = []
for item in message.content:
if isinstance(item, str):
content.append(
{
"type": "text",
"text": item,
}
)
elif isinstance(item, dict):
if "type" not in item:
raise ValueError("Dict content item must have a type key")
elif item["type"] == "image_url":
# convert format
source = _format_image(item["image_url"]["url"])
content.append(
{
"type": "image",
"source": source,
}
)
elif item["type"] == "tool_use":
# If a tool_call with the same id as a tool_use content block
# exists, the tool_call is preferred.
if isinstance(message, AIMessage) and item["id"] in [
tc["id"] for tc in message.tool_calls
]:
overlapping = [
tc
for tc in message.tool_calls
if tc["id"] == item["id"]
]
content.extend(
_lc_tool_calls_to_anthropic_tool_use_blocks(overlapping)
)
else:
item.pop("text", None)
content.append(item)
elif item["type"] == "text":
text = item.get("text", "")
# Only add non-empty strings for now as empty ones are not
# accepted.
# https://github.com/anthropics/anthropic-sdk-python/issues/461
if text.strip():
content.append({"type": "text", "text": text})
else:
content.append(item)
else:
raise ValueError(
f"Content items must be str or dict, instead was: {type(item)}"
)
elif isinstance(message, AIMessage) and message.tool_calls:
content = (
[]
if not message.content
else [{"type": "text", "text": message.content}]
)
# Note: Anthropic can't have invalid tool calls as presently defined,
# since the model already returns dicts args not JSON strings, and invalid
# tool calls are those with invalid JSON for args.
content += _lc_tool_calls_to_anthropic_tool_use_blocks(message.tool_calls)
else:
content = message.content
fm = _format_message_anthropic(message)
if not fm:
continue
formatted_messages.append(fm)

formatted_messages.append({"role": role, "content": content})
return system_message, formatted_messages


Expand Down Expand Up @@ -186,7 +184,7 @@ def _merge_messages(
"""Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501
merged: list = []
for curr in messages:
curr = curr.copy(deep=True)
curr = curr.model_copy(deep=True)
if isinstance(curr, ToolMessage):
if isinstance(curr.content, list) and all(
isinstance(block, dict) and block.get("type") == "tool_result"
Expand Down Expand Up @@ -226,20 +224,15 @@ class _AnthropicToolUse(TypedDict):
id: str


def _lc_tool_calls_to_anthropic_tool_use_blocks(
tool_calls: List[ToolCall],
) -> List[_AnthropicToolUse]:
blocks = []
for tool_call in tool_calls:
blocks.append(
_AnthropicToolUse(
type="tool_use",
name=tool_call["name"],
input=tool_call["args"],
id=cast(str, tool_call["id"]),
)
)
return blocks
def _lc_tool_call_to_anthropic_tool_use_block(
tool_call: ToolCall,
) -> _AnthropicToolUse:
return _AnthropicToolUse(
type="tool_use",
name=tool_call["name"],
input=tool_call["args"],
id=cast(str, tool_call["id"]),
)


def _make_message_chunk_from_anthropic_event(
Expand Down
Loading

0 comments on commit b33d7b6

Please sign in to comment.