Skip to content

Commit

Permalink
Feature/bedrock (#125)
Browse files Browse the repository at this point in the history
* adding Bedrock integration

* adding langchain requirements to requirements.txt

* adding basic tests for bedrock integrations

* fixing typo in requirements.txt

* extending venv ignoring

* langchain import optional

---------

Co-authored-by: Kristof Tabori <[email protected]>
  • Loading branch information
kristoftabori and datachristopher authored Oct 24, 2024
1 parent 60b328e commit adadbd2
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ venv/
ENV/
env.bak/
venv.bak/
*venv/

# Spyder project settings
.spyderproject
Expand Down
75 changes: 75 additions & 0 deletions tests/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from unittest.mock import Mock

import pytest
from botocore.client import BaseClient
from langchain_aws import ChatBedrock

from textgrad.engine.bedrock import ChatBedrockEngine


@pytest.fixture
def mock_bedrock_client():
return Mock(spec=BaseClient)


def test_chat_bedrock_engine_init_custom_values(mock_bedrock_client):
custom_model_kwargs = {"temperature": 0.7, "max_tokens": 1000}
custom_model_string = "anthropic.claude-3-haiku-20240307-v1:0"
custom_system_prompt = "You are the best AI assistant ever."

engine = ChatBedrockEngine(
bedrock_client=mock_bedrock_client,
model_string=custom_model_string,
system_prompt=custom_system_prompt,
is_multimodal=True,
**custom_model_kwargs
)

assert isinstance(engine.client, ChatBedrock)
assert engine.model_string == custom_model_string
assert engine.system_prompt == custom_system_prompt
assert engine.is_multimodal is True
assert engine.kwargs == custom_model_kwargs


def test_chat_bedrock_engine_init_default_values(mock_bedrock_client):
engine = ChatBedrockEngine(bedrock_client=mock_bedrock_client)

assert isinstance(engine.client, ChatBedrock)
assert engine.model_string == "anthropic.claude-3-sonnet-20240229-v1:0"
assert engine.system_prompt == ChatBedrockEngine.SYSTEM_PROMPT
assert engine.is_multimodal is False
assert engine.kwargs == {}


def test_chat_bedrock_engine_invalid_system_prompt(mock_bedrock_client):
with pytest.raises(AssertionError):
ChatBedrockEngine(bedrock_client=mock_bedrock_client, system_prompt=123)


def test_chat_bedrock_engine_call(mock_bedrock_client):
model_kwargs = {"temperature": 0.7, "max_tokens": 1000}
additional_kwargs = {"temperature": 0.8}

engine = ChatBedrockEngine(bedrock_client=mock_bedrock_client, **model_kwargs)

engine.generate = Mock(return_value="Mocked response")

prompt = "Hello, how are you?"
response = engine(prompt, **additional_kwargs)
assert response == "Mocked response"
engine.generate.assert_called_with(prompt, max_tokens=1000, temperature=0.8)

response = engine(prompt)
assert response == "Mocked response"
engine.generate.assert_called_with(prompt, max_tokens=1000, temperature=0.7)


def test_generate_with_string_input(mock_bedrock_client):
engine = ChatBedrockEngine(bedrock_client=mock_bedrock_client)
engine._generate_from_single_prompt = Mock(return_value="Mocked response")

response = engine.generate("Hello, how are you?")

assert response == "Mocked response"
engine._generate_from_single_prompt.assert_called_once_with("Hello, how are you?", system_prompt=None)
155 changes: 155 additions & 0 deletions textgrad/engine/bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
try:
from langchain_aws import ChatBedrock
except ImportError:
raise ImportError("If you'd like to use Bedrock models, please install the `langchain_aws` package by running `pip install langchain-aws`, and instantiate a Bedrock Client.")

import base64
import json
import os
from typing import Any, Dict, List, Union

from botocore.client import BaseClient
from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage, SystemMessage
from platformdirs import user_cache_dir
from tenacity import retry, stop_after_attempt, wait_random_exponential
from .base import CachedEngine, EngineLM
from .engine_utils import get_image_type_from_bytes


class ChatBedrockEngine(EngineLM, CachedEngine):
SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant"

def __init__(
self,
bedrock_client: BaseClient,
model_string: str = "anthropic.claude-3-sonnet-20240229-v1:0",
system_prompt: str = SYSTEM_PROMPT,
is_multimodal: bool = False,
**kwargs: Any,
):
root = user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_bedrock_{model_string}.db")
super().__init__(cache_path=cache_path)
self.bedrock_client = bedrock_client
self.client = ChatBedrock(
client=bedrock_client,
model_id=model_string,
model_kwargs=kwargs,
)
self.kwargs = kwargs
self.model_string = model_string
self.system_prompt = system_prompt
assert isinstance(self.system_prompt, str)
self.is_multimodal = is_multimodal

def __call__(self, prompt: Union[str, List[Union[str, bytes]]], **kwargs):
passed_through_kwargs = self.kwargs.copy()
passed_through_kwargs.update(kwargs)
return self.generate(prompt, **passed_through_kwargs)

@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
def generate(self, content, system_prompt=None, **kwargs):
if isinstance(content, str):
return self._generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)

