From 58858ce155efe72879900eb8264afb6943a8a7b3 Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 10 Sep 2024 16:53:55 -0700 Subject: [PATCH 01/36] Integration test --- .../langchain_community/chat_models/reka.py | 188 ++++++++++++++++++ .../chat_models/test_reka.py | 82 ++++++++ 2 files changed, 270 insertions(+) create mode 100644 libs/community/langchain_community/chat_models/reka.py create mode 100644 libs/community/tests/integration_tests/chat_models/test_reka.py diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py new file mode 100644 index 0000000000000..99646281c67ea --- /dev/null +++ b/libs/community/langchain_community/chat_models/reka.py @@ -0,0 +1,188 @@ +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import ( + BaseChatModel, + agenerate_from_stream, + generate_from_stream, +) +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.pydantic_v1 import Field, PrivateAttr +from langchain_core.utils import get_from_dict_or_env + +try: + from reka.client import Reka, AsyncReka +except ImportError: + raise ValueError( + "Reka is not installed. Please install it with `pip install reka-api`." + ) + +REKA_MODELS = [ + "reka-edge", + "reka-flash", + "reka-core", + "reka-core-20240501", +] + +DEFAULT_REKA_MODEL = "reka-core-20240501" + +def get_role(message: BaseMessage) -> str: + """Get the role of the message.""" + if isinstance(message, (ChatMessage, HumanMessage)): + return "user" + elif isinstance(message, AIMessage): + return "assistant" + elif isinstance(message, SystemMessage): + return "system" + else: + raise ValueError(f"Got unknown type {message}") + +def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str]]: + """Process messages for Reka format.""" + reka_messages = [] + system_message = None + + for message in messages: + if isinstance(message, SystemMessage): + if system_message is None: + system_message = message.content + else: + raise ValueError("Multiple system messages are not supported.") + else: + content = message.content + if system_message and isinstance(message, HumanMessage): + content = f"{system_message}\n{content}" + system_message = None + reka_messages.append({"role": get_role(message), "content": content}) + + return reka_messages + +class ChatReka(BaseChatModel): + """Reka chat large language models.""" + + model: str = Field(default=DEFAULT_REKA_MODEL, description="The Reka model to use.") + temperature: float = Field(default=0.7, description="The sampling temperature.") + max_tokens: int = Field(default=512, description="The maximum number of tokens to generate.") + api_key: str = Field(default=None, description="The API key for Reka.") + streaming: bool = Field(default=False, description="Whether to stream the response.") + + _client: Reka = PrivateAttr() + _aclient: AsyncReka = PrivateAttr() + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + api_key = get_from_dict_or_env(kwargs, "api_key", "REKA_API_KEY") + self._client = Reka(api_key=api_key) + self._aclient = AsyncReka(api_key=api_key) + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return 
"reka-chat" + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling Reka API.""" + return { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + } + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + reka_messages = process_messages_for_reka(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + + stream = self._client.chat.create_stream(messages=reka_messages, **params) + + for chunk in stream: + content = chunk.responses[0].chunk.content + chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) + yield chunk + if run_manager: + run_manager.on_llm_new_token(content, chunk=chunk) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + reka_messages = process_messages_for_reka(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + + stream = await self._aclient.chat.create_stream(messages=reka_messages, **params) + + async for chunk in stream: + content = chunk.responses[0].chunk.content + chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) + yield chunk + if run_manager: + await run_manager.on_llm_new_token(content, chunk=chunk) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + return generate_from_stream( + self._stream(messages, stop=stop, run_manager=run_manager, **kwargs) + ) + + reka_messages = process_messages_for_reka(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + response = self._client.chat.create(messages=reka_messages, **params) + + message = AIMessage(content=response.responses[0].message.content) + return ChatResult(generations=[ChatGeneration(message=message)]) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + return await agenerate_from_stream( + self._astream(messages, stop=stop, run_manager=run_manager, **kwargs) + ) + + reka_messages = process_messages_for_reka(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + response = await self._aclient.chat.create(messages=reka_messages, **params) + + message = AIMessage(content=response.responses[0].message.content) + return ChatResult(generations=[ChatGeneration(message=message)]) + + def get_num_tokens(self, text: str) -> int: + """Calculate number of tokens.""" + raise NotImplementedError("Token counting is not implemented for Reka models.") diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py new file mode 100644 index 0000000000000..de9dc200710ae --- /dev/null +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -0,0 +1,82 @@ +"""Test Reka API wrapper.""" + +from typing import List + +import pytest +from langchain_core.callbacks import CallbackManager +from langchain_core.messages import AIMessage, 
BaseMessage, HumanMessage +from langchain_core.outputs import ChatGeneration, LLMResult + +from langchain_community.chat_models.reka import ChatReka +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + +@pytest.mark.scheduled +def test_reka_call() -> None: + """Test valid call to Reka.""" + chat = ChatReka(model="reka-flash") + message = HumanMessage(content="Hello") + response = chat.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + +@pytest.mark.scheduled +def test_reka_generate() -> None: + """Test generate method of Reka.""" + chat = ChatReka(model="reka-flash") + chat_messages: List[List[BaseMessage]] = [ + [HumanMessage(content="How many toes do dogs have?")] + ] + messages_copy = [messages.copy() for messages in chat_messages] + result: LLMResult = chat.generate(chat_messages) + assert isinstance(result, LLMResult) + for response in result.generations[0]: + assert isinstance(response, ChatGeneration) + assert isinstance(response.text, str) + assert response.text == response.message.content + assert chat_messages == messages_copy + +@pytest.mark.scheduled +def test_reka_streaming() -> None: + """Test streaming tokens from Reka.""" + chat = ChatReka(model="reka-flash", streaming=True) + message = HumanMessage(content="Hello") + response = chat.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + +@pytest.mark.scheduled +def test_reka_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = ChatReka( + model="reka-flash", + streaming=True, + callback_manager=callback_manager, + verbose=True, + ) + message = HumanMessage(content="Write me a sentence with 10 words.") + chat.invoke([message]) + assert callback_handler.llm_streams > 1 + +@pytest.mark.scheduled +async def test_reka_async_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = ChatReka( + model="reka-flash", + streaming=True, + callback_manager=callback_manager, + verbose=True, + ) + chat_messages: List[BaseMessage] = [ + HumanMessage(content="How many toes do dogs have?") + ] + result: LLMResult = await chat.agenerate([chat_messages]) + assert callback_handler.llm_streams > 1 + assert isinstance(result, LLMResult) + for response in result.generations[0]: + assert isinstance(response, ChatGeneration) + assert isinstance(response.text, str) + assert response.text == response.message.content \ No newline at end of file From 6f1d25949d6ab05c5a4a0a85756b92b861f3535d Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 10 Sep 2024 16:59:24 -0700 Subject: [PATCH 02/36] Linting fix --- .../langchain_community/chat_models/reka.py | 33 +++++++++++++++---- .../chat_models/test_reka.py | 1 + 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 99646281c67ea..ad037247616c2 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -1,4 +1,5 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional + from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, 
CallbackManagerForLLMRun, @@ -21,7 +22,7 @@ from langchain_core.utils import get_from_dict_or_env try: - from reka.client import Reka, AsyncReka + from reka.client import AsyncReka, Reka except ImportError: raise ValueError( "Reka is not installed. Please install it with `pip install reka-api`." @@ -70,11 +71,26 @@ def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str class ChatReka(BaseChatModel): """Reka chat large language models.""" - model: str = Field(default=DEFAULT_REKA_MODEL, description="The Reka model to use.") - temperature: float = Field(default=0.7, description="The sampling temperature.") - max_tokens: int = Field(default=512, description="The maximum number of tokens to generate.") - api_key: str = Field(default=None, description="The API key for Reka.") - streaming: bool = Field(default=False, description="Whether to stream the response.") + model: str = Field( + default=DEFAULT_REKA_MODEL, + description="The Reka model to use." + ) + temperature: float = Field( + default=0.7, + description="The sampling temperature." + ) + max_tokens: int = Field( + default=512, + description="The maximum number of tokens to generate." + ) + api_key: str = Field( + default=None, + description="The API key for Reka." + ) + streaming: bool = Field( + default=False, + description="Whether to stream the response." + ) _client: Reka = PrivateAttr() _aclient: AsyncReka = PrivateAttr() @@ -132,7 +148,10 @@ async def _astream( if stop: params["stop"] = stop - stream = await self._aclient.chat.create_stream(messages=reka_messages, **params) + stream = await self._aclient.chat.create_stream( + messages=reka_messages, + **params + ) async for chunk in stream: content = chunk.responses[0].chunk.content diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index de9dc200710ae..5f1b256c34477 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -10,6 +10,7 @@ from langchain_community.chat_models.reka import ChatReka from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + @pytest.mark.scheduled def test_reka_call() -> None: """Test valid call to Reka.""" From 5ecad6d13dd531ba1b842673ab58e9b9dfa88c9b Mon Sep 17 00:00:00 2001 From: findalexli Date: Wed, 11 Sep 2024 14:57:04 -0700 Subject: [PATCH 03/36] code formatter changes --- .../chat_models/__init__.py | 7 +- .../langchain_community/chat_models/reka.py | 145 ++++++++++++------ .../chat_models/test_reka.py | 6 +- .../tests/unit_tests/chat_models/test_reka.py | 123 +++++++++++++++ 4 files changed, 229 insertions(+), 52 deletions(-) create mode 100644 libs/community/tests/unit_tests/chat_models/test_reka.py diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index db5076375288f..8c68842272358 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -149,7 +149,10 @@ ) from langchain_community.chat_models.sambanova import ( ChatSambaNovaCloud, - ChatSambaStudio, + ChatSambaStudio + ) + from langchain_community.chat_models.reka import ( + ChatReka, ) from langchain_community.chat_models.snowflake import ( ChatSnowflakeCortex, @@ -214,6 +217,7 @@ "ChatOllama", "ChatOpenAI", "ChatPerplexity", + "ChatReka", "ChatPremAI", "ChatSambaNovaCloud", "ChatSambaStudio", @@ 
-274,6 +278,7 @@ "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai", "ChatOllama": "langchain_community.chat_models.ollama", "ChatOpenAI": "langchain_community.chat_models.openai", + "ChatReka": "langchain_community.chat_models.reka", "ChatPerplexity": "langchain_community.chat_models.perplexity", "ChatSambaNovaCloud": "langchain_community.chat_models.sambanova", "ChatSambaStudio": "langchain_community.chat_models.sambanova", diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index ad037247616c2..77d68d5185fba 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -1,9 +1,10 @@ -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, @@ -18,8 +19,12 @@ SystemMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, PrivateAttr -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, +) +from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str try: from reka.client import AsyncReka, Reka @@ -32,10 +37,10 @@ "reka-edge", "reka-flash", "reka-core", - "reka-core-20240501", ] -DEFAULT_REKA_MODEL = "reka-core-20240501" +DEFAULT_REKA_MODEL = "reka-flash" + def get_role(message: BaseMessage) -> str: """Get the role of the message.""" @@ -48,6 +53,7 @@ def get_role(message: BaseMessage) -> str: else: raise ValueError(f"Got unknown type {message}") + def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str]]: """Process messages for Reka format.""" reka_messages = [] @@ -68,52 +74,90 @@ def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str return reka_messages -class ChatReka(BaseChatModel): - """Reka chat large language models.""" - model: str = Field( - default=DEFAULT_REKA_MODEL, - description="The Reka model to use." - ) - temperature: float = Field( - default=0.7, - description="The sampling temperature." - ) - max_tokens: int = Field( - default=512, - description="The maximum number of tokens to generate." - ) - api_key: str = Field( - default=None, - description="The API key for Reka." - ) - streaming: bool = Field( - default=False, - description="Whether to stream the response." 
- ) +class RekaCommon(BaseLanguageModel): + client: Any = None #: :meta private: + async_client: Any = None #: :meta private: + model: str = Field(default=DEFAULT_REKA_MODEL, alias="model_name") + """Model name to use.""" - _client: Reka = PrivateAttr() - _aclient: AsyncReka = PrivateAttr() + max_tokens: int = Field(default=256) + """Denotes the number of tokens to predict per generation.""" - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - api_key = get_from_dict_or_env(kwargs, "api_key", "REKA_API_KEY") - self._client = Reka(api_key=api_key) - self._aclient = AsyncReka(api_key=api_key) + temperature: Optional[float] = None + """A non-negative float that tunes the degree of randomness in generation.""" - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "reka-chat" + streaming: bool = False + """Whether to stream the results.""" + + default_request_timeout: Optional[float] = None + """Timeout for requests to Reka Completion API. Default is 600 seconds.""" + + max_retries: int = 2 + """Number of retries allowed for requests sent to the Reka Completion API.""" + + reka_api_key: Optional[SecretStr] = None + + count_tokens: Optional[Callable[[str], int]] = None + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + + @root_validator(pre=True) + def build_extra(cls, values: Dict) -> Dict: + extra = values.get("model_kwargs", {}) + all_required_field_names = get_pydantic_field_names(cls) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["reka_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "reka_api_key", "REKA_API_KEY") + ) + + try: + from reka.client import AsyncReka, Reka + + values["client"] = Reka( + api_key=values["reka_api_key"].get_secret_value(), + ) + values["async_client"] = AsyncReka( + api_key=values["reka_api_key"].get_secret_value(), + ) + + except ImportError: + raise ImportError( + "Could not import reka python package. " + "Please install it with `pip install reka-api`." 
+ ) + return values @property - def _default_params(self) -> Dict[str, Any]: + def _default_params(self) -> Mapping[str, Any]: """Get the default parameters for calling Reka API.""" - return { - "model": self.model, - "temperature": self.temperature, + d = { "max_tokens": self.max_tokens, + "model": self.model, } + if self.temperature is not None: + d["temperature"] = self.temperature + return {**d, **self.model_kwargs} + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{}, **self._default_params} + + +class ChatReka(BaseChatModel, RekaCommon): + """Reka chat large language models.""" + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "reka-chat" def _stream( self, @@ -127,7 +171,7 @@ def _stream( if stop: params["stop"] = stop - stream = self._client.chat.create_stream(messages=reka_messages, **params) + stream = self.client.chat.create_stream(messages=reka_messages, **params) for chunk in stream: content = chunk.responses[0].chunk.content @@ -148,10 +192,7 @@ async def _astream( if stop: params["stop"] = stop - stream = await self._aclient.chat.create_stream( - messages=reka_messages, - **params - ) + stream = self.async_client.chat.create_stream(messages=reka_messages, **params) async for chunk in stream: content = chunk.responses[0].chunk.content @@ -176,7 +217,7 @@ def _generate( params = {**self._default_params, **kwargs} if stop: params["stop"] = stop - response = self._client.chat.create(messages=reka_messages, **params) + response = self.client.chat.create(messages=reka_messages, **params) message = AIMessage(content=response.responses[0].message.content) return ChatResult(generations=[ChatGeneration(message=message)]) @@ -197,11 +238,15 @@ async def _agenerate( params = {**self._default_params, **kwargs} if stop: params["stop"] = stop - response = await self._aclient.chat.create(messages=reka_messages, **params) + response = await self.async_client.chat.create(messages=reka_messages, **params) message = AIMessage(content=response.responses[0].message.content) return ChatResult(generations=[ChatGeneration(message=message)]) def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" - raise NotImplementedError("Token counting is not implemented for Reka models.") + if self.count_tokens is None: + raise NotImplementedError( + "get_num_tokens() is not implemented for Reka models." 
+ ) + return self.count_tokens(text) diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index 5f1b256c34477..7b599cd7c2b48 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -20,6 +20,7 @@ def test_reka_call() -> None: assert isinstance(response, AIMessage) assert isinstance(response.content, str) + @pytest.mark.scheduled def test_reka_generate() -> None: """Test generate method of Reka.""" @@ -36,6 +37,7 @@ def test_reka_generate() -> None: assert response.text == response.message.content assert chat_messages == messages_copy + @pytest.mark.scheduled def test_reka_streaming() -> None: """Test streaming tokens from Reka.""" @@ -45,6 +47,7 @@ def test_reka_streaming() -> None: assert isinstance(response, AIMessage) assert isinstance(response.content, str) + @pytest.mark.scheduled def test_reka_streaming_callback() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" @@ -60,6 +63,7 @@ def test_reka_streaming_callback() -> None: chat.invoke([message]) assert callback_handler.llm_streams > 1 + @pytest.mark.scheduled async def test_reka_async_streaming_callback() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" @@ -80,4 +84,4 @@ async def test_reka_async_streaming_callback() -> None: for response in result.generations[0]: assert isinstance(response, ChatGeneration) assert isinstance(response.text, str) - assert response.text == response.message.content \ No newline at end of file + assert response.text == response.message.content diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py new file mode 100644 index 0000000000000..6c61d9a069fe4 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -0,0 +1,123 @@ +"""Test Reka Chat API wrapper.""" + +import os +from typing import List + +import pytest +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage + +from langchain_community.chat_models import ChatReka +from langchain_community.chat_models.reka import process_messages_for_reka + +os.environ["REKA_API_KEY"] = "dummy_key" + + +@pytest.mark.requires("reka") +def test_reka_model_name_param() -> None: + llm = ChatReka(model_name="reka-flash") + assert llm.model == "reka-flash" + + +@pytest.mark.requires("reka") +def test_reka_model_param() -> None: + llm = ChatReka(model="reka-flash") + assert llm.model == "reka-flash" + + +@pytest.mark.requires("reka") +def test_reka_model_kwargs() -> None: + llm = ChatReka(model_kwargs={"foo": "bar"}) + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.requires("reka") +def test_reka_invalid_model_kwargs() -> None: + with pytest.raises(ValueError): + ChatReka(model_kwargs={"max_tokens": "invalid"}) + + +@pytest.mark.requires("reka") +def test_reka_incorrect_field() -> None: + with pytest.warns(match="not default parameter"): + llm = ChatReka(foo="bar") + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.requires("reka") +def test_reka_initialization() -> None: + """Test Reka initialization.""" + # Verify that ChatReka can be initialized using a secret key provided + # as a parameter rather than an environment variable. 
+ ChatReka(model="reka-flash", reka_api_key="test_key") + + +@pytest.mark.parametrize( + ("messages", "expected"), + [ + ([HumanMessage(content="Hello")], [{"role": "user", "content": "Hello"}]), + ( + [HumanMessage(content="Hello"), AIMessage(content="Hi there!")], + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + ), + ( + [ + SystemMessage(content="You're an assistant"), + HumanMessage(content="Hello"), + AIMessage(content="Hi there!"), + ], + [ + {"role": "user", "content": "You're an assistant\nHello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + ), + ], +) +def test_message_processing(messages: List[BaseMessage], expected: List[dict]) -> None: + result = process_messages_for_reka(messages) + assert result == expected + + +@pytest.mark.requires("reka") +def test_reka_streaming() -> None: + llm = ChatReka(streaming=True) + assert llm.streaming is True + + +@pytest.mark.requires("reka") +def test_reka_temperature() -> None: + llm = ChatReka(temperature=0.5) + assert llm.temperature == 0.5 + + +@pytest.mark.requires("reka") +def test_reka_max_tokens() -> None: + llm = ChatReka(max_tokens=100) + assert llm.max_tokens == 100 + + +@pytest.mark.requires("reka") +def test_reka_default_params() -> None: + llm = ChatReka() + assert llm._default_params == { + "max_tokens": 256, + "model": "reka-flash", + } + + +@pytest.mark.requires("reka") +def test_reka_identifying_params() -> None: + llm = ChatReka(temperature=0.7) + assert llm._identifying_params == { + "max_tokens": 256, + "model": "reka-flash", + "temperature": 0.7, + } + + +@pytest.mark.requires("reka") +def test_reka_llm_type() -> None: + llm = ChatReka() + assert llm._llm_type == "reka-chat" From e0637a5506dc1c6ee920e7feb4dc2d313fb92226 Mon Sep 17 00:00:00 2001 From: findalexli Date: Thu, 12 Sep 2024 14:29:14 -0700 Subject: [PATCH 04/36] Ruff linting check --- docs/docs/integrations/chat/reka.ipynb | 266 ++++++++++++++++++ .../langchain_community/chat_models/reka.py | 76 +++-- 2 files changed, 317 insertions(+), 25 deletions(-) create mode 100644 docs/docs/integrations/chat/reka.ipynb diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb new file mode 100644 index 0000000000000..eaf267252eceb --- /dev/null +++ b/docs/docs/integrations/chat/reka.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models import ChatReka" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model = ChatReka()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "sidebar_label: Reka\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ChatReka\n", + "\n", + "This notebook provides a quick overview for getting started with Reka [chat models](/docs/concepts/#chat-models). \n", + "\n", + "Reka has several chat models. 
You can find information about their latest models and their costs, context windows, and supported input types in the [Reka docs](https://docs.reka.ai/available-models).\n", "\n", "\n", "\n", "## Overview\n", "### Integration details\n", "\n", "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n", "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", "| [ChatReka] | [langchain_community](https://python.langchain.com/v0.2/api_reference/community/index.html) | ✅ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n", "\n", "### Model features\n", "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", "| ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | \n", "\n", "## Setup\n", "\n", "To access Reka models you'll need to create a Reka developer account, get an API key, and install the `langchain_community` integration package and the `reka-api` Python package via `pip install reka-api`.\n", "\n", "### Credentials\n", "\n", "Head to https://platform.reka.ai/ to sign up for Reka and generate an API key. Once you've done this, set the REKA_API_KEY environment variable:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Installation\n", "\n", "The LangChain Reka integration lives in the `langchain_community` package:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install -qU langchain_community" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Initialize a client" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import getpass\n", "import os\n", "\n", "# os.environ[\"REKA_API_KEY\"] = getpass.getpass(\"Enter your Reka API key: \")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install -qU langchain_community" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from langchain_community.chat_models import ChatReka\n", "# The client reads the REKA_API_KEY environment variable set above\n", "model = ChatReka()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Single turn text message" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AIMessage(content=' Hello! How can I help you today? 
If you have any questions or need assistance with something, feel free to ask.\\n\\n', id='run-492dcda9-79b8-4e16-a81f-2f77b7d4e23a-0')" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.invoke('hi')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Image input" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " The image you've uploaded is an indoor shot, and therefore, it does not provide any information about the weather outside. Since there are no windows or outdoor views depicted, it's not possible to determine the weather conditions at the time the photo was taken. If you need information about the weather, you should check a weather app, website, or news source based on your location or the location where the image was captured.\n" ] } ], "source": [ "from langchain_community.chat_models import ChatReka\n", "import httpx\n", "\n", "model = ChatReka()\n", "image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"\n", "from langchain_core.messages import HumanMessage\n", "\n", "message = HumanMessage(\n", " content=[\n", " {\"type\": \"text\", \"text\": \"describe the weather in this image\"},\n", " {\n", " \"type\": \"image_url\",\n", " \"image_url\": {\"url\": image_url},\n", " },\n", " ],\n", ")\n", "response = model.invoke([message])\n", "print(response.content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Multiple images as input" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " The two images are quite different in several ways. \n", "\n", "Firstly, the subjects of the images are different. The first image features a cat, while the second image features two German Shepherds, an adult and a puppy.\n", "\n", "Secondly, the actions depicted are different. In the first image, the cat is sitting still and looking directly at the camera. In contrast, the second image captures the two dogs in motion, running across a grassy field with a stick in the mouth of the adult dog.\n", "\n", "Lastly, the setting of the images is different. The first image has a neutral, blurred background, providing no specific context about the location. The second image, on the other hand, is set in an outdoor environment with a grassy field and trees in the background.\n", "\n", "These differences highlight the distinct characteristics and behaviors of the animals as well as the varying contexts in which they are photographed.\n" ] } ], "source": [ "message = HumanMessage(\n", " content=[\n", " {\"type\": \"text\", \"text\": \"What are the difference between the two images? 
\"},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": \"https://cdn.pixabay.com/photo/2019/07/23/13/51/shepherd-dog-4357790_1280.jpg\"},\n", + " },\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": \"https://cdn.pixabay.com/photo/2024/02/17/00/18/cat-8578562_1280.jpg\"},\n", + " },\n", + " ],\n", + ")\n", + "response = model.invoke([message])\n", + "print(response.content)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langchain_reka", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 77d68d5185fba..40df5bf9492bd 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -1,4 +1,14 @@ -from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Union, +) from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -14,7 +24,6 @@ AIMessage, AIMessageChunk, BaseMessage, - ChatMessage, HumanMessage, SystemMessage, ) @@ -27,6 +36,7 @@ from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str try: + from reka import ChatMessage from reka.client import AsyncReka, Reka except ImportError: raise ValueError( @@ -42,20 +52,28 @@ DEFAULT_REKA_MODEL = "reka-flash" -def get_role(message: BaseMessage) -> str: - """Get the role of the message.""" - if isinstance(message, (ChatMessage, HumanMessage)): - return "user" - elif isinstance(message, AIMessage): - return "assistant" - elif isinstance(message, SystemMessage): - return "system" +def process_content_item(item: Dict[str, Any]) -> Dict[str, Any]: + """Process a single content item.""" + if item["type"] == "image_url": + image_url = item["image_url"] + if isinstance(image_url, dict) and "url" in image_url: + # If it's in LangChain format, extract the URL value + item["image_url"] = image_url["url"] + return item + + +def process_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]: + """Process content to handle both text and media inputs, returning a list of content items.""" + if isinstance(content, str): + return [{"type": "text", "text": content}] + elif isinstance(content, list): + return [process_content_item(item) for item in content] else: - raise ValueError(f"Got unknown type {message}") + raise ValueError("Invalid content format") -def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str]]: - """Process messages for Reka format.""" +def convert_to_reka_messages(messages: List[Any]) -> List[ChatMessage]: + """Convert LangChain messages to Reka ChatMessage format.""" reka_messages = [] system_message = None @@ -65,12 +83,20 @@ def process_messages_for_reka(messages: List[BaseMessage]) -> List[Dict[str, str system_message = message.content else: raise ValueError("Multiple system messages are not supported.") - else: - content = message.content - if system_message and isinstance(message, HumanMessage): - content = f"{system_message}\n{content}" + elif 
isinstance(message, HumanMessage): + content = process_content(message.content) + if system_message: + if isinstance(content[0], dict) and content[0].get("type") == "text": + content[0]["text"] = f"{system_message}\n{content[0]['text']}" + else: + content.insert(0, {"type": "text", "text": system_message}) system_message = None - reka_messages.append({"role": get_role(message), "content": content}) + reka_messages.append(ChatMessage(content=content, role="user")) + elif isinstance(message, AIMessage): + content = process_content(message.content) + reka_messages.append(ChatMessage(content=content, role="assistant")) + else: + raise ValueError(f"Unsupported message type: {type(message)}") return reka_messages @@ -118,8 +144,6 @@ def validate_environment(cls, values: Dict) -> Dict: ) try: - from reka.client import AsyncReka, Reka - values["client"] = Reka( api_key=values["reka_api_key"].get_secret_value(), ) @@ -166,7 +190,7 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - reka_messages = process_messages_for_reka(messages) + reka_messages = convert_to_reka_messages(messages) params = {**self._default_params, **kwargs} if stop: params["stop"] = stop @@ -187,12 +211,14 @@ async def _astream( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - reka_messages = process_messages_for_reka(messages) + reka_messages = convert_to_reka_messages(messages) params = {**self._default_params, **kwargs} if stop: params["stop"] = stop - stream = self.async_client.chat.create_stream(messages=reka_messages, **params) + stream = await self.async_client.chat.create_stream( + messages=reka_messages, **params + ) async for chunk in stream: content = chunk.responses[0].chunk.content @@ -213,7 +239,7 @@ def _generate( self._stream(messages, stop=stop, run_manager=run_manager, **kwargs) ) - reka_messages = process_messages_for_reka(messages) + reka_messages = convert_to_reka_messages(messages) params = {**self._default_params, **kwargs} if stop: params["stop"] = stop @@ -234,7 +260,7 @@ async def _agenerate( self._astream(messages, stop=stop, run_manager=run_manager, **kwargs) ) - reka_messages = process_messages_for_reka(messages) + reka_messages = convert_to_reka_messages(messages) params = {**self._default_params, **kwargs} if stop: params["stop"] = stop From e93631747228b0a25a86a72c9c777d5934e96ac1 Mon Sep 17 00:00:00 2001 From: findalexli Date: Thu, 12 Sep 2024 15:03:59 -0700 Subject: [PATCH 05/36] Unit test update --- .../tests/unit_tests/chat_models/test_reka.py | 84 +++++++++++++++---- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 6c61d9a069fe4..eca1a948b8b22 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -1,4 +1,4 @@ -"""Test Reka Chat API wrapper.""" +"""Test Reka Chat wrapper.""" import os from typing import List @@ -7,7 +7,10 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_community.chat_models import ChatReka -from langchain_community.chat_models.reka import process_messages_for_reka +from langchain_community.chat_models.reka import ( + convert_to_reka_messages, + process_content, +) os.environ["REKA_API_KEY"] = "dummy_key" @@ -52,34 +55,87 @@ def test_reka_initialization() -> None: 
@pytest.mark.parametrize( ("content", "expected"), [ ("Hello", [{"type": "text", "text": "Hello"}]), ( [ {"type": "text", "text": "Hello"}, {"type": "image_url", "image_url": "https://example.com/image.jpg"}, ], [ {"type": "text", "text": "Hello"}, {"type": "image_url", "image_url": "https://example.com/image.jpg"}, ], ), ( [ {"type": "text", "text": "Hello"}, { "type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}, }, ], [ {"type": "text", "text": "Hello"}, {"type": "image_url", "image_url": "https://example.com/image.jpg"}, ], ), ], ) def test_process_content(content, expected) -> None: result = process_content(content) assert result == expected @pytest.mark.parametrize( ("messages", "expected"), [ ( [HumanMessage(content="Hello")], [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], ), ( [ HumanMessage( content=[ {"type": "text", "text": "Describe this image"}, { "type": "image_url", "image_url": "https://example.com/image.jpg", }, ] ), AIMessage(content="It's a beautiful landscape."), ], [ { "role": "user", "content": [ {"type": "text", "text": "Describe this image"}, { "type": "image_url", "image_url": "https://example.com/image.jpg", }, ], }, { "role": "assistant", "content": [ {"type": "text", "text": "It's a beautiful landscape."} ], }, ], ), ], ) def test_convert_to_reka_messages( messages: List[BaseMessage], expected: List[dict] ) -> None: result = convert_to_reka_messages(messages) assert [message.dict() for message in result] == expected @pytest.mark.requires("reka") def test_reka_streaming() -> None: llm = ChatReka(streaming=True) From a7fe379c207f1723830aa4680240dbcdaff63c4f Mon Sep 17 00:00:00 2001 From: findalexli Date: Thu, 10 Oct 2024 16:39:55 -0700 Subject: [PATCH 06/36] Included dependency --- libs/community/extended_testing_deps.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 3cdf3d52c632a..d8dc0362cb639 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -72,6 +72,7 @@ rapidfuzz>=3.1.1,<4 rapidocr-onnxruntime>=1.3.2,<2 rdflib==7.0.0 requests-toolbelt>=1.0.0,<2 +reka-api>=3.2.0 rspace_client>=2.5.0,<3 scikit-learn>=1.2.2,<2 simsimd>=5.0.0,<6 From cf807afaf79896d90c29556a58889063edd6f1b6 Mon Sep 17 00:00:00 2001 From: findalexli Date: Thu, 10 Oct 2024 17:25:18 -0700 Subject: [PATCH 07/36] Fixed linting --- .../community/langchain_community/chat_models/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 8c68842272358..790d5b461f46e 100644 --- 
a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -147,13 +147,13 @@ from langchain_community.chat_models.promptlayer_openai import ( PromptLayerChatOpenAI, ) - from langchain_community.chat_models.sambanova import ( - ChatSambaNovaCloud, - ChatSambaStudio - ) from langchain_community.chat_models.reka import ( ChatReka, ) + from langchain_community.chat_models.sambanova import ( + ChatSambaNovaCloud, + ChatSambaStudio, + ) from langchain_community.chat_models.snowflake import ( ChatSnowflakeCortex, ) From 6973ad0027683abd6a10f1558e23d3b9eaccbd9b Mon Sep 17 00:00:00 2001 From: findalexli Date: Mon, 14 Oct 2024 16:41:26 -0700 Subject: [PATCH 08/36] Track changes to tool integration --- docs/docs/integrations/chat/reka.ipynb | 42 ++- .../langchain_community/chat_models/reka.py | 241 +++++++++++++++--- .../chat_models/test_reka.py | 109 +++++++- .../tests/unit_tests/chat_models/test_reka.py | 41 ++- .../chat_models/test_reka_search.py | 42 +++ 5 files changed, 407 insertions(+), 68 deletions(-) create mode 100644 libs/community/tests/unit_tests/chat_models/test_reka_search.py diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index eaf267252eceb..71c83fa0b8b2e 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -81,9 +81,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "langchain-experimental 0.0.65 requires langchain-community<0.3.0,>=0.2.16, but you have langchain-community 0.3.2 which is incompatible.\n", + "langchain-experimental 0.0.65 requires langchain-core<0.3.0,>=0.2.38, but you have langchain-core 0.3.10 which is incompatible.\n", + "langchain-openai 0.1.23 requires langchain-core<0.3.0,>=0.2.35, but you have langchain-core 0.3.10 which is incompatible.\n", + "langchain-standard-tests 0.1.1 requires langchain-core<0.3,>=0.1.40, but you have langchain-core 0.3.10 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "%pip install -qU langchain_community" ] @@ -142,7 +155,7 @@ { "data": { "text/plain": [ - "AIMessage(content=' Hello! How can I help you today? If you have any questions or need assistance with something, feel free to ask.\\n\\n', id='run-492dcda9-79b8-4e16-a81f-2f77b7d4e23a-0')" + "AIMessage(content=' Hello! How can I help you today? If you have a question, need assistance, or just want to chat, feel free to let me know. Have a great day!\\n\\n', additional_kwargs={}, response_metadata={}, id='run-206ce7e0-7c6c-4d81-b66b-8b98cb1232cc-0')" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.invoke('hi')" ] @@ -163,14 +176,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " The image you've uploaded is an indoor shot, and therefore, it does not provide any information about the weather outside. Since there are no windows or outdoor views depicted, it's not possible to determine the weather conditions at the time the photo was taken. 
If you need information about the weather, you should check a weather app, website, or news source based on your location or the location where the image was captured.\n" + " The image shows an indoor setting with no visible weather conditions. The focus is on a ginger cat inspecting a computer keyboard. There are no windows or natural light sources that would provide information about the weather outside. The environment is a typical home office setup with a desk, computer, and some other items like a pen holder and a mobile phone.\n" ] } ], @@ -204,22 +217,18 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " The two images are quite different in several ways. \n", - "\n", - "Firstly, the subjects of the images are different. The first image features a cat, while the second image features two German Shepherds, an adult and a puppy.\n", + " The first image shows two German Shepherd dogs, one adult and one puppy, running through grass. The adult dog is carrying a large stick in its mouth, suggesting play or exercise, while the puppy follows close behind. Both are in a dynamic, natural outdoor setting with lush greenery.\n", "\n", - "Secondly, the actions depicted are different. In the first image, the cat is sitting still and looking directly at the camera. In contrast, the second image captures the two dogs in motion, running across a grassy field with a stick in the mouth of the adult dog.\n", + "The second image features a close-up of a single, adult Siamese cat with striking blue eyes, sitting in a natural setting that appears to be outdoors with dried leaves or grass around. The cat's expression is calm and focused, and its fur is predominantly light-colored with darker points on its ears, face, and tail.\n", "\n", - "Lastly, the setting of the images is different. The first image has a neutral, blurred background, providing no specific context about the location. The second image, on the other hand, is set in an outdoor environment with a grassy field and trees in the background.\n", - "\n", - "These differences highlight the distinct characteristics and behaviors of the animals as well as the varying contexts in which they are photographed.\n" + "The key differences between the images are the subjects (dogs vs. cat) and their expressions (playful vs. calm and focused). 
Additionally, the settings are similar yet distinct, with both animals in natural environments but the first image depicting more active engagement with the surroundings, while the second is more still and serene.\n" ] } ], @@ -240,6 +249,13 @@ "response = model.invoke([message])\n", "print(response.content)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 40df5bf9492bd..5fec7f067acdd 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -1,3 +1,4 @@ +import json from typing import ( Any, AsyncIterator, @@ -5,8 +6,11 @@ Dict, Iterator, List, + Literal, Mapping, Optional, + Sequence, + Type, Union, ) @@ -14,7 +18,7 @@ AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain_core.language_models import BaseLanguageModel +from langchain_core.language_models import BaseLanguageModel, LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, @@ -26,17 +30,21 @@ BaseMessage, HumanMessage, SystemMessage, + ToolMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from langchain_core.utils import ( get_from_dict_or_env, get_pydantic_field_names, ) +from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str +from pydantic import Field, SecretStr, model_validator try: - from reka import ChatMessage + from reka import ChatMessage, ToolCall from reka.client import AsyncReka, Reka except ImportError: raise ValueError( @@ -63,7 +71,8 @@ def process_content_item(item: Dict[str, Any]) -> Dict[str, Any]: def process_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]: - """Process content to handle both text and media inputs, returning a list of content items.""" + """Process content to handle both text and media inputs, + Returning a list of content items.""" if isinstance(content, str): return [{"type": "text", "text": content}] elif isinstance(content, list): @@ -72,10 +81,10 @@ def process_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, raise ValueError("Invalid content format") -def convert_to_reka_messages(messages: List[Any]) -> List[ChatMessage]: - """Convert LangChain messages to Reka ChatMessage format.""" +def convert_to_reka_messages(messages: List[Any]) -> List[Dict[str, Any]]: + """Convert LangChain messages to Reka message format.""" reka_messages = [] - system_message = None + system_message = None # Double check on the system message for message in messages: if isinstance(message, SystemMessage): @@ -91,10 +100,39 @@ def convert_to_reka_messages(messages: List[Any]) -> List[ChatMessage]: else: content.insert(0, {"type": "text", "text": system_message}) system_message = None - reka_messages.append(ChatMessage(content=content, role="user")) + reka_messages.append({"role": "user", "content": content}) elif isinstance(message, AIMessage): + reka_message = {"role": "assistant"} + if message.content: + reka_message["content"] = process_content(message.content) + + if "tool_calls" in 
message.additional_kwargs: + tool_calls = message.additional_kwargs["tool_calls"] + formatted_tool_calls = [] + for tool_call in tool_calls: + formatted_tool_call = ToolCall( + id=tool_call["id"], + name=tool_call["function"]["name"], + parameters=json.loads(tool_call["function"]["arguments"]), + ) + formatted_tool_calls.append(formatted_tool_call) + reka_message["tool_calls"] = formatted_tool_calls + reka_messages.append(reka_message) + elif isinstance(message, ToolMessage): + reka_messages.append( + { + "role": "tool_output", + "content": [ + { + "tool_call_id": message.tool_call_id, + "output": json.dumps({"status": message.content}), + } + ], + } + ) + elif isinstance(message, ChatMessage): content = process_content(message.content) - reka_messages.append(ChatMessage(content=content, role="assistant")) + reka_messages.append({"role": message.role, "content": content}) else: raise ValueError(f"Unsupported message type: {type(message)}") @@ -105,29 +143,16 @@ class RekaCommon(BaseLanguageModel): client: Any = None #: :meta private: async_client: Any = None #: :meta private: model: str = Field(default=DEFAULT_REKA_MODEL, alias="model_name") - """Model name to use.""" - max_tokens: int = Field(default=256) - """Denotes the number of tokens to predict per generation.""" - temperature: Optional[float] = None - """A non-negative float that tunes the degree of randomness in generation.""" - streaming: bool = False - """Whether to stream the results.""" - default_request_timeout: Optional[float] = None - """Timeout for requests to Reka Completion API. Default is 600 seconds.""" - max_retries: int = 2 - """Number of retries allowed for requests sent to the Reka Completion API.""" - reka_api_key: Optional[SecretStr] = None - count_tokens: Optional[Callable[[str], int]] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) - @root_validator(pre=True) + @model_validator(mode="before") def build_extra(cls, values: Dict) -> Dict: extra = values.get("model_kwargs", {}) all_required_field_names = get_pydantic_field_names(cls) @@ -136,27 +161,27 @@ def build_extra(cls, values: Dict) -> Dict: ) return values - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - values["reka_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "reka_api_key", "REKA_API_KEY") + @model_validator(mode="after") + def validate_environment(cls, self: "RekaCommon") -> "RekaCommon": + """Validate that API key and Python package exist in the environment.""" + self.reka_api_key = convert_to_secret_str( + get_from_dict_or_env(self, "reka_api_key", "REKA_API_KEY") ) try: - values["client"] = Reka( - api_key=values["reka_api_key"].get_secret_value(), + self.client = Reka( + api_key=self.reka_api_key.get_secret_value(), ) - values["async_client"] = AsyncReka( - api_key=values["reka_api_key"].get_secret_value(), + self.async_client = AsyncReka( + api_key=self.reka_api_key.get_secret_value(), ) except ImportError: raise ImportError( - "Could not import reka python package. " + "Could not import Reka Python package. " "Please install it with `pip install reka-api`." 
) - return values + return self @property def _default_params(self) -> Mapping[str, Any]: @@ -183,6 +208,18 @@ def _llm_type(self) -> str: """Return type of chat model.""" return "reka-chat" + @property + def _default_params(self) -> Mapping[str, Any]: + """Get the default parameters for calling Reka API.""" + d = { + "max_tokens": self.max_tokens, + "model": self.model, + } + if self.temperature is not None: + d["temperature"] = self.temperature + + return {**d, **self.model_kwargs} + def _stream( self, messages: List[BaseMessage], @@ -216,9 +253,7 @@ async def _astream( if stop: params["stop"] = stop - stream = await self.async_client.chat.create_stream( - messages=reka_messages, **params - ) + stream = self.async_client.chat.create_stream(messages=reka_messages, **params) async for chunk in stream: content = chunk.responses[0].chunk.content @@ -245,7 +280,29 @@ def _generate( params["stop"] = stop response = self.client.chat.create(messages=reka_messages, **params) - message = AIMessage(content=response.responses[0].message.content) + if response.responses[0].message.tool_calls: + tool_calls = response.responses[0].message.tool_calls + message = AIMessage( + content="", # Empty string instead of None + additional_kwargs={ + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.parameters), + }, + } + for tc in tool_calls + ] + }, + ) + else: + content = response.responses[0].message.content + # Ensure content is never None + message = AIMessage(content=content if content is not None else "") + return ChatResult(generations=[ChatGeneration(message=message)]) async def _agenerate( @@ -266,7 +323,29 @@ async def _agenerate( params["stop"] = stop response = await self.async_client.chat.create(messages=reka_messages, **params) - message = AIMessage(content=response.responses[0].message.content) + if response.responses[0].message.tool_calls: + tool_calls = response.responses[0].message.tool_calls + message = AIMessage( + content="", # Empty string instead of None + additional_kwargs={ + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.parameters), + }, + } + for tc in tool_calls + ] + }, + ) + else: + content = response.responses[0].message.content + # Ensure content is never None + message = AIMessage(content=content if content is not None else "") + return ChatResult(generations=[ChatGeneration(message=message)]) def get_num_tokens(self, text: str) -> int: @@ -276,3 +355,87 @@ def get_num_tokens(self, text: str) -> int: "get_num_tokens() is not implemented for Reka models." ) return self.count_tokens(text) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], + *, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ] = "auto", + strict: Optional[bool] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + + Assumes model is compatible with OpenAI tool-calling API. + + Args: + tools: A list of tool definitions to bind to this chat model. + Supports any tool definition handled by + :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`. + tool_choice: Which tool to require the model to call. Options are: + + - str of the form ``"<>"``: calls <> tool. + - ``"auto"``: automatically selects a tool (including no tool). + - ``"none"``: does not call a tool. 
+ - ``"any"`` or ``"required"`` or ``True``: force at least one tool to be called. + - dict of the form ``{"type": "function", "function": {"name": <>}}``: calls <> tool. + - ``False`` or ``None``: no effect, default OpenAI behavior. + strict: If True, model output is guaranteed to exactly match the JSON Schema + provided in the tool definition. If True, the input schema will be + validated according to + https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. + If False, input schema will not be validated and model output will not + be validated. + If None, ``strict`` argument will not be passed to the model. + kwargs: Any additional parameters are passed directly to + :meth:`~langchain_openai.chat_models.base.ChatOpenAI.bind`. + + .. versionchanged:: 0.1.21 + + Support for ``strict`` argument added. + + """ # noqa: E501 + + formatted_tools = [ + convert_to_openai_tool(tool, strict=strict) for tool in tools + ] + if tool_choice: + if isinstance(tool_choice, str): + # tool_choice is a tool/function name + if tool_choice not in ("auto", "none", "any", "required"): + tool_choice = { + "type": "function", + "function": {"name": tool_choice}, + } + # 'any' is not natively supported by OpenAI API. + # We support 'any' since other models use this instead of 'required'. + if tool_choice == "any": + tool_choice = "required" + elif isinstance(tool_choice, bool): + tool_choice = "required" + elif isinstance(tool_choice, dict): + tool_names = [ + formatted_tool["function"]["name"] + for formatted_tool in formatted_tools + ] + if not any( + tool_name == tool_choice["function"]["name"] + for tool_name in tool_names + ): + raise ValueError( + f"Tool choice {tool_choice} was specified, but the only " + f"provided tools were {tool_names}." + ) + else: + raise ValueError( + f"Unrecognized tool_choice type. Expected str, bool or dict. 
" + f"Received: {tool_choice}" + ) + kwargs["tool_choice"] = tool_choice + # Formatting hack TODO + formatted_tools = [ + formatted_tool["function"] for formatted_tool in formatted_tools + ] + return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index 7b599cd7c2b48..d1aacc252b2a9 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -1,30 +1,39 @@ """Test Reka API wrapper.""" +import logging from typing import List import pytest -from langchain_core.callbacks import CallbackManager -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + ToolMessage, +) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_community.chat_models.reka import ChatReka from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + @pytest.mark.scheduled def test_reka_call() -> None: - """Test valid call to Reka.""" - chat = ChatReka(model="reka-flash") + """Test a simple call to Reka.""" + chat = ChatReka(model="reka-flash", verbose=True) message = HumanMessage(content="Hello") response = chat.invoke([message]) assert isinstance(response, AIMessage) assert isinstance(response.content, str) + logger.debug(f"Response content: {response.content}") @pytest.mark.scheduled def test_reka_generate() -> None: - """Test generate method of Reka.""" - chat = ChatReka(model="reka-flash") + """Test the generate method of Reka.""" + chat = ChatReka(model="reka-flash", verbose=True) chat_messages: List[List[BaseMessage]] = [ [HumanMessage(content="How many toes do dogs have?")] ] @@ -35,44 +44,45 @@ def test_reka_generate() -> None: assert isinstance(response, ChatGeneration) assert isinstance(response.text, str) assert response.text == response.message.content + logger.debug(f"Generated response: {response.text}") assert chat_messages == messages_copy @pytest.mark.scheduled def test_reka_streaming() -> None: """Test streaming tokens from Reka.""" - chat = ChatReka(model="reka-flash", streaming=True) - message = HumanMessage(content="Hello") + chat = ChatReka(model="reka-flash", streaming=True, verbose=True) + message = HumanMessage(content="Tell me a story.") response = chat.invoke([message]) assert isinstance(response, AIMessage) assert isinstance(response.content, str) + logger.debug(f"Streaming response content: {response.content}") @pytest.mark.scheduled def test_reka_streaming_callback() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" + """Test that streaming correctly invokes callbacks.""" callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) chat = ChatReka( model="reka-flash", streaming=True, - callback_manager=callback_manager, + callbacks=[callback_handler], verbose=True, ) message = HumanMessage(content="Write me a sentence with 10 words.") chat.invoke([message]) assert callback_handler.llm_streams > 1 + logger.debug(f"Number of LLM streams: {callback_handler.llm_streams}") @pytest.mark.scheduled async def test_reka_async_streaming_callback() -> None: - """Test that streaming correctly invokes on_llm_new_token callback.""" + """Test asynchronous streaming with callbacks.""" 
callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) chat = ChatReka( model="reka-flash", streaming=True, - callback_manager=callback_manager, + callbacks=[callback_handler], verbose=True, ) chat_messages: List[BaseMessage] = [ @@ -85,3 +95,74 @@ async def test_reka_async_streaming_callback() -> None: assert isinstance(response, ChatGeneration) assert isinstance(response.text, str) assert response.text == response.message.content + logger.debug(f"Async generated response: {response.text}") + + +@pytest.mark.scheduled +def test_reka_tool_usage_integration() -> None: + """Test tool usage with Reka API integration.""" + # Initialize the ChatReka model with tools and verbose logging + chat_reka = ChatReka(model="reka-flash", verbose=True) + tools = [ + { + "type": "function", + "function": { + "name": "get_product_availability", + "description": ( + "Determine whether a product is currently in stock given " + "a product ID." + ), + "parameters": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": ( + "The unique product ID to check availability for" + ), + }, + }, + "required": ["product_id"], + }, + }, + }, + ] + chat_reka_with_tools = chat_reka.bind_tools(tools) + + # Start a conversation + messages: List[BaseMessage] = [ + HumanMessage(content="Is product A12345 in stock right now?") + ] + + # Get the initial response + response = chat_reka_with_tools.invoke(messages) + assert isinstance(response, AIMessage) + logger.debug(f"Initial AI message: {response.content}") + + # Check if the model wants to use a tool + if "tool_calls" in response.additional_kwargs: + tool_calls = response.additional_kwargs["tool_calls"] + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + arguments = tool_call["function"]["arguments"] + logger.debug( + f"Tool call requested: {function_name} with arguments {arguments}" + ) + + # Simulate executing the tool + tool_output = "AVAILABLE" + + tool_message = ToolMessage( + content=tool_output, tool_call_id=tool_call["id"] + ) + messages.append(response) + messages.append(tool_message) + + final_response = chat_reka_with_tools.invoke(messages) + assert isinstance(final_response, AIMessage) + logger.debug(f"Final AI message: {final_response.content}") + + # Assert that the response message is non-empty + assert final_response.content, "The final response content is empty." 
+ else: + pytest.fail("The model did not request a tool.") diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index eca1a948b8b22..7e222cda65156 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -1,10 +1,12 @@ """Test Reka Chat wrapper.""" +import json import os from typing import List +from unittest.mock import MagicMock, patch import pytest -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_community.chat_models import ChatReka from langchain_community.chat_models.reka import ( @@ -133,7 +135,7 @@ def test_convert_to_reka_messages( messages: List[BaseMessage], expected: List[dict] ) -> None: result = convert_to_reka_messages(messages) - assert [message.dict() for message in result] == expected + assert result == expected @pytest.mark.requires("reka") @@ -177,3 +179,38 @@ def test_reka_identifying_params() -> None: def test_reka_llm_type() -> None: llm = ChatReka() assert llm._llm_type == "reka-chat" + + +@pytest.mark.requires("reka") +def test_reka_tool_use_with_mocked_response() -> None: + with patch("langchain_community.chat_models.reka.Reka") as MockReka: + # Mock the Reka client + mock_client = MockReka.return_value + mock_chat = MagicMock() + mock_client.chat = mock_chat + mock_response = MagicMock() + mock_message = MagicMock() + mock_tool_call = MagicMock() + mock_tool_call.id = "tool_call_1" + mock_tool_call.name = "search_tool" + mock_tool_call.parameters = {"query": "LangChain"} + mock_message.tool_calls = [mock_tool_call] + mock_message.content = None + mock_response.responses = [MagicMock(message=mock_message)] + mock_chat.create.return_value = mock_response + + llm = ChatReka() + messages = [HumanMessage(content="Tell me about LangChain")] + result = llm._generate(messages) + + assert len(result.generations) == 1 + ai_message = result.generations[0].message + assert ai_message.content == "" + assert "tool_calls" in ai_message.additional_kwargs + tool_calls = ai_message.additional_kwargs["tool_calls"] + assert len(tool_calls) == 1 + assert tool_calls[0]["id"] == "tool_call_1" + assert tool_calls[0]["function"]["name"] == "search_tool" + assert tool_calls[0]["function"]["arguments"] == json.dumps( + {"query": "LangChain"} + ) diff --git a/libs/community/tests/unit_tests/chat_models/test_reka_search.py b/libs/community/tests/unit_tests/chat_models/test_reka_search.py new file mode 100644 index 0000000000000..ed3028cdc1230 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_reka_search.py @@ -0,0 +1,42 @@ +from dotenv import load_dotenv +from langchain import hub +from langchain.agents import AgentExecutor, create_self_ask_with_search_agent +from langchain.chat_models import ChatReka +from langchain_core.tools import Tool + +from langchain_community.utilities import GoogleSerperAPIWrapper + +# Set up API keys +# os.environ["SERPER_API_KEY"] = "your_serper_api_key_here" +load_dotenv() + +# Initialize ChatReka +chat_reka = ChatReka( + model="reka-core", + temperature=0.4, +) +prompt = hub.pull("hwchase17/self-ask-with-search") +# Initialize Google Serper API Wrapper +search = GoogleSerperAPIWrapper() + +# Define tools +tools = [ + Tool( + name="Intermediate Answer", + func=search.run, + description=""" + useful for when you need to ask with search. 
+ """, + ) +] + +# Initialize the agent +agent = create_self_ask_with_search_agent(chat_reka, tools, prompt) + +agent_executor = AgentExecutor(agent=agent, tools=tools) + + +# Example usage +if __name__ == "__main__": + query = "What is the hometown of the reigning men's U.S. Open champion?" + agent_executor.invoke({"input": "query"}) From 64adf02bcf0a3983f2917276bffab13885b700f1 Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 15 Oct 2024 15:01:11 -0700 Subject: [PATCH 09/36] Fix linting --- docs/docs/integrations/chat/reka.ipynb | 52 +++- .../langchain_community/chat_models/reka.py | 254 +++++++++--------- .../tests/unit_tests/chat_models/test_reka.py | 46 ++-- .../chat_models/test_reka_search.py | 42 --- 4 files changed, 190 insertions(+), 204 deletions(-) delete mode 100644 libs/community/tests/unit_tests/chat_models/test_reka_search.py diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index 71c83fa0b8b2e..5528c8476dd7e 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -250,12 +250,62 @@ "print(response.content)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tool call example with self-ask-with-search agent" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from dotenv import load_dotenv\n", + "from langchain import hub\n", + "from langchain.agents import AgentExecutor, create_self_ask_with_search_agent\n", + "from langchain.chat_models import ChatReka\n", + "from langchain_core.tools import Tool\n", + "\n", + "from langchain_community.utilities import GoogleSerperAPIWrapper\n", + "\n", + "# Set up API keys\n", + "# os.environ[\"SERPER_API_KEY\"] = \"your_serper_api_key_here\"\n", + "load_dotenv()\n", + "\n", + "# Initialize ChatReka\n", + "chat_reka = ChatReka(\n", + " model=\"reka-core\",\n", + " temperature=0.4,\n", + ")\n", + "prompt = hub.pull(\"hwchase17/self-ask-with-search\")\n", + "# Initialize Google Serper API Wrapper\n", + "search = GoogleSerperAPIWrapper()\n", + "\n", + "# Define tools\n", + "tools = [\n", + " Tool(\n", + " name=\"Intermediate Answer\",\n", + " func=search.run,\n", + " description=\"\"\"\n", + " useful for when you need to ask with search. \n", + " \"\"\",\n", + " )\n", + "]\n", + "\n", + "# Initialize the agent\n", + "agent = create_self_ask_with_search_agent(chat_reka, tools, prompt)\n", + "\n", + "agent_executor = AgentExecutor(agent=agent, tools=tools)\n", + "\n", + "\n", + "# Example usage\n", + "if __name__ == \"__main__\":\n", + " query = \"What is the hometown of the reigning men's U.S. 
Open champion?\"\n",
    "    agent_executor.invoke({\"input\": query})"
   ]
  }
 ],
 "metadata": {
diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py
index 5fec7f067acdd..ece04ac5fd4f3 100644
--- a/libs/community/langchain_community/chat_models/reka.py
+++ b/libs/community/langchain_community/chat_models/reka.py
@@ -37,14 +37,13 @@
 from langchain_core.tools import BaseTool
 from langchain_core.utils import (
     get_from_dict_or_env,
-    get_pydantic_field_names,
 )
 from langchain_core.utils.function_calling import convert_to_openai_tool
-from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
-from pydantic import Field, SecretStr, model_validator
+from pydantic import ConfigDict, Field, model_validator
 
 try:
-    from reka import ChatMessage, ToolCall
+    from reka import ChatMessage as RekaChatMessage
+    from reka import ToolCall
     from reka.client import AsyncReka, Reka
 except ImportError:
     raise ValueError(
@@ -59,6 +58,8 @@
 
 DEFAULT_REKA_MODEL = "reka-flash"
 
+ContentType = Union[str, List[Union[str, Dict[str, Any]]]]
+
 
 def process_content_item(item: Dict[str, Any]) -> Dict[str, Any]:
     """Process a single content item."""
@@ -70,42 +71,62 @@ def process_content_item(item: Dict[str, Any]) -> Dict[str, Any]:
     return item
 
 
-def process_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
+def process_content(content: ContentType) -> List[Dict[str, Any]]:
     """Process content to handle both text and media inputs,
-    Returning a list of content items."""
+    returning a list of content items."""
     if isinstance(content, str):
         return [{"type": "text", "text": content}]
     elif isinstance(content, list):
-        return [process_content_item(item) for item in content]
+        result = []
+        for item in content:
+            if isinstance(item, str):
+                result.append({"type": "text", "text": item})
+            elif isinstance(item, dict):
+                result.append(process_content_item(item))
+            else:
+                raise ValueError(f"Invalid content item format: {item}")
+        return result
     else:
         raise ValueError("Invalid content format")
 
 
-def convert_to_reka_messages(messages: List[Any]) -> List[Dict[str, Any]]:
+def convert_to_reka_messages(messages: List[BaseMessage]) -> List[Dict[str, Any]]:
     """Convert LangChain messages to Reka message format."""
-    reka_messages = []
-    system_message = None  # Double check on the system message
+    reka_messages: List[Dict[str, Any]] = []
+    system_message: Optional[str] = None
 
     for message in messages:
         if isinstance(message, SystemMessage):
             if system_message is None:
-                system_message = message.content
+                if isinstance(message.content, str):
+                    system_message = message.content
+                else:
+                    raise TypeError("SystemMessage content must be a string.")
             else:
                 raise ValueError("Multiple system messages are not supported.")
         elif isinstance(message, HumanMessage):
-            content = process_content(message.content)
+            processed_content = process_content(message.content)
             if system_message:
-                if isinstance(content[0], dict) and content[0].get("type") == "text":
-                    content[0]["text"] = f"{system_message}\n{content[0]['text']}"
+                if (
+                    processed_content
+                    and isinstance(processed_content[0], dict)
+                    and processed_content[0].get("type") == "text"
+                    and "text" in processed_content[0]
+                ):
+                    processed_content[0]["text"] = (
+                        f"{system_message}\n{processed_content[0]['text']}"
+                    )
                 else:
-                    content.insert(0, {"type": "text", "text": system_message})
+                    processed_content.insert(
+                        0, {"type": "text", "text": system_message}
+                    )
                 system_message = None
-            
reka_messages.append({"role": "user", "content": content}) + reka_messages.append({"role": "user", "content": processed_content}) elif isinstance(message, AIMessage): - reka_message = {"role": "assistant"} + reka_message: Dict[str, Any] = {"role": "assistant"} if message.content: - reka_message["content"] = process_content(message.content) - + processed_content = process_content(message.content) + reka_message["content"] = processed_content if "tool_calls" in message.additional_kwargs: tool_calls = message.additional_kwargs["tool_calls"] formatted_tool_calls = [] @@ -119,20 +140,22 @@ def convert_to_reka_messages(messages: List[Any]) -> List[Dict[str, Any]]: reka_message["tool_calls"] = formatted_tool_calls reka_messages.append(reka_message) elif isinstance(message, ToolMessage): + content_list: List[Dict[str, Any]] = [] + content_list.append( + { + "tool_call_id": message.tool_call_id, + "output": json.dumps({"status": message.content}), + } + ) reka_messages.append( { "role": "tool_output", - "content": [ - { - "tool_call_id": message.tool_call_id, - "output": json.dumps({"status": message.content}), - } - ], + "content": content_list, } ) - elif isinstance(message, ChatMessage): - content = process_content(message.content) - reka_messages.append({"role": message.role, "content": content}) + elif isinstance(message, RekaChatMessage): + processed_content = process_content(message.content) + reka_messages.append({"role": message.role, "content": processed_content}) else: raise ValueError(f"Unsupported message type: {type(message)}") @@ -142,84 +165,78 @@ def convert_to_reka_messages(messages: List[Any]) -> List[Dict[str, Any]]: class RekaCommon(BaseLanguageModel): client: Any = None #: :meta private: async_client: Any = None #: :meta private: - model: str = Field(default=DEFAULT_REKA_MODEL, alias="model_name") + model: str = Field(default=DEFAULT_REKA_MODEL) max_tokens: int = Field(default=256) temperature: Optional[float] = None streaming: bool = False default_request_timeout: Optional[float] = None max_retries: int = 2 - reka_api_key: Optional[SecretStr] = None + reka_api_key: Optional[str] = None count_tokens: Optional[Callable[[str], int]] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) - @model_validator(mode="before") - def build_extra(cls, values: Dict) -> Dict: - extra = values.get("model_kwargs", {}) - all_required_field_names = get_pydantic_field_names(cls) - values["model_kwargs"] = build_extra_kwargs( - extra, values, all_required_field_names - ) - return values + # @model_validator(mode="before") + # @classmethod + # def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + # extra = values.get("model_kwargs", {}) + # all_field_names = set(cls.model_fields) + # for field in list(values.keys()): + # if field not in all_field_names: + # extra[field] = values.pop(field) + # values["model_kwargs"] = extra + # return values - @model_validator(mode="after") - def validate_environment(cls, self: "RekaCommon") -> "RekaCommon": + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate that API key and Python package exist in the environment.""" - self.reka_api_key = convert_to_secret_str( - get_from_dict_or_env(self, "reka_api_key", "REKA_API_KEY") + reka_api_key = values.get("reka_api_key") + reka_api_key = get_from_dict_or_env( + {"reka_api_key": reka_api_key}, "reka_api_key", "REKA_API_KEY" ) + values["reka_api_key"] = reka_api_key try: - self.client = Reka( - 
api_key=self.reka_api_key.get_secret_value(), + values["client"] = Reka( + api_key=reka_api_key, ) - self.async_client = AsyncReka( - api_key=self.reka_api_key.get_secret_value(), + values["async_client"] = AsyncReka( + api_key=reka_api_key, ) - except ImportError: raise ImportError( "Could not import Reka Python package. " "Please install it with `pip install reka-api`." ) - return self + return values @property def _default_params(self) -> Mapping[str, Any]: """Get the default parameters for calling Reka API.""" - d = { - "max_tokens": self.max_tokens, + params = { "model": self.model, + "max_tokens": self.max_tokens, } if self.temperature is not None: - d["temperature"] = self.temperature - return {**d, **self.model_kwargs} - - @property - def _identifying_params(self) -> Mapping[str, Any]: - """Get the identifying parameters.""" - return {**{}, **self._default_params} + params["temperature"] = self.temperature + return {**params, **self.model_kwargs} class ChatReka(BaseChatModel, RekaCommon): """Reka chat large language models.""" + model: str = Field(default=DEFAULT_REKA_MODEL) + reka_api_key: Optional[str] = None + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + temperature: Optional[float] = None + max_tokens: int = Field(default=256) + model_config = ConfigDict(extra="forbid") + @property def _llm_type(self) -> str: """Return type of chat model.""" return "reka-chat" - @property - def _default_params(self) -> Mapping[str, Any]: - """Get the default parameters for calling Reka API.""" - d = { - "max_tokens": self.max_tokens, - "model": self.model, - } - if self.temperature is not None: - d["temperature"] = self.temperature - - return {**d, **self.model_kwargs} - def _stream( self, messages: List[BaseMessage], @@ -360,82 +377,51 @@ def bind_tools( self, tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], *, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] - ] = "auto", + tool_choice: Literal["auto", "none", "tool"] = "auto", strict: Optional[bool] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. - Assumes model is compatible with OpenAI tool-calling API. + The `tool_choice` parameter controls how the model uses the tools you pass. + There are three available options: + + - `"auto"`: Lets the model decide whether or not to invoke a tool. This is the + recommended way to do function calling with our models. + - `"none"`: Disables tool calling. In this case, even if you pass tools to + the model, the model will not invoke any tools. + - `"tool"`: Forces the model to invoke one or more of the tools it has + been passed. Args: tools: A list of tool definitions to bind to this chat model. Supports any tool definition handled by :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`. - tool_choice: Which tool to require the model to call. Options are: - - - str of the form ``"<>"``: calls <> tool. - - ``"auto"``: automatically selects a tool (including no tool). - - ``"none"``: does not call a tool. - - ``"any"`` or ``"required"`` or ``True``: force at least one tool to be called. - - dict of the form ``{"type": "function", "function": {"name": <>}}``: calls <> tool. - - ``False`` or ``None``: no effect, default OpenAI behavior. - strict: If True, model output is guaranteed to exactly match the JSON Schema - provided in the tool definition. 
If True, the input schema will be - validated according to - https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. - If False, input schema will not be validated and model output will not - be validated. - If None, ``strict`` argument will not be passed to the model. - kwargs: Any additional parameters are passed directly to - :meth:`~langchain_openai.chat_models.base.ChatOpenAI.bind`. - - .. versionchanged:: 0.1.21 - - Support for ``strict`` argument added. - - """ # noqa: E501 - + tool_choice: Controls how the model uses the tools you pass. + Options are "auto", "none", or "tool". Defaults to "auto". + strict: + If True, model output is guaranteed to exactly match the JSON Schema + provided in the tool definition. + If False, input schema will not be validated + and model output will not be validated. + If None, ``strict`` argument will not + be passed to the model. + kwargs: Any additional parameters are passed directly to the model. + + Returns: + Runnable: An executable chain or component. + """ formatted_tools = [ convert_to_openai_tool(tool, strict=strict) for tool in tools ] - if tool_choice: - if isinstance(tool_choice, str): - # tool_choice is a tool/function name - if tool_choice not in ("auto", "none", "any", "required"): - tool_choice = { - "type": "function", - "function": {"name": tool_choice}, - } - # 'any' is not natively supported by OpenAI API. - # We support 'any' since other models use this instead of 'required'. - if tool_choice == "any": - tool_choice = "required" - elif isinstance(tool_choice, bool): - tool_choice = "required" - elif isinstance(tool_choice, dict): - tool_names = [ - formatted_tool["function"]["name"] - for formatted_tool in formatted_tools - ] - if not any( - tool_name == tool_choice["function"]["name"] - for tool_name in tool_names - ): - raise ValueError( - f"Tool choice {tool_choice} was specified, but the only " - f"provided tools were {tool_names}." - ) - else: - raise ValueError( - f"Unrecognized tool_choice type. Expected str, bool or dict. 
" - f"Received: {tool_choice}" - ) - kwargs["tool_choice"] = tool_choice - # Formatting hack TODO - formatted_tools = [ - formatted_tool["function"] for formatted_tool in formatted_tools - ] + + # Ensure tool_choice is one of the allowed options + if tool_choice not in ("auto", "none", "tool"): + raise ValueError("Must be one of 'auto', 'none', or 'tool'.") + + # Map tool_choice to the parameter expected by the Reka API + kwargs["tool_choice"] = tool_choice + + # Pass the tools and updated kwargs to the model + formatted_tools = [tool["function"] for tool in formatted_tools] return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 7e222cda65156..ba0b2665ab359 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -1,12 +1,11 @@ -"""Test Reka Chat wrapper.""" - import json import os -from typing import List +from typing import Any, Dict, List from unittest.mock import MagicMock, patch import pytest from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from pydantic import ValidationError from langchain_community.chat_models import ChatReka from langchain_community.chat_models.reka import ( @@ -17,12 +16,6 @@ os.environ["REKA_API_KEY"] = "dummy_key" -@pytest.mark.requires("reka") -def test_reka_model_name_param() -> None: - llm = ChatReka(model_name="reka-flash") - assert llm.model == "reka-flash" - - @pytest.mark.requires("reka") def test_reka_model_param() -> None: llm = ChatReka(model="reka-flash") @@ -35,17 +28,11 @@ def test_reka_model_kwargs() -> None: assert llm.model_kwargs == {"foo": "bar"} -@pytest.mark.requires("reka") -def test_reka_invalid_model_kwargs() -> None: - with pytest.raises(ValueError): - ChatReka(model_kwargs={"max_tokens": "invalid"}) - - @pytest.mark.requires("reka") def test_reka_incorrect_field() -> None: - with pytest.warns(match="not default parameter"): - llm = ChatReka(foo="bar") - assert llm.model_kwargs == {"foo": "bar"} + """Test that providing an incorrect field raises ValidationError.""" + with pytest.raises(ValidationError): + ChatReka(unknown_field="bar") # type: ignore @pytest.mark.requires("reka") @@ -62,11 +49,14 @@ def test_reka_initialization() -> None: ("Hello", [{"type": "text", "text": "Hello"}]), ( [ - {"type": "text", "text": "Hello"}, - {"type": "image_url", "image_url": "https://example.com/image.jpg"}, + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, ], [ - {"type": "text", "text": "Hello"}, + {"type": "text", "text": "Describe this image"}, {"type": "image_url", "image_url": "https://example.com/image.jpg"}, ], ), @@ -85,7 +75,7 @@ def test_reka_initialization() -> None: ), ], ) -def test_process_content(content, expected) -> None: +def test_process_content(content: Any, expected: List[Dict[str, Any]]) -> None: result = process_content(content) assert result == expected @@ -132,7 +122,7 @@ def test_process_content(content, expected) -> None: ], ) def test_convert_to_reka_messages( - messages: List[BaseMessage], expected: List[dict] + messages: List[BaseMessage], expected: List[Dict[str, Any]] ) -> None: result = convert_to_reka_messages(messages) assert result == expected @@ -167,12 +157,14 @@ def test_reka_default_params() -> None: @pytest.mark.requires("reka") def test_reka_identifying_params() -> None: - llm = 
ChatReka(temperature=0.7) - assert llm._identifying_params == { - "max_tokens": 256, + """Test that ChatReka identifies its default parameters correctly.""" + chat = ChatReka(model="reka-flash", temperature=0.7, max_tokens=256) + expected_params = { "model": "reka-flash", "temperature": 0.7, + "max_tokens": 256, } + assert chat._default_params == expected_params @pytest.mark.requires("reka") @@ -200,7 +192,7 @@ def test_reka_tool_use_with_mocked_response() -> None: mock_chat.create.return_value = mock_response llm = ChatReka() - messages = [HumanMessage(content="Tell me about LangChain")] + messages: List[BaseMessage] = [HumanMessage(content="Tell me about LangChain")] result = llm._generate(messages) assert len(result.generations) == 1 diff --git a/libs/community/tests/unit_tests/chat_models/test_reka_search.py b/libs/community/tests/unit_tests/chat_models/test_reka_search.py deleted file mode 100644 index ed3028cdc1230..0000000000000 --- a/libs/community/tests/unit_tests/chat_models/test_reka_search.py +++ /dev/null @@ -1,42 +0,0 @@ -from dotenv import load_dotenv -from langchain import hub -from langchain.agents import AgentExecutor, create_self_ask_with_search_agent -from langchain.chat_models import ChatReka -from langchain_core.tools import Tool - -from langchain_community.utilities import GoogleSerperAPIWrapper - -# Set up API keys -# os.environ["SERPER_API_KEY"] = "your_serper_api_key_here" -load_dotenv() - -# Initialize ChatReka -chat_reka = ChatReka( - model="reka-core", - temperature=0.4, -) -prompt = hub.pull("hwchase17/self-ask-with-search") -# Initialize Google Serper API Wrapper -search = GoogleSerperAPIWrapper() - -# Define tools -tools = [ - Tool( - name="Intermediate Answer", - func=search.run, - description=""" - useful for when you need to ask with search. - """, - ) -] - -# Initialize the agent -agent = create_self_ask_with_search_agent(chat_reka, tools, prompt) - -agent_executor = AgentExecutor(agent=agent, tools=tools) - - -# Example usage -if __name__ == "__main__": - query = "What is the hometown of the reigning men's U.S. Open champion?" 
-    agent_executor.invoke({"input": "query"})

From 1d25d1391dbee1adad5b76367473ac0110c77770 Mon Sep 17 00:00:00 2001
From: findalexli
Date: Tue, 15 Oct 2024 15:03:13 -0700
Subject: [PATCH 10/36] Update notebook with tool use example

---
 docs/docs/integrations/chat/reka.ipynb | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb
index 5528c8476dd7e..6e9f11a78c0f1 100644
--- a/docs/docs/integrations/chat/reka.ipynb
+++ b/docs/docs/integrations/chat/reka.ipynb
@@ -59,7 +59,7 @@
     "### Model features\n",
     "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
     "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
-    "| ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | \n",
+    "| ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | \n",
     "\n",
     "## Setup\n",
     "\n",
@@ -136,7 +136,6 @@
    "outputs": [],
    "source": [
     "from langchain_community.chat_models import ChatReka\n",
-    "# This \n",
     "model = ChatReka()"
    ]
   },

From 6471ad14c9cd38bab8fe2d0586dd6d0b3e87d126 Mon Sep 17 00:00:00 2001
From: findalexli
Date: Tue, 15 Oct 2024 15:26:38 -0700
Subject: [PATCH 11/36] Tool use: enforce a version

---
 libs/community/extended_testing_deps.txt      |  4 ++--
 .../langchain_community/chat_models/reka.py   | 12 +++++++++---
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt
index cc2f9f6da8a2e..ae9d84c9bc366 100644
--- a/libs/community/extended_testing_deps.txt
+++ b/libs/community/extended_testing_deps.txt
@@ -73,7 +73,7 @@ rapidfuzz>=3.1.1,<4
 rapidocr-onnxruntime>=1.3.2,<2
 rdflib==7.0.0
 requests-toolbelt>=1.0.0,<2
-reka-api>=3.2.0
+reka-api>=3.2.0,<4
 rspace_client>=2.5.0,<3
 scikit-learn>=1.2.2,<2
 simsimd>=5.0.0,<6
@@ -83,7 +83,7 @@ sseclient-py>=1.8.0,<2
 streamlit>=1.18.0,<2
 sympy>=1.12,<2
 telethon>=1.28.5,<2
-tidb-vector>=0.0.3,<1.0.0
+tidb-vector>=0.0.3,<1.0.0 
 timescale-vector==0.0.1
 tqdm>=4.48.0
 tree-sitter>=0.20.2,<0.21
diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py
index ece04ac5fd4f3..65c82110dd576 100644
--- a/libs/community/langchain_community/chat_models/reka.py
+++ b/libs/community/langchain_community/chat_models/reka.py
@@ -1,4 +1,5 @@
 import json
+import warnings
 from typing import (
     Any,
     AsyncIterator,
@@ -6,7 +7,6 @@
     Dict,
     Iterator,
     List,
-    Literal,
     Mapping,
     Optional,
     Sequence,
@@ -377,7 +377,7 @@ def bind_tools(
         self,
         tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
         *,
-        tool_choice: Literal["auto", "none", "tool"] = "auto",
+        tool_choice: str = "auto",
         strict: Optional[bool] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -417,7 +417,13 @@ def bind_tools(
 
         # Ensure tool_choice is one of the allowed options
         if tool_choice not in ("auto", "none", "tool"):
-            raise ValueError("Must be one of 'auto', 'none', or 'tool'.")
+            warnings.warn(
+                f"Invalid tool_choice '{tool_choice}' provided. "
+                "Reka model cannot be forced to use this tool name. "
+                "Defaulting to 'tool', which will force the model "
+                "to invoke one or more of the tools it has been passed."
+ ) + tool_choice = "tool" # Map tool_choice to the parameter expected by the Reka API kwargs["tool_choice"] = tool_choice From ca4cf2e75bbb8a2026824d315b3273934c7fe6e8 Mon Sep 17 00:00:00 2001 From: findalexli Date: Wed, 16 Oct 2024 14:16:47 -0700 Subject: [PATCH 12/36] use pytest mark requires reka-api --- .../tests/unit_tests/chat_models/test_reka.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index ba0b2665ab359..7b235cb88bfaf 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -16,26 +16,26 @@ os.environ["REKA_API_KEY"] = "dummy_key" -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_model_param() -> None: llm = ChatReka(model="reka-flash") assert llm.model == "reka-flash" -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_model_kwargs() -> None: llm = ChatReka(model_kwargs={"foo": "bar"}) assert llm.model_kwargs == {"foo": "bar"} -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_incorrect_field() -> None: """Test that providing an incorrect field raises ValidationError.""" with pytest.raises(ValidationError): ChatReka(unknown_field="bar") # type: ignore -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_initialization() -> None: """Test Reka initialization.""" # Verify that ChatReka can be initialized using a secret key provided @@ -43,6 +43,7 @@ def test_reka_initialization() -> None: ChatReka(model="reka-flash", reka_api_key="test_key") +@pytest.mark.requires("reka-api") @pytest.mark.parametrize( ("content", "expected"), [ @@ -80,6 +81,7 @@ def test_process_content(content: Any, expected: List[Dict[str, Any]]) -> None: assert result == expected +@pytest.mark.requires("reka-api") @pytest.mark.parametrize( ("messages", "expected"), [ @@ -128,25 +130,25 @@ def test_convert_to_reka_messages( assert result == expected -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_streaming() -> None: llm = ChatReka(streaming=True) assert llm.streaming is True -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_temperature() -> None: llm = ChatReka(temperature=0.5) assert llm.temperature == 0.5 -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_max_tokens() -> None: llm = ChatReka(max_tokens=100) assert llm.max_tokens == 100 -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_default_params() -> None: llm = ChatReka() assert llm._default_params == { @@ -155,7 +157,7 @@ def test_reka_default_params() -> None: } -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_identifying_params() -> None: """Test that ChatReka identifies its default parameters correctly.""" chat = ChatReka(model="reka-flash", temperature=0.7, max_tokens=256) @@ -167,13 +169,13 @@ def test_reka_identifying_params() -> None: assert chat._default_params == expected_params -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_llm_type() -> None: llm = ChatReka() assert llm._llm_type == "reka-chat" -@pytest.mark.requires("reka") +@pytest.mark.requires("reka-api") def test_reka_tool_use_with_mocked_response() -> None: with patch("langchain_community.chat_models.reka.Reka") as MockReka: # Mock the Reka client From 
142f4ca4fb28e506b1b0515653b1ebd2587e5a40 Mon Sep 17 00:00:00 2001
From: findalexli
Date: Wed, 16 Oct 2024 15:19:59 -0700
Subject: [PATCH 13/36] Lint and mark test requiring reka-api

---
 docs/docs/integrations/chat/reka.ipynb          | 11 ++++++++---
 .../integration_tests/chat_models/test_reka.py  |  6 ++++++
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb
index 6e9f11a78c0f1..d0291238bbe54 100644
--- a/docs/docs/integrations/chat/reka.ipynb
+++ b/docs/docs/integrations/chat/reka.ipynb
@@ -136,6 +136,7 @@
    "outputs": [],
    "source": [
     "from langchain_community.chat_models import ChatReka\n",
+    "\n",
     "model = ChatReka()"
    ]
   },
@@ -163,7 +164,7 @@
    }
   ],
   "source": [
-    "model.invoke('hi')"
+    "model.invoke(\"hi\")"
   ]
  },
 {
@@ -237,11 +238,15 @@
     " {\"type\": \"text\", \"text\": \"What are the difference between the two images? \"},\n",
     " {\n",
     " \"type\": \"image_url\",\n",
-    " \"image_url\": {\"url\": \"https://cdn.pixabay.com/photo/2019/07/23/13/51/shepherd-dog-4357790_1280.jpg\"},\n",
+    " \"image_url\": {\n",
+    " \"url\": \"https://cdn.pixabay.com/photo/2019/07/23/13/51/shepherd-dog-4357790_1280.jpg\"\n",
+    " },\n",
     " },\n",
     " {\n",
     " \"type\": \"image_url\",\n",
-    " \"image_url\": {\"url\": \"https://cdn.pixabay.com/photo/2024/02/17/00/18/cat-8578562_1280.jpg\"},\n",
+    " \"image_url\": {\n",
+    " \"url\": \"https://cdn.pixabay.com/photo/2024/02/17/00/18/cat-8578562_1280.jpg\"\n",
+    " },\n",
     " },\n",
     " ],\n",
 ")\n",
diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py
index d1aacc252b2a9..14f4a3b9e83a8 100644
--- a/libs/community/tests/integration_tests/chat_models/test_reka.py
+++ b/libs/community/tests/integration_tests/chat_models/test_reka.py
@@ -19,6 +19,7 @@
 logger = logging.getLogger(__name__)
 
 
+@pytest.mark.requires("reka-api")
 @pytest.mark.scheduled
 def test_reka_call() -> None:
     """Test a simple call to Reka."""
@@ -30,6 +31,7 @@ def test_reka_call() -> None:
     logger.debug(f"Response content: {response.content}")
 
 
+@pytest.mark.requires("reka-api")
 @pytest.mark.scheduled
 def test_reka_generate() -> None:
     """Test the generate method of Reka."""
@@ -48,6 +50,7 @@ def test_reka_generate() -> None:
     assert chat_messages == messages_copy
 
 
+@pytest.mark.requires("reka-api")
 @pytest.mark.scheduled
 def test_reka_streaming() -> None:
     """Test streaming tokens from Reka."""
@@ -59,6 +62,7 @@ def test_reka_streaming() -> None:
     logger.debug(f"Streaming response content: {response.content}")
 
 
+@pytest.mark.requires("reka-api")
 @pytest.mark.scheduled
 def test_reka_streaming_callback() -> None:
     """Test that streaming correctly invokes callbacks."""
@@ -75,6 +79,7 @@ def test_reka_streaming_callback() -> None:
     logger.debug(f"Number of LLM streams: {callback_handler.llm_streams}")
 
 
+@pytest.mark.requires("reka-api")
 @pytest.mark.scheduled
 async def test_reka_async_streaming_callback() -> None:
     """Test asynchronous streaming with callbacks."""
@@ -98,6 +103,7 @@ async def test_reka_async_streaming_callback() -> None:
     logger.debug(f"Async generated response: {response.text}")
 
 
+@pytest.mark.requires("reka-api")
 @pytest.mark.scheduled
 def test_reka_tool_usage_integration() -> None:
     """Test tool usage with Reka API integration."""

From 43972a8215c7928e0e914808a13b90b5767f514d Mon Sep 17 00:00:00 2001
From: findalexli
Date: Wed, 16 Oct 2024 15:32:56 -0700
Subject: [PATCH 14/36] Update reka import in notebook

---
docs/docs/integrations/chat/reka.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index d0291238bbe54..acbe81e0365dd 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -189,7 +189,6 @@ ], "source": [ "from langchain_community.chat_models import ChatReka\n", - "import httpx\n", "\n", "model = ChatReka()\n", "image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"\n", @@ -270,7 +269,7 @@ "from dotenv import load_dotenv\n", "from langchain import hub\n", "from langchain.agents import AgentExecutor, create_self_ask_with_search_agent\n", - "from langchain.chat_models import ChatReka\n", + "from langchain_community.chat_models import ChatReka\n", "from langchain_core.tools import Tool\n", "\n", "from langchain_community.utilities import GoogleSerperAPIWrapper\n", From eacde83f1c715b30dd10f9bb42e202c88e74d5ad Mon Sep 17 00:00:00 2001 From: findalexli Date: Wed, 16 Oct 2024 16:47:57 -0700 Subject: [PATCH 15/36] Import order update to pass unit test --- docs/docs/integrations/chat/reka.ipynb | 3 +- .../langchain_community/chat_models/reka.py | 58 ++++++------------- 2 files changed, 19 insertions(+), 42 deletions(-) diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index acbe81e0365dd..05dd1e9e4c581 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -270,9 +270,8 @@ "from langchain import hub\n", "from langchain.agents import AgentExecutor, create_self_ask_with_search_agent\n", "from langchain_community.chat_models import ChatReka\n", - "from langchain_core.tools import Tool\n", - "\n", "from langchain_community.utilities import GoogleSerperAPIWrapper\n", + "from langchain_core.tools import Tool\n", "\n", "# Set up API keys\n", "# os.environ[\"SERPER_API_KEY\"] = \"your_serper_api_key_here\"\n", diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 65c82110dd576..b6f81e09f607a 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -35,20 +35,9 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool -from langchain_core.utils import ( - get_from_dict_or_env, -) +from langchain_core.utils import get_from_dict_or_env from langchain_core.utils.function_calling import convert_to_openai_tool -from pydantic import ConfigDict, Field, model_validator - -try: - from reka import ChatMessage as RekaChatMessage - from reka import ToolCall - from reka.client import AsyncReka, Reka -except ImportError: - raise ValueError( - "Reka is not installed. Please install it with `pip install reka-api`." 
- ) +from pydantic import BaseModel, ConfigDict, Field, model_validator REKA_MODELS = [ "reka-edge", @@ -131,11 +120,11 @@ def convert_to_reka_messages(messages: List[BaseMessage]) -> List[Dict[str, Any] tool_calls = message.additional_kwargs["tool_calls"] formatted_tool_calls = [] for tool_call in tool_calls: - formatted_tool_call = ToolCall( - id=tool_call["id"], - name=tool_call["function"]["name"], - parameters=json.loads(tool_call["function"]["arguments"]), - ) + formatted_tool_call = { + "id": tool_call["id"], + "name": tool_call["function"]["name"], + "parameters": json.loads(tool_call["function"]["arguments"]), + } formatted_tool_calls.append(formatted_tool_call) reka_message["tool_calls"] = formatted_tool_calls reka_messages.append(reka_message) @@ -153,9 +142,6 @@ def convert_to_reka_messages(messages: List[BaseMessage]) -> List[Dict[str, Any] "content": content_list, } ) - elif isinstance(message, RekaChatMessage): - processed_content = process_content(message.content) - reka_messages.append({"role": message.role, "content": processed_content}) else: raise ValueError(f"Unsupported message type: {type(message)}") @@ -175,17 +161,6 @@ class RekaCommon(BaseLanguageModel): count_tokens: Optional[Callable[[str], int]] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) - # @model_validator(mode="before") - # @classmethod - # def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: - # extra = values.get("model_kwargs", {}) - # all_field_names = set(cls.model_fields) - # for field in list(values.keys()): - # if field not in all_field_names: - # extra[field] = values.pop(field) - # values["model_kwargs"] = extra - # return values - @model_validator(mode="before") @classmethod def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: @@ -197,6 +172,9 @@ def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["reka_api_key"] = reka_api_key try: + # Import reka libraries here + from reka.client import AsyncReka, Reka + values["client"] = Reka( api_key=reka_api_key, ) @@ -253,10 +231,10 @@ def _stream( for chunk in stream: content = chunk.responses[0].chunk.content - chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) - yield chunk + chat_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) if run_manager: - run_manager.on_llm_new_token(content, chunk=chunk) + run_manager.on_llm_new_token(content, chunk=chat_chunk) + yield chat_chunk async def _astream( self, @@ -274,10 +252,10 @@ async def _astream( async for chunk in stream: content = chunk.responses[0].chunk.content - chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) - yield chunk + chat_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) if run_manager: - await run_manager.on_llm_new_token(content, chunk=chunk) + await run_manager.on_llm_new_token(content, chunk=chat_chunk) + yield chat_chunk def _generate( self, @@ -375,7 +353,7 @@ def get_num_tokens(self, text: str) -> int: def bind_tools( self, - tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], *, tool_choice: str = "auto", strict: Optional[bool] = None, @@ -430,4 +408,4 @@ def bind_tools( # Pass the tools and updated kwargs to the model formatted_tools = [tool["function"] for tool in formatted_tools] - return super().bind(tools=formatted_tools, **kwargs) + return super().bind(tools=formatted_tools, **kwargs) \ No newline at end of file From 
2e9439346673c54b33daf364d5c9bfd5830c53fc Mon Sep 17 00:00:00 2001 From: vbarda Date: Thu, 24 Oct 2024 17:30:11 -0400 Subject: [PATCH 16/36] lint --- libs/community/langchain_community/chat_models/reka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index b6f81e09f607a..0d6a7d0da58a4 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -408,4 +408,4 @@ def bind_tools( # Pass the tools and updated kwargs to the model formatted_tools = [tool["function"] for tool in formatted_tools] - return super().bind(tools=formatted_tools, **kwargs) \ No newline at end of file + return super().bind(tools=formatted_tools, **kwargs) From 72446fa233bef06cf08050b97c357d45fa6a43c0 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Thu, 24 Oct 2024 17:32:25 -0400 Subject: [PATCH 17/36] remove extra whitespace --- libs/community/extended_testing_deps.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index f98771f038f0c..6b101f925ef89 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -83,7 +83,7 @@ sseclient-py>=1.8.0,<2 streamlit>=1.18.0,<2 sympy>=1.12,<2 telethon>=1.28.5,<2 -tidb-vector>=0.0.3,<1.0.0 +tidb-vector>=0.0.3,<1.0.0 timescale-vector==0.0.1 tqdm>=4.48.0 tree-sitter>=0.20.2,<0.21 From 3bdc7255a11e4600ec1f417a9e027130525a1ec2 Mon Sep 17 00:00:00 2001 From: findalexli Date: Fri, 25 Oct 2024 10:30:55 -0700 Subject: [PATCH 18/36] Remove unused model list --- libs/community/langchain_community/chat_models/reka.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 0d6a7d0da58a4..57b809e08a189 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -39,12 +39,6 @@ from langchain_core.utils.function_calling import convert_to_openai_tool from pydantic import BaseModel, ConfigDict, Field, model_validator -REKA_MODELS = [ - "reka-edge", - "reka-flash", - "reka-core", -] - DEFAULT_REKA_MODEL = "reka-flash" ContentType = Union[str, List[Union[str, Dict[str, Any]]]] From e0c7eec65a55e2fb235c232bed0b7046ffa57e8e Mon Sep 17 00:00:00 2001 From: findalexli Date: Fri, 1 Nov 2024 10:19:37 -0700 Subject: [PATCH 19/36] Combine reka chat into one class --- .../langchain_community/chat_models/reka.py | 22 ++++----------- .../chat_models/test_reka.py | 12 ++++---- .../tests/unit_tests/chat_models/test_reka.py | 28 +++++++++---------- 3 files changed, 25 insertions(+), 37 deletions(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 57b809e08a189..cb98f6850ccd4 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -141,8 +141,9 @@ def convert_to_reka_messages(messages: List[BaseMessage]) -> List[Dict[str, Any] return reka_messages +class ChatReka(BaseChatModel): + """Reka chat large language models.""" -class RekaCommon(BaseLanguageModel): client: Any = None #: :meta private: async_client: Any = None #: :meta private: model: str = Field(default=DEFAULT_REKA_MODEL) @@ -154,6 +155,7 @@ class 
RekaCommon(BaseLanguageModel): reka_api_key: Optional[str] = None count_tokens: Optional[Callable[[str], int]] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) + model_config = ConfigDict(extra="forbid") @model_validator(mode="before") @classmethod @@ -193,17 +195,6 @@ def _default_params(self) -> Mapping[str, Any]: params["temperature"] = self.temperature return {**params, **self.model_kwargs} - -class ChatReka(BaseChatModel, RekaCommon): - """Reka chat large language models.""" - - model: str = Field(default=DEFAULT_REKA_MODEL) - reka_api_key: Optional[str] = None - model_kwargs: Dict[str, Any] = Field(default_factory=dict) - temperature: Optional[float] = None - max_tokens: int = Field(default=256) - model_config = ConfigDict(extra="forbid") - @property def _llm_type(self) -> str: """Return type of chat model.""" @@ -389,13 +380,10 @@ def bind_tools( # Ensure tool_choice is one of the allowed options if tool_choice not in ("auto", "none", "tool"): - warnings.warn( + raise ValueError( f"Invalid tool_choice '{tool_choice}' provided. " - "Reka model cannot be forced to use this tool name. " - "Defaulting to 'tool', which will force the model " - "to invoke one or more of the tools it has been passed." + "Tool choice must be one of: 'auto', 'none', or 'tool'." ) - tool_choice = "tool" # Map tool_choice to the parameter expected by the Reka API kwargs["tool_choice"] = tool_choice diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index 14f4a3b9e83a8..870e09f041144 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") @pytest.mark.scheduled def test_reka_call() -> None: """Test a simple call to Reka.""" @@ -31,7 +31,7 @@ def test_reka_call() -> None: logger.debug(f"Response content: {response.content}") -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") @pytest.mark.scheduled def test_reka_generate() -> None: """Test the generate method of Reka.""" @@ -50,7 +50,7 @@ def test_reka_generate() -> None: assert chat_messages == messages_copy -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") @pytest.mark.scheduled def test_reka_streaming() -> None: """Test streaming tokens from Reka.""" @@ -62,7 +62,7 @@ def test_reka_streaming() -> None: logger.debug(f"Streaming response content: {response.content}") -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") @pytest.mark.scheduled def test_reka_streaming_callback() -> None: """Test that streaming correctly invokes callbacks.""" @@ -79,7 +79,7 @@ def test_reka_streaming_callback() -> None: logger.debug(f"Number of LLM streams: {callback_handler.llm_streams}") -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") @pytest.mark.scheduled async def test_reka_async_streaming_callback() -> None: """Test asynchronous streaming with callbacks.""" @@ -103,7 +103,7 @@ async def test_reka_async_streaming_callback() -> None: logger.debug(f"Async generated response: {response.text}") -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") @pytest.mark.scheduled def test_reka_tool_usage_integration() -> None: """Test tool usage with Reka API integration.""" diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py 
index 7b235cb88bfaf..876bdefa84a1e 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -16,26 +16,26 @@ os.environ["REKA_API_KEY"] = "dummy_key" -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_model_param() -> None: llm = ChatReka(model="reka-flash") assert llm.model == "reka-flash" -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_model_kwargs() -> None: llm = ChatReka(model_kwargs={"foo": "bar"}) assert llm.model_kwargs == {"foo": "bar"} -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_incorrect_field() -> None: """Test that providing an incorrect field raises ValidationError.""" with pytest.raises(ValidationError): ChatReka(unknown_field="bar") # type: ignore -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_initialization() -> None: """Test Reka initialization.""" # Verify that ChatReka can be initialized using a secret key provided @@ -43,7 +43,7 @@ def test_reka_initialization() -> None: ChatReka(model="reka-flash", reka_api_key="test_key") -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") @pytest.mark.parametrize( ("content", "expected"), [ @@ -81,7 +81,7 @@ def test_process_content(content: Any, expected: List[Dict[str, Any]]) -> None: assert result == expected -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") @pytest.mark.parametrize( ("messages", "expected"), [ @@ -130,25 +130,25 @@ def test_convert_to_reka_messages( assert result == expected -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_streaming() -> None: llm = ChatReka(streaming=True) assert llm.streaming is True -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_temperature() -> None: llm = ChatReka(temperature=0.5) assert llm.temperature == 0.5 -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_max_tokens() -> None: llm = ChatReka(max_tokens=100) assert llm.max_tokens == 100 -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_default_params() -> None: llm = ChatReka() assert llm._default_params == { @@ -157,7 +157,7 @@ def test_reka_default_params() -> None: } -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_identifying_params() -> None: """Test that ChatReka identifies its default parameters correctly.""" chat = ChatReka(model="reka-flash", temperature=0.7, max_tokens=256) @@ -169,15 +169,15 @@ def test_reka_identifying_params() -> None: assert chat._default_params == expected_params -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_llm_type() -> None: llm = ChatReka() assert llm._llm_type == "reka-chat" -@pytest.mark.requires("reka-api") +@pytest.mark.requires("reka") def test_reka_tool_use_with_mocked_response() -> None: - with patch("langchain_community.chat_models.reka.Reka") as MockReka: + with patch("reka.client.Reka") as MockReka: # Mock the Reka client mock_client = MockReka.return_value mock_chat = MagicMock() From 5343ec5520c15babb7f7c603a925f8f31f37a26d Mon Sep 17 00:00:00 2001 From: findalexli Date: Fri, 1 Nov 2024 10:26:48 -0700 Subject: [PATCH 20/36] Unit and integration test with system messages --- .../chat_models/test_reka.py | 40 +++++++ .../tests/unit_tests/chat_models/test_reka.py | 105 +++++++++++++++++- 2 files changed, 144 insertions(+), 1 deletion(-) diff --git 
a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index 870e09f041144..fe4a6acc6f170 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -9,6 +9,7 @@ BaseMessage, HumanMessage, ToolMessage, + SystemMessage, ) from langchain_core.outputs import ChatGeneration, LLMResult @@ -172,3 +173,42 @@ def test_reka_tool_usage_integration() -> None: assert final_response.content, "The final response content is empty." else: pytest.fail("The model did not request a tool.") + + +@pytest.mark.requires("reka") +@pytest.mark.scheduled +def test_reka_system_message() -> None: + """Test Reka with system message.""" + chat = ChatReka(model="reka-flash", verbose=True) + messages = [ + SystemMessage(content="You are a helpful AI that speaks like Shakespeare."), + HumanMessage(content="Tell me about the weather today.") + ] + response = chat.invoke(messages) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + logger.debug(f"Response with system message: {response.content}") + + +@pytest.mark.requires("reka") +@pytest.mark.scheduled +def test_reka_system_message_multi_turn() -> None: + """Test multi-turn conversation with system message.""" + chat = ChatReka(model="reka-flash", verbose=True) + messages = [ + SystemMessage(content="You are a math tutor who explains concepts simply."), + HumanMessage(content="What is a prime number?"), + ] + + # First turn + response1 = chat.invoke(messages) + assert isinstance(response1, AIMessage) + messages.append(response1) + + # Second turn + messages.append(HumanMessage(content="Can you give me an example?")) + response2 = chat.invoke(messages) + assert isinstance(response2, AIMessage) + + logger.debug(f"First response: {response1.content}") + logger.debug(f"Second response: {response2.content}") \ No newline at end of file diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 876bdefa84a1e..7f83fa67f27ff 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import pytest -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from pydantic import ValidationError from langchain_community.chat_models import ChatReka @@ -208,3 +208,106 @@ def test_reka_tool_use_with_mocked_response() -> None: assert tool_calls[0]["function"]["arguments"] == json.dumps( {"query": "LangChain"} ) + + +@pytest.mark.requires("reka") +@pytest.mark.parametrize( + ("messages", "expected"), + [ + # Test single system message + ( + [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hello"), + ], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant.\nHello" + } + ] + } + ], + ), + # Test system message with multiple messages + ( + [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="What is 2+2?"), + AIMessage(content="4"), + HumanMessage(content="Thanks!"), + ], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant.\nWhat is 2+2?" 
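+ # The system message is expected to be merged into the first human message, joined by a newline.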
+ } + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "4"}] + }, + { + "role": "user", + "content": [{"type": "text", "text": "Thanks!"}] + }, + ], + ), + # Test system message with media content + ( + [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage( + content=[ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ] + ), + ], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant.\nWhat's in this image?" + }, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ] + }, + ], + ), + ], +) +def test_system_message_handling( + messages: List[BaseMessage], expected: List[Dict[str, Any]] +) -> None: + """Test that system messages are handled correctly.""" + result = convert_to_reka_messages(messages) + assert result == expected + + +@pytest.mark.requires("reka") +def test_multiple_system_messages_error() -> None: + """Test that multiple system messages raise an error.""" + messages = [ + SystemMessage(content="System message 1"), + SystemMessage(content="System message 2"), + HumanMessage(content="Hello"), + ] + + with pytest.raises(ValueError, match="Multiple system messages are not supported."): + convert_to_reka_messages(messages) \ No newline at end of file From 541f0aa39077ce1765d1474c735465ff2001b7f3 Mon Sep 17 00:00:00 2001 From: findalexli Date: Fri, 1 Nov 2024 11:30:33 -0700 Subject: [PATCH 21/36] Notebook doc update (pip, e2e agent) --- docs/docs/integrations/chat/reka.ipynb | 370 +++++++++++++++++++------ 1 file changed, 285 insertions(+), 85 deletions(-) diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index 05dd1e9e4c581..7cee7e9824838 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -1,32 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_community.chat_models import ChatReka" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "model = ChatReka()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -88,17 +61,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", - "langchain-experimental 0.0.65 requires langchain-community<0.3.0,>=0.2.16, but you have langchain-community 0.3.2 which is incompatible.\n", - "langchain-experimental 0.0.65 requires langchain-core<0.3.0,>=0.2.38, but you have langchain-core 0.3.10 which is incompatible.\n", - "langchain-openai 0.1.23 requires langchain-core<0.3.0,>=0.2.35, but you have langchain-core 0.3.10 which is incompatible.\n", - "langchain-standard-tests 0.1.1 requires langchain-core<0.3,>=0.1.40, but you have langchain-core 0.3.10 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" + "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ - "%pip install -qU langchain_community" + "%pip install -qU langchain_community reka-api" ] }, { @@ -110,14 +78,21 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import getpass\n", "import os\n", "\n", - "# os.environ[\"REKA_API_KEY\"] = getpass.getpass(\"Enter your Reka API key: \")" + "os.environ[\"REKA_API_KEY\"] = getpass.getpass(\"Enter your Reka API key: \")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Optional: use Langsmith to trace the execution of the model" ] }, { @@ -126,12 +101,16 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install -qU langchain_community" + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", + "os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass('Enter your Langsmith API key: ')" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -149,16 +128,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content=' Hello! How can I help you today? If you have a question, need assistance, or just want to chat, feel free to let me know. Have a great day!\\n\\n', additional_kwargs={}, response_metadata={}, id='run-206ce7e0-7c6c-4d81-b66b-8b98cb1232cc-0')" + "AIMessage(content=' Hello! How can I help you today? If you have a question, need assistance, or just want to chat, feel free to let me know. Have a great day!\\n\\n', additional_kwargs={}, response_metadata={}, id='run-b40e505a-5110-451a-92e6-a2a34988472c-0')" ] }, - "execution_count": 6, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -176,23 +155,21 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " The image shows an indoor setting with no visible weather conditions. The focus is on a ginger cat inspecting a computer keyboard. There are no windows or natural light sources that would provide information about the weather outside. The environment is a typical home office setup with a desk, computer, and some other items like a pen holder and a mobile phone.\n" + " The image shows an indoor setting with no visible weather elements. It features a cat on a desk licking a computer keyboard. 
The background includes a computer monitor, a desk with a few items like a pen holder and a mobile phone, and a glimpse of a window with blinds partially drawn.\n" ] } ], "source": [ - "from langchain_community.chat_models import ChatReka\n", + "from langchain_core.messages import HumanMessage\n", "\n", - "model = ChatReka()\n", "image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"\n", - "from langchain_core.messages import HumanMessage\n", "\n", "message = HumanMessage(\n", " content=[\n", @@ -216,18 +193,18 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " The first image shows two German Shepherd dogs, one adult and one puppy, running through grass. The adult dog is carrying a large stick in its mouth, suggesting play or exercise, while the puppy follows close behind. Both are in a dynamic, natural outdoor setting with lush greenery.\n", + " The first image shows two German Shepherds, one adult and one puppy, in a grassy field. The adult dog is carrying a large stick in its mouth, indicating playfulness or a game being played. The background features a natural, leafy environment, suggesting an outdoor setting conducive to activities like running or training.\n", "\n", - "The second image features a close-up of a single, adult Siamese cat with striking blue eyes, sitting in a natural setting that appears to be outdoors with dried leaves or grass around. The cat's expression is calm and focused, and its fur is predominantly light-colored with darker points on its ears, face, and tail.\n", + "The second image features a close-up of a single cat with striking blue eyes, set against a background of dry leaves or grass. The cat has a calm and somewhat intense expression, with its fur neatly groomed and whiskers prominently visible. The focus is on the cat's face, capturing its serene demeanor in a quiet, natural outdoor setting.\n", "\n", - "The key differences between the images are the subjects (dogs vs. cat) and their expressions (playful vs. calm and focused). Additionally, the settings are similar yet distinct, with both animals in natural environments but the first image depicting more active engagement with the surroundings, while the second is more still and serene.\n" + "The main differences lie in the subjects (dogs vs. cat) and their expressions (playful vs. serene), as well as the composition and focus of the images (outdoor play vs. close-up portrait).\n" ] } ], @@ -257,7 +234,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Tool call example with self-ask-with-search agent" + "Use with the Tavily search API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tool use and agent creation\n", "\n", "## Define the tools\n", "\n", "We first need to create the tools we want to use. Our main tool of choice will be Tavily - a search engine. 
We have a built-in tool in LangChain to easily use Tavily search engine as tool.\n", + "\n" ] }, { @@ -266,47 +255,258 @@ "metadata": {}, "outputs": [], "source": [ - "from dotenv import load_dotenv\n", - "from langchain import hub\n", - "from langchain.agents import AgentExecutor, create_self_ask_with_search_agent\n", - "from langchain_community.chat_models import ChatReka\n", - "from langchain_community.utilities import GoogleSerperAPIWrapper\n", - "from langchain_core.tools import Tool\n", + "import getpass\n", + "import os\n", "\n", - "# Set up API keys\n", - "# os.environ[\"SERPER_API_KEY\"] = \"your_serper_api_key_here\"\n", - "load_dotenv()\n", + "os.environ[\"TAVILY_API_KEY\"] = getpass.getpass('Enter your Tavily API key: ')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730484342, 'localtime': '2024-11-01 11:05'}, 'current': {'last_updated_epoch': 1730484000, 'last_updated': '2024-11-01 11:00', 'temp_c': 11.1, 'temp_f': 52.0, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.9, 'wind_kph': 4.7, 'wind_degree': 247, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.1, 'feelslike_f': 52.0, 'windchill_c': 10.3, 'windchill_f': 50.5, 'heatindex_c': 10.8, 'heatindex_f': 51.5, 'dewpoint_c': 10.4, 'dewpoint_f': 50.6, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.8, 'gust_kph': 6.1}}\"}, {'url': 'https://weatherspark.com/h/m/557/2024/1/Historical-Weather-in-January-2024-in-San-Francisco-California-United-States', 'content': 'San Francisco Temperature History January 2024\\nHourly Temperature in January 2024 in San Francisco\\nCompare San Francisco to another city:\\nCloud Cover in January 2024 in San Francisco\\nDaily Precipitation in January 2024 in San Francisco\\nObserved Weather in January 2024 in San Francisco\\nHours of Daylight and Twilight in January 2024 in San Francisco\\nSunrise & Sunset with Twilight in January 2024 in San Francisco\\nSolar Elevation and Azimuth in January 2024 in San Francisco\\nMoon Rise, Set & Phases in January 2024 in San Francisco\\nHumidity Comfort Levels in January 2024 in San Francisco\\nWind Speed in January 2024 in San Francisco\\nHourly Wind Speed in January 2024 in San Francisco\\nHourly Wind Direction in 2024 in San Francisco\\nAtmospheric Pressure in January 2024 in San Francisco\\nData Sources\\n See all nearby weather stations\\nLatest Report — 1:56 PM\\nFri, Jan 12, 2024\\xa0\\xa0\\xa0\\xa04 min ago\\xa0\\xa0\\xa0\\xa0UTC 21:56\\nCall Sign KSFO\\nTemp.\\n54.0°F\\nPrecipitation\\nNo Report\\nWind\\n8.1 mph\\nCloud Cover\\nMostly Cloudy\\n14,000 ft\\nRaw: KSFO 122156Z 08007KT 10SM FEW030 SCT050 BKN140 12/07 A3022 While having the tremendous advantages of temporal and spatial completeness, these reconstructions: (1) are based on computer models that may have model-based errors, (2) are coarsely sampled on a 50 km grid and are therefore unable to reconstruct the local variations of many microclimates, and (3) have particular difficulty with the weather in some coastal areas, especially small islands.\\n We further 
caution that our travel scores are only as good as the data that underpin them, that weather conditions at any given location and time are unpredictable and variable, and that the definition of the scores reflects a particular set of preferences that may not agree with those of any particular reader.\\n January 2024 Weather History in San Francisco California, United States\\nThe data for this report comes from the San Francisco International Airport.'}]\n" ] } ], "source": [ "from langchain_community.tools.tavily_search import TavilySearchResults\n", "\n", "search = TavilySearchResults(max_results=2)\n", "search_results = search.invoke(\"what is the weather in SF\")\n", "print(search_results)\n", "# If we want, we can create other tools.\n", "# Once we have all the tools we want, we can put them in a list that we will reference later.\n", "tools = [search]\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now see what it is like to enable this model to do tool calling. In order to enable that, we use `.bind_tools` to give the language model knowledge of these tools.\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "model_with_tools = model.bind_tools(tools)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now call the model. Let's first call it with a normal message, and see how it responds. We can look at both the content field as well as the tool_calls field.\n", "\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ContentString: Hello! How can I help you today? If you have a question or need information on a specific topic, feel free to ask. Just type your search query and I'll do my best to assist using the available function.\n", "\n", "\n", "ToolCalls: []\n" ] } ], "source": [ "from langchain_core.messages import HumanMessage\n", "\n", "response = model_with_tools.invoke([HumanMessage(content=\"Hi!\")])\n", "\n", "print(f\"ContentString: {response.content}\")\n", "print(f\"ToolCalls: {response.tool_calls}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's try calling it with some input that would expect a tool to be called.\n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ContentString: \n", "ToolCalls: [{'name': 'tavily_search_results_json', 'args': {'query': 'weather in SF'}, 'id': '2548c622-3553-42df-8220-39fde0632bdb', 'type': 'tool_call'}]\n" ] } ], "source": [ "response = model_with_tools.invoke([HumanMessage(content=\"What's the weather in SF?\")])\n", "\n", "print(f\"ContentString: {response.content}\")\n", "print(f\"ToolCalls: {response.tool_calls}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that there's now no text content, but there is a tool call! It wants us to call the Tavily Search tool.\n", "\n", "This isn't calling that tool yet - it's just telling us to. 
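If we wanted to, we could run that call by hand - something like `search.invoke(response.tool_calls[0]['args'])` would execute the requested Tavily query directly (purely an illustration of the call shape, not a step in this notebook). 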
In order to actually call it, we'll want to create our agent." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Create the agent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have defined the tools and the LLM, we can create the agent. We will be using LangGraph to construct the agent. Currently, we are using a high-level interface to construct the agent, but the nice thing about LangGraph is that this high-level interface is backed by a low-level, highly controllable API in case you want to modify the agent logic.\n", "\n", "Now, we can initialize the agent with the LLM and the tools.\n", "\n", "Note that we are passing in the model, not model_with_tools. That is because `create_react_agent` will call `.bind_tools` for us under the hood." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from langgraph.prebuilt import create_react_agent\n", "\n", "agent_executor = create_react_agent(model, tools)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now try it out on an example where it should be invoking the tool." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[HumanMessage(content='hi!', additional_kwargs={}, response_metadata={}, id='0ab1f3c7-9079-42d4-8a8a-13af5f6c226b'),\n", " AIMessage(content=' Hello! How can I help you today? If you have a question or need information on a specific topic, feel free to ask. 
For example, you can start with a search query like \"latest news on climate change\" or \"biography of Albert Einstein\".\\n\\n', additional_kwargs={}, response_metadata={}, id='run-276d9dcd-13f3-481d-b562-8fe3962d9ba1-0')]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = agent_executor.invoke({\"messages\": [HumanMessage(content=\"hi!\")]})\n", "\n", + "response[\"messages\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to see exactly what is happening under the hood (and to make sure it's not calling a tool) we can take a look at the LangSmith trace: https://smith.langchain.com/public/2372d9c5-855a-45ee-80f2-94b63493563d/r" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='whats the weather in sf?', additional_kwargs={}, response_metadata={}, id='af276c61-3df7-4241-8cb0-81d1f1477bb3'),\n", + " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '86da84b8-0d44-444f-8448-7f134f9afa41', 'type': 'function', 'function': {'name': 'tavily_search_results_json', 'arguments': '{\"query\": \"weather in SF\"}'}}]}, response_metadata={}, id='run-abe1b8e2-98a6-4f69-8f95-278ac8c141ff-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'weather in SF'}, 'id': '86da84b8-0d44-444f-8448-7f134f9afa41', 'type': 'tool_call'}]),\n", + " ToolMessage(content='[{\"url\": \"https://www.weatherapi.com/\", \"content\": \"{\\'location\\': {\\'name\\': \\'San Francisco\\', \\'region\\': \\'California\\', \\'country\\': \\'United States of America\\', \\'lat\\': 37.775, \\'lon\\': -122.4183, \\'tz_id\\': \\'America/Los_Angeles\\', \\'localtime_epoch\\': 1730483436, \\'localtime\\': \\'2024-11-01 10:50\\'}, \\'current\\': {\\'last_updated_epoch\\': 1730483100, \\'last_updated\\': \\'2024-11-01 10:45\\', \\'temp_c\\': 11.4, \\'temp_f\\': 52.5, \\'is_day\\': 1, \\'condition\\': {\\'text\\': \\'Mist\\', \\'icon\\': \\'//cdn.weatherapi.com/weather/64x64/day/143.png\\', \\'code\\': 1030}, \\'wind_mph\\': 2.2, \\'wind_kph\\': 3.6, \\'wind_degree\\': 237, \\'wind_dir\\': \\'WSW\\', \\'pressure_mb\\': 1019.0, \\'pressure_in\\': 30.08, \\'precip_mm\\': 0.0, \\'precip_in\\': 0.0, \\'humidity\\': 100, \\'cloud\\': 100, \\'feelslike_c\\': 11.8, \\'feelslike_f\\': 53.2, \\'windchill_c\\': 11.2, \\'windchill_f\\': 52.1, \\'heatindex_c\\': 11.7, \\'heatindex_f\\': 53.0, \\'dewpoint_c\\': 10.1, \\'dewpoint_f\\': 50.1, \\'vis_km\\': 2.8, \\'vis_miles\\': 1.0, \\'uv\\': 3.0, \\'gust_mph\\': 3.0, \\'gust_kph\\': 4.9}}\"}, {\"url\": \"https://www.timeanddate.com/weather/@z-us-94134/ext\", \"content\": \"Forecasted weather conditions the coming 2 weeks for San Francisco. Sign in. News. News Home; Astronomy News; Time Zone News ... 01 pm: Mon Nov 11: 60 / 53 °F: Tstorms early. Broken clouds. 
54 °F: 19 mph: ↑: 70%: 58%: 0.20\\\\\" 0 (Low) 6:46 am: 5:00 pm * Updated Monday, October 28, 2024 2:24:10 pm San Francisco time - Weather by CustomWeather\"}]', name='tavily_search_results_json', id='de8c8d78-ae24-4a8a-9c73-795c1e4fdd41', tool_call_id='86da84b8-0d44-444f-8448-7f134f9afa41', artifact={'query': 'weather in SF', 'follow_up_questions': None, 'answer': None, 'images': [], 'results': [{'title': 'Weather in San Francisco', 'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730483436, 'localtime': '2024-11-01 10:50'}, 'current': {'last_updated_epoch': 1730483100, 'last_updated': '2024-11-01 10:45', 'temp_c': 11.4, 'temp_f': 52.5, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.2, 'wind_kph': 3.6, 'wind_degree': 237, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.8, 'feelslike_f': 53.2, 'windchill_c': 11.2, 'windchill_f': 52.1, 'heatindex_c': 11.7, 'heatindex_f': 53.0, 'dewpoint_c': 10.1, 'dewpoint_f': 50.1, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.0, 'gust_kph': 4.9}}\", 'score': 0.9989501, 'raw_content': None}, {'title': 'San Francisco, USA 14 day weather forecast - timeanddate.com', 'url': 'https://www.timeanddate.com/weather/@z-us-94134/ext', 'content': 'Forecasted weather conditions the coming 2 weeks for San Francisco. Sign in. News. News Home; Astronomy News; Time Zone News ... 01 pm: Mon Nov 11: 60 / 53 °F: Tstorms early. Broken clouds. 54 °F: 19 mph: ↑: 70%: 58%: 0.20\" 0 (Low) 6:46 am: 5:00 pm * Updated Monday, October 28, 2024 2:24:10 pm San Francisco time - Weather by CustomWeather', 'score': 0.9938309, 'raw_content': None}], 'response_time': 3.56}),\n", + " AIMessage(content=' The current weather in San Francisco is mist with a temperature of 11.4°C (52.5°F). There is a 100% humidity and the wind is blowing at 2.2 mph from the WSW direction. The forecast for the coming weeks shows a mix of cloudy and partly cloudy days with some chances of thunderstorms. Temperatures are expected to range between 53°F and 60°F.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-de4207d6-e8e8-4382-ad16-4de0dcf0812a-0')]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = agent_executor.invoke(\n", + " {\"messages\": [HumanMessage(content=\"whats the weather in sf?\")]}\n", + ")\n", + "response[\"messages\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can check out the LangSmith trace to make sure it's calling the search tool effectively.\n", "\n", - "# Example usage\n", - "if __name__ == \"__main__\":\n", - " query = \"What is the hometown of the reigning men's U.S. Open champion?\"\n", - " agent_executor.invoke({\"input\": \"query\"})" + "https://smith.langchain.com/public/013ef704-654b-4447-8428-637b343d646e/r" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've seen how the agent can be called with `.invoke` to get a final response. If the agent executes multiple steps, this may take a while. 
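(In the streamed output below, each chunk is keyed by the LangGraph node that produced it - `'agent'` for model turns and `'tools'` for tool executions.) 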
To show intermediate progress, we can stream back messages as they occur.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '2457d3ea-f001-4b8c-a1ed-3dc3d1381639', 'type': 'function', 'function': {'name': 'tavily_search_results_json', 'arguments': '{\"query\": \"weather in San Francisco\"}'}}]}, response_metadata={}, id='run-0363deab-84d2-4319-bb1e-b55b47fe2274-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'weather in San Francisco'}, 'id': '2457d3ea-f001-4b8c-a1ed-3dc3d1381639', 'type': 'tool_call'}])]}}\n", + "----\n", + "{'tools': {'messages': [ToolMessage(content='[{\"url\": \"https://www.weatherapi.com/\", \"content\": \"{\\'location\\': {\\'name\\': \\'San Francisco\\', \\'region\\': \\'California\\', \\'country\\': \\'United States of America\\', \\'lat\\': 37.775, \\'lon\\': -122.4183, \\'tz_id\\': \\'America/Los_Angeles\\', \\'localtime_epoch\\': 1730483636, \\'localtime\\': \\'2024-11-01 10:53\\'}, \\'current\\': {\\'last_updated_epoch\\': 1730483100, \\'last_updated\\': \\'2024-11-01 10:45\\', \\'temp_c\\': 11.4, \\'temp_f\\': 52.5, \\'is_day\\': 1, \\'condition\\': {\\'text\\': \\'Mist\\', \\'icon\\': \\'//cdn.weatherapi.com/weather/64x64/day/143.png\\', \\'code\\': 1030}, \\'wind_mph\\': 2.2, \\'wind_kph\\': 3.6, \\'wind_degree\\': 237, \\'wind_dir\\': \\'WSW\\', \\'pressure_mb\\': 1019.0, \\'pressure_in\\': 30.08, \\'precip_mm\\': 0.0, \\'precip_in\\': 0.0, \\'humidity\\': 100, \\'cloud\\': 100, \\'feelslike_c\\': 11.8, \\'feelslike_f\\': 53.2, \\'windchill_c\\': 11.2, \\'windchill_f\\': 52.1, \\'heatindex_c\\': 11.7, \\'heatindex_f\\': 53.0, \\'dewpoint_c\\': 10.1, \\'dewpoint_f\\': 50.1, \\'vis_km\\': 2.8, \\'vis_miles\\': 1.0, \\'uv\\': 3.0, \\'gust_mph\\': 3.0, \\'gust_kph\\': 4.9}}\"}, {\"url\": \"https://weather.com/weather/monthly/l/69bedc6a5b6e977993fb3e5344e3c06d8bc36a1fb6754c3ddfb5310a3c6d6c87\", \"content\": \"Weather.com brings you the most accurate monthly weather forecast for San Francisco, CA with average/record and high/low temperatures, precipitation and more. ... 11. 66 ° 55 ° 12. 
69 ° 60\"}]', name='tavily_search_results_json', id='e675f99b-130f-4e98-8477-badd45938d9d', tool_call_id='2457d3ea-f001-4b8c-a1ed-3dc3d1381639', artifact={'query': 'weather in San Francisco', 'follow_up_questions': None, 'answer': None, 'images': [], 'results': [{'title': 'Weather in San Francisco', 'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730483636, 'localtime': '2024-11-01 10:53'}, 'current': {'last_updated_epoch': 1730483100, 'last_updated': '2024-11-01 10:45', 'temp_c': 11.4, 'temp_f': 52.5, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.2, 'wind_kph': 3.6, 'wind_degree': 237, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.8, 'feelslike_f': 53.2, 'windchill_c': 11.2, 'windchill_f': 52.1, 'heatindex_c': 11.7, 'heatindex_f': 53.0, 'dewpoint_c': 10.1, 'dewpoint_f': 50.1, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.0, 'gust_kph': 4.9}}\", 'score': 0.9968992, 'raw_content': None}, {'title': 'Monthly Weather Forecast for San Francisco, CA - weather.com', 'url': 'https://weather.com/weather/monthly/l/69bedc6a5b6e977993fb3e5344e3c06d8bc36a1fb6754c3ddfb5310a3c6d6c87', 'content': 'Weather.com brings you the most accurate monthly weather forecast for San Francisco, CA with average/record and high/low temperatures, precipitation and more. ... 11. 66 ° 55 ° 12. 69 ° 60', 'score': 0.97644573, 'raw_content': None}], 'response_time': 3.16})]}}\n", + "----\n", + "{'agent': {'messages': [AIMessage(content=' The current weather in San Francisco is misty with a temperature of 11.4°C (52.5°F). The wind is blowing at 2.2 mph (3.6 kph) from the WSW direction. The humidity is at 100%, and the visibility is 2.8 km (1.0 miles). 
The monthly forecast shows average temperatures ranging from 55°F to 66°F (13°C to 19°C) with some precipitation expected.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-99ccf444-d286-4244-a5a5-7b1b511153a6-0')]}}\n", + "----\n" + ] + } + ], + "source": [ + "for chunk in agent_executor.stream(\n", + " {\"messages\": [HumanMessage(content=\"whats the weather in sf?\")]}\n", + "):\n", + " print(chunk)\n", + " print(\"----\")" ] } ], From 05c427ce5a7e02c09d1215b8facfd86e84284d49 Mon Sep 17 00:00:00 2001 From: findalexli Date: Fri, 1 Nov 2024 11:35:39 -0700 Subject: [PATCH 22/36] Lint/formatt --- .../langchain_community/chat_models/reka.py | 4 +-- .../chat_models/test_reka.py | 12 +++---- .../tests/unit_tests/chat_models/test_reka.py | 31 +++++++------------ 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index cb98f6850ccd4..a1a3a1a8c35b6 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -1,5 +1,4 @@ import json -import warnings from typing import ( Any, AsyncIterator, @@ -18,7 +17,7 @@ AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain_core.language_models import BaseLanguageModel, LanguageModelInput +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, @@ -141,6 +140,7 @@ def convert_to_reka_messages(messages: List[BaseMessage]) -> List[Dict[str, Any] return reka_messages + class ChatReka(BaseChatModel): """Reka chat large language models.""" diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index fe4a6acc6f170..02439cd31db75 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -8,8 +8,8 @@ AIMessage, BaseMessage, HumanMessage, - ToolMessage, SystemMessage, + ToolMessage, ) from langchain_core.outputs import ChatGeneration, LLMResult @@ -182,7 +182,7 @@ def test_reka_system_message() -> None: chat = ChatReka(model="reka-flash", verbose=True) messages = [ SystemMessage(content="You are a helpful AI that speaks like Shakespeare."), - HumanMessage(content="Tell me about the weather today.") + HumanMessage(content="Tell me about the weather today."), ] response = chat.invoke(messages) assert isinstance(response, AIMessage) @@ -199,16 +199,16 @@ def test_reka_system_message_multi_turn() -> None: SystemMessage(content="You are a math tutor who explains concepts simply."), HumanMessage(content="What is a prime number?"), ] - + # First turn response1 = chat.invoke(messages) assert isinstance(response1, AIMessage) messages.append(response1) - + # Second turn messages.append(HumanMessage(content="Can you give me an example?")) response2 = chat.invoke(messages) assert isinstance(response2, AIMessage) - + logger.debug(f"First response: {response1.content}") - logger.debug(f"Second response: {response2.content}") \ No newline at end of file + logger.debug(f"Second response: {response2.content}") diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 7f83fa67f27ff..642690e0d117a 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ 
b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -224,11 +224,8 @@ def test_reka_tool_use_with_mocked_response() -> None: { "role": "user", "content": [ - { - "type": "text", - "text": "You are a helpful assistant.\nHello" - } - ] + {"type": "text", "text": "You are a helpful assistant.\nHello"} + ], } ], ), @@ -246,24 +243,18 @@ def test_reka_tool_use_with_mocked_response() -> None: "content": [ { "type": "text", - "text": "You are a helpful assistant.\nWhat is 2+2?" + "text": "You are a helpful assistant.\nWhat is 2+2?", } - ] - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "4"}] - }, - { - "role": "user", - "content": [{"type": "text", "text": "Thanks!"}] + ], }, + {"role": "assistant", "content": [{"type": "text", "text": "4"}]}, + {"role": "user", "content": [{"type": "text", "text": "Thanks!"}]}, ], ), # Test system message with media content ( [ - SystemMessage(content="You are a helpful assistant."), + SystemMessage(content="Hi."), HumanMessage( content=[ {"type": "text", "text": "What's in this image?"}, @@ -280,13 +271,13 @@ def test_reka_tool_use_with_mocked_response() -> None: "content": [ { "type": "text", - "text": "You are a helpful assistant.\nWhat's in this image?" + "text": "Hi.\nWhat's in this image?", }, { "type": "image_url", "image_url": "https://example.com/image.jpg", }, - ] + ], }, ], ), @@ -308,6 +299,6 @@ def test_multiple_system_messages_error() -> None: SystemMessage(content="System message 2"), HumanMessage(content="Hello"), ] - + with pytest.raises(ValueError, match="Multiple system messages are not supported."): - convert_to_reka_messages(messages) \ No newline at end of file + convert_to_reka_messages(messages) From 3489b995ee745069d116494ff0657658f8f77c67 Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 5 Nov 2024 10:14:39 -0800 Subject: [PATCH 23/36] Linted notebook --- docs/docs/integrations/chat/reka.ipynb | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index 7cee7e9824838..d17e2c0d196be 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -105,7 +105,7 @@ "import os\n", "\n", "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", - "os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass('Enter your Langsmith API key: ')" + "os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your Langsmith API key: \")" ] }, { @@ -258,7 +258,7 @@ "import getpass\n", "import os\n", "\n", - "os.environ[\"TAVILY_API_KEY\"] = getpass.getpass('Enter your Tavily API key: ')" + "os.environ[\"TAVILY_API_KEY\"] = getpass.getpass(\"Enter your Tavily API key: \")" ] }, { @@ -282,8 +282,7 @@ "print(search_results)\n", "# If we want, we can create other tools.\n", "# Once we have all the tools we want, we can put them in a list that we will reference later.\n", - "tools = [search]\n", - "\n" + "tools = [search]" ] }, { @@ -300,7 +299,7 @@ "metadata": {}, "outputs": [], "source": [ - "model_with_tools = model.bind_tools(tools)\n" + "model_with_tools = model.bind_tools(tools)" ] }, { From b59fb7f746851c10b9e322a8ed26d79454d92b30 Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 5 Nov 2024 10:30:53 -0800 Subject: [PATCH 24/36] add chatke in test import --- libs/community/tests/unit_tests/chat_models/test_imports.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py 
index ee3240168e01e..b8598165d9bf2 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -45,6 +45,7 @@ "ChatVertexAI", "ChatYandexGPT", "ChatYuan2", + "ChatReKa", "ChatZhipuAI", "ErnieBotChat", "FakeListChatModel", From c7d75a2b7f25a755bc3f9688671a4e8ad198eacd Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 5 Nov 2024 12:18:09 -0800 Subject: [PATCH 25/36] add tiktoken token count method --- .../langchain_community/chat_models/reka.py | 11 +++++----- .../tests/unit_tests/chat_models/test_reka.py | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index a1a3a1a8c35b6..d8c4b3abb7184 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -13,6 +13,7 @@ Union, ) +import tiktoken from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -153,9 +154,9 @@ class ChatReka(BaseChatModel): default_request_timeout: Optional[float] = None max_retries: int = 2 reka_api_key: Optional[str] = None - count_tokens: Optional[Callable[[str], int]] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_config = ConfigDict(extra="forbid") + _tiktoken_encoder = None @model_validator(mode="before") @classmethod @@ -330,11 +331,9 @@ async def _agenerate( def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" - if self.count_tokens is None: - raise NotImplementedError( - "get_num_tokens() is not implemented for Reka models." - ) - return self.count_tokens(text) + if self._tiktoken_encoder is None: + self._tiktoken_encoder = tiktoken.get_encoding("cl100k_base") + return len(self._tiktoken_encoder.encode(text)) def bind_tools( self, diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 642690e0d117a..00818f8bd33c4 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest +import tiktoken from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from pydantic import ValidationError @@ -302,3 +303,22 @@ def test_multiple_system_messages_error() -> None: with pytest.raises(ValueError, match="Multiple system messages are not supported."): convert_to_reka_messages(messages) + + +@pytest.mark.requires("reka") +def test_get_num_tokens() -> None: + """Test that token counting works correctly.""" + llm = ChatReka() + + # Test basic text + text = "Hello, world!" + expected_tokens = len(tiktoken.get_encoding("cl100k_base").encode(text)) + assert llm.get_num_tokens(text) == expected_tokens + + # Test empty string + assert llm.get_num_tokens("") == 0 + + # Test longer text with special characters + complex_text = "Hello 🌍! 
This is a test of the token counting" + expected_tokens = len(tiktoken.get_encoding("cl100k_base").encode(complex_text)) + assert llm.get_num_tokens(complex_text) == expected_tokens From d6f0ef43d18427bda2676b7f2288bc480b81b8b5 Mon Sep 17 00:00:00 2001 From: findalexli Date: Wed, 6 Nov 2024 14:47:32 -0800 Subject: [PATCH 26/36] Add required header in doc notebook --- docs/docs/integrations/chat/reka.ipynb | 84 ++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 12 deletions(-) diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index d17e2c0d196be..4016185759819 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -73,7 +73,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Initialize a client" + "## Instantiation" ] }, { @@ -110,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -123,21 +123,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Single turn text message" + "## Invocation" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content=' Hello! How can I help you today? If you have a question, need assistance, or just want to chat, feel free to let me know. Have a great day!\\n\\n', additional_kwargs={}, response_metadata={}, id='run-b40e505a-5110-451a-92e6-a2a34988472c-0')" + "AIMessage(content=' Hello! How can I help you today? If you have a question, need assistance, or just want to chat, feel free to let me know. Have a great day!\\n\\n', additional_kwargs={}, response_metadata={}, id='run-61522ec2-0587-4fd5-a492-5b205fd8860c-0')" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -155,14 +155,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " The image shows an indoor setting with no visible weather elements. It features a cat on a desk licking a computer keyboard. The background includes a computer monitor, a desk with a few items like a pen holder and a mobile phone, and a glimpse of a window with blinds partially drawn.\n" + " The image shows an indoor setting with no visible windows or natural light, and there are no indicators of weather conditions. The focus is on a cat sitting on a computer keyboard, and the background includes a computer monitor and various office supplies.\n" ] } ], @@ -193,18 +193,18 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " The first image shows two German Shepherds, one adult and one puppy, in a grassy field. The adult dog is carrying a large stick in its mouth, indicating playfulness or a game being played. The background features a natural, leafy environment, suggesting an outdoor setting conducive to activities like running or training.\n", + " The first image features two German Shepherds, one adult and one puppy, in a vibrant, lush green setting. The adult dog is carrying a large stick in its mouth, running through what appears to be a grassy field, with the puppy following close behind. 
Both dogs exhibit striking physical characteristics typical of the breed, such as pointed ears and dense fur.\n", "\n", - "The second image features a close-up of a single cat with striking blue eyes, set against a background of dry leaves or grass. The cat has a calm and somewhat intense expression, with its fur neatly groomed and whiskers prominently visible. The focus is on the cat's face, capturing its serene demeanor in a quiet, natural outdoor setting.\n", + "The second image shows a close-up of a single cat with striking blue eyes, likely a breed like the Siberian or Maine Coon, in a natural outdoor setting. The cat's fur is lighter, possibly a mix of white and gray, and it has a more subdued expression compared to the dogs. The background is blurred, suggesting a focus on the cat's face.\n", "\n", - "The main differences lie in the subjects (dogs vs. cat) and their expressions (playful vs. serene), as well as the composition and focus of the images (outdoor play vs. close-up portrait).\n" + "Overall, the differences lie in the subjects (two dogs vs. one cat), the setting (lush, vibrant grassy field vs. a more muted outdoor background), and the overall mood and activity depicted (playful and active vs. serene and focused).\n" ] } ], @@ -230,6 +230,52 @@ "print(response.content)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chaining" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' Ich liebe Programmieren.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-ffc4ace1-b73a-4fb3-ad0f-57e60a0f9b8d-0')" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n", + " ),\n", + " (\"human\", \"{input}\"),\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | model\n", + "chain.invoke(\n", + " {\n", + " \"input_language\": \"English\",\n", + " \"output_language\": \"German\",\n", + " \"input\": \"I love programming.\",\n", + " }\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -507,6 +553,20 @@ " print(chunk)\n", " print(\"----\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API reference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "https://docs.reka.ai/quick-start" + ] } ], "metadata": { From c72aefb1b11f1d81cb859899f5e5eb60fad169ea Mon Sep 17 00:00:00 2001 From: findalexli Date: Wed, 6 Nov 2024 15:46:46 -0800 Subject: [PATCH 27/36] token count update --- libs/community/extended_testing_deps.txt | 1 + .../langchain_community/chat_models/reka.py | 31 +++++++++++++--- .../unit_tests/chat_models/test_imports.py | 2 +- .../tests/unit_tests/chat_models/test_reka.py | 37 +++++++++++++++---- 4 files changed, 57 insertions(+), 14 deletions(-) diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 8d1deccb323bb..94482d2397e25 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -87,6 +87,7 @@ telethon>=1.28.5,<2 tidb-vector>=0.0.3,<1.0.0 timescale-vector==0.0.1 tqdm>=4.48.0 +tiktoken>=0.8.0 tree-sitter>=0.20.2,<0.21 tree-sitter-languages>=1.8.0,<2 upstash-redis>=1.1.0,<2 diff --git 
a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index d8c4b3abb7184..06b5cce2d900d 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -13,7 +13,6 @@ Union, ) -import tiktoken from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -156,7 +155,9 @@ class ChatReka(BaseChatModel): reka_api_key: Optional[str] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_config = ConfigDict(extra="forbid") - _tiktoken_encoder = None + token_counter: Optional[ + Union[Callable[[list[BaseMessage]], int], Callable[[BaseMessage], int]] + ] = None @model_validator(mode="before") @classmethod @@ -329,11 +330,29 @@ async def _agenerate( return ChatResult(generations=[ChatGeneration(message=message)]) - def get_num_tokens(self, text: str) -> int: + def get_num_tokens(self, input: Union[str, BaseMessage, List[BaseMessage]]) -> int: """Calculate number of tokens.""" - if self._tiktoken_encoder is None: - self._tiktoken_encoder = tiktoken.get_encoding("cl100k_base") - return len(self._tiktoken_encoder.encode(text)) + # Initialize encoder if not already set + + if self.token_counter is None: + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "Please install it with `pip install tiktoken`." + ) + encoding = tiktoken.get_encoding("cl100k_base") + + if isinstance(input, str): + return len(encoding.encode(input)) + elif isinstance(input, BaseMessage): + return len(encoding.encode(input.content)) + elif isinstance(input, list): + return sum(len(encoding.encode(msg.content)) for msg in input) + raise ValueError(f"Got unexpected type for input: {type(input)}") + + return self.token_counter(input) def bind_tools( self, diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index b8598165d9bf2..4022fe781bf02 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -45,7 +45,7 @@ "ChatVertexAI", "ChatYandexGPT", "ChatYuan2", - "ChatReKa", + "ChatReka", "ChatZhipuAI", "ErnieBotChat", "FakeListChatModel", diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 00818f8bd33c4..01c402e26919a 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -4,7 +4,6 @@ from unittest.mock import MagicMock, patch import pytest -import tiktoken from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from pydantic import ValidationError @@ -305,20 +304,44 @@ def test_multiple_system_messages_error() -> None: convert_to_reka_messages(messages) +@pytest.mark.requires("tiktoken") @pytest.mark.requires("reka") def test_get_num_tokens() -> None: - """Test that token counting works correctly.""" + """Test that token counting works correctly for different input types.""" llm = ChatReka() + import tiktoken - # Test basic text + encoding = tiktoken.get_encoding("cl100k_base") + + # Test string input text = "Hello, world!" 
- expected_tokens = len(tiktoken.get_encoding("cl100k_base").encode(text)) + expected_tokens = len(encoding.encode(text)) assert llm.get_num_tokens(text) == expected_tokens - # Test empty string + # Test BaseMessage input + message = HumanMessage(content="What is the weather like today?") + expected_tokens = len(encoding.encode(message.content)) + assert llm.get_num_tokens(message) == expected_tokens + + # Test List[BaseMessage] input + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hi!"), + AIMessage(content="Hello! How can I help you today?"), + ] + expected_tokens = sum(len(encoding.encode(msg.content)) for msg in messages) + assert llm.get_num_tokens(messages) == expected_tokens + + # Test empty inputs assert llm.get_num_tokens("") == 0 + assert llm.get_num_tokens(HumanMessage(content="")) == 0 + assert llm.get_num_tokens([]) == 0 - # Test longer text with special characters + # Test complex text with special characters complex_text = "Hello 🌍! This is a test of the token counting" - expected_tokens = len(tiktoken.get_encoding("cl100k_base").encode(complex_text)) + expected_tokens = len(encoding.encode(complex_text)) assert llm.get_num_tokens(complex_text) == expected_tokens + + # Test invalid input type + with pytest.raises(ValueError, match="Got unexpected type for input:"): + llm.get_num_tokens(123) # type: ignore From a229d20da02af3a679ff3683d3eb5101941e5efb Mon Sep 17 00:00:00 2001 From: findalexli Date: Thu, 7 Nov 2024 18:02:11 -0800 Subject: [PATCH 28/36] Type consistency for token counter --- .../langchain_community/chat_models/reka.py | 66 +++++++++++++------ .../tests/unit_tests/chat_models/test_reka.py | 22 +++---- 2 files changed, 53 insertions(+), 35 deletions(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index 06b5cce2d900d..f56001f37b160 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -156,7 +156,7 @@ class ChatReka(BaseChatModel): model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_config = ConfigDict(extra="forbid") token_counter: Optional[ - Union[Callable[[list[BaseMessage]], int], Callable[[BaseMessage], int]] + Callable[[Union[str, BaseMessage, List[BaseMessage]]], int] ] = None @model_validator(mode="before") @@ -331,28 +331,52 @@ async def _agenerate( return ChatResult(generations=[ChatGeneration(message=message)]) def get_num_tokens(self, input: Union[str, BaseMessage, List[BaseMessage]]) -> int: - """Calculate number of tokens.""" - # Initialize encoder if not already set - - if self.token_counter is None: - try: - import tiktoken - except ImportError: - raise ImportError( - "Could not import tiktoken python package. " - "Please install it with `pip install tiktoken`." - ) - encoding = tiktoken.get_encoding("cl100k_base") + """Calculate number of tokens. + + Args: + input: Either a string, a single BaseMessage, or a list of BaseMessages. + + Returns: + int: Number of tokens in the input. - if isinstance(input, str): - return len(encoding.encode(input)) - elif isinstance(input, BaseMessage): - return len(encoding.encode(input.content)) - elif isinstance(input, list): - return sum(len(encoding.encode(msg.content)) for msg in input) - raise ValueError(f"Got unexpected type for input: {type(input)}") + Raises: + ImportError: If tiktoken is not installed. + ValueError: If message content is not a string. 
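+
+        Example (illustrative only, assuming the default tiktoken-based counting):
+
+            llm = ChatReka()
+            llm.get_num_tokens("Hello, world!")  # token count for a plain string
+            llm.get_num_tokens([HumanMessage(content="Hello!")])  # summed over message contents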
+ """ + if self.token_counter is not None: + return self.token_counter(input) - return self.token_counter(input) + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "Please install it with `pip install tiktoken`." + ) + + encoding = tiktoken.get_encoding("cl100k_base") + + if isinstance(input, str): + return len(encoding.encode(input)) + elif isinstance(input, BaseMessage): + content = input.content + if not isinstance(content, str): + raise ValueError( + f"Message content must be a string, got {type(content)}" + ) + return len(encoding.encode(content)) + elif isinstance(input, list): + total = 0 + for msg in input: + content = msg.content + if not isinstance(content, str): + raise ValueError( + f"Message content must be a string, got {type(content)}" + ) + total += len(encoding.encode(content)) + return total + else: + raise TypeError(f"Unsupported input type: {type(input)}") def bind_tools( self, diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 01c402e26919a..f5afbdf3f201b 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -314,12 +314,13 @@ def test_get_num_tokens() -> None: encoding = tiktoken.get_encoding("cl100k_base") # Test string input - text = "Hello, world!" + text = "What is the weather like today?" expected_tokens = len(encoding.encode(text)) assert llm.get_num_tokens(text) == expected_tokens # Test BaseMessage input message = HumanMessage(content="What is the weather like today?") + assert isinstance(message.content, str) expected_tokens = len(encoding.encode(message.content)) assert llm.get_num_tokens(message) == expected_tokens @@ -329,19 +330,12 @@ def test_get_num_tokens() -> None: HumanMessage(content="Hi!"), AIMessage(content="Hello! How can I help you today?"), ] - expected_tokens = sum(len(encoding.encode(msg.content)) for msg in messages) + expected_tokens = sum( + len(encoding.encode(msg.content)) + for msg in messages + if isinstance(msg.content, str) + ) assert llm.get_num_tokens(messages) == expected_tokens - # Test empty inputs - assert llm.get_num_tokens("") == 0 - assert llm.get_num_tokens(HumanMessage(content="")) == 0 + # Test empty message list assert llm.get_num_tokens([]) == 0 - - # Test complex text with special characters - complex_text = "Hello 🌍! 
This is a test of the token counting" - expected_tokens = len(encoding.encode(complex_text)) - assert llm.get_num_tokens(complex_text) == expected_tokens - - # Test invalid input type - with pytest.raises(ValueError, match="Got unexpected type for input:"): - llm.get_num_tokens(123) # type: ignore From 88a24fe8de502dddb7816f5735494cf7133b5db5 Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 15 Nov 2024 07:59:49 -0500 Subject: [PATCH 29/36] remove from deps --- libs/community/extended_testing_deps.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 94482d2397e25..d331fb66e85dd 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -74,7 +74,6 @@ rapidfuzz>=3.1.1,<4 rapidocr-onnxruntime>=1.3.2,<2 rdflib==7.0.0 requests-toolbelt>=1.0.0,<2 -reka-api>=3.2.0,<4 rspace_client>=2.5.0,<3 scikit-learn>=1.2.2,<2 simsimd>=5.0.0,<6 From d1261319c09dc958827a277ddb11a960e43bf2aa Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 15 Nov 2024 08:07:15 -0500 Subject: [PATCH 30/36] remove scheduled --- .../tests/integration_tests/chat_models/test_reka.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index 02439cd31db75..5f76f5a7f49ae 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -21,7 +21,6 @@ @pytest.mark.requires("reka") -@pytest.mark.scheduled def test_reka_call() -> None: """Test a simple call to Reka.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -33,7 +32,6 @@ def test_reka_call() -> None: @pytest.mark.requires("reka") -@pytest.mark.scheduled def test_reka_generate() -> None: """Test the generate method of Reka.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -52,7 +50,6 @@ def test_reka_generate() -> None: @pytest.mark.requires("reka") -@pytest.mark.scheduled def test_reka_streaming() -> None: """Test streaming tokens from Reka.""" chat = ChatReka(model="reka-flash", streaming=True, verbose=True) @@ -64,7 +61,6 @@ def test_reka_streaming() -> None: @pytest.mark.requires("reka") -@pytest.mark.scheduled def test_reka_streaming_callback() -> None: """Test that streaming correctly invokes callbacks.""" callback_handler = FakeCallbackHandler() @@ -81,7 +77,6 @@ def test_reka_streaming_callback() -> None: @pytest.mark.requires("reka") -@pytest.mark.scheduled async def test_reka_async_streaming_callback() -> None: """Test asynchronous streaming with callbacks.""" callback_handler = FakeCallbackHandler() @@ -105,7 +100,6 @@ async def test_reka_async_streaming_callback() -> None: @pytest.mark.requires("reka") -@pytest.mark.scheduled def test_reka_tool_usage_integration() -> None: """Test tool usage with Reka API integration.""" # Initialize the ChatReka model with tools and verbose logging @@ -176,7 +170,6 @@ def test_reka_tool_usage_integration() -> None: @pytest.mark.requires("reka") -@pytest.mark.scheduled def test_reka_system_message() -> None: """Test Reka with system message.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -191,7 +184,6 @@ def test_reka_system_message() -> None: @pytest.mark.requires("reka") -@pytest.mark.scheduled def test_reka_system_message_multi_turn() -> None: """Test multi-turn conversation with system message.""" chat = ChatReka(model="reka-flash", verbose=True) From 
b395905f3f7a2554443a7e07daaaf65454405005 Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 15 Nov 2024 08:24:59 -0500 Subject: [PATCH 31/36] skip --- .../tests/integration_tests/chat_models/test_reka.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index 5f76f5a7f49ae..f245bfe5c5bfa 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -21,6 +21,7 @@ @pytest.mark.requires("reka") +@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") def test_reka_call() -> None: """Test a simple call to Reka.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -32,6 +33,7 @@ def test_reka_call() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") def test_reka_generate() -> None: """Test the generate method of Reka.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -50,6 +52,7 @@ def test_reka_generate() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") def test_reka_streaming() -> None: """Test streaming tokens from Reka.""" chat = ChatReka(model="reka-flash", streaming=True, verbose=True) @@ -61,6 +64,7 @@ def test_reka_streaming() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") def test_reka_streaming_callback() -> None: """Test that streaming correctly invokes callbacks.""" callback_handler = FakeCallbackHandler() @@ -77,6 +81,7 @@ def test_reka_streaming_callback() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") async def test_reka_async_streaming_callback() -> None: """Test asynchronous streaming with callbacks.""" callback_handler = FakeCallbackHandler() @@ -100,6 +105,7 @@ async def test_reka_async_streaming_callback() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") def test_reka_tool_usage_integration() -> None: """Test tool usage with Reka API integration.""" # Initialize the ChatReka model with tools and verbose logging @@ -170,6 +176,7 @@ def test_reka_tool_usage_integration() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") def test_reka_system_message() -> None: """Test Reka with system message.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -184,6 +191,7 @@ def test_reka_system_message() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") def test_reka_system_message_multi_turn() -> None: """Test multi-turn conversation with system message.""" chat = ChatReka(model="reka-flash", verbose=True) From 21e62f4864db05f921376030605ddc330af1cacf Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 15 Nov 2024 08:27:04 -0500 Subject: [PATCH 32/36] lint --- .../chat_models/test_reka.py | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index f245bfe5c5bfa..76182957528ae 100644 --- 
a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -21,7 +21,9 @@ @pytest.mark.requires("reka") -@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_call() -> None: """Test a simple call to Reka.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -33,7 +35,9 @@ def test_reka_call() -> None: @pytest.mark.requires("reka") -@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_generate() -> None: """Test the generate method of Reka.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -52,7 +56,9 @@ def test_reka_generate() -> None: @pytest.mark.requires("reka") -@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_streaming() -> None: """Test streaming tokens from Reka.""" chat = ChatReka(model="reka-flash", streaming=True, verbose=True) @@ -64,7 +70,9 @@ def test_reka_streaming() -> None: @pytest.mark.requires("reka") -@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_streaming_callback() -> None: """Test that streaming correctly invokes callbacks.""" callback_handler = FakeCallbackHandler() @@ -81,7 +89,9 @@ def test_reka_streaming_callback() -> None: @pytest.mark.requires("reka") -@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) async def test_reka_async_streaming_callback() -> None: """Test asynchronous streaming with callbacks.""" callback_handler = FakeCallbackHandler() @@ -105,7 +115,9 @@ async def test_reka_async_streaming_callback() -> None: @pytest.mark.requires("reka") -@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_tool_usage_integration() -> None: """Test tool usage with Reka API integration.""" # Initialize the ChatReka model with tools and verbose logging @@ -176,7 +188,9 @@ def test_reka_tool_usage_integration() -> None: @pytest.mark.requires("reka") -@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_system_message() -> None: """Test Reka with system message.""" chat = ChatReka(model="reka-flash", verbose=True) @@ -191,7 +205,9 @@ def test_reka_system_message() -> None: @pytest.mark.requires("reka") -@pytest.mark.skip(reason="Dependency conflict w/ other dependencies for urllib3 versions.") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." 
+) def test_reka_system_message_multi_turn() -> None: """Test multi-turn conversation with system message.""" chat = ChatReka(model="reka-flash", verbose=True) From b3ea9deb4c0d91c3ad0cd547779bb694c1bba6c0 Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 15 Nov 2024 08:41:09 -0500 Subject: [PATCH 33/36] skip --- .../tests/unit_tests/chat_models/test_reka.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index f5afbdf3f201b..5bb25497823a4 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -17,18 +17,27 @@ @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_model_param() -> None: llm = ChatReka(model="reka-flash") assert llm.model == "reka-flash" @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_model_kwargs() -> None: llm = ChatReka(model_kwargs={"foo": "bar"}) assert llm.model_kwargs == {"foo": "bar"} @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_incorrect_field() -> None: """Test that providing an incorrect field raises ValidationError.""" with pytest.raises(ValidationError): @@ -36,6 +45,9 @@ def test_reka_incorrect_field() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_initialization() -> None: """Test Reka initialization.""" # Verify that ChatReka can be initialized using a secret key provided @@ -44,6 +56,9 @@ def test_reka_initialization() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) @pytest.mark.parametrize( ("content", "expected"), [ @@ -82,6 +97,9 @@ def test_process_content(content: Any, expected: List[Dict[str, Any]]) -> None: @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) @pytest.mark.parametrize( ("messages", "expected"), [ @@ -131,24 +149,36 @@ def test_convert_to_reka_messages( @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_streaming() -> None: llm = ChatReka(streaming=True) assert llm.streaming is True @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_temperature() -> None: llm = ChatReka(temperature=0.5) assert llm.temperature == 0.5 @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_max_tokens() -> None: llm = ChatReka(max_tokens=100) assert llm.max_tokens == 100 @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_default_params() -> None: llm = ChatReka() assert llm._default_params == { @@ -158,6 +188,9 @@ def test_reka_default_params() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." 
+) def test_reka_identifying_params() -> None: """Test that ChatReka identifies its default parameters correctly.""" chat = ChatReka(model="reka-flash", temperature=0.7, max_tokens=256) @@ -170,12 +203,18 @@ def test_reka_identifying_params() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_llm_type() -> None: llm = ChatReka() assert llm._llm_type == "reka-chat" @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_reka_tool_use_with_mocked_response() -> None: with patch("reka.client.Reka") as MockReka: # Mock the Reka client @@ -211,6 +250,9 @@ def test_reka_tool_use_with_mocked_response() -> None: @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) @pytest.mark.parametrize( ("messages", "expected"), [ @@ -292,6 +334,9 @@ def test_system_message_handling( @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_multiple_system_messages_error() -> None: """Test that multiple system messages raise an error.""" messages = [ @@ -306,6 +351,9 @@ def test_multiple_system_messages_error() -> None: @pytest.mark.requires("tiktoken") @pytest.mark.requires("reka") +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) def test_get_num_tokens() -> None: """Test that token counting works correctly for different input types.""" llm = ChatReka() From 6874313dc6ce5e676564e47d149911199c29160f Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 15 Nov 2024 08:43:55 -0500 Subject: [PATCH 34/36] update --- .../integration_tests/chat_models/test_reka.py | 8 -------- .../tests/unit_tests/chat_models/test_reka.py | 17 ----------------- 2 files changed, 25 deletions(-) diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py index 76182957528ae..848a0f04bcf87 100644 --- a/libs/community/tests/integration_tests/chat_models/test_reka.py +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -20,7 +20,6 @@ logger = logging.getLogger(__name__) -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -34,7 +33,6 @@ def test_reka_call() -> None: logger.debug(f"Response content: {response.content}") -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -55,7 +53,6 @@ def test_reka_generate() -> None: assert chat_messages == messages_copy -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -69,7 +66,6 @@ def test_reka_streaming() -> None: logger.debug(f"Streaming response content: {response.content}") -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -88,7 +84,6 @@ def test_reka_streaming_callback() -> None: logger.debug(f"Number of LLM streams: {callback_handler.llm_streams}") -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." 
) @@ -114,7 +109,6 @@ async def test_reka_async_streaming_callback() -> None: logger.debug(f"Async generated response: {response.text}") -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -187,7 +181,6 @@ def test_reka_tool_usage_integration() -> None: pytest.fail("The model did not request a tool.") -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -204,7 +197,6 @@ def test_reka_system_message() -> None: logger.debug(f"Response with system message: {response.content}") -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 5bb25497823a4..bbacadf7fd926 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -16,7 +16,6 @@ os.environ["REKA_API_KEY"] = "dummy_key" -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -25,7 +24,6 @@ def test_reka_model_param() -> None: assert llm.model == "reka-flash" -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -34,7 +32,6 @@ def test_reka_model_kwargs() -> None: assert llm.model_kwargs == {"foo": "bar"} -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -44,7 +41,6 @@ def test_reka_incorrect_field() -> None: ChatReka(unknown_field="bar") # type: ignore -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -55,7 +51,6 @@ def test_reka_initialization() -> None: ChatReka(model="reka-flash", reka_api_key="test_key") -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -96,7 +91,6 @@ def test_process_content(content: Any, expected: List[Dict[str, Any]]) -> None: assert result == expected -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -148,7 +142,6 @@ def test_convert_to_reka_messages( assert result == expected -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -157,7 +150,6 @@ def test_reka_streaming() -> None: assert llm.streaming is True -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -166,7 +158,6 @@ def test_reka_temperature() -> None: assert llm.temperature == 0.5 -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -175,7 +166,6 @@ def test_reka_max_tokens() -> None: assert llm.max_tokens == 100 -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -187,7 +177,6 @@ def test_reka_default_params() -> None: } -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." 
) @@ -202,7 +191,6 @@ def test_reka_identifying_params() -> None: assert chat._default_params == expected_params -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -211,7 +199,6 @@ def test_reka_llm_type() -> None: assert llm._llm_type == "reka-chat" -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -249,7 +236,6 @@ def test_reka_tool_use_with_mocked_response() -> None: ) -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -333,7 +319,6 @@ def test_system_message_handling( assert result == expected -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) @@ -349,8 +334,6 @@ def test_multiple_system_messages_error() -> None: convert_to_reka_messages(messages) -@pytest.mark.requires("tiktoken") -@pytest.mark.requires("reka") @pytest.mark.skip( reason="Dependency conflict w/ other dependencies for urllib3 versions." ) From 4eab2089306f2a1c4853de82be726c7ebd08159e Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 15 Nov 2024 09:04:47 -0500 Subject: [PATCH 35/36] fix link --- docs/docs/integrations/chat/reka.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index 4016185759819..2e20a8a887ee6 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -15,7 +15,7 @@ "source": [ "# ChatReka\n", "\n", - "This notebook provides a quick overview for getting started with Reka [chat models](/docs/concepts/#chat-models). \n", + "This notebook provides a quick overview for getting started with Reka [chat models](/docs/concepts/chat-models). \n", "\n", "Reka has several chat models. 
You can find information about their latest models and their costs, context windows, and supported input types in the [Reka docs](https://docs.reka.ai/available-models).\n", "\n", @@ -27,7 +27,7 @@ "\n", "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n", "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", - "| [ChatReka] | [langchain_community](https://python.langchain.com/v0.2/api_reference/community/index.html) | ✅ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n", + "| [ChatReka] | [langchain_community](https://python.langchain.com/api_reference/community/index.html) | ✅ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n", "\n", "### Model features\n", "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", From e1e79cfed5a2e43f596a826c56b63d82851a8779 Mon Sep 17 00:00:00 2001 From: vbarda Date: Fri, 15 Nov 2024 12:54:53 -0500 Subject: [PATCH 36/36] update concepts link --- docs/docs/integrations/chat/reka.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb index 2e20a8a887ee6..1ebedb66979d1 100644 --- a/docs/docs/integrations/chat/reka.ipynb +++ b/docs/docs/integrations/chat/reka.ipynb @@ -15,7 +15,7 @@ "source": [ "# ChatReka\n", "\n", - "This notebook provides a quick overview for getting started with Reka [chat models](/docs/concepts/chat-models). \n", + "This notebook provides a quick overview for getting started with Reka [chat models](../../concepts/chat_models.mdx). \n", "\n", "Reka has several chat models. You can find information about their latest models and their costs, context windows, and supported input types in the [Reka docs](https://docs.reka.ai/available-models).\n", "\n",
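+    "As a minimal quick-start sketch (assuming `REKA_API_KEY` is set in your environment):\n",
+    "\n",
+    "```python\n",
+    "from langchain_community.chat_models import ChatReka\n",
+    "\n",
+    "chat = ChatReka(model=\"reka-flash\")\n",
+    "chat.invoke(\"Hello!\")\n",
+    "```\n",
+    "\n",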