Skip to content

Commit

Permalink
added ChatAnthropicVertex
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin committed Apr 5, 2024
1 parent ca3724a commit 2df4405
Show file tree
Hide file tree
Showing 5 changed files with 864 additions and 205 deletions.
93 changes: 93 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import re
from typing import Dict, List, Optional, Tuple, Union

from langchain_core.messages import BaseMessage

_message_type_lookups = {"human": "user", "ai": "assistant"}


def _format_image(image_url: str) -> Dict:
"""Formats a message image to a dict for anthropic api."""
regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
match = re.match(regex, image_url)
if match is None:
raise ValueError(
"Anthropic only supports base64-encoded images currently."
" Example: data:image/png;base64,'/9j/4AAQSk'..."
)
return {
"type": "base64",
"media_type": match.group("media_type"),
"data": match.group("data"),
}


def _format_messages_anthropic(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
    """Split messages into an Anthropic system prompt and message dicts.

    The input messages and their content items are never modified.

    Args:
        messages: Conversation history. A system message is only accepted as
            the very first element and must have string content.

    Returns:
        A ``(system_message, formatted_messages)`` tuple. ``system_message``
        is the system message's content, or ``None`` when there is none.
        ``formatted_messages`` contains one ``{"role": ..., "content": ...}``
        dict per non-system message, in the original order.

    Raises:
        ValueError: If a system message is not first or has non-string
            content, if a dict content item has no ``"type"`` key, if a
            content item is neither ``str`` nor ``dict``, or if
            ``_format_image`` rejects an image URL.
        KeyError: If a message's ``type`` has no entry in
            ``_message_type_lookups``.
        AssertionError: If non-string message content is not a list.
    """
    system_message: Optional[str] = None
    formatted_messages: List[Dict] = []

    for i, message in enumerate(messages):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            if not isinstance(message.content, str):
                raise ValueError(
                    "System message must be a string, "
                    f"instead was: {type(message.content)}"
                )
            system_message = message.content
            continue

        role = _message_type_lookups[message.type]
        content: Union[str, List[Dict]]

        if isinstance(message.content, str):
            content = message.content
        else:
            # Narrow the type for the loop below.
            assert isinstance(
                message.content, list
            ), "Anthropic message content must be str or list of dicts"

            content = []
            for item in message.content:
                if isinstance(item, str):
                    content.append({"type": "text", "text": item})
                elif isinstance(item, dict):
                    if "type" not in item:
                        raise ValueError("Dict content item must have a type key")
                    if item["type"] == "image_url":
                        # Convert the image_url item into an "image" block.
                        source = _format_image(item["image_url"]["url"])
                        content.append({"type": "image", "source": source})
                    elif item["type"] == "tool_use":
                        # Drop the "text" key on a copy so the caller's
                        # message content is left untouched.
                        content.append(
                            {k: v for k, v in item.items() if k != "text"}
                        )
                    else:
                        content.append(item)
                else:
                    raise ValueError(
                        f"Content items must be str or dict, instead was: {type(item)}"
                    )

        formatted_messages.append({"role": role, "content": content})

    return system_message, formatted_messages
161 changes: 158 additions & 3 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
from __future__ import annotations

import asyncio
from typing import Any, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
Generation,
LLMResult,
)
from langchain_core.pydantic_v1 import root_validator

from langchain_google_vertexai._base import _BaseVertexAIModelGarden
from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic
from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon


class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
Expand Down Expand Up @@ -70,3 +88,140 @@ async def _agenerate(
endpoint=self.endpoint_path, instances=instances
)
return self._parse_response(response)


class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
    """Chat model that talks to Anthropic models via the ``anthropic`` Vertex clients.

    A synchronous ``AnthropicVertex`` client and an ``AsyncAnthropicVertex``
    client are built during validation and used for blocking, async and
    streaming calls respectively.
    """

    async_client: Any = None  #: :meta private:
    model_name: str = "claude-3-sonnet@20240229"
    "Underlying model name."
    max_output_tokens: int = 1024

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Create the sync and async Anthropic Vertex clients.

        Reads ``project``, ``location`` and ``max_retries`` from ``values`` and
        stores the clients under ``client`` and ``async_client``.
        """
        from anthropic import (  # type: ignore[import]
            AnthropicVertex,
            AsyncAnthropicVertex,
        )

        # Both clients are configured identically.
        client_kwargs = {
            "project_id": values["project"],
            "region": values["location"],
            "max_retries": values["max_retries"],
        }
        values["client"] = AnthropicVertex(**client_kwargs)
        values["async_client"] = AsyncAnthropicVertex(**client_kwargs)
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return a fresh dict of request parameters taken from model fields."""
        # NOTE(review): temperature/top_k/top_p are not declared in this
        # class; presumably inherited from _VertexAICommon - confirm.
        return {
            "model": self.model_name,
            "max_tokens": self.max_output_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }

    def _format_params(
        self,
        *,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Build the keyword arguments for a ``messages`` API call.

        Precedence, lowest to highest: ``_default_params``, then ``kwargs``,
        then values derived from ``messages`` and ``stop``. Entries whose
        value is ``None`` are omitted from the result.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences, sent as ``stop_sequences``.
            **kwargs: Extra request parameters overriding the defaults.

        Returns:
            The request parameters with all ``None`` values removed.
        """
        system_message, formatted_messages = _format_messages_anthropic(messages)
        params = self._default_params
        params.update(kwargs)
        derived = {
            "system": system_message,
            "messages": formatted_messages,
            "stop_sequences": stop,
        }
        # Only overwrite with values that were actually provided; an absent
        # system message or stop list must not erase a value passed in kwargs.
        params.update({k: v for k, v in derived.items() if v is not None})
        return {k: v for k, v in params.items() if v is not None}

    def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
        """Convert an API response object into a ``ChatResult``.

        ``data.model_dump()`` supplies the response as a dict. Every key other
        than ``content``, ``role`` and ``type`` is passed through as
        ``llm_output``. A response made of a single text block becomes a
        plain-string ``AIMessage``; anything else keeps the list of blocks.
        """
        data_dict = data.model_dump()
        content = data_dict["content"]
        llm_output = {
            k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
        }
        if len(content) == 1 and content[0]["type"] == "text":
            msg = AIMessage(content=content[0]["text"])
        else:
            msg = AIMessage(content=content)
        return ChatResult(
            generations=[ChatGeneration(message=msg)],
            llm_output=llm_output,
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run one blocking completion, via streaming when ``self.streaming`` is set."""
        if self.streaming:
            # _stream formats its own request parameters, so none are built here.
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        data = self.client.messages.create(**params)
        return self._format_output(data, **kwargs)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate`` using ``async_client``."""
        if self.streaming:
            # _astream formats its own request parameters, so none are built here.
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        data = await self.async_client.messages.create(**params)
        return self._format_output(data, **kwargs)

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "anthropic-chat-vertexai"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield one chunk per text fragment from the streaming API.

        Each fragment is reported to ``run_manager`` (when given) before the
        chunk is yielded.
        """
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        with self.client.messages.stream(**params) as stream:
            for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream`` using ``async_client``."""
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        async with self.async_client.messages.stream(**params) as stream:
            async for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    await run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk
Loading

0 comments on commit 2df4405

Please sign in to comment.