elif isinstance(content, list):
has_multimodal_input = any(isinstance(item, bytes) for item in content)
if (has_multimodal_input) and (not self.is_multimodal):
raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.")

return self._generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)

def _generate_from_single_prompt(self, prompt: str, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99):

sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
cache_or_none = self._check_cache(sys_prompt_arg + prompt)
if cache_or_none is not None:
return cache_or_none

chat_client = self._update_chat_client(temperature=temperature, max_tokens=max_tokens, top_p=top_p)

messages = [SystemMessage(content=sys_prompt_arg), HumanMessage(content=prompt)]

response = chat_client.invoke(messages)

response_text = str(response.content)
self._save_cache(sys_prompt_arg + prompt, response_text)
return response_text

def _format_content(self, content: List[Union[str, bytes]]) -> List[Union[str, Dict[Any, Any]]]:
formatted_content: List[Union[str, Dict[Any, Any]]] = []
for item in content:
if isinstance(item, bytes):
image_type = get_image_type_from_bytes(item)

image_media_type = f"image/{image_type}"
base64_image = base64.b64encode(item).decode("utf-8")
formatted_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": image_media_type,
"data": base64_image,
},
}
)
elif isinstance(item, str):
formatted_content.append({"type": "text", "text": item})
else:
raise ValueError(f"Unsupported input type: {type(item)}")
return formatted_content

def _generate_from_multiple_input(
self,
content: List[Union[str, bytes]],
system_prompt=None,
temperature=0,
max_tokens=2000,
top_p=0.99,
):
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
formatted_content = self._format_content(content)

cache_key = sys_prompt_arg + json.dumps(formatted_content)
cache_or_none = self._check_cache(cache_key)
if cache_or_none is not None:
return cache_or_none

chat_client = self._update_chat_client(temperature=temperature, max_tokens=max_tokens, top_p=top_p)

messages = [SystemMessage(content=sys_prompt_arg), HumanMessage(content=formatted_content)]

response = chat_client.invoke(messages)

response_text = str(response.content)
self._save_cache(cache_key, response_text)
return response_text

def _update_chat_client(
self,
temperature,
max_tokens,
top_p,
) -> ChatBedrock:
chat_client = self.client

if any(
[
self.kwargs.get("temperature", -1) != temperature,
self.kwargs.get("max_tokens", -1) != max_tokens,
self.kwargs.get("top_p", -1) != top_p,
]
):
updated_kwargs = self.kwargs.copy()
updated_kwargs["temperature"] = temperature
updated_kwargs["max_tokens"] = max_tokens
updated_kwargs["top_p"] = top_p

chat_client = ChatBedrock(
client=self.bedrock_client,
model_id=self.model_string,
model_kwargs=updated_kwargs,
)
return chat_client

0 comments on commit adadbd2

Please sign in to comment.