diff --git a/docs/docs/integrations/chat/reka.ipynb b/docs/docs/integrations/chat/reka.ipynb new file mode 100644 index 0000000000000..1ebedb66979d1 --- /dev/null +++ b/docs/docs/integrations/chat/reka.ipynb @@ -0,0 +1,593 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "sidebar_label: Reka\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ChatReka\n", + "\n", + "This notebook provides a quick overview for getting started with Reka [chat models](../../concepts/chat_models.mdx).\n", + "\n", + "Reka has several chat models. You can find information about their latest models and their costs, context windows, and supported input types in the [Reka docs](https://docs.reka.ai/available-models).\n", + "\n", + "## Overview\n", + "### Integration details\n", + "\n", + "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", + "| ChatReka | [langchain_community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n", + "\n", + "### Model features\n", + "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", + "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", + "| ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | \n", + "\n", + "## Setup\n", + "\n", + "To access Reka models you'll need to create a Reka developer account, get an API key, and install the `langchain_community` integration package along with the `reka-api` Python package.\n", + "\n", + "### Credentials\n", + "\n", + "Head to https://platform.reka.ai/ to sign up for Reka and generate an API key. 
Once you've done this, set the REKA_API_KEY environment variable:" + ] + },
+ { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"REKA_API_KEY\"] = getpass.getpass(\"Enter your Reka API key: \")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Optional: use LangSmith to trace the execution of the model" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", + "os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The LangChain Reka integration lives in the `langchain_community` package:" + ] + },
+ { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install -qU langchain_community reka-api" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiation" + ] + },
+ { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models import ChatReka\n", + "\n", + "model = ChatReka()" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Invocation" + ] + },
+ { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' Hello! How can I help you today? If you have a question, need assistance, or just want to chat, feel free to let me know. Have a great day!\\n\\n', additional_kwargs={}, response_metadata={}, id='run-61522ec2-0587-4fd5-a492-5b205fd8860c-0')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.invoke(\"hi\")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Image input" + ] + },
+ { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " The image shows an indoor setting with no visible windows or natural light, and there are no indicators of weather conditions. The focus is on a cat sitting on a computer keyboard, and the background includes a computer monitor and various office supplies.\n" + ] + } + ], + "source": [ + "from langchain_core.messages import HumanMessage\n", + "\n", + "image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"\n", + "\n", + "message = HumanMessage(\n", + " content=[\n", + " {\"type\": \"text\", \"text\": \"describe the weather in this image\"},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\"url\": image_url},\n", + " },\n", + " ],\n", + ")\n", + "response = model.invoke([message])\n", + "print(response.content)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multiple images as input" + ] + },
+ { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " The first image features two German Shepherds, one adult and one puppy, in a vibrant, lush green setting. 
The adult dog is carrying a large stick in its mouth, running through what appears to be a grassy field, with the puppy following close behind. Both dogs exhibit striking physical characteristics typical of the breed, such as pointed ears and dense fur.\n", + "\n", + "The second image shows a close-up of a single cat with striking blue eyes, likely a breed like the Siberian or Maine Coon, in a natural outdoor setting. The cat's fur is lighter, possibly a mix of white and gray, and it has a more subdued expression compared to the dogs. The background is blurred, suggesting a focus on the cat's face.\n", + "\n", + "Overall, the differences lie in the subjects (two dogs vs. one cat), the setting (lush, vibrant grassy field vs. a more muted outdoor background), and the overall mood and activity depicted (playful and active vs. serene and focused).\n" + ] + } + ], + "source": [ + "message = HumanMessage(\n", + " content=[\n", + " {\"type\": \"text\", \"text\": \"What are the differences between the two images? \"},\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\n", + " \"url\": \"https://cdn.pixabay.com/photo/2019/07/23/13/51/shepherd-dog-4357790_1280.jpg\"\n", + " },\n", + " },\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\n", + " \"url\": \"https://cdn.pixabay.com/photo/2024/02/17/00/18/cat-8578562_1280.jpg\"\n", + " },\n", + " },\n", + " ],\n", + ")\n", + "response = model.invoke([message])\n", + "print(response.content)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chaining" + ] + },
+ { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' Ich liebe Programmieren.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-ffc4ace1-b73a-4fb3-ad0f-57e60a0f9b8d-0')" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n", + " ),\n", + " (\"human\", \"{input}\"),\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | model\n", + "chain.invoke(\n", + " {\n", + " \"input_language\": \"English\",\n", + " \"output_language\": \"German\",\n", + " \"input\": \"I love programming.\",\n", + " }\n", + ")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool use and agent creation\n", + "\n", + "### Define the tools\n", + "\n", + "We first need to create the tools we want to use. Our main tool of choice will be Tavily - a search engine. 
LangChain has a built-in tool that makes it easy to use the Tavily search engine as a tool.\n", + "\n" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"TAVILY_API_KEY\"] = getpass.getpass(\"Enter your Tavily API key: \")" + ] + },
+ { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730484342, 'localtime': '2024-11-01 11:05'}, 'current': {'last_updated_epoch': 1730484000, 'last_updated': '2024-11-01 11:00', 'temp_c': 11.1, 'temp_f': 52.0, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.9, 'wind_kph': 4.7, 'wind_degree': 247, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.1, 'feelslike_f': 52.0, 'windchill_c': 10.3, 'windchill_f': 50.5, 'heatindex_c': 10.8, 'heatindex_f': 51.5, 'dewpoint_c': 10.4, 'dewpoint_f': 50.6, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.8, 'gust_kph': 6.1}}\"}, {'url': 'https://weatherspark.com/h/m/557/2024/1/Historical-Weather-in-January-2024-in-San-Francisco-California-United-States', 'content': 'San Francisco Temperature History January 2024\\nHourly Temperature in January 2024 in San Francisco\\nCompare San Francisco to another city:\\nCloud Cover in January 2024 in San Francisco\\nDaily Precipitation in January 2024 in San Francisco\\nObserved Weather in January 2024 in San Francisco\\nHours of Daylight and Twilight in January 2024 in San Francisco\\nSunrise & Sunset with Twilight in January 2024 in San Francisco\\nSolar Elevation and Azimuth in January 2024 in San Francisco\\nMoon Rise, Set & Phases in January 2024 in San Francisco\\nHumidity Comfort Levels in January 2024 in San Francisco\\nWind Speed in January 2024 in San Francisco\\nHourly Wind Speed in January 2024 in San Francisco\\nHourly Wind Direction in 2024 in San Francisco\\nAtmospheric Pressure in January 2024 in San Francisco\\nData Sources\\n See all nearby weather stations\\nLatest Report — 1:56 PM\\nFri, Jan 12, 2024\\xa0\\xa0\\xa0\\xa04 min ago\\xa0\\xa0\\xa0\\xa0UTC 21:56\\nCall Sign KSFO\\nTemp.\\n54.0°F\\nPrecipitation\\nNo Report\\nWind\\n8.1 mph\\nCloud Cover\\nMostly Cloudy\\n14,000 ft\\nRaw: KSFO 122156Z 08007KT 10SM FEW030 SCT050 BKN140 12/07 A3022 While having the tremendous advantages of temporal and spatial completeness, these reconstructions: (1) are based on computer models that may have model-based errors, (2) are coarsely sampled on a 50 km grid and are therefore unable to reconstruct the local variations of many microclimates, and (3) have particular difficulty with the weather in some coastal areas, especially small islands.\\n We further caution that our travel scores are only as good as the data that underpin them, that weather conditions at any given location and time are unpredictable and variable, and that the definition of the scores reflects a particular set of preferences that may not agree with those of any particular reader.\\n January 2024 Weather History in San Francisco California, United States\\nThe data for this report 
comes from the San Francisco International Airport.'}]\n" + ] + } + ], + "source": [ + "from langchain_community.tools.tavily_search import TavilySearchResults\n", + "\n", + "search = TavilySearchResults(max_results=2)\n", + "search_results = search.invoke(\"what is the weather in SF\")\n", + "print(search_results)\n", + "# If we want, we can create other tools.\n", + "# Once we have all the tools we want, we can put them in a list that we will reference later.\n", + "tools = [search]" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now see what it is like to enable this model to do tool calling. To enable it, we use `.bind_tools` to give the language model knowledge of these tools.\n", + "\n" + ] + },
+ { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "model_with_tools = model.bind_tools(tools)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now call the model. Let's first call it with a normal message and see how it responds. We can look at both the `content` field and the `tool_calls` field.\n", + "\n" + ] + },
+ { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ContentString: Hello! How can I help you today? If you have a question or need information on a specific topic, feel free to ask. Just type your search query and I'll do my best to assist using the available function.\n", + "\n", + "\n", + "ToolCalls: []\n" + ] + } + ], + "source": [ + "from langchain_core.messages import HumanMessage\n", + "\n", + "response = model_with_tools.invoke([HumanMessage(content=\"Hi!\")])\n", + "\n", + "print(f\"ContentString: {response.content}\")\n", + "print(f\"ToolCalls: {response.tool_calls}\")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, let's try calling it with some input that we would expect to trigger a tool call.\n", + "\n" + ] + },
+ { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ContentString: \n", + "ToolCalls: [{'name': 'tavily_search_results_json', 'args': {'query': 'weather in SF'}, 'id': '2548c622-3553-42df-8220-39fde0632bdb', 'type': 'tool_call'}]\n" + ] + } + ], + "source": [ + "response = model_with_tools.invoke([HumanMessage(content=\"What's the weather in SF?\")])\n", + "\n", + "print(f\"ContentString: {response.content}\")\n", + "print(f\"ToolCalls: {response.tool_calls}\")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that there's now no text content, but there is a tool call! It wants us to call the Tavily Search tool.\n", + "\n", + "This isn't calling that tool yet - it's just telling us to. In order to actually call it, we'll want to create our agent." + ] + },
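+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also control when the model uses its tools via the `tool_choice` parameter of `bind_tools`. Per the `ChatReka.bind_tools` docstring, it accepts \"auto\" (the default, letting the model decide), \"none\" (disables tool calling even when tools are bound), and \"tool\" (forces the model to invoke a tool). A minimal sketch, not executed here:" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sketch: tool_choice values supported by ChatReka.bind_tools.\n", + "model_no_tools = model.bind_tools(tools, tool_choice=\"none\")  # never call tools\n", + "model_forced = model.bind_tools(tools, tool_choice=\"tool\")  # force a tool call" + ] + },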
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create the agent" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have defined the tools and the LLM, we can create the agent. We will be using LangGraph to construct the agent. Currently, we are using a high-level interface to construct the agent, but the nice thing about LangGraph is that this high-level interface is backed by a low-level, highly controllable API in case you want to modify the agent logic.\n", + "\n", + "Now, we can initialize the agent with the LLM and the tools.\n", + "\n", + "Note that we are passing in the model, not `model_with_tools`. That is because `create_react_agent` will call `.bind_tools` for us under the hood." + ] + },
+ { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from langgraph.prebuilt import create_react_agent\n", + "\n", + "agent_executor = create_react_agent(model, tools)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's first try it on an example where it shouldn't need to invoke a tool" + ] + },
+ { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='hi!', additional_kwargs={}, response_metadata={}, id='0ab1f3c7-9079-42d4-8a8a-13af5f6c226b'),\n", + " AIMessage(content=' Hello! How can I help you today? If you have a question or need information on a specific topic, feel free to ask. For example, you can start with a search query like \"latest news on climate change\" or \"biography of Albert Einstein\".\\n\\n', additional_kwargs={}, response_metadata={}, id='run-276d9dcd-13f3-481d-b562-8fe3962d9ba1-0')]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = agent_executor.invoke({\"messages\": [HumanMessage(content=\"hi!\")]})\n", + "\n", + "response[\"messages\"]" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to see exactly what is happening under the hood (and to make sure it's not calling a tool), we can take a look at the LangSmith trace: https://smith.langchain.com/public/2372d9c5-855a-45ee-80f2-94b63493563d/r" + ] + },
+ { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='whats the weather in sf?', additional_kwargs={}, response_metadata={}, id='af276c61-3df7-4241-8cb0-81d1f1477bb3'),\n", + " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '86da84b8-0d44-444f-8448-7f134f9afa41', 'type': 'function', 'function': {'name': 'tavily_search_results_json', 'arguments': '{\"query\": \"weather in SF\"}'}}]}, response_metadata={}, id='run-abe1b8e2-98a6-4f69-8f95-278ac8c141ff-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'weather in SF'}, 'id': '86da84b8-0d44-444f-8448-7f134f9afa41', 'type': 'tool_call'}]),\n", + " ToolMessage(content='[{\"url\": \"https://www.weatherapi.com/\", \"content\": \"{\\'location\\': {\\'name\\': \\'San Francisco\\', \\'region\\': \\'California\\', \\'country\\': \\'United States of America\\', \\'lat\\': 37.775, \\'lon\\': -122.4183, \\'tz_id\\': \\'America/Los_Angeles\\', \\'localtime_epoch\\': 1730483436, \\'localtime\\': \\'2024-11-01 10:50\\'}, \\'current\\': {\\'last_updated_epoch\\': 1730483100, \\'last_updated\\': \\'2024-11-01 10:45\\', \\'temp_c\\': 11.4, \\'temp_f\\': 52.5, \\'is_day\\': 1, \\'condition\\': {\\'text\\': \\'Mist\\', \\'icon\\': \\'//cdn.weatherapi.com/weather/64x64/day/143.png\\', \\'code\\': 1030}, \\'wind_mph\\': 2.2, \\'wind_kph\\': 3.6, \\'wind_degree\\': 237, \\'wind_dir\\': \\'WSW\\', \\'pressure_mb\\': 1019.0, \\'pressure_in\\': 30.08, \\'precip_mm\\': 0.0, 
\\'precip_in\\': 0.0, \\'humidity\\': 100, \\'cloud\\': 100, \\'feelslike_c\\': 11.8, \\'feelslike_f\\': 53.2, \\'windchill_c\\': 11.2, \\'windchill_f\\': 52.1, \\'heatindex_c\\': 11.7, \\'heatindex_f\\': 53.0, \\'dewpoint_c\\': 10.1, \\'dewpoint_f\\': 50.1, \\'vis_km\\': 2.8, \\'vis_miles\\': 1.0, \\'uv\\': 3.0, \\'gust_mph\\': 3.0, \\'gust_kph\\': 4.9}}\"}, {\"url\": \"https://www.timeanddate.com/weather/@z-us-94134/ext\", \"content\": \"Forecasted weather conditions the coming 2 weeks for San Francisco. Sign in. News. News Home; Astronomy News; Time Zone News ... 01 pm: Mon Nov 11: 60 / 53 °F: Tstorms early. Broken clouds. 54 °F: 19 mph: ↑: 70%: 58%: 0.20\\\\\" 0 (Low) 6:46 am: 5:00 pm * Updated Monday, October 28, 2024 2:24:10 pm San Francisco time - Weather by CustomWeather\"}]', name='tavily_search_results_json', id='de8c8d78-ae24-4a8a-9c73-795c1e4fdd41', tool_call_id='86da84b8-0d44-444f-8448-7f134f9afa41', artifact={'query': 'weather in SF', 'follow_up_questions': None, 'answer': None, 'images': [], 'results': [{'title': 'Weather in San Francisco', 'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730483436, 'localtime': '2024-11-01 10:50'}, 'current': {'last_updated_epoch': 1730483100, 'last_updated': '2024-11-01 10:45', 'temp_c': 11.4, 'temp_f': 52.5, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.2, 'wind_kph': 3.6, 'wind_degree': 237, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.8, 'feelslike_f': 53.2, 'windchill_c': 11.2, 'windchill_f': 52.1, 'heatindex_c': 11.7, 'heatindex_f': 53.0, 'dewpoint_c': 10.1, 'dewpoint_f': 50.1, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.0, 'gust_kph': 4.9}}\", 'score': 0.9989501, 'raw_content': None}, {'title': 'San Francisco, USA 14 day weather forecast - timeanddate.com', 'url': 'https://www.timeanddate.com/weather/@z-us-94134/ext', 'content': 'Forecasted weather conditions the coming 2 weeks for San Francisco. Sign in. News. News Home; Astronomy News; Time Zone News ... 01 pm: Mon Nov 11: 60 / 53 °F: Tstorms early. Broken clouds. 54 °F: 19 mph: ↑: 70%: 58%: 0.20\" 0 (Low) 6:46 am: 5:00 pm * Updated Monday, October 28, 2024 2:24:10 pm San Francisco time - Weather by CustomWeather', 'score': 0.9938309, 'raw_content': None}], 'response_time': 3.56}),\n", + " AIMessage(content=' The current weather in San Francisco is mist with a temperature of 11.4°C (52.5°F). There is a 100% humidity and the wind is blowing at 2.2 mph from the WSW direction. The forecast for the coming weeks shows a mix of cloudy and partly cloudy days with some chances of thunderstorms. 
Temperatures are expected to range between 53°F and 60°F.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-de4207d6-e8e8-4382-ad16-4de0dcf0812a-0')]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = agent_executor.invoke(\n", + " {\"messages\": [HumanMessage(content=\"whats the weather in sf?\")]}\n", + ")\n", + "response[\"messages\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can check out the LangSmith trace to make sure it's calling the search tool effectively.\n", + "\n", + "https://smith.langchain.com/public/013ef704-654b-4447-8428-637b343d646e/r" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've seen how the agent can be called with `.invoke` to get a final response. If the agent executes multiple steps, this may take a while. To show intermediate progress, we can stream back messages as they occur.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '2457d3ea-f001-4b8c-a1ed-3dc3d1381639', 'type': 'function', 'function': {'name': 'tavily_search_results_json', 'arguments': '{\"query\": \"weather in San Francisco\"}'}}]}, response_metadata={}, id='run-0363deab-84d2-4319-bb1e-b55b47fe2274-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'weather in San Francisco'}, 'id': '2457d3ea-f001-4b8c-a1ed-3dc3d1381639', 'type': 'tool_call'}])]}}\n", + "----\n", + "{'tools': {'messages': [ToolMessage(content='[{\"url\": \"https://www.weatherapi.com/\", \"content\": \"{\\'location\\': {\\'name\\': \\'San Francisco\\', \\'region\\': \\'California\\', \\'country\\': \\'United States of America\\', \\'lat\\': 37.775, \\'lon\\': -122.4183, \\'tz_id\\': \\'America/Los_Angeles\\', \\'localtime_epoch\\': 1730483636, \\'localtime\\': \\'2024-11-01 10:53\\'}, \\'current\\': {\\'last_updated_epoch\\': 1730483100, \\'last_updated\\': \\'2024-11-01 10:45\\', \\'temp_c\\': 11.4, \\'temp_f\\': 52.5, \\'is_day\\': 1, \\'condition\\': {\\'text\\': \\'Mist\\', \\'icon\\': \\'//cdn.weatherapi.com/weather/64x64/day/143.png\\', \\'code\\': 1030}, \\'wind_mph\\': 2.2, \\'wind_kph\\': 3.6, \\'wind_degree\\': 237, \\'wind_dir\\': \\'WSW\\', \\'pressure_mb\\': 1019.0, \\'pressure_in\\': 30.08, \\'precip_mm\\': 0.0, \\'precip_in\\': 0.0, \\'humidity\\': 100, \\'cloud\\': 100, \\'feelslike_c\\': 11.8, \\'feelslike_f\\': 53.2, \\'windchill_c\\': 11.2, \\'windchill_f\\': 52.1, \\'heatindex_c\\': 11.7, \\'heatindex_f\\': 53.0, \\'dewpoint_c\\': 10.1, \\'dewpoint_f\\': 50.1, \\'vis_km\\': 2.8, \\'vis_miles\\': 1.0, \\'uv\\': 3.0, \\'gust_mph\\': 3.0, \\'gust_kph\\': 4.9}}\"}, {\"url\": \"https://weather.com/weather/monthly/l/69bedc6a5b6e977993fb3e5344e3c06d8bc36a1fb6754c3ddfb5310a3c6d6c87\", \"content\": \"Weather.com brings you the most accurate monthly weather forecast for San Francisco, CA with average/record and high/low temperatures, precipitation and more. ... 11. 66 ° 55 ° 12. 
69 ° 60\"}]', name='tavily_search_results_json', id='e675f99b-130f-4e98-8477-badd45938d9d', tool_call_id='2457d3ea-f001-4b8c-a1ed-3dc3d1381639', artifact={'query': 'weather in San Francisco', 'follow_up_questions': None, 'answer': None, 'images': [], 'results': [{'title': 'Weather in San Francisco', 'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730483636, 'localtime': '2024-11-01 10:53'}, 'current': {'last_updated_epoch': 1730483100, 'last_updated': '2024-11-01 10:45', 'temp_c': 11.4, 'temp_f': 52.5, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.2, 'wind_kph': 3.6, 'wind_degree': 237, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.8, 'feelslike_f': 53.2, 'windchill_c': 11.2, 'windchill_f': 52.1, 'heatindex_c': 11.7, 'heatindex_f': 53.0, 'dewpoint_c': 10.1, 'dewpoint_f': 50.1, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.0, 'gust_kph': 4.9}}\", 'score': 0.9968992, 'raw_content': None}, {'title': 'Monthly Weather Forecast for San Francisco, CA - weather.com', 'url': 'https://weather.com/weather/monthly/l/69bedc6a5b6e977993fb3e5344e3c06d8bc36a1fb6754c3ddfb5310a3c6d6c87', 'content': 'Weather.com brings you the most accurate monthly weather forecast for San Francisco, CA with average/record and high/low temperatures, precipitation and more. ... 11. 66 ° 55 ° 12. 69 ° 60', 'score': 0.97644573, 'raw_content': None}], 'response_time': 3.16})]}}\n", + "----\n", + "{'agent': {'messages': [AIMessage(content=' The current weather in San Francisco is misty with a temperature of 11.4°C (52.5°F). The wind is blowing at 2.2 mph (3.6 kph) from the WSW direction. The humidity is at 100%, and the visibility is 2.8 km (1.0 miles). 
The monthly forecast shows average temperatures ranging from 55°F to 66°F (13°C to 19°C) with some precipitation expected.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-99ccf444-d286-4244-a5a5-7b1b511153a6-0')]}}\n", + "----\n" + ] + } + ], + "source": [ + "for chunk in agent_executor.stream(\n", + " {\"messages\": [HumanMessage(content=\"whats the weather in sf?\")]}\n", + "):\n", + " print(chunk)\n", + " print(\"----\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API reference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "https://docs.reka.ai/quick-start" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langchain_reka", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 5be87606873d4..d331fb66e85dd 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -86,6 +86,7 @@ telethon>=1.28.5,<2 tidb-vector>=0.0.3,<1.0.0 timescale-vector==0.0.1 tqdm>=4.48.0 +tiktoken>=0.8.0 tree-sitter>=0.20.2,<0.21 tree-sitter-languages>=1.8.0,<2 upstash-redis>=1.1.0,<2 diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 0cff776c786d7..ec514566f3053 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -155,6 +155,9 @@ from langchain_community.chat_models.promptlayer_openai import ( PromptLayerChatOpenAI, ) + from langchain_community.chat_models.reka import ( + ChatReka, + ) from langchain_community.chat_models.sambanova import ( ChatSambaNovaCloud, ChatSambaStudio, @@ -226,6 +229,7 @@ "ChatOllama", "ChatOpenAI", "ChatPerplexity", + "ChatReka", "ChatPremAI", "ChatSambaNovaCloud", "ChatSambaStudio", @@ -290,6 +294,7 @@ "ChatOCIModelDeploymentTGI": "langchain_community.chat_models.oci_data_science", "ChatOllama": "langchain_community.chat_models.ollama", "ChatOpenAI": "langchain_community.chat_models.openai", + "ChatReka": "langchain_community.chat_models.reka", "ChatPerplexity": "langchain_community.chat_models.perplexity", "ChatSambaNovaCloud": "langchain_community.chat_models.sambanova", "ChatSambaStudio": "langchain_community.chat_models.sambanova", diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py new file mode 100644 index 0000000000000..f56001f37b160 --- /dev/null +++ b/libs/community/langchain_community/chat_models/reka.py @@ -0,0 +1,435 @@ +import json +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Sequence, + Type, + Union, +) + +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.chat_models import ( + BaseChatModel, + agenerate_from_stream, + generate_from_stream, +) +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from 
langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import BaseModel, ConfigDict, Field, model_validator + +DEFAULT_REKA_MODEL = "reka-flash" + +ContentType = Union[str, List[Union[str, Dict[str, Any]]]] + + +def process_content_item(item: Dict[str, Any]) -> Dict[str, Any]: + """Process a single content item.""" + if item["type"] == "image_url": + image_url = item["image_url"] + if isinstance(image_url, dict) and "url" in image_url: + # If it's in LangChain format, extract the URL value + item["image_url"] = image_url["url"] + return item + + +def process_content(content: ContentType) -> List[Dict[str, Any]]: + """Process content to handle both text and media inputs, + returning a list of content items.""" + if isinstance(content, str): + return [{"type": "text", "text": content}] + elif isinstance(content, list): + result = [] + for item in content: + if isinstance(item, str): + result.append({"type": "text", "text": item}) + elif isinstance(item, dict): + result.append(process_content_item(item)) + else: + raise ValueError(f"Invalid content item format: {item}") + return result + else: + raise ValueError("Invalid content format") + + +def convert_to_reka_messages(messages: List[BaseMessage]) -> List[Dict[str, Any]]: + """Convert LangChain messages to Reka message format.""" + reka_messages: List[Dict[str, Any]] = [] + system_message: Optional[str] = None + + for message in messages: + if isinstance(message, SystemMessage): + if system_message is None: + if isinstance(message.content, str): + system_message = message.content + else: + raise TypeError("SystemMessage content must be a string.") + else: + raise ValueError("Multiple system messages are not supported.") + elif isinstance(message, HumanMessage): + processed_content = process_content(message.content) + if system_message: + if ( + processed_content + and isinstance(processed_content[0], dict) + and processed_content[0].get("type") == "text" + and "text" in processed_content[0] + ): + processed_content[0]["text"] = ( + f"{system_message}\n{processed_content[0]['text']}" + ) + else: + processed_content.insert( + 0, {"type": "text", "text": system_message} + ) + system_message = None + reka_messages.append({"role": "user", "content": processed_content}) + elif isinstance(message, AIMessage): + reka_message: Dict[str, Any] = {"role": "assistant"} + if message.content: + processed_content = process_content(message.content) + reka_message["content"] = processed_content + if "tool_calls" in message.additional_kwargs: + tool_calls = message.additional_kwargs["tool_calls"] + formatted_tool_calls = [] + for tool_call in tool_calls: + formatted_tool_call = { + "id": tool_call["id"], + "name": tool_call["function"]["name"], + "parameters": json.loads(tool_call["function"]["arguments"]), + } + formatted_tool_calls.append(formatted_tool_call) + reka_message["tool_calls"] = formatted_tool_calls + reka_messages.append(reka_message) + elif isinstance(message, ToolMessage): + content_list: List[Dict[str, Any]] = [] + content_list.append( + { + "tool_call_id": message.tool_call_id, + "output": json.dumps({"status": message.content}), + } + ) + reka_messages.append( + { + "role": "tool_output", + "content": content_list, + } + ) + else: + raise ValueError(f"Unsupported message type: 
{type(message)}") + + return reka_messages + + +class ChatReka(BaseChatModel): + """Reka chat large language models.""" + + client: Any = None #: :meta private: + async_client: Any = None #: :meta private: + model: str = Field(default=DEFAULT_REKA_MODEL) + max_tokens: int = Field(default=256) + temperature: Optional[float] = None + streaming: bool = False + default_request_timeout: Optional[float] = None + max_retries: int = 2 + reka_api_key: Optional[str] = None + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + model_config = ConfigDict(extra="forbid") + token_counter: Optional[ + Callable[[Union[str, BaseMessage, List[BaseMessage]]], int] + ] = None + + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate that API key and Python package exist in the environment.""" + reka_api_key = values.get("reka_api_key") + reka_api_key = get_from_dict_or_env( + {"reka_api_key": reka_api_key}, "reka_api_key", "REKA_API_KEY" + ) + values["reka_api_key"] = reka_api_key + + try: + # Import reka libraries here + from reka.client import AsyncReka, Reka + + values["client"] = Reka( + api_key=reka_api_key, + ) + values["async_client"] = AsyncReka( + api_key=reka_api_key, + ) + except ImportError: + raise ImportError( + "Could not import Reka Python package. " + "Please install it with `pip install reka-api`." + ) + return values + + @property + def _default_params(self) -> Mapping[str, Any]: + """Get the default parameters for calling Reka API.""" + params = { + "model": self.model, + "max_tokens": self.max_tokens, + } + if self.temperature is not None: + params["temperature"] = self.temperature + return {**params, **self.model_kwargs} + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "reka-chat" + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + reka_messages = convert_to_reka_messages(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + + stream = self.client.chat.create_stream(messages=reka_messages, **params) + + for chunk in stream: + content = chunk.responses[0].chunk.content + chat_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) + if run_manager: + run_manager.on_llm_new_token(content, chunk=chat_chunk) + yield chat_chunk + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + reka_messages = convert_to_reka_messages(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + + stream = self.async_client.chat.create_stream(messages=reka_messages, **params) + + async for chunk in stream: + content = chunk.responses[0].chunk.content + chat_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) + if run_manager: + await run_manager.on_llm_new_token(content, chunk=chat_chunk) + yield chat_chunk + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + return generate_from_stream( + self._stream(messages, stop=stop, run_manager=run_manager, **kwargs) + ) + + reka_messages = 
convert_to_reka_messages(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + response = self.client.chat.create(messages=reka_messages, **params) + + if response.responses[0].message.tool_calls: + tool_calls = response.responses[0].message.tool_calls + message = AIMessage( + content="", # Empty string instead of None + additional_kwargs={ + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.parameters), + }, + } + for tc in tool_calls + ] + }, + ) + else: + content = response.responses[0].message.content + # Ensure content is never None + message = AIMessage(content=content if content is not None else "") + + return ChatResult(generations=[ChatGeneration(message=message)]) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + return await agenerate_from_stream( + self._astream(messages, stop=stop, run_manager=run_manager, **kwargs) + ) + + reka_messages = convert_to_reka_messages(messages) + params = {**self._default_params, **kwargs} + if stop: + params["stop"] = stop + response = await self.async_client.chat.create(messages=reka_messages, **params) + + if response.responses[0].message.tool_calls: + tool_calls = response.responses[0].message.tool_calls + message = AIMessage( + content="", # Empty string instead of None + additional_kwargs={ + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.parameters), + }, + } + for tc in tool_calls + ] + }, + ) + else: + content = response.responses[0].message.content + # Ensure content is never None + message = AIMessage(content=content if content is not None else "") + + return ChatResult(generations=[ChatGeneration(message=message)]) + + def get_num_tokens(self, input: Union[str, BaseMessage, List[BaseMessage]]) -> int: + """Calculate number of tokens. + + Args: + input: Either a string, a single BaseMessage, or a list of BaseMessages. + + Returns: + int: Number of tokens in the input. + + Raises: + ImportError: If tiktoken is not installed. + ValueError: If message content is not a string. + """ + if self.token_counter is not None: + return self.token_counter(input) + + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "Please install it with `pip install tiktoken`." + ) + + encoding = tiktoken.get_encoding("cl100k_base") + + if isinstance(input, str): + return len(encoding.encode(input)) + elif isinstance(input, BaseMessage): + content = input.content + if not isinstance(content, str): + raise ValueError( + f"Message content must be a string, got {type(content)}" + ) + return len(encoding.encode(content)) + elif isinstance(input, list): + total = 0 + for msg in input: + content = msg.content + if not isinstance(content, str): + raise ValueError( + f"Message content must be a string, got {type(content)}" + ) + total += len(encoding.encode(content)) + return total + else: + raise TypeError(f"Unsupported input type: {type(input)}") + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + *, + tool_choice: str = "auto", + strict: Optional[bool] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. 
+ + The `tool_choice` parameter controls how the model uses the tools you pass. + There are three available options: + + - `"auto"`: Lets the model decide whether or not to invoke a tool. This is the + recommended way to do function calling with our models. + - `"none"`: Disables tool calling. In this case, even if you pass tools to + the model, the model will not invoke any tools. + - `"tool"`: Forces the model to invoke one or more of the tools it has + been passed. + + Args: + tools: A list of tool definitions to bind to this chat model. + Supports any tool definition handled by + :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`. + tool_choice: Controls how the model uses the tools you pass. + Options are "auto", "none", or "tool". Defaults to "auto". + strict: + If True, model output is guaranteed to exactly match the JSON Schema + provided in the tool definition. + If False, input schema will not be validated + and model output will not be validated. + If None, ``strict`` argument will not + be passed to the model. + kwargs: Any additional parameters are passed directly to the model. + + Returns: + Runnable: An executable chain or component. + """ + formatted_tools = [ + convert_to_openai_tool(tool, strict=strict) for tool in tools + ] + + # Ensure tool_choice is one of the allowed options + if tool_choice not in ("auto", "none", "tool"): + raise ValueError( + f"Invalid tool_choice '{tool_choice}' provided. " + "Tool choice must be one of: 'auto', 'none', or 'tool'." + ) + + # Map tool_choice to the parameter expected by the Reka API + kwargs["tool_choice"] = tool_choice + + # Pass the tools and updated kwargs to the model + formatted_tools = [tool["function"] for tool in formatted_tools] + return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/community/tests/integration_tests/chat_models/test_reka.py b/libs/community/tests/integration_tests/chat_models/test_reka.py new file mode 100644 index 0000000000000..848a0f04bcf87 --- /dev/null +++ b/libs/community/tests/integration_tests/chat_models/test_reka.py @@ -0,0 +1,222 @@ +"""Test Reka API wrapper.""" + +import logging +from typing import List + +import pytest +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.outputs import ChatGeneration, LLMResult + +from langchain_community.chat_models.reka import ChatReka +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_call() -> None: + """Test a simple call to Reka.""" + chat = ChatReka(model="reka-flash", verbose=True) + message = HumanMessage(content="Hello") + response = chat.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + logger.debug(f"Response content: {response.content}") + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." 
+) +def test_reka_generate() -> None: + """Test the generate method of Reka.""" + chat = ChatReka(model="reka-flash", verbose=True) + chat_messages: List[List[BaseMessage]] = [ + [HumanMessage(content="How many toes do dogs have?")] + ] + messages_copy = [messages.copy() for messages in chat_messages] + result: LLMResult = chat.generate(chat_messages) + assert isinstance(result, LLMResult) + for response in result.generations[0]: + assert isinstance(response, ChatGeneration) + assert isinstance(response.text, str) + assert response.text == response.message.content + logger.debug(f"Generated response: {response.text}") + assert chat_messages == messages_copy + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_streaming() -> None: + """Test streaming tokens from Reka.""" + chat = ChatReka(model="reka-flash", streaming=True, verbose=True) + message = HumanMessage(content="Tell me a story.") + response = chat.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + logger.debug(f"Streaming response content: {response.content}") + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_streaming_callback() -> None: + """Test that streaming correctly invokes callbacks.""" + callback_handler = FakeCallbackHandler() + chat = ChatReka( + model="reka-flash", + streaming=True, + callbacks=[callback_handler], + verbose=True, + ) + message = HumanMessage(content="Write me a sentence with 10 words.") + chat.invoke([message]) + assert callback_handler.llm_streams > 1 + logger.debug(f"Number of LLM streams: {callback_handler.llm_streams}") + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +async def test_reka_async_streaming_callback() -> None: + """Test asynchronous streaming with callbacks.""" + callback_handler = FakeCallbackHandler() + chat = ChatReka( + model="reka-flash", + streaming=True, + callbacks=[callback_handler], + verbose=True, + ) + chat_messages: List[BaseMessage] = [ + HumanMessage(content="How many toes do dogs have?") + ] + result: LLMResult = await chat.agenerate([chat_messages]) + assert callback_handler.llm_streams > 1 + assert isinstance(result, LLMResult) + for response in result.generations[0]: + assert isinstance(response, ChatGeneration) + assert isinstance(response.text, str) + assert response.text == response.message.content + logger.debug(f"Async generated response: {response.text}") + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_tool_usage_integration() -> None: + """Test tool usage with Reka API integration.""" + # Initialize the ChatReka model with tools and verbose logging + chat_reka = ChatReka(model="reka-flash", verbose=True) + tools = [ + { + "type": "function", + "function": { + "name": "get_product_availability", + "description": ( + "Determine whether a product is currently in stock given " + "a product ID." 
+ ), + "parameters": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": ( + "The unique product ID to check availability for" + ), + }, + }, + "required": ["product_id"], + }, + }, + }, + ] + chat_reka_with_tools = chat_reka.bind_tools(tools) + + # Start a conversation + messages: List[BaseMessage] = [ + HumanMessage(content="Is product A12345 in stock right now?") + ] + + # Get the initial response + response = chat_reka_with_tools.invoke(messages) + assert isinstance(response, AIMessage) + logger.debug(f"Initial AI message: {response.content}") + + # Check if the model wants to use a tool + if "tool_calls" in response.additional_kwargs: + tool_calls = response.additional_kwargs["tool_calls"] + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + arguments = tool_call["function"]["arguments"] + logger.debug( + f"Tool call requested: {function_name} with arguments {arguments}" + ) + + # Simulate executing the tool + tool_output = "AVAILABLE" + + tool_message = ToolMessage( + content=tool_output, tool_call_id=tool_call["id"] + ) + messages.append(response) + messages.append(tool_message) + + final_response = chat_reka_with_tools.invoke(messages) + assert isinstance(final_response, AIMessage) + logger.debug(f"Final AI message: {final_response.content}") + + # Assert that the response message is non-empty + assert final_response.content, "The final response content is empty." + else: + pytest.fail("The model did not request a tool.") + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_system_message() -> None: + """Test Reka with system message.""" + chat = ChatReka(model="reka-flash", verbose=True) + messages = [ + SystemMessage(content="You are a helpful AI that speaks like Shakespeare."), + HumanMessage(content="Tell me about the weather today."), + ] + response = chat.invoke(messages) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + logger.debug(f"Response with system message: {response.content}") + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." 
+) +def test_reka_system_message_multi_turn() -> None: + """Test multi-turn conversation with system message.""" + chat = ChatReka(model="reka-flash", verbose=True) + messages = [ + SystemMessage(content="You are a math tutor who explains concepts simply."), + HumanMessage(content="What is a prime number?"), + ] + + # First turn + response1 = chat.invoke(messages) + assert isinstance(response1, AIMessage) + messages.append(response1) + + # Second turn + messages.append(HumanMessage(content="Can you give me an example?")) + response2 = chat.invoke(messages) + assert isinstance(response2, AIMessage) + + logger.debug(f"First response: {response1.content}") + logger.debug(f"Second response: {response2.content}") diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index ee3240168e01e..4022fe781bf02 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -45,6 +45,7 @@ "ChatVertexAI", "ChatYandexGPT", "ChatYuan2", + "ChatReka", "ChatZhipuAI", "ErnieBotChat", "FakeListChatModel", diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py new file mode 100644 index 0000000000000..bbacadf7fd926 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -0,0 +1,372 @@ +import json +import os +from typing import Any, Dict, List +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from pydantic import ValidationError + +from langchain_community.chat_models import ChatReka +from langchain_community.chat_models.reka import ( + convert_to_reka_messages, + process_content, +) + +os.environ["REKA_API_KEY"] = "dummy_key" + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_model_param() -> None: + llm = ChatReka(model="reka-flash") + assert llm.model == "reka-flash" + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_model_kwargs() -> None: + llm = ChatReka(model_kwargs={"foo": "bar"}) + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_incorrect_field() -> None: + """Test that providing an incorrect field raises ValidationError.""" + with pytest.raises(ValidationError): + ChatReka(unknown_field="bar") # type: ignore + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_initialization() -> None: + """Test Reka initialization.""" + # Verify that ChatReka can be initialized using a secret key provided + # as a parameter rather than an environment variable. + ChatReka(model="reka-flash", reka_api_key="test_key") + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." 
+) +@pytest.mark.parametrize( + ("content", "expected"), + [ + ("Hello", [{"type": "text", "text": "Hello"}]), + ( + [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ], + [ + {"type": "text", "text": "Describe this image"}, + {"type": "image_url", "image_url": "https://example.com/image.jpg"}, + ], + ), + ( + [ + {"type": "text", "text": "Hello"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.jpg"}, + }, + ], + [ + {"type": "text", "text": "Hello"}, + {"type": "image_url", "image_url": "https://example.com/image.jpg"}, + ], + ), + ], +) +def test_process_content(content: Any, expected: List[Dict[str, Any]]) -> None: + result = process_content(content) + assert result == expected + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +@pytest.mark.parametrize( + ("messages", "expected"), + [ + ( + [HumanMessage(content="Hello")], + [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + ), + ( + [ + HumanMessage( + content=[ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ] + ), + AIMessage(content="It's a beautiful landscape."), + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "It's a beautiful landscape."} + ], + }, + ], + ), + ], +) +def test_convert_to_reka_messages( + messages: List[BaseMessage], expected: List[Dict[str, Any]] +) -> None: + result = convert_to_reka_messages(messages) + assert result == expected + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_streaming() -> None: + llm = ChatReka(streaming=True) + assert llm.streaming is True + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_temperature() -> None: + llm = ChatReka(temperature=0.5) + assert llm.temperature == 0.5 + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_max_tokens() -> None: + llm = ChatReka(max_tokens=100) + assert llm.max_tokens == 100 + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_default_params() -> None: + llm = ChatReka() + assert llm._default_params == { + "max_tokens": 256, + "model": "reka-flash", + } + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_identifying_params() -> None: + """Test that ChatReka identifies its default parameters correctly.""" + chat = ChatReka(model="reka-flash", temperature=0.7, max_tokens=256) + expected_params = { + "model": "reka-flash", + "temperature": 0.7, + "max_tokens": 256, + } + assert chat._default_params == expected_params + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_reka_llm_type() -> None: + llm = ChatReka() + assert llm._llm_type == "reka-chat" + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." 
+) +def test_reka_tool_use_with_mocked_response() -> None: + with patch("reka.client.Reka") as MockReka: + # Mock the Reka client + mock_client = MockReka.return_value + mock_chat = MagicMock() + mock_client.chat = mock_chat + mock_response = MagicMock() + mock_message = MagicMock() + mock_tool_call = MagicMock() + mock_tool_call.id = "tool_call_1" + mock_tool_call.name = "search_tool" + mock_tool_call.parameters = {"query": "LangChain"} + mock_message.tool_calls = [mock_tool_call] + mock_message.content = None + mock_response.responses = [MagicMock(message=mock_message)] + mock_chat.create.return_value = mock_response + + llm = ChatReka() + messages: List[BaseMessage] = [HumanMessage(content="Tell me about LangChain")] + result = llm._generate(messages) + + assert len(result.generations) == 1 + ai_message = result.generations[0].message + assert ai_message.content == "" + assert "tool_calls" in ai_message.additional_kwargs + tool_calls = ai_message.additional_kwargs["tool_calls"] + assert len(tool_calls) == 1 + assert tool_calls[0]["id"] == "tool_call_1" + assert tool_calls[0]["function"]["name"] == "search_tool" + assert tool_calls[0]["function"]["arguments"] == json.dumps( + {"query": "LangChain"} + ) + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +@pytest.mark.parametrize( + ("messages", "expected"), + [ + # Test single system message + ( + [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hello"), + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "You are a helpful assistant.\nHello"} + ], + } + ], + ), + # Test system message with multiple messages + ( + [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="What is 2+2?"), + AIMessage(content="4"), + HumanMessage(content="Thanks!"), + ], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant.\nWhat is 2+2?", + } + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": "4"}]}, + {"role": "user", "content": [{"type": "text", "text": "Thanks!"}]}, + ], + ), + # Test system message with media content + ( + [ + SystemMessage(content="Hi."), + HumanMessage( + content=[ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ] + ), + ], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hi.\nWhat's in this image?", + }, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ], + }, + ], + ), + ], +) +def test_system_message_handling( + messages: List[BaseMessage], expected: List[Dict[str, Any]] +) -> None: + """Test that system messages are handled correctly.""" + result = convert_to_reka_messages(messages) + assert result == expected + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." +) +def test_multiple_system_messages_error() -> None: + """Test that multiple system messages raise an error.""" + messages = [ + SystemMessage(content="System message 1"), + SystemMessage(content="System message 2"), + HumanMessage(content="Hello"), + ] + + with pytest.raises(ValueError, match="Multiple system messages are not supported."): + convert_to_reka_messages(messages) + + +@pytest.mark.skip( + reason="Dependency conflict w/ other dependencies for urllib3 versions." 
+) +def test_get_num_tokens() -> None: + """Test that token counting works correctly for different input types.""" + llm = ChatReka() + import tiktoken + + encoding = tiktoken.get_encoding("cl100k_base") + + # Test string input + text = "What is the weather like today?" + expected_tokens = len(encoding.encode(text)) + assert llm.get_num_tokens(text) == expected_tokens + + # Test BaseMessage input + message = HumanMessage(content="What is the weather like today?") + assert isinstance(message.content, str) + expected_tokens = len(encoding.encode(message.content)) + assert llm.get_num_tokens(message) == expected_tokens + + # Test List[BaseMessage] input + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hi!"), + AIMessage(content="Hello! How can I help you today?"), + ] + expected_tokens = sum( + len(encoding.encode(msg.content)) + for msg in messages + if isinstance(msg.content, str) + ) + assert llm.get_num_tokens(messages) == expected_tokens + + # Test empty message list + assert llm.get_num_tokens([]) == 0
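+
+
+@pytest.mark.skip(
+    reason="Dependency conflict w/ other dependencies for urllib3 versions."
+)
+def test_convert_to_reka_messages_tool_calls() -> None:
+    """Sketch of how tool calls round-trip through convert_to_reka_messages.
+
+    The tool name and arguments below are hypothetical; the expected output
+    mirrors the conversion logic in langchain_community.chat_models.reka:
+    assistant tool calls become {id, name, parameters} entries, and a
+    ToolMessage becomes a "tool_output" entry with a JSON-encoded status.
+    """
+    # ToolMessage is imported locally because the module-level imports above
+    # do not include it.
+    from langchain_core.messages import ToolMessage
+
+    messages: List[BaseMessage] = [
+        HumanMessage(content="What is the weather in SF?"),
+        AIMessage(
+            content="",
+            additional_kwargs={
+                "tool_calls": [
+                    {
+                        "id": "tool_call_1",
+                        "type": "function",
+                        "function": {
+                            "name": "get_weather",
+                            "arguments": json.dumps({"city": "SF"}),
+                        },
+                    }
+                ]
+            },
+        ),
+        ToolMessage(content="sunny", tool_call_id="tool_call_1"),
+    ]
+
+    result = convert_to_reka_messages(messages)
+
+    assert result == [
+        {
+            "role": "user",
+            "content": [{"type": "text", "text": "What is the weather in SF?"}],
+        },
+        {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "id": "tool_call_1",
+                    "name": "get_weather",
+                    "parameters": {"city": "SF"},
+                }
+            ],
+        },
+        {
+            "role": "tool_output",
+            "content": [
+                {
+                    "tool_call_id": "tool_call_1",
+                    "output": json.dumps({"status": "sunny"}),
+                }
+            ],
+        },
+    ]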