From adadbd29ea7754610e898a15242449305dff47c8 Mon Sep 17 00:00:00 2001
From: kristoftabori <43764014+kristoftabori@users.noreply.github.com>
Date: Thu, 24 Oct 2024 03:06:07 +0200
Subject: [PATCH] Feature/bedrock (#125)

* adding Bedrock integration

* adding langchain requirements to requirements.txt

* adding basic tests for bedrock integrations

* fixing typo in requirements.txt

* extending venv ignoring

* langchain import optional

---------

Co-authored-by: Kristof Tabori
---
 .gitignore                 |   1 +
 tests/test_bedrock.py      |  75 ++++++++++++++++++
 textgrad/engine/bedrock.py | 155 +++++++++++++++++++++++++++++++++++++
 3 files changed, 231 insertions(+)
 create mode 100644 tests/test_bedrock.py
 create mode 100644 textgrad/engine/bedrock.py

diff --git a/.gitignore b/.gitignore
index c81bd44..df3774f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -129,6 +129,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+*venv/
 
 # Spyder project settings
 .spyderproject
diff --git a/tests/test_bedrock.py b/tests/test_bedrock.py
new file mode 100644
index 0000000..ce8d70c
--- /dev/null
+++ b/tests/test_bedrock.py
@@ -0,0 +1,75 @@
+from unittest.mock import Mock
+
+import pytest
+from botocore.client import BaseClient
+from langchain_aws import ChatBedrock
+
+from textgrad.engine.bedrock import ChatBedrockEngine
+
+
+@pytest.fixture
+def mock_bedrock_client():
+    return Mock(spec=BaseClient)
+
+
+def test_chat_bedrock_engine_init_custom_values(mock_bedrock_client):
+    custom_model_kwargs = {"temperature": 0.7, "max_tokens": 1000}
+    custom_model_string = "anthropic.claude-3-haiku-20240307-v1:0"
+    custom_system_prompt = "You are the best AI assistant ever."
+
+    engine = ChatBedrockEngine(
+        bedrock_client=mock_bedrock_client,
+        model_string=custom_model_string,
+        system_prompt=custom_system_prompt,
+        is_multimodal=True,
+        **custom_model_kwargs
+    )
+
+    assert isinstance(engine.client, ChatBedrock)
+    assert engine.model_string == custom_model_string
+    assert engine.system_prompt == custom_system_prompt
+    assert engine.is_multimodal is True
+    assert engine.kwargs == custom_model_kwargs
+
+
+def test_chat_bedrock_engine_init_default_values(mock_bedrock_client):
+    engine = ChatBedrockEngine(bedrock_client=mock_bedrock_client)
+
+    assert isinstance(engine.client, ChatBedrock)
+    assert engine.model_string == "anthropic.claude-3-sonnet-20240229-v1:0"
+    assert engine.system_prompt == ChatBedrockEngine.SYSTEM_PROMPT
+    assert engine.is_multimodal is False
+    assert engine.kwargs == {}
+
+
+def test_chat_bedrock_engine_invalid_system_prompt(mock_bedrock_client):
+    with pytest.raises(AssertionError):
+        ChatBedrockEngine(bedrock_client=mock_bedrock_client, system_prompt=123)
+
+
+def test_chat_bedrock_engine_call(mock_bedrock_client):
+    model_kwargs = {"temperature": 0.7, "max_tokens": 1000}
+    additional_kwargs = {"temperature": 0.8}
+
+    engine = ChatBedrockEngine(bedrock_client=mock_bedrock_client, **model_kwargs)
+
+    engine.generate = Mock(return_value="Mocked response")
+
+    prompt = "Hello, how are you?"
+    response = engine(prompt, **additional_kwargs)
+    assert response == "Mocked response"
+    engine.generate.assert_called_with(prompt, max_tokens=1000, temperature=0.8)
+
+    response = engine(prompt)
+    assert response == "Mocked response"
+    engine.generate.assert_called_with(prompt, max_tokens=1000, temperature=0.7)
+
+
+def test_generate_with_string_input(mock_bedrock_client):
+    engine = ChatBedrockEngine(bedrock_client=mock_bedrock_client)
+    engine._generate_from_single_prompt = Mock(return_value="Mocked response")
+
+    response = engine.generate("Hello, how are you?")
+
+    assert response == "Mocked response"
+    engine._generate_from_single_prompt.assert_called_once_with("Hello, how are you?", system_prompt=None)
\ No newline at end of file
diff --git a/textgrad/engine/bedrock.py b/textgrad/engine/bedrock.py
new file mode 100644
index 0000000..f49da44
--- /dev/null
+++ b/textgrad/engine/bedrock.py
@@ -0,0 +1,155 @@
+try:
+    from langchain_aws import ChatBedrock
+except ImportError:
+    raise ImportError("If you'd like to use Bedrock models, please install the `langchain_aws` package by running `pip install langchain-aws`, and instantiate a Bedrock Client.")
+
+import base64
+import json
+import os
+from typing import Any, Dict, List, Union
+
+from botocore.client import BaseClient
+from langchain_aws import ChatBedrock
+from langchain_core.messages import HumanMessage, SystemMessage
+from platformdirs import user_cache_dir
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+from .base import CachedEngine, EngineLM
+from .engine_utils import get_image_type_from_bytes
+
+
+class ChatBedrockEngine(EngineLM, CachedEngine):
+    SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant"
+
+    def __init__(
+        self,
+        bedrock_client: BaseClient,
+        model_string: str = "anthropic.claude-3-sonnet-20240229-v1:0",
+        system_prompt: str = SYSTEM_PROMPT,
+        is_multimodal: bool = False,
+        **kwargs: Any,
+    ):
+        root = user_cache_dir("textgrad")
+        cache_path = os.path.join(root, f"cache_bedrock_{model_string}.db")
+        super().__init__(cache_path=cache_path)
+        self.bedrock_client = bedrock_client
+        self.client = ChatBedrock(
+            client=bedrock_client,
+            model_id=model_string,
+            model_kwargs=kwargs,
+        )
+        self.kwargs = kwargs
+        self.model_string = model_string
+        self.system_prompt = system_prompt
+        assert isinstance(self.system_prompt, str)
+        self.is_multimodal = is_multimodal
+
+    def __call__(self, prompt: Union[str, List[Union[str, bytes]]], **kwargs):
+        passed_through_kwargs = self.kwargs.copy()
+        passed_through_kwargs.update(kwargs)
+        return self.generate(prompt, **passed_through_kwargs)
+
+    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
+    def generate(self, content, system_prompt=None, **kwargs):
+        if isinstance(content, str):
+            return self._generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)
+
+        elif isinstance(content, list):
+            has_multimodal_input = any(isinstance(item, bytes) for item in content)
+            if (has_multimodal_input) and (not self.is_multimodal):
+                raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.")
+
+            return self._generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)
+
+    def _generate_from_single_prompt(self, prompt: str, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99):
+
+        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+        cache_or_none = self._check_cache(sys_prompt_arg + prompt)
+        if cache_or_none is not None:
+            return cache_or_none
+
+        chat_client = self._update_chat_client(temperature=temperature, max_tokens=max_tokens, top_p=top_p)
+
+        messages = [SystemMessage(content=sys_prompt_arg), HumanMessage(content=prompt)]
+
+        response = chat_client.invoke(messages)
+
+        response_text = str(response.content)
+        self._save_cache(sys_prompt_arg + prompt, response_text)
+        return response_text
+
+    def _format_content(self, content: List[Union[str, bytes]]) -> List[Union[str, Dict[Any, Any]]]:
+        formatted_content: List[Union[str, Dict[Any, Any]]] = []
+        for item in content:
+            if isinstance(item, bytes):
+                image_type = get_image_type_from_bytes(item)
+
+                image_media_type = f"image/{image_type}"
+                base64_image = base64.b64encode(item).decode("utf-8")
+                formatted_content.append(
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": image_media_type,
+                            "data": base64_image,
+                        },
+                    }
+                )
+            elif isinstance(item, str):
+                formatted_content.append({"type": "text", "text": item})
+            else:
+                raise ValueError(f"Unsupported input type: {type(item)}")
+        return formatted_content
+
+    def _generate_from_multiple_input(
+        self,
+        content: List[Union[str, bytes]],
+        system_prompt=None,
+        temperature=0,
+        max_tokens=2000,
+        top_p=0.99,
+    ):
+        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+        formatted_content = self._format_content(content)
+
+        cache_key = sys_prompt_arg + json.dumps(formatted_content)
+        cache_or_none = self._check_cache(cache_key)
+        if cache_or_none is not None:
+            return cache_or_none
+
+        chat_client = self._update_chat_client(temperature=temperature, max_tokens=max_tokens, top_p=top_p)
+
+        messages = [SystemMessage(content=sys_prompt_arg), HumanMessage(content=formatted_content)]
+
+        response = chat_client.invoke(messages)
+
+        response_text = str(response.content)
+        self._save_cache(cache_key, response_text)
+        return response_text
+
+    def _update_chat_client(
+        self,
+        temperature,
+        max_tokens,
+        top_p,
+    ) -> ChatBedrock:
+        chat_client = self.client
+
+        if any(
+            [
+                self.kwargs.get("temperature", -1) != temperature,
+                self.kwargs.get("max_tokens", -1) != max_tokens,
+                self.kwargs.get("top_p", -1) != top_p,
+            ]
+        ):
+            updated_kwargs = self.kwargs.copy()
+            updated_kwargs["temperature"] = temperature
+            updated_kwargs["max_tokens"] = max_tokens
+            updated_kwargs["top_p"] = top_p
+
+            chat_client = ChatBedrock(
+                client=self.bedrock_client,
+                model_id=self.model_string,
+                model_kwargs=updated_kwargs,
+            )
+        return chat_client
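
Usage note (not part of the patch): a minimal sketch of how the new engine is driven end to end, assuming `boto3` is installed, AWS credentials are configured, and the account has been granted Bedrock model access. The region name and prompt below are illustrative.

    # Sketch only: region, model choice, and prompt are illustrative assumptions.
    import boto3

    from textgrad.engine.bedrock import ChatBedrockEngine

    # ChatBedrock wraps a "bedrock-runtime" service client.
    bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")

    # Extra keyword arguments here become the engine's default model kwargs.
    engine = ChatBedrockEngine(
        bedrock_client=bedrock_client,
        model_string="anthropic.claude-3-sonnet-20240229-v1:0",
        temperature=0.7,
    )

    # Per-call kwargs win: __call__ copies self.kwargs and merges the call's
    # kwargs on top before delegating to generate(), as the tests exercise.
    response = engine("Summarize this PR in one sentence.", temperature=0.2)
    print(response)

Because the engine is a CachedEngine keyed on system prompt plus prompt, a repeated identical call returns the cached completion without invoking Bedrock again.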