Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added ChatAnthropicVertex #119

Merged
merged 2 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import re
from typing import Dict, List, Optional, Tuple, Union

from langchain_core.messages import BaseMessage

_message_type_lookups = {"human": "user", "ai": "assistant"}


def _format_image(image_url: str) -> Dict:
"""Formats a message image to a dict for anthropic api."""
regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
match = re.match(regex, image_url)
if match is None:
raise ValueError(
"Anthropic only supports base64-encoded images currently."
" Example: data:image/png;base64,'/9j/4AAQSk'..."
)
return {
"type": "base64",
"media_type": match.group("media_type"),
"data": match.group("data"),
}


def _format_messages_anthropic(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
    """Split messages into an Anthropic system prompt and message list.

    Args:
        messages: Conversation to convert. A system message is accepted only
            as the first element and must have string content. Every other
            message needs a type listed in ``_message_type_lookups`` and
            content that is either a string or a list of strings/dicts.

    Returns:
        A ``(system_message, formatted_messages)`` tuple. ``system_message``
        is ``None`` when there is no system message. Each formatted message
        is a dict with ``role`` and ``content`` keys.

    Raises:
        ValueError: If a system message is misplaced or not a string, if
            message content is neither a string nor a list, or if a content
            item is malformed.
        KeyError: If a message type has no entry in ``_message_type_lookups``.
    """
    system_message: Optional[str] = None
    formatted_messages: List[Dict] = []

    for i, message in enumerate(messages):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            if not isinstance(message.content, str):
                raise ValueError(
                    "System message must be a string, "
                    f"instead was: {type(message.content)}"
                )
            system_message = message.content
            continue

        role = _message_type_lookups[message.type]
        content: Union[str, List[Dict]]

        if isinstance(message.content, str):
            content = message.content
        elif isinstance(message.content, list):
            content = [_format_content_item(item) for item in message.content]
        else:
            # Raised explicitly: an ``assert`` here would be stripped by
            # ``python -O`` and let invalid content through.
            raise ValueError(
                "Anthropic message content must be str or list of dicts"
            )

        formatted_messages.append(
            {
                "role": role,
                "content": content,
            }
        )
    return system_message, formatted_messages


def _format_content_item(item: Union[str, Dict]) -> Dict:
    """Convert one entry of a list-valued message content to a content block.

    Strings become text blocks, ``image_url`` dicts become base64 image
    blocks, ``tool_use`` dicts are copied without their ``text`` key, and any
    other typed dict is passed through unchanged.

    Raises:
        ValueError: If ``item`` is neither a string nor a dict, or is a dict
            without a ``type`` key.
    """
    if isinstance(item, str):
        return {"type": "text", "text": item}
    if not isinstance(item, dict):
        raise ValueError(
            f"Content items must be str or dict, instead was: {type(item)}"
        )
    if "type" not in item:
        raise ValueError("Dict content item must have a type key")
    if item["type"] == "image_url":
        return {
            "type": "image",
            "source": _format_image(item["image_url"]["url"]),
        }
    if item["type"] == "tool_use":
        # Build a filtered copy instead of popping "text" in place, so the
        # caller's message object is not modified by formatting it.
        return {key: value for key, value in item.items() if key != "text"}
    return item
166 changes: 163 additions & 3 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
from __future__ import annotations

import asyncio
from typing import Any, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
Generation,
LLMResult,
)
from langchain_core.pydantic_v1 import root_validator

from langchain_google_vertexai._base import _BaseVertexAIModelGarden
from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic
from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon


class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
Expand Down Expand Up @@ -70,3 +88,145 @@ async def _agenerate(
endpoint=self.endpoint_path, instances=instances
)
return self._parse_response(response)


class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
"""Chat model that talks to Anthropic models via the ``anthropic`` Vertex clients."""
# Asynchronous client; validate_environment sets it to an AsyncAnthropicVertex
# instance (the synchronous one is stored under ``client``).
async_client: Any = None #: :meta private:
model_name: Optional[str] = None # type: ignore[assignment]
"Underlying model name."
# Sent to the API as ``max_tokens`` by ``_default_params``.
max_output_tokens: int = 1024

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Create the synchronous and asynchronous Anthropic Vertex clients.

Reads ``project``, ``location`` and ``max_retries`` from ``values`` and
stores the new clients under the ``client`` and ``async_client`` keys.
"""
# Imported inside the validator rather than at module level
# (presumably so ``anthropic`` is only required when this class is used;
# TODO confirm).
from anthropic import ( # type: ignore[import-not-found]
AnthropicVertex,
AsyncAnthropicVertex,
)

# Both clients are configured with the same project, region and retry count.
values["client"] = AnthropicVertex(
project_id=values["project"],
region=values["location"],
max_retries=values["max_retries"],
)
values["async_client"] = AsyncAnthropicVertex(
project_id=values["project"],
region=values["location"],
max_retries=values["max_retries"],
)
Comment on lines +101 to +115
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes me think this makes more sense in langchain-anthropic - happy to discuss though! If it's co-dependent on both google-cloud-aiplatform and anthropic then it matters less.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather keep it here, since we might have other partner integrations on Vertex Model Garden eventually, so langchain-google-vertexai sounds like a more general place to me when people are looking for this integration. wdyt?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me. From what I can tell here: https://docs.anthropic.com/claude/reference/claude-on-vertex-ai

It looks like it requires both SDKs.

return values

@property
def _default_params(self):
return {
"model": self.model_name,
"max_tokens": self.max_output_tokens,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
}

def _format_params(
    self,
    *,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Assemble the keyword arguments for an Anthropic messages request.

    Args:
        messages: Conversation to send; converted with
            ``_format_messages_anthropic``.
        stop: Optional stop sequences, sent as ``stop_sequences``.
        **kwargs: Per-call overrides layered on top of ``_default_params``.
            A truthy ``model`` wins over a truthy ``model_name``; the
            ``model_name`` key itself is never forwarded.

    Returns:
        The merged parameters with every ``None``-valued entry removed.
    """
    system_prompt, anthropic_messages = _format_messages_anthropic(messages)

    # Per-call kwargs take precedence over the model-level defaults.
    request = {**self._default_params, **kwargs}

    # "model" beats "model_name" whenever both are supplied and truthy.
    model_override = kwargs.get("model") or kwargs.get("model_name")
    if model_override:
        request["model"] = model_override
    request.pop("model_name", None)

    request["system"] = system_prompt
    request["messages"] = anthropic_messages
    request["stop_sequences"] = stop

    return {name: value for name, value in request.items() if value is not None}

def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
    """Wrap an API response object in a ``ChatResult``.

    Args:
        data: Response object; only its ``model_dump()`` result is used.
        **kwargs: Accepted but not used.

    Returns:
        A ``ChatResult`` with a single generation. The message content is the
        plain text when the response holds exactly one ``text`` block, and the
        raw list of content blocks otherwise. Every dumped field other than
        ``content``, ``role`` and ``type`` is kept in ``llm_output``.
    """
    payload = data.model_dump()
    blocks = payload["content"]
    skipped_keys = ("content", "role", "type")
    llm_output = {
        key: value for key, value in payload.items() if key not in skipped_keys
    }

    is_single_text_block = len(blocks) == 1 and blocks[0]["type"] == "text"
    message = AIMessage(content=blocks[0]["text"] if is_single_text_block else blocks)

    generation = ChatGeneration(message=message)
    return ChatResult(generations=[generation], llm_output=llm_output)

def _generate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> ChatResult:
    """Produce a chat result for ``messages``.

    Args:
        messages: Conversation to send to the model.
        stop: Optional stop sequences.
        run_manager: Callback manager; passed on to ``_stream`` when
            ``self.streaming`` is set and unused otherwise.
        **kwargs: Extra request parameters forwarded to ``_format_params``.

    Returns:
        The result aggregated from ``_stream`` when ``self.streaming`` is
        set, otherwise the formatted response of one ``messages.create`` call.
    """
    if self.streaming:
        # ``_stream`` builds its own request parameters, so check this first
        # to avoid formatting the messages twice on the streaming path.
        stream_iter = self._stream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return generate_from_stream(stream_iter)
    params = self._format_params(messages=messages, stop=stop, **kwargs)
    data = self.client.messages.create(**params)
    return self._format_output(data, **kwargs)

async def _agenerate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> ChatResult:
    """Asynchronously produce a chat result for ``messages``.

    Args:
        messages: Conversation to send to the model.
        stop: Optional stop sequences.
        run_manager: Async callback manager; passed on to ``_astream`` when
            ``self.streaming`` is set and unused otherwise.
        **kwargs: Extra request parameters forwarded to ``_format_params``.

    Returns:
        The result aggregated from ``_astream`` when ``self.streaming`` is
        set, otherwise the formatted response of one awaited
        ``messages.create`` call on the async client.
    """
    if self.streaming:
        # ``_astream`` builds its own request parameters, so check this first
        # to avoid formatting the messages twice on the streaming path.
        stream_iter = self._astream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await agenerate_from_stream(stream_iter)
    params = self._format_params(messages=messages, stop=stop, **kwargs)
    data = await self.async_client.messages.create(**params)
    return self._format_output(data, **kwargs)

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "anthropic-chat-vertexai"

def _stream(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
    """Stream the model's reply as one chunk per text fragment.

    Args:
        messages: Conversation to send to the model.
        stop: Optional stop sequences.
        run_manager: When truthy, its ``on_llm_new_token`` is called with
            each fragment before the matching chunk is yielded.
        **kwargs: Extra request parameters forwarded to ``_format_params``.

    Yields:
        A ``ChatGenerationChunk`` for every item of the stream's
        ``text_stream``.
    """
    request = self._format_params(messages=messages, stop=stop, **kwargs)
    with self.client.messages.stream(**request) as response:
        for fragment in response.text_stream:
            piece = AIMessageChunk(content=fragment)
            generation = ChatGenerationChunk(message=piece)
            if run_manager:
                run_manager.on_llm_new_token(fragment, chunk=generation)
            yield generation

async def _astream(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
    """Asynchronously stream the reply as one chunk per text fragment.

    Args:
        messages: Conversation to send to the model.
        stop: Optional stop sequences.
        run_manager: When truthy, its ``on_llm_new_token`` is awaited with
            each fragment before the matching chunk is yielded.
        **kwargs: Extra request parameters forwarded to ``_format_params``.

    Yields:
        A ``ChatGenerationChunk`` for every item of the async stream's
        ``text_stream``.
    """
    request = self._format_params(messages=messages, stop=stop, **kwargs)
    async with self.async_client.messages.stream(**request) as response:
        async for fragment in response.text_stream:
            piece = AIMessageChunk(content=fragment)
            generation = ChatGenerationChunk(message=piece)
            if run_manager:
                await run_manager.on_llm_new_token(fragment, chunk=generation)
            yield generation
Loading
Loading