From 58858ce155efe72879900eb8264afb6943a8a7b3 Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 10 Sep 2024 16:53:55 -0700 Subject: [PATCH 1/6] Integration test --- .../langchain_community/chat_models/reka.py | 188 ++++++++++++++++++ .../chat_models/test_reka.py | 82 ++++++++ 2 files changed, 270 insertions(+) create mode 100644 libs/community/langchain_community/chat_models/reka.py create mode 100644 libs/community/tests/integration_tests/chat_models/test_reka.py diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py new file mode 100644 index 0000000000000..99646281c67ea --- /dev/null +++ b/libs/community/langchain_community/chat_models/reka.py @@ -0,0 +1,188 @@ +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import ( + BaseChatModel, + agenerate_from_stream, + generate_from_stream, +) +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.pydantic_v1 import Field, PrivateAttr +from langchain_core.utils import get_from_dict_or_env + +try: + from reka.client import Reka, AsyncReka +except ImportError: + raise ValueError( + "Reka is not installed. Please install it with `pip install reka-api`." + ) + +REKA_MODELS = [ + "reka-edge", + "reka-flash", + "reka-core", + "reka-core-20240501", +] + +DEFAULT_REKA_MODEL = "reka-core-20240501" + +def get_role(message: BaseMessage) -> str: + """Get the role of the message.""" + if isinstance(message, (ChatMessage, HumanMessage)): + return "user" + elif isinstance(message, AIMessage): + return "assistant" + elif isinstance(message, SystemMessage): + return "system" + else: + raise ValueError(f"Got unknown type {message}") + +def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str]]: + """Process messages for Reka format.""" + reka_messages = [] + system_message = None + + for message in messages: + if isinstance(message, SystemMessage): + if system_message is None: + system_message = message.content + else: + raise ValueError("Multiple system messages are not supported.") + else: + content = message.content + if system_message and isinstance(message, HumanMessage): + content = f"{system_message}\n{content}" + system_message = None + reka_messages.append({"role": get_role(message), "content": content}) + + return reka_messages + +class ChatReka(BaseChatModel): + """Reka chat large language models.""" + + model: str = Field(default=DEFAULT_REKA_MODEL, description="The Reka model to use.") + temperature: float = Field(default=0.7, description="The sampling temperature.") + max_tokens: int = Field(default=512, description="The maximum number of tokens to generate.") + api_key: str = Field(default=None, description="The API key for Reka.") + streaming: bool = Field(default=False, description="Whether to stream the response.") + + _client: Reka = PrivateAttr() + _aclient: AsyncReka = PrivateAttr() + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + api_key = get_from_dict_or_env(kwargs, "api_key", "REKA_API_KEY") + self._client = Reka(api_key=api_key) + self._aclient = AsyncReka(api_key=api_key) + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return 
"reka-chat" + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling Reka API.""" + return { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + } + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + reka_messages = process_messages_for_reka(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + + stream = self._client.chat.create_stream(messages=reka_messages, **params) + + for chunk in stream: + content = chunk.responses[0].chunk.content + chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) + yield chunk + if run_manager: + run_manager.on_llm_new_token(content, chunk=chunk) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + reka_messages = process_messages_for_reka(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + + stream = await self._aclient.chat.create_stream(messages=reka_messages, **params) + + async for chunk in stream: + content = chunk.responses[0].chunk.content + chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) + yield chunk + if run_manager: + await run_manager.on_llm_new_token(content, chunk=chunk) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + return generate_from_stream( + self._stream(messages, stop=stop, run_manager=run_manager, **kwargs) + ) + + reka_messages = process_messages_for_reka(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + response = self._client.chat.create(messages=reka_messages, **params) + + message = AIMessage(content=response.responses[0].message.content) + return ChatResult(generations=[ChatGeneration(message=message)]) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + return await agenerate_from_stream( + self._astream(messages, stop=stop, run_manager=run_manager, **kwargs) + ) + + reka_messages = process_messages_for_reka(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + response = await self._aclient.chat.create(messages=reka_messages, **params) + + message = AIMessage(content=response.responses[0].message.content) + return ChatResult(generations=[ChatGeneration(message=message)]) + + def get_num_tokens(self, text: str) -> int: + """Calculate number of tokens.""" + raise NotImplementedError("Token counting is not implemented for Reka models.") diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py new file mode 100644 index 0000000000000..de9dc200710ae --- /dev/null +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -0,0 +1,82 @@ +"""Test Reka API wrapper.""" + +from typing import List + +import pytest +from langchain_core.callbacks import CallbackManager +from langchain_core.messages import AIMessage, 
BaseMessage, HumanMessage +from langchain_core.outputs import ChatGeneration, LLMResult + +from langchain_community.chat_models.reka import ChatReka +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + +@pytest.mark.scheduled +def test_reka_call() -> None: + """Test valid call to Reka.""" + chat = ChatReka(model="reka-flash") + message = HumanMessage(content="Hello") + response = chat.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + +@pytest.mark.scheduled +def test_reka_generate() -> None: + """Test generate method of Reka.""" + chat = ChatReka(model="reka-flash") + chat_messages: List[List[BaseMessage]] = [ + [HumanMessage(content="How many toes do dogs have?")] + ] + messages_copy = [messages.copy() for messages in chat_messages] + result: LLMResult = chat.generate(chat_messages) + assert isinstance(result, LLMResult) + for response in result.generations[0]: + assert isinstance(response, ChatGeneration) + assert isinstance(response.text, str) + assert response.text == response.message.content + assert chat_messages == messages_copy + +@pytest.mark.scheduled +def test_reka_streaming() -> None: + """Test streaming tokens from Reka.""" + chat = ChatReka(model="reka-flash", streaming=True) + message = HumanMessage(content="Hello") + response = chat.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + +@pytest.mark.scheduled +def test_reka_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = ChatReka( + model="reka-flash", + streaming=True, + callback_manager=callback_manager, + verbose=True, + ) + message = HumanMessage(content="Write me a sentence with 10 words.") + chat.invoke([message]) + assert callback_handler.llm_streams > 1 + +@pytest.mark.scheduled +async def test_reka_async_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = ChatReka( + model="reka-flash", + streaming=True, + callback_manager=callback_manager, + verbose=True, + ) + chat_messages: List[BaseMessage] = [ + HumanMessage(content="How many toes do dogs have?") + ] + result: LLMResult = await chat.agenerate([chat_messages]) + assert callback_handler.llm_streams > 1 + assert isinstance(result, LLMResult) + for response in result.generations[0]: + assert isinstance(response, ChatGeneration) + assert isinstance(response.text, str) + assert response.text == response.message.content \ No newline at end of file From 6f1d25949d6ab05c5a4a0a85756b92b861f3535d Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 10 Sep 2024 16:59:24 -0700 Subject: [PATCH 2/6] Linting fix --- .../langchain_community/chat_models/reka.py | 33 +++++++++++++++---- .../chat_models/test_reka.py | 1 + 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 99646281c67ea..ad037247616c2 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -1,4 +1,5 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional + from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, 
CallbackManagerForLLMRun, @@ -21,7 +22,7 @@ from langchain_core.utils import get_from_dict_or_env try: - from reka.client import Reka, AsyncReka + from reka.client import AsyncReka, Reka except ImportError: raise ValueError( "Reka is not installed. Please install it with `pip install reka-api`." @@ -70,11 +71,26 @@ def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str class ChatReka(BaseChatModel): """Reka chat large language models.""" - model: str = Field(default=DEFAULT_REKA_MODEL, description="The Reka model to use.") - temperature: float = Field(default=0.7, description="The sampling temperature.") - max_tokens: int = Field(default=512, description="The maximum number of tokens to generate.") - api_key: str = Field(default=None, description="The API key for Reka.") - streaming: bool = Field(default=False, description="Whether to stream the response.") + model: str = Field( + default=DEFAULT_REKA_MODEL, + description="The Reka model to use." + ) + temperature: float = Field( + default=0.7, + description="The sampling temperature." + ) + max_tokens: int = Field( + default=512, + description="The maximum number of tokens to generate." + ) + api_key: str = Field( + default=None, + description="The API key for Reka." + ) + streaming: bool = Field( + default=False, + description="Whether to stream the response." + ) _client: Reka = PrivateAttr() _aclient: AsyncReka = PrivateAttr() @@ -132,7 +148,10 @@ async def _astream( if stop: params["stop"] = stop - stream = await self._aclient.chat.create_stream(messages=reka_messages, **params) + stream = await self._aclient.chat.create_stream( + messages=reka_messages, + **params + ) async for chunk in stream: content = chunk.responses[0].chunk.content diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index de9dc200710ae..5f1b256c34477 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -10,6 +10,7 @@ from langchain_community.chat_models.reka import ChatReka from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + @pytest.mark.scheduled def test_reka_call() -> None: """Test valid call to Reka.""" From 5ecad6d13dd531ba1b842673ab58e9b9dfa88c9b Mon Sep 17 00:00:00 2001 From: findalexli Date: Wed, 11 Sep 2024 14:57:04 -0700 Subject: [PATCH 3/6] code formatter changes --- .../chat_models/__init__.py | 7 +- .../langchain_community/chat_models/reka.py | 145 ++++++++++++------ .../chat_models/test_reka.py | 6 +- .../tests/unit_tests/chat_models/test_reka.py | 123 +++++++++++++++ 4 files changed, 229 insertions(+), 52 deletions(-) create mode 100644 libs/community/tests/unit_tests/chat_models/test_reka.py diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index db5076375288f..8c68842272358 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -149,7 +149,10 @@ ) from langchain_community.chat_models.sambanova import ( ChatSambaNovaCloud, - ChatSambaStudio, + ChatSambaStudio + ) + from langchain_community.chat_models.reka import ( + ChatReka, ) from langchain_community.chat_models.snowflake import ( ChatSnowflakeCortex, @@ -214,6 +217,7 @@ "ChatOllama", "ChatOpenAI", "ChatPerplexity", + "ChatReka", "ChatPremAI", "ChatSambaNovaCloud", "ChatSambaStudio", @@ 
-274,6 +278,7 @@ "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai", "ChatOllama": "langchain_community.chat_models.ollama", "ChatOpenAI": "langchain_community.chat_models.openai", + "ChatReka": "langchain_community.chat_models.reka", "ChatPerplexity": "langchain_community.chat_models.perplexity", "ChatSambaNovaCloud": "langchain_community.chat_models.sambanova", "ChatSambaStudio": "langchain_community.chat_models.sambanova", diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index ad037247616c2..77d68d5185fba 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -1,9 +1,10 @@ -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, @@ -18,8 +19,12 @@ SystemMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, PrivateAttr -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, +) +from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str try: from reka.client import AsyncReka, Reka @@ -32,10 +37,10 @@ "reka-edge", "reka-flash", "reka-core", - "reka-core-20240501", ] -DEFAULT_REKA_MODEL = "reka-core-20240501" +DEFAULT_REKA_MODEL = "reka-flash" + def get_role(message: BaseMessage) -> str: """Get the role of the message.""" @@ -48,6 +53,7 @@ def get_role(message: BaseMessage) -> str: else: raise ValueError(f"Got unknown type {message}") + def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str]]: """Process messages for Reka format.""" reka_messages = [] @@ -68,52 +74,90 @@ def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str return reka_messages -class ChatReka(BaseChatModel): - """Reka chat large language models.""" - model: str = Field( - default=DEFAULT_REKA_MODEL, - description="The Reka model to use." - ) - temperature: float = Field( - default=0.7, - description="The sampling temperature." - ) - max_tokens: int = Field( - default=512, - description="The maximum number of tokens to generate." - ) - api_key: str = Field( - default=None, - description="The API key for Reka." - ) - streaming: bool = Field( - default=False, - description="Whether to stream the response." 
- ) +class RekaCommon(BaseLanguageModel): + client: Any = None #: :meta private: + async_client: Any = None #: :meta private: + model: str = Field(default=DEFAULT_REKA_MODEL, alias="model_name") + """Model name to use.""" - _client: Reka = PrivateAttr() - _aclient: AsyncReka = PrivateAttr() + max_tokens: int = Field(default=256) + """Denotes the number of tokens to predict per generation.""" - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - api_key = get_from_dict_or_env(kwargs, "api_key", "REKA_API_KEY") - self._client = Reka(api_key=api_key) - self._aclient = AsyncReka(api_key=api_key) + temperature: Optional[float] = None + """A non-negative float that tunes the degree of randomness in generation.""" - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "reka-chat" + streaming: bool = False + """Whether to stream the results.""" + + default_request_timeout: Optional[float] = None + """Timeout for requests to Reka Completion API. Default is 600 seconds.""" + + max_retries: int = 2 + """Number of retries allowed for requests sent to the Reka Completion API.""" + + reka_api_key: Optional[SecretStr] = None + + count_tokens: Optional[Callable[[str], int]] = None + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + + @root_validator(pre=True) + def build_extra(cls, values: Dict) -> Dict: + extra = values.get("model_kwargs", {}) + all_required_field_names = get_pydantic_field_names(cls) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["reka_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "reka_api_key", "REKA_API_KEY") + ) + + try: + from reka.client import AsyncReka, Reka + + values["client"] = Reka( + api_key=values["reka_api_key"].get_secret_value(), + ) + values["async_client"] = AsyncReka( + api_key=values["reka_api_key"].get_secret_value(), + ) + + except ImportError: + raise ImportError( + "Could not import reka python package. " + "Please install it with `pip install reka-api`." 
+ ) + return values @property - def _default_params(self) -> Dict[str, Any]: + def _default_params(self) -> Mapping[str, Any]: """Get the default parameters for calling Reka API.""" - return { - "model": self.model, - "temperature": self.temperature, + d = { "max_tokens": self.max_tokens, + "model": self.model, } + if self.temperature is not None: + d["temperature"] = self.temperature + return {**d, **self.model_kwargs} + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{}, **self._default_params} + + +class ChatReka(BaseChatModel, RekaCommon): + """Reka chat large language models.""" + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "reka-chat" def _stream( self, @@ -127,7 +171,7 @@ def _stream( if stop: params["stop"] = stop - stream = self._client.chat.create_stream(messages=reka_messages, **params) + stream = self.client.chat.create_stream(messages=reka_messages, **params) for chunk in stream: content = chunk.responses[0].chunk.content @@ -148,10 +192,7 @@ async def _astream( if stop: params["stop"] = stop - stream = await self._aclient.chat.create_stream( - messages=reka_messages, - **params - ) + stream = self.async_client.chat.create_stream(messages=reka_messages, **params) async for chunk in stream: content = chunk.responses[0].chunk.content @@ -176,7 +217,7 @@ def _generate( params = {**self._default_params, **kwargs} if stop: params["stop"] = stop - response = self._client.chat.create(messages=reka_messages, **params) + response = self.client.chat.create(messages=reka_messages, **params) message = AIMessage(content=response.responses[0].message.content) return ChatResult(generations=[ChatGeneration(message=message)]) @@ -197,11 +238,15 @@ async def _agenerate( params = {**self._default_params, **kwargs} if stop: params["stop"] = stop - response = await self._aclient.chat.create(messages=reka_messages, **params) + response = await self.async_client.chat.create(messages=reka_messages, **params) message = AIMessage(content=response.responses[0].message.content) return ChatResult(generations=[ChatGeneration(message=message)]) def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" - raise NotImplementedError("Token counting is not implemented for Reka models.") + if self.count_tokens is None: + raise NotImplementedError( + "get_num_tokens() is not implemented for Reka models." 
+ ) + return self.count_tokens(text) diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index 5f1b256c34477..7b599cd7c2b48 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -20,6 +20,7 @@ def test_reka_call() -> None: assert isinstance(response, AIMessage) assert isinstance(response.content, str) + @pytest.mark.scheduled def test_reka_generate() -> None: """Test generate method of Reka.""" @@ -36,6 +37,7 @@ def test_reka_generate() -> None: assert response.text == response.message.content assert chat_messages == messages_copy + @pytest.mark.scheduled def test_reka_streaming() -> None: """Test streaming tokens from Reka.""" @@ -45,6 +47,7 @@ def test_reka_streaming() -> None: assert isinstance(response, AIMessage) assert isinstance(response.content, str) + @pytest.mark.scheduled def test_reka_streaming_callback() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" @@ -60,6 +63,7 @@ def test_reka_streaming_callback() -> None: chat.invoke([message]) assert callback_handler.llm_streams > 1 + @pytest.mark.scheduled async def test_reka_async_streaming_callback() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" @@ -80,4 +84,4 @@ async def test_reka_async_streaming_callback() -> None: for response in result.generations[0]: assert isinstance(response, ChatGeneration) assert isinstance(response.text, str) - assert response.text == response.message.content \ No newline at end of file + assert response.text == response.message.content diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py new file mode 100644 index 0000000000000..6c61d9a069fe4 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -0,0 +1,123 @@ +"""Test Reka Chat API wrapper.""" + +import os +from typing import List + +import pytest +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage + +from langchain_community.chat_models import ChatReka +from langchain_community.chat_models.reka import process_messages_for_reka + +os.environ["REKA_API_KEY"] = "dummy_key" + + +@pytest.mark.requires("reka") +def test_reka_model_name_param() -> None: + llm = ChatReka(model_name="reka-flash") + assert llm.model == "reka-flash" + + +@pytest.mark.requires("reka") +def test_reka_model_param() -> None: + llm = ChatReka(model="reka-flash") + assert llm.model == "reka-flash" + + +@pytest.mark.requires("reka") +def test_reka_model_kwargs() -> None: + llm = ChatReka(model_kwargs={"foo": "bar"}) + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.requires("reka") +def test_reka_invalid_model_kwargs() -> None: + with pytest.raises(ValueError): + ChatReka(model_kwargs={"max_tokens": "invalid"}) + + +@pytest.mark.requires("reka") +def test_reka_incorrect_field() -> None: + with pytest.warns(match="not default parameter"): + llm = ChatReka(foo="bar") + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.requires("reka") +def test_reka_initialization() -> None: + """Test Reka initialization.""" + # Verify that ChatReka can be initialized using a secret key provided + # as a parameter rather than an environment variable. 
+ ChatReka(model="reka-flash", reka_api_key="test_key") + + +@pytest.mark.parametrize( + ("messages", "expected"), + [ + ([HumanMessage(content="Hello")], [{"role": "user", "content": "Hello"}]), + ( + [HumanMessage(content="Hello"), AIMessage(content="Hi there!")], + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + ), + ( + [ + SystemMessage(content="You're an assistant"), + HumanMessage(content="Hello"), + AIMessage(content="Hi there!"), + ], + [ + {"role": "user", "content": "You're an assistant\nHello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + ), + ], +) +def test_message_processing(messages: List[BaseMessage], expected: List[dict]) -> None: + result = process_messages_for_reka(messages) + assert result == expected + + +@pytest.mark.requires("reka") +def test_reka_streaming() -> None: + llm = ChatReka(streaming=True) + assert llm.streaming is True + + +@pytest.mark.requires("reka") +def test_reka_temperature() -> None: + llm = ChatReka(temperature=0.5) + assert llm.temperature == 0.5 + + +@pytest.mark.requires("reka") +def test_reka_max_tokens() -> None: + llm = ChatReka(max_tokens=100) + assert llm.max_tokens == 100 + + +@pytest.mark.requires("reka") +def test_reka_default_params() -> None: + llm = ChatReka() + assert llm._default_params == { + "max_tokens": 256, + "model": "reka-flash", + } + + +@pytest.mark.requires("reka") +def test_reka_identifying_params() -> None: + llm = ChatReka(temperature=0.7) + assert llm._identifying_params == { + "max_tokens": 256, + "model": "reka-flash", + "temperature": 0.7, + } + + +@pytest.mark.requires("reka") +def test_reka_llm_type() -> None: + llm = ChatReka() + assert llm._llm_type == "reka-chat" From e0637a5506dc1c6ee920e7feb4dc2d313fb92226 Mon Sep 17 00:00:00 2001 From: findalexli Date: Thu, 12 Sep 2024 14:29:14 -0700 Subject: [PATCH 4/6] Ruff linting check --- docs/docs/integrations/chat/reka.ipynb | 266 ++++++++++++++++++ .../langchain_community/chat_models/reka.py | 76 +++-- 2 files changed, 317 insertions(+), 25 deletions(-) create mode 100644 docs/docs/integrations/chat/reka.ipynb diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb new file mode 100644 index 0000000000000..eaf267252eceb --- /dev/null +++ b/docs/docs/integrations/chat/reka.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models import ChatReka" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model = ChatReka()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "sidebar_label: Reka\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ChatReka\n", + "\n", + "This notebook provides a quick overview for getting started with Reka [chat models](/docs/concepts/#chat-models). \n", + "\n", + "Reka has several chat models. 
You can find information about their latest models and their costs, context windows, and supported input types in the [Reka docs](https://docs.reka.ai/available-models).\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "## Overview\n",
+    "### Integration details\n",
+    "\n",
+    "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
+    "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
+    "| [ChatReka] | [langchain_community](https://python.langchain.com/v0.2/api_reference/community/index.html) | ✅ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n",
+    "\n",
+    "### Model features\n",
+    "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
+    "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
+    "| ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | \n",
+    "\n",
+    "## Setup\n",
+    "\n",
+    "To access Reka models you'll need to create a Reka developer account, get an API key, and install the `langchain_community` integration package and the reka python package via `pip install reka-api`.\n",
+    "\n",
+    "### Credentials\n",
+    "\n",
+    "Head to https://platform.reka.ai/ to sign up for Reka and generate an API key. Once you've done this, set the REKA_API_KEY environment variable:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Installation\n",
+    "\n",
+    "The LangChain Reka integration lives in the `langchain_community` package:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install -qU langchain_community"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Initialize a client"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import getpass\n",
+    "import os\n",
+    "\n",
+    "# os.environ[\"REKA_API_KEY\"] = getpass.getpass(\"Enter your Reka API key: \")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install -qU langchain_community"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.chat_models import ChatReka\n",
+    "# ChatReka() reads the REKA_API_KEY environment variable by default\n",
+    "model = ChatReka()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Single turn text message"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content=' Hello! How can I help you today? 
If you have any questions or need assistance with something, feel free to ask.\\n\\n', id='run-492dcda9-79b8-4e16-a81f-2f77b7d4e23a-0')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.invoke('hi')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Images input " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " The image you've uploaded is an indoor shot, and therefore, it does not provide any information about the weather outside. Since there are no windows or outdoor views depicted, it's not possible to determine the weather conditions at the time the photo was taken. If you need information about the weather, you should check a weather app, website, or news source based on your location or the location where the image was captured.\n" + ] + } + ], + "source": [ + "from langchain_community.chat_models import ChatReka\n", + "import httpx\n", + "\n", + "model = ChatReka()\n", + "image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"\n", + "from langchain_core.messages import HumanMessage\n", + "\n", + "message = HumanMessage(\n", + " content=[\n", + " {\"type\": \"text\", \"text\": \"describe the weather in this image\"},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": image_url},\n", + " },\n", + " ],\n", + ")\n", + "response = model.invoke([message])\n", + "print(response.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multiple images as input" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " The two images are quite different in several ways. \n", + "\n", + "Firstly, the subjects of the images are different. The first image features a cat, while the second image features two German Shepherds, an adult and a puppy.\n", + "\n", + "Secondly, the actions depicted are different. In the first image, the cat is sitting still and looking directly at the camera. In contrast, the second image captures the two dogs in motion, running across a grassy field with a stick in the mouth of the adult dog.\n", + "\n", + "Lastly, the setting of the images is different. The first image has a neutral, blurred background, providing no specific context about the location. The second image, on the other hand, is set in an outdoor environment with a grassy field and trees in the background.\n", + "\n", + "These differences highlight the distinct characteristics and behaviors of the animals as well as the varying contexts in which they are photographed.\n" + ] + } + ], + "source": [ + "message = HumanMessage(\n", + " content=[\n", + " {\"type\": \"text\", \"text\": \"What are the difference between the two images? 
\"},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": \"https://cdn.pixabay.com/photo/2019/07/23/13/51/shepherd-dog-4357790_1280.jpg\"},\n", + " },\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": \"https://cdn.pixabay.com/photo/2024/02/17/00/18/cat-8578562_1280.jpg\"},\n", + " },\n", + " ],\n", + ")\n", + "response = model.invoke([message])\n", + "print(response.content)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langchain_reka", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 77d68d5185fba..40df5bf9492bd 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -1,4 +1,14 @@ -from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Union, +) from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -14,7 +24,6 @@ AIMessage, AIMessageChunk, BaseMessage, - ChatMessage, HumanMessage, SystemMessage, ) @@ -27,6 +36,7 @@ from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str try: + from reka import ChatMessage from reka.client import AsyncReka, Reka except ImportError: raise ValueError( @@ -42,20 +52,28 @@ DEFAULT_REKA_MODEL = "reka-flash" -def get_role(message: BaseMessage) -> str: - """Get the role of the message.""" - if isinstance(message, (ChatMessage, HumanMessage)): - return "user" - elif isinstance(message, AIMessage): - return "assistant" - elif isinstance(message, SystemMessage): - return "system" +def process_content_item(item: Dict[str, Any]) -> Dict[str, Any]: + """Process a single content item.""" + if item["type"] == "image_url": + image_url = item["image_url"] + if isinstance(image_url, dict) and "url" in image_url: + # If it's in LangChain format, extract the URL value + item["image_url"] = image_url["url"] + return item + + +def process_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]: + """Process content to handle both text and media inputs, returning a list of content items.""" + if isinstance(content, str): + return [{"type": "text", "text": content}] + elif isinstance(content, list): + return [process_content_item(item) for item in content] else: - raise ValueError(f"Got unknown type {message}") + raise ValueError("Invalid content format") -def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str]]: - """Process messages for Reka format.""" +def convert_to_reka_messages(messages: List[Any]) -> List[ChatMessage]: + """Convert LangChain messages to Reka ChatMessage format.""" reka_messages = [] system_message = None @@ -65,12 +83,20 @@ def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str system_message = message.content else: raise ValueError("Multiple system messages are not supported.") - else: - content = message.content - if system_message and isinstance(message, HumanMessage): - content = f"{system_message}\n{content}" + elif 
isinstance(message, HumanMessage): + content = process_content(message.content) + if system_message: + if isinstance(content[0], dict) and content[0].get("type") == "text": + content[0]["text"] = f"{system_message}\n{content[0]['text']}" + else: + content.insert(0, {"type": "text", "text": system_message}) system_message = None - reka_messages.append({"role": get_role(message), "content": content}) + reka_messages.append(ChatMessage(content=content, role="user")) + elif isinstance(message, AIMessage): + content = process_content(message.content) + reka_messages.append(ChatMessage(content=content, role="assistant")) + else: + raise ValueError(f"Unsupported message type: {type(message)}") return reka_messages @@ -118,8 +144,6 @@ def validate_environment(cls, values: Dict) -> Dict: ) try: - from reka.client import AsyncReka, Reka - values["client"] = Reka( api_key=values["reka_api_key"].get_secret_value(), ) @@ -166,7 +190,7 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - reka_messages = process_messages_for_reka(messages) + reka_messages = convert_to_reka_messages(messages) params = {**self._default_params, **kwargs} if stop: params["stop"] = stop @@ -187,12 +211,14 @@ async def _astream( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - reka_messages = process_messages_for_reka(messages) + reka_messages = convert_to_reka_messages(messages) params = {**self._default_params, **kwargs} if stop: params["stop"] = stop - stream = self.async_client.chat.create_stream(messages=reka_messages, **params) + stream = await self.async_client.chat.create_stream( + messages=reka_messages, **params + ) async for chunk in stream: content = chunk.responses[0].chunk.content @@ -213,7 +239,7 @@ def _generate( self._stream(messages, stop=stop, run_manager=run_manager, **kwargs) ) - reka_messages = process_messages_for_reka(messages) + reka_messages = convert_to_reka_messages(messages) params = {**self._default_params, **kwargs} if stop: params["stop"] = stop @@ -234,7 +260,7 @@ async def _agenerate( self._astream(messages, stop=stop, run_manager=run_manager, **kwargs) ) - reka_messages = process_messages_for_reka(messages) + reka_messages = convert_to_reka_messages(messages) params = {**self._default_params, **kwargs} if stop: params["stop"] = stop From e93631747228b0a25a86a72c9c777d5934e96ac1 Mon Sep 17 00:00:00 2001 From: findalexli Date: Thu, 12 Sep 2024 15:03:59 -0700 Subject: [PATCH 5/6] Unit test update --- .../tests/unit_tests/chat_models/test_reka.py | 84 +++++++++++++++---- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 6c61d9a069fe4..eca1a948b8b22 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -1,4 +1,4 @@ -"""Test Reka Chat API wrapper.""" +"""Test Reka Chat wrapper.""" import os from typing import List @@ -7,7 +7,10 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_community.chat_models import ChatReka -from langchain_community.chat_models.reka import process_messages_for_reka +from langchain_community.chat_models.reka import ( + convert_to_reka_messages, + process_content, +) os.environ["REKA_API_KEY"] = "dummy_key" @@ -52,34 +55,87 @@ def test_reka_initialization() -> None: 
 @pytest.mark.parametrize(
-    ("messages", "expected"),
+    ("content", "expected"),
     [
-        ([HumanMessage(content="Hello")], [{"role": "user", "content": "Hello"}]),
+        ("Hello", [{"type": "text", "text": "Hello"}]),
         (
-            [HumanMessage(content="Hello"), AIMessage(content="Hi there!")],
             [
-                {"role": "user", "content": "Hello"},
-                {"role": "assistant", "content": "Hi there!"},
+                {"type": "text", "text": "Hello"},
+                {"type": "image_url", "image_url": "https://example.com/image.jpg"},
+            ],
+            [
+                {"type": "text", "text": "Hello"},
+                {"type": "image_url", "image_url": "https://example.com/image.jpg"},
             ],
         ),
         (
             [
-                SystemMessage(content="You're an assistant"),
-                HumanMessage(content="Hello"),
-                AIMessage(content="Hi there!"),
+                {"type": "text", "text": "Hello"},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "https://example.com/image.jpg"},
+                },
             ],
             [
-                {"role": "user", "content": "You're an assistant\nHello"},
-                {"role": "assistant", "content": "Hi there!"},
+                {"type": "text", "text": "Hello"},
+                {"type": "image_url", "image_url": "https://example.com/image.jpg"},
             ],
         ),
     ],
 )
-def test_message_processing(messages: List[BaseMessage], expected: List[dict]) -> None:
-    result = process_messages_for_reka(messages)
+def test_process_content(content, expected) -> None:
+    result = process_content(content)
     assert result == expected
 
 
+@pytest.mark.parametrize(
+    ("messages", "expected"),
+    [
+        (
+            [HumanMessage(content="Hello")],
+            [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
+        ),
+        (
+            [
+                HumanMessage(
+                    content=[
+                        {"type": "text", "text": "Describe this image"},
+                        {
+                            "type": "image_url",
+                            "image_url": "https://example.com/image.jpg",
+                        },
+                    ]
+                ),
+                AIMessage(content="It's a beautiful landscape."),
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Describe this image"},
+                        {
+                            "type": "image_url",
+                            "image_url": "https://example.com/image.jpg",
+                        },
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": "It's a beautiful landscape."}
+                    ],
+                },
+            ],
+        ),
+    ],
+)
+def test_convert_to_reka_messages(
+    messages: List[BaseMessage], expected: List[dict]
+) -> None:
+    result = convert_to_reka_messages(messages)
+    assert [message.dict() for message in result] == expected
+
+
 @pytest.mark.requires("reka")
 def test_reka_streaming() -> None:
     llm = ChatReka(streaming=True)

From a7fe379c207f1723830aa4680240dbcdaff63c4f Mon Sep 17 00:00:00 2001
From: findalexli
Date: Thu, 10 Oct 2024 16:39:55 -0700
Subject: [PATCH 6/6] Included dependency

---
 libs/community/extended_testing_deps.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt
index 3cdf3d52c632a..d8dc0362cb639 100644
--- a/libs/community/extended_testing_deps.txt
+++ b/libs/community/extended_testing_deps.txt
@@ -72,6 +72,7 @@ rapidfuzz>=3.1.1,<4
 rapidocr-onnxruntime>=1.3.2,<2
 rdflib==7.0.0
 requests-toolbelt>=1.0.0,<2
+reka-api>=3.2.0
 rspace_client>=2.5.0,<3
 scikit-learn>=1.2.2,<2
 simsimd>=5.0.0,<6
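Applied together, the six patches above register ChatReka in langchain_community, add unit and integration tests, a documentation notebook, and the reka-api test dependency. A minimal usage sketch of the resulting interface follows; it assumes `pip install langchain-community reka-api` and a REKA_API_KEY environment variable, reuses the reka-flash model and sample image URL from the PR's own notebook, and treats the temperature value and prompt strings as illustrative placeholders rather than recommendations.

# Usage sketch for the ChatReka integration introduced above; not part of
# the patches. Assumes REKA_API_KEY is set and reka-api is installed.
from langchain_core.messages import HumanMessage

from langchain_community.chat_models import ChatReka

# Model and temperature are illustrative; ChatReka defaults to reka-flash.
chat = ChatReka(model="reka-flash", temperature=0.7)

# Single-turn text call through the standard Runnable interface.
reply = chat.invoke([HumanMessage(content="Hello")])
print(reply.content)

# Multimodal call: process_content() in reka.py unwraps the {"url": ...}
# dict from the LangChain-style block before calling the Reka SDK.
multimodal = HumanMessage(
    content=[
        {"type": "text", "text": "Describe the weather in this image"},
        {
            "type": "image_url",
            "image_url": {"url": "https://v0.docs.reka.ai/_images/000000245576.jpg"},
        },
    ]
)
print(chat.invoke([multimodal]).content)

# Token-level streaming, backed by the _stream() implementation above.
for chunk in chat.stream([HumanMessage(content="Write one sentence.")]):
    print(chunk.content, end="", flush=True)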