diff --git a/docs/docs/integrations/chat/writer.ipynb b/docs/docs/integrations/chat/writer.ipynb index a76752ef2f64c..d7a47b9c4767a 100644 --- a/docs/docs/integrations/chat/writer.ipynb +++ b/docs/docs/integrations/chat/writer.ipynb @@ -17,7 +17,7 @@ "source": [ "# ChatWriter\n", "\n", - "This notebook provides a quick overview for getting started with Writer [chat models](/docs/concepts/chat_models).\n", + "This notebook provides a quick overview for getting started with Writer [chat models](/docs/concepts/#chat-models).\n", "\n", "Writer has several chat models. You can find information about their latest models and their costs, context windows, and supported input types in the [Writer docs](https://dev.writer.com/home/models).\n", "\n", @@ -25,21 +25,20 @@ ] }, { - "cell_type": "markdown", - "id": "e49f1e0d", "metadata": {}, + "cell_type": "markdown", "source": [ "## Overview\n", "\n", "### Integration details\n", - "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/openai) | Package downloads | Package latest |\n", - "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", - "| ChatWriter | langchain-community | ❌ | ❌ | ❌ | ❌ | ❌ |\n", + "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: |:----------:| :---: | :---: |\n", + "| ChatWriter | langchain-community | ❌ | ❌ | ❌ | ❌ | ❌ |\n", "\n", "### Model features\n", - "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | Image input | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", - "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", - "| ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | \n", + "| [Tool calling](/docs/how_to/tool_calling) | 
Structured output | JSON mode | Image input | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | Logprobs |\n", + "| :---: |:-----------------:| :---: | :---: | :---: | :---: | :---: | :---: |:--------------------------------:|:--------:|\n", + "| ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ |\n", "\n", "## Setup\n", "\n", @@ -48,15 +47,16 @@ "### Credentials\n", "\n", "Head to [Writer AI Studio](https://app.writer.com/aistudio/signup?utm_campaign=devrel) to sign up to OpenAI and generate an API key. Once you've done this set the WRITER_API_KEY environment variable:" - ] + ], + "id": "617a6e98205ab7c8" }, { "cell_type": "code", "id": "e817fe2e-4f1d-4533-b19e-2400b1cf6ce8", "metadata": { "ExecuteTime": { - "end_time": "2024-10-24T13:51:54.323678Z", - "start_time": "2024-10-24T13:51:42.127404Z" + "end_time": "2024-11-14T09:46:26.800627Z", + "start_time": "2024-11-14T09:27:59.652281Z" } }, "source": [ @@ -64,7 +64,7 @@ "import os\n", "\n", "if not os.environ.get(\"WRITER_API_KEY\"):\n", - " os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key: \")" + " os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key:\")" ], "outputs": [], "execution_count": 1 @@ -84,23 +84,24 @@ "id": "2113471c-75d7-45df-b784-d78da4ef7aba", "metadata": { "ExecuteTime": { - "end_time": "2024-10-24T13:52:49.262240Z", - "start_time": "2024-10-24T13:52:47.564879Z" + "end_time": "2024-11-14T09:46:32.415354Z", + "start_time": "2024-11-14T09:46:26.826112Z" } }, - "source": [ - "%pip install -qU langchain-community writer-sdk" - ], + "source": "%pip install -qU langchain-community writer-sdk", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m24.2\u001B[0m\u001B[39;49m -> 
\u001B[0m\u001B[32;49m24.3.1\u001B[0m\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], - "execution_count": 4 + "execution_count": 2 }, { "cell_type": "markdown", @@ -118,8 +119,8 @@ "metadata": { "tags": [], "ExecuteTime": { - "end_time": "2024-10-24T13:52:38.822950Z", - "start_time": "2024-10-24T13:52:38.674441Z" + "end_time": "2024-11-14T09:46:33.504711Z", + "start_time": "2024-11-14T09:46:32.574505Z" } }, "source": [ @@ -129,24 +130,10 @@ " model=\"palmyra-x-004\",\n", " temperature=0.7,\n", " max_tokens=1000,\n", - " # api_key=\"...\", # if you prefer to pass api key in directly instaed of using env vars\n", - " # base_url=\"...\",\n", " # other params...\n", ")" ], - "outputs": [ - { - "ename": "ImportError", - "evalue": "cannot import name 'ChatWriter' from 'langchain_community.chat_models' (/home/yanomaly/PycharmProjects/whitesnake/writer/langсhain/libs/community/langchain_community/chat_models/__init__.py)", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mImportError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[0;32mIn[3], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mlangchain_community\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mchat_models\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ChatWriter\n\u001B[1;32m 3\u001B[0m llm \u001B[38;5;241m=\u001B[39m ChatWriter(\n\u001B[1;32m 4\u001B[0m model\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpalmyra-x-004\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[1;32m 5\u001B[0m temperature\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0.7\u001B[39m,\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 
9\u001B[0m \u001B[38;5;66;03m# other params...\u001B[39;00m\n\u001B[1;32m 10\u001B[0m )\n", - "\u001B[0;31mImportError\u001B[0m: cannot import name 'ChatWriter' from 'langchain_community.chat_models' (/home/yanomaly/PycharmProjects/whitesnake/writer/langсhain/libs/community/langchain_community/chat_models/__init__.py)" - ] - } - ], + "outputs": [], "execution_count": 3 }, { @@ -159,12 +146,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "ce16ad78-8e6f-48cd-954e-98be75eb5836", "metadata": { - "tags": [] + "tags": [], + "ExecuteTime": { + "end_time": "2024-11-14T09:46:38.856174Z", + "start_time": "2024-11-14T09:46:33.520062Z" + } }, - "outputs": [], "source": [ "messages = [\n", " (\n", @@ -173,19 +162,127 @@ " ),\n", " (\"human\", \"Write a poem about Python.\"),\n", "]\n", - "ai_msg = llm.invoke(messages)\n", - "ai_msg" - ] + "ai_msg = llm.invoke(messages)" + ], + "outputs": [], + "execution_count": 4 }, { "cell_type": "code", - "execution_count": null, "id": "2cd224b8-4499-41fb-a604-d53a7ff17b2e", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:38.866651Z", + "start_time": "2024-11-14T09:46:38.863817Z" + } + }, + "source": [ + "print(ai_msg.content)" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In realms of code, where logic weaves and flows,\n", + "A language rises, Python by its name,\n", + "With syntax clear, where elegance it shows,\n", + "A serpent, wise, that time and space can tame.\n", + "\n", + "Born from the mind of Guido, pure and bright,\n", + "Its beauty lies in simplicity and grace,\n", + "A tool of power, yet gentle in its might,\n", + "In every programmer's heart, a cherished place.\n", + "\n", + "It dances through the data, vast and deep,\n", + "With libraries that span the digital realm,\n", + "From machine learning's secrets to keep,\n", + "To web development, it wields the helm.\n", + "\n", + "In the hands of the novice and the sage,\n", + "Python spins the threads of 
digital dreams,\n", + "A language that can turn the age,\n", + "With a gentle learning curve, its appeal gleams.\n", + "\n", + "It's more than code, a community it builds,\n", + "Where knowledge freely flows, and all are heard,\n", + "In Python's world, the future unfolds,\n", + "A language of the people, for the world.\n", + "\n", + "So here's to Python, in its gentle might,\n", + "A master of the modern coding art,\n", + "May it continue to light our path each night,\n", + "In the vast, evolving world of code, its heart.\n" + ] + } + ], + "execution_count": 5 + }, + { "metadata": {}, + "cell_type": "markdown", + "source": "## Streaming", + "id": "35b3a5b3dabef65" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:38.914883Z", + "start_time": "2024-11-14T09:46:38.912564Z" + } + }, + "cell_type": "code", + "source": "ai_stream = llm.stream(messages)", + "id": "2725770182bf96dc", "outputs": [], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:43.226449Z", + "start_time": "2024-11-14T09:46:38.955512Z" + } + }, + "cell_type": "code", "source": [ - "print(ai_msg.content)" - ] + "for chunk in ai_stream:\n", + " print(chunk.content, end=\"\")" + ], + "id": "a48410d9488162e3", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In realms of code where logic weaves,\n", + "A language rises, Python, it breezes,\n", + "With syntax clear and simple to read,\n", + "Through its elegance, our spirits are fed.\n", + "\n", + "Like rivers flowing, smooth and serene,\n", + "Its structure harmonious, a coder's dream,\n", + "Indentations guide the flow of control,\n", + "In Python's world, confusion takes no toll.\n", + "\n", + "A vast library, a treasure trove so bright,\n", + "For web and data, it offers its might,\n", + "With modules and packages, a rich array,\n", + "Python empowers us to code in play.\n", + "\n", + "From AI to scripts, in flexibility it thrives,\n", + "A 
language of the future, as many now derive,\n", + "Its community, a beacon of support and cheer,\n", + "With Python, the possibilities are vast, far and near.\n", + "\n", + "So here's to Python, in its gentle grace,\n", + "A tool that enhances, a language that embraces,\n", + "The art of coding, with a fluent, flowing pen,\n", + "In the Python world, we code, and we begin." + ] + } + ], + "execution_count": 7 }, { "cell_type": "markdown", @@ -199,12 +296,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "fbb043e6", "metadata": { - "tags": [] + "tags": [], + "ExecuteTime": { + "end_time": "2024-11-14T09:46:50.721645Z", + "start_time": "2024-11-14T09:46:43.234590Z" + } }, - "outputs": [], "source": [ "from langchain_core.prompts import ChatPromptTemplate\n", "\n", @@ -225,7 +324,20 @@ " \"input\": \"Write a poem about Java.\",\n", " }\n", ")" - ] + ], + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessageChunk(content='In the realm of code, where logic weaves and flows, \\nA language rises, like a phoenix from the code\\'s throes. \\nJava, the name, a cup of coffee\\'s steam, \\nBrewed in the minds, where digital dreams gleam.\\n\\nWith syntax clear, like morning\\'s misty hue, \\nIn classes and objects, it spins a tale so true. \\nA platform agnostic, with a byte to spare, \\nAcross the devices, it journeys everywhere.\\n\\nInheritance and polymorphism, its power\\'s core, \\nLike ancient runes, in every line they bore. \\nEncapsulation, a shield, with data it does hide, \\nIn the vast jungle of code, it stands as a guide.\\n\\nFrom applets small, to vast, server-side apps, \\nIts threads run swift, through the computing traps. \\nA language of the people, by the people, for the people’s use, \\nBuilt on the principle, \"write once, run anywhere, with no excuse.\"\\n\\nIn the heart of Android, it beats, a steady drum, \\nCrafting experiences, in every smartphone\\'s hum. 
\\nIn the cloud, in the enterprise, its presence is vast, \\nA cornerstone of computing, built to last.\\n\\nOh Java, thy elegance, thy robust design, \\nA language that stands, in any computing line. \\nWith every update, with every new release, \\nThy community grows, with a vibrant, diverse peace.\\n\\nSo here\\'s to Java, the versatile, the grand, \\nA language that shapes the digital land. \\nMay it continue to evolve, to grow, to inspire, \\nIn the endless quest of turning thoughts into digital fire.', additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 345, 'prompt_tokens': 33, 'total_tokens': 378, 'completion_tokens_details': None, 'prompt_token_details': None}, 'model_name': 'palmyra-x-004', 'system_fingerprint': 'v1', 'finish_reason': 'stop'}, id='run-a5b4be59-0eb0-41bd-80f7-72477861b0bd-0')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 8 }, { "cell_type": "markdown", @@ -251,10 +363,13 @@ }, { "cell_type": "code", - "execution_count": 6, "id": "b7ea7690-ec7a-4337-b392-e87d1f39a6ec", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:50.891937Z", + "start_time": "2024-11-14T09:46:50.733463Z" + } + }, "source": [ "from pydantic import BaseModel, Field\n", "\n", @@ -266,20 +381,26 @@ "\n", "\n", "llm_with_tools = llm.bind_tools([GetWeather])" - ] + ], + "outputs": [], + "execution_count": 9 }, { "cell_type": "code", - "execution_count": null, "id": "1d1ab955-6a68-42f8-bb5d-86eb1111478a", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:51.725422Z", + "start_time": "2024-11-14T09:46:50.904699Z" + } + }, "source": [ "ai_msg = llm_with_tools.invoke(\n", " \"what is the weather like in New York City\",\n", - ")\n", - "ai_msg" - ] + ")" + ], + "outputs": [], + "execution_count": 10 }, { "cell_type": "markdown", @@ -292,13 +413,30 @@ }, { "cell_type": "code", - 
"execution_count": null, "id": "166cb7ce-831d-4a7c-9721-abc107f11084", - "metadata": {}, - "outputs": [], - "source": [ - "ai_msg.tool_calls" - ] + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T09:46:51.744202Z", + "start_time": "2024-11-14T09:46:51.738431Z" + } + }, + "source": "print(ai_msg.tool_calls)", + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'GetWeather',\n", + " 'args': {'location': 'New York City, NY'},\n", + " 'id': 'chatcmpl-tool-fe70912c800d40fc8700d604d4823001',\n", + " 'type': 'tool_call'}]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 11 }, { "cell_type": "markdown", diff --git a/docs/docs/integrations/llms/writer.ipynb b/docs/docs/integrations/llms/writer.ipynb index 7488eff3efe16..bc17ba76582dd 100644 --- a/docs/docs/integrations/llms/writer.ipynb +++ b/docs/docs/integrations/llms/writer.ipynb @@ -4,120 +4,161 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Writer\n", + "# Writer LLM\n", "\n", "[Writer](https://writer.com/) is a platform to generate different language content.\n", "\n", "This example goes over how to use LangChain to interact with `Writer` [models](https://dev.writer.com/docs/models).\n", "\n", - "You have to get the WRITER_API_KEY [here](https://dev.writer.com/docs)." + "## Setup\n", + "\n", + "To access Writer models you'll need to create a Writer account, get an API key, and install the `writer-sdk` and `langchain-community` packages.\n", + "\n", + "### Credentials\n", + "\n", + "Head to [Writer AI Studio](https://app.writer.com/aistudio/signup?utm_campaign=devrel) to sign up to Writer and generate an API key. 
Once you've done this set the WRITER_API_KEY environment variable:" ] }, { + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-14T11:10:46.824961Z", + "start_time": "2024-11-14T11:10:44.864137Z" + } + }, "cell_type": "code", - "execution_count": 4, + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if not os.environ.get(\"WRITER_API_KEY\"):\n", + " os.environ[\"WRITER_API_KEY\"] = getpass.getpass(\"Enter your Writer API key:\")" + ], + "outputs": [], + "execution_count": 1 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Installation\n", + "\n", + "The LangChain Writer integration lives in the `langchain-community` package:" + ] + }, + { "metadata": { - "tags": [] + "ExecuteTime": { + "end_time": "2024-11-14T11:10:48.297429Z", + "start_time": "2024-11-14T11:10:46.843983Z" + } }, + "cell_type": "code", + "source": "%pip install -qU langchain-community writer-sdk", "outputs": [ { - "name": "stdin", + "name": "stdout", "output_type": "stream", "text": [ - " ········\n" + "\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m24.2\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m24.3.1\u001B[0m\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n", + "Note: you may need to restart the kernel to use updated packages.\n" ] } ], - "source": [ - "from getpass import getpass\n", - "\n", - "WRITER_API_KEY = getpass()" - ] + "execution_count": 2 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Now we can initialize our model object to interact with writer LLMs" }, { - "cell_type": "code", - "execution_count": 5, "metadata": { - "tags": [] + "ExecuteTime": { + "end_time": "2024-11-14T11:10:49.818902Z", + "start_time": "2024-11-14T11:10:48.580516Z" + } }, - "outputs": [], + "cell_type": "code", "source": 
[ - "import os\n", + "from langchain_community.llms import Writer as WriterLLM\n", "\n", - "os.environ[\"WRITER_API_KEY\"] = WRITER_API_KEY" - ] + "llm = WriterLLM(\n", + " temperature=0.7,\n", + " max_tokens=1000,\n", + " # other params...\n", + ")" + ], + "outputs": [], + "execution_count": 3 }, { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from langchain.chains import LLMChain\n", - "from langchain_community.llms import Writer\n", - "from langchain_core.prompts import PromptTemplate" - ] + "metadata": {}, + "cell_type": "markdown", + "source": "## Invocation" }, { - "cell_type": "code", - "execution_count": 7, "metadata": { - "tags": [] + "jupyter": { + "is_executing": true + }, + "ExecuteTime": { + "start_time": "2024-11-14T11:10:49.832822Z" + } }, + "cell_type": "code", + "source": "response_text = llm.invoke(input=\"Write a poem\")", "outputs": [], - "source": [ - "template = \"\"\"Question: {question}\n", - "\n", - "Answer: Let's think step by step.\"\"\"\n", - "\n", - "prompt = PromptTemplate.from_template(template)" - ] + "execution_count": null }, { + "metadata": {}, "cell_type": "code", - "execution_count": 14, - "metadata": { - "tags": [] - }, + "source": "print(response_text)", "outputs": [], - "source": [ - "# If you get an error, probably, you need to set up the \"base_url\" parameter that can be taken from the error log.\n", - "\n", - "llm = Writer()" - ] + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Streaming" }, { + "metadata": {}, "cell_type": "code", - "execution_count": 15, - "metadata": { - "tags": [] - }, + "source": "stream_response = llm.stream(input=\"Tell me a fairytale\")", "outputs": [], - "source": [ - "llm_chain = LLMChain(prompt=prompt, llm=llm)" - ] + "execution_count": null }, { + "metadata": {}, "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, + "source": [ + "for chunk in 
stream_response:\n", + " print(chunk, end=\"\")" + ], "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", "source": [ - "question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n", + "## Async\n", "\n", - "llm_chain.run(question)" + "Writer supports asynchronous calls via **ainvoke()** and **astream()** methods" ] }, { - "cell_type": "code", - "execution_count": null, "metadata": {}, - "outputs": [], - "source": [] + "cell_type": "markdown", + "source": [ + "## API reference\n", + "\n", + "For detailed documentation of all Writer features, head to our [API reference](https://dev.writer.com/api-guides/api-reference/completion-api/text-generation#text-generation)." + ] } ], "metadata": { diff --git a/libs/community/langchain_community/chat_models/writer.py b/libs/community/langchain_community/chat_models/writer.py index 945b9d8b0b6d2..4101b6e23eb35 100644 --- a/libs/community/langchain_community/chat_models/writer.py +++ b/libs/community/langchain_community/chat_models/writer.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import logging from typing import ( Any, @@ -11,7 +12,6 @@ Iterator, List, Literal, - Mapping, Optional, Sequence, Tuple, @@ -26,8 +26,6 @@ from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, - agenerate_from_stream, - generate_from_stream, ) from langchain_core.messages import ( AIMessage, @@ -40,99 +38,49 @@ ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable +from langchain_core.utils import get_from_dict_or_env from langchain_core.utils.function_calling import convert_to_openai_tool -from pydantic import BaseModel, ConfigDict, Field, SecretStr +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator logger = logging.getLogger(__name__) -def _convert_message_to_dict(message: 
BaseMessage) -> dict: - """Convert a LangChain message to a Writer message dict.""" - message_dict = {"role": "", "content": message.content} - - if isinstance(message, ChatMessage): - message_dict["role"] = message.role - elif isinstance(message, HumanMessage): - message_dict["role"] = "user" - elif isinstance(message, AIMessage): - message_dict["role"] = "assistant" - if message.tool_calls: - message_dict["tool_calls"] = [ - { - "id": tool["id"], - "type": "function", - "function": {"name": tool["name"], "arguments": tool["args"]}, - } - for tool in message.tool_calls - ] - elif isinstance(message, SystemMessage): - message_dict["role"] = "system" - elif isinstance(message, ToolMessage): - message_dict["role"] = "tool" - message_dict["tool_call_id"] = message.tool_call_id - else: - raise ValueError(f"Got unknown message type: {type(message)}") - - if message.name: - message_dict["name"] = message.name - - return message_dict - - -def _convert_dict_to_message(response_dict: Dict[str, Any]) -> BaseMessage: - """Convert a Writer message dict to a LangChain message.""" - role = response_dict["role"] - content = response_dict.get("content", "") - - if role == "user": - return HumanMessage(content=content) - elif role == "assistant": - additional_kwargs = {} - if tool_calls := response_dict.get("tool_calls"): - additional_kwargs["tool_calls"] = tool_calls - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) - elif role == "system": - return SystemMessage(content=content) - elif role == "tool": - return ToolMessage( - content=content, - tool_call_id=response_dict["tool_call_id"], - name=response_dict.get("name"), - ) - else: - return ChatMessage(content=content, role=role) - - class ChatWriter(BaseChatModel): """Writer chat model. To use, you should have the ``writer-sdk`` Python package installed, and the - environment variable ``WRITER_API_KEY`` set with your API key. 
+ environment variable ``WRITER_API_KEY`` set with your API key or pass 'api_key' + init param. Example: .. code-block:: python from langchain_community.chat_models import ChatWriter - chat = ChatWriter(model="palmyra-x-004") + chat = ChatWriter( + api_key="your key", + model="palmyra-x-004" + ) """ client: Any = Field(default=None, exclude=True) #: :meta private: async_client: Any = Field(default=None, exclude=True) #: :meta private: + + api_key: Optional[SecretStr] = Field(default=None) + """Writer API key.""" + model_name: str = Field(default="palmyra-x-004", alias="model") """Model name to use.""" + temperature: float = 0.7 """What sampling temperature to use.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - writer_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") - """Writer API key.""" - writer_api_base: Optional[str] = Field(default=None, alias="base_url") - """Base URL for API requests.""" - streaming: bool = False - """Whether to stream the results or not.""" + n: int = 1 """Number of chat completions to generate for each prompt.""" + max_tokens: Optional[int] = None """Maximum number of tokens to generate.""" @@ -149,37 +97,159 @@ def _identifying_params(self) -> Dict[str, Any]: return { "model_name": self.model_name, "temperature": self.temperature, - "streaming": self.streaming, **self.model_kwargs, } - def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling Writer API.""" + return { + "model": self.model_name, + "temperature": self.temperature, + "n": self.n, + "max_tokens": self.max_tokens, + **self.model_kwargs, + } + + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: + """Validates that api key is passed and creates Writer clients.""" + try: + from writerai import 
AsyncClient, Client + except ImportError as e: + raise ImportError( + "Could not import writerai python package. " + "Please install it with `pip install writerai`." + ) from e + + if not values.get("client"): + values.update( + { + "client": Client( + api_key=get_from_dict_or_env( + values, "api_key", "WRITER_API_KEY" + ) + ) + } + ) + + if not values.get("async_client"): + values.update( + { + "async_client": AsyncClient( + api_key=get_from_dict_or_env( + values, "api_key", "WRITER_API_KEY" + ) + ) + } + ) + + if not ( + type(values.get("client")) is Client + and type(values.get("async_client")) is AsyncClient + ): + raise ValueError( + "'client' attribute must be with type 'Client' and " + "'async_client' must be with type 'AsyncClient' from 'writerai' package" + ) + + return values + + def _create_chat_result(self, response: Any) -> ChatResult: generations = [] - for choice in response["choices"]: - message = _convert_dict_to_message(choice["message"]) + for choice in response.choices: + message = self._convert_writer_to_langchain(choice.message) gen = ChatGeneration( message=message, - generation_info=dict(finish_reason=choice.get("finish_reason")), + generation_info=dict(finish_reason=choice.finish_reason), ) generations.append(gen) - token_usage = response.get("usage", {}) + token_usage = {} + + if response.usage: + token_usage = response.usage.__dict__ llm_output = { "token_usage": token_usage, "model_name": self.model_name, - "system_fingerprint": response.get("system_fingerprint", ""), + "system_fingerprint": response.system_fingerprint, } return ChatResult(generations=generations, llm_output=llm_output) - def _convert_messages_to_dicts( + @staticmethod + def _convert_langchain_to_writer(message: BaseMessage) -> dict: + """Convert a LangChain message to a Writer message dict.""" + message_dict = {"role": "", "content": message.content} + + if isinstance(message, ChatMessage): + message_dict["role"] = message.role + elif isinstance(message, HumanMessage): 
+ message_dict["role"] = "user" + elif isinstance(message, AIMessage): + message_dict["role"] = "assistant" + if message.tool_calls: + message_dict["tool_calls"] = [ + { + "id": tool["id"], + "type": "function", + "function": {"name": tool["name"], "arguments": tool["args"]}, + } + for tool in message.tool_calls + ] + elif isinstance(message, SystemMessage): + message_dict["role"] = "system" + elif isinstance(message, ToolMessage): + message_dict["role"] = "tool" + message_dict["tool_call_id"] = message.tool_call_id + else: + raise ValueError(f"Got unknown message type: {type(message)}") + + if message.name: + message_dict["name"] = message.name + + return message_dict + + @staticmethod + def _convert_writer_to_langchain(response_message: Any) -> BaseMessage: + """Convert a Writer message to a LangChain message.""" + if not isinstance(response_message, dict): + response_message = json.loads( + json.dumps(response_message, default=lambda o: o.__dict__) + ) + + role = response_message.get("role", "") + content = response_message.get("content") + if not content: + content = "" + + if role == "user": + return HumanMessage(content=content) + elif role == "assistant": + additional_kwargs = {} + if tool_calls := response_message.get("tool_calls", []): + additional_kwargs["tool_calls"] = tool_calls + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=content) + elif role == "tool": + return ToolMessage( + content=content, + tool_call_id=response_message.get("tool_call_id", ""), + name=response_message.get("name", ""), + ) + else: + return ChatMessage(content=content, role=role) + + def _convert_messages_to_writer( self, messages: List[BaseMessage], stop: Optional[List[str]] = None ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """Convert a list of LangChain messages to List of Writer dicts.""" params = { "model": self.model_name, "temperature": self.temperature, "n": self.n, - "stream": 
self.streaming, **self.model_kwargs, } if stop: @@ -187,7 +257,7 @@ def _convert_messages_to_dicts( if self.max_tokens is not None: params["max_tokens"] = self.max_tokens - message_dicts = [_convert_message_to_dict(m) for m in messages] + message_dicts = [self._convert_langchain_to_writer(m) for m in messages] return message_dicts, params def _stream( @@ -197,17 +267,17 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - message_dicts, params = self._convert_messages_to_dicts(messages, stop) + message_dicts, params = self._convert_messages_to_writer(messages, stop) params = {**params, **kwargs, "stream": True} response = self.client.chat.chat(messages=message_dicts, **params) for chunk in response: - delta = chunk["choices"][0].get("delta") - if not delta or not delta.get("content"): + delta = chunk.choices[0].delta + if not delta or not delta.content: continue - chunk = _convert_dict_to_message( - {"role": "assistant", "content": delta["content"]} + chunk = self._convert_writer_to_langchain( + {"role": "assistant", "content": delta.content} ) chunk = ChatGenerationChunk(message=chunk) @@ -223,17 +293,17 @@ async def _astream( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - message_dicts, params = self._convert_messages_to_dicts(messages, stop) + message_dicts, params = self._convert_messages_to_writer(messages, stop) params = {**params, **kwargs, "stream": True} response = await self.async_client.chat.chat(messages=message_dicts, **params) async for chunk in response: - delta = chunk["choices"][0].get("delta") - if not delta or not delta.get("content"): + delta = chunk.choices[0].delta + if not delta or not delta.content: continue - chunk = _convert_dict_to_message( - {"role": "assistant", "content": delta["content"]} + chunk = self._convert_writer_to_langchain( + {"role": "assistant", "content": delta.content} ) chunk = 
ChatGenerationChunk(message=chunk) @@ -249,12 +319,7 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if self.streaming: - return generate_from_stream( - self._stream(messages, stop, run_manager, **kwargs) - ) - - message_dicts, params = self._convert_messages_to_dicts(messages, stop) + message_dicts, params = self._convert_messages_to_writer(messages, stop) params = {**params, **kwargs} response = self.client.chat.chat(messages=message_dicts, **params) return self._create_chat_result(response) @@ -266,28 +331,11 @@ async def _agenerate( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if self.streaming: - return await agenerate_from_stream( - self._astream(messages, stop, run_manager, **kwargs) - ) - - message_dicts, params = self._convert_messages_to_dicts(messages, stop) + message_dicts, params = self._convert_messages_to_writer(messages, stop) params = {**params, **kwargs} response = await self.async_client.chat.chat(messages=message_dicts, **params) return self._create_chat_result(response) - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling Writer API.""" - return { - "model": self.model_name, - "temperature": self.temperature, - "stream": self.streaming, - "n": self.n, - "max_tokens": self.max_tokens, - **self.model_kwargs, - } - def bind_tools( self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], diff --git a/libs/community/langchain_community/llms/writer.py b/libs/community/langchain_community/llms/writer.py index d82a346c43616..e68909d06e13e 100644 --- a/libs/community/langchain_community/llms/writer.py +++ b/libs/community/langchain_community/llms/writer.py @@ -1,108 +1,89 @@ -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional -import requests -from langchain_core.callbacks import 
CallbackManagerForLLMRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain_core.language_models.llms import LLM -from langchain_core.utils import get_from_dict_or_env, pre_init -from pydantic import ConfigDict - -from langchain_community.llms.utils import enforce_stop_tokens +from langchain_core.outputs import GenerationChunk +from langchain_core.utils import get_from_dict_or_env +from pydantic import ConfigDict, Field, SecretStr, model_validator class Writer(LLM): """Writer large language models. - To use, you should have the environment variable ``WRITER_API_KEY`` and - ``WRITER_ORG_ID`` set with your API key and organization ID respectively. + To use, you should have the ``writer-sdk`` Python package installed, and the + environment variable ``WRITER_API_KEY`` set with your API key. Example: .. code-block:: python - from langchain_community.llms import Writer - writer = Writer(model_id="palmyra-base") + from langchain_community.llms import Writer as WriterLLM + from writerai import Writer, AsyncWriter + + client = Writer() + async_client = AsyncWriter() + + chat = WriterLLM( + client=client, + async_client=async_client + ) """ - writer_org_id: Optional[str] = None - """Writer organization ID.""" + client: Any = Field(default=None, exclude=True) #: :meta private: + async_client: Any = Field(default=None, exclude=True) #: :meta private: - model_id: str = "palmyra-instruct" - """Model name to use.""" + api_key: Optional[SecretStr] = Field(default=None) + """Writer API key.""" - min_tokens: Optional[int] = None - """Minimum number of tokens to generate.""" + model_name: str = Field(default="palmyra-x-003-instruct", alias="model") + """Model name to use.""" max_tokens: Optional[int] = None - """Maximum number of tokens to generate.""" + """The maximum number of tokens that the model can generate in the response.""" - temperature: Optional[float] = None - """What sampling temperature to use.""" + 
temperature: Optional[float] = 0.7 + """Controls the randomness of the model's outputs. Higher values lead to more + random outputs, while lower values make the model more deterministic.""" top_p: Optional[float] = None - """Total probability mass of tokens to consider at each step.""" + """Used to control the nucleus sampling, where only the most probable tokens + with a cumulative probability of top_p are considered for sampling, providing + a way to fine-tune the randomness of predictions.""" stop: Optional[List[str]] = None - """Sequences when completion generation will stop.""" - - presence_penalty: Optional[float] = None - """Penalizes repeated tokens regardless of frequency.""" - - repetition_penalty: Optional[float] = None - """Penalizes repeated tokens according to frequency.""" + """Specifies stopping conditions for the model's output generation. This can + be an array of strings or a single string that the model will look for as a + signal to stop generating further tokens.""" best_of: Optional[int] = None - """Generates this many completions server-side and returns the "best".""" - - logprobs: bool = False - """Whether to return log probabilities.""" - - n: Optional[int] = None - """How many completions to generate.""" + """Specifies the number of completions to generate and return the best one. 
+ Useful for generating multiple outputs and choosing the best based on some + criteria.""" - writer_api_key: Optional[str] = None - """Writer API key.""" - - base_url: Optional[str] = None - """Base url to use, if None decides based on model name.""" - - model_config = ConfigDict( - extra="forbid", - ) + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" - @pre_init - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and organization id exist in environment.""" - - writer_api_key = get_from_dict_or_env( - values, "writer_api_key", "WRITER_API_KEY" - ) - values["writer_api_key"] = writer_api_key - - writer_org_id = get_from_dict_or_env(values, "writer_org_id", "WRITER_ORG_ID") - values["writer_org_id"] = writer_org_id - - return values + model_config = ConfigDict(populate_by_name=True) @property def _default_params(self) -> Mapping[str, Any]: """Get the default parameters for calling Writer API.""" return { - "minTokens": self.min_tokens, - "maxTokens": self.max_tokens, + "max_tokens": self.max_tokens, "temperature": self.temperature, - "topP": self.top_p, + "top_p": self.top_p, "stop": self.stop, - "presencePenalty": self.presence_penalty, - "repetitionPenalty": self.repetition_penalty, - "bestOf": self.best_of, - "logprobs": self.logprobs, - "n": self.n, + "best_of": self.best_of, + **self.model_kwargs, } @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" return { - **{"model_id": self.model_id, "writer_org_id": self.writer_org_id}, + "model": self.model_name, **self._default_params, } @@ -111,6 +92,51 @@ def _llm_type(self) -> str: """Return type of llm.""" return "writer" + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: + """Validates that api key is passed and creates Writer clients.""" + try: + from writerai import AsyncClient, Client + 
except ImportError as e: + raise ImportError( + "Could not import writerai python package. " + "Please install it with `pip install writerai`." + ) from e + + if not values.get("client"): + values.update( + { + "client": Client( + api_key=get_from_dict_or_env( + values, "api_key", "WRITER_API_KEY" + ) + ) + } + ) + + if not values.get("async_client"): + values.update( + { + "async_client": AsyncClient( + api_key=get_from_dict_or_env( + values, "api_key", "WRITER_API_KEY" + ) + ) + } + ) + + if not ( + type(values.get("client")) is Client + and type(values.get("async_client")) is AsyncClient + ): + raise ValueError( + "'client' attribute must be with type 'Client' and " + "'async_client' must be with type 'AsyncClient' from 'writerai' package" + ) + + return values + def _call( self, prompt: str, @@ -118,41 +144,54 @@ def _call( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: - """Call out to Writer's completions endpoint. - - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - The string generated by the model. - - Example: - .. 
code-block:: python - - response = Writer("Tell me a joke.") - """ - if self.base_url is not None: - base_url = self.base_url - else: - base_url = ( - "https://enterprise-api.writer.com/llm" - f"/organization/{self.writer_org_id}" - f"/model/{self.model_id}/completions" - ) - params = {**self._default_params, **kwargs} - response = requests.post( - url=base_url, - headers={ - "Authorization": f"{self.writer_api_key}", - "Content-Type": "application/json", - "Accept": "application/json", - }, - json={"prompt": prompt, **params}, - ) - text = response.text + params = {**self._identifying_params, **kwargs} + if stop is not None: + params.update({"stop": stop}) + text = self.client.completions.create(prompt=prompt, **params).choices[0].text + return text + + async def _acall( + self, + prompt: str, + stop: Optional[list[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + params = {**self._identifying_params, **kwargs} if stop is not None: - # I believe this is required since the stop tokens - # are not enforced by the model parameters - text = enforce_stop_tokens(text, stop) + params.update({"stop": stop}) + response = await self.async_client.completions.create(prompt=prompt, **params) + text = response.choices[0].text return text + + def _stream( + self, + prompt: str, + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + params = {**self._identifying_params, **kwargs, "stream": True} + if stop is not None: + params.update({"stop": stop}) + response = self.client.completions.create(prompt=prompt, **params) + for chunk in response: + if run_manager: + run_manager.on_llm_new_token(chunk.value) + yield GenerationChunk(text=chunk.value) + + async def _astream( + self, + prompt: str, + stop: Optional[list[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> 
AsyncIterator[GenerationChunk]: + params = {**self._identifying_params, **kwargs, "stream": True} + if stop is not None: + params.update({"stop": stop}) + response = await self.async_client.completions.create(prompt=prompt, **params) + async for chunk in response: + if run_manager: + await run_manager.on_llm_new_token(chunk.value) + yield GenerationChunk(text=chunk.value) diff --git a/libs/community/scripts/check_pydantic.sh b/libs/community/scripts/check_pydantic.sh index 99cb222d2b26e..ca83c483d515a 100755 --- a/libs/community/scripts/check_pydantic.sh +++ b/libs/community/scripts/check_pydantic.sh @@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini # PRs that increase the current count will not be accepted. # PRs that decrease update the code in the repository # and allow decreasing the count of are welcome! -current_count=126 +current_count=125 if [ "$count" -gt "$current_count" ]; then echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator." 
diff --git a/libs/community/tests/integration_tests/llms/test_writer.py b/libs/community/tests/integration_tests/llms/test_writer.py deleted file mode 100644 index db8ad809144b0..0000000000000 --- a/libs/community/tests/integration_tests/llms/test_writer.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Test Writer API wrapper.""" - -from langchain_community.llms.writer import Writer - - -def test_writer_call() -> None: - """Test valid call to Writer.""" - llm = Writer() - output = llm.invoke("Say foo:") - assert isinstance(output, str) diff --git a/libs/community/tests/unit_tests/chat_models/test_writer.py b/libs/community/tests/unit_tests/chat_models/test_writer.py index 944a9dfeaba1f..9bde10df02128 100644 --- a/libs/community/tests/unit_tests/chat_models/test_writer.py +++ b/libs/community/tests/unit_tests/chat_models/test_writer.py @@ -1,61 +1,251 @@ -"""Unit tests for Writer chat model integration.""" - import json -from typing import Any, Dict, List -from unittest.mock import AsyncMock, MagicMock, patch +from typing import Any, Dict, List, Literal, Optional, Tuple, Type +from unittest import mock +from unittest.mock import AsyncMock, MagicMock import pytest from langchain_core.callbacks.manager import CallbackManager +from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_standard_tests.unit_tests import ChatModelUnitTests from pydantic import SecretStr -from langchain_community.chat_models.writer import ChatWriter, _convert_dict_to_message +from langchain_community.chat_models.writer import ChatWriter from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +"""Classes for mocking Writer responses.""" + + +class ChoiceDelta: + def __init__(self, content: str): + self.content = content + + +class ChunkChoice: + def __init__(self, index: int, finish_reason: str, delta: ChoiceDelta): + self.index = index + self.finish_reason = finish_reason + 
self.delta = delta + + +class ChatCompletionChunk: + def __init__( + self, + id: str, + object: str, + created: int, + model: str, + choices: List[ChunkChoice], + ): + self.id = id + self.object = object + self.created = created + self.model = model + self.choices = choices + + +class ToolCallFunction: + def __init__(self, name: str, arguments: str): + self.name = name + self.arguments = arguments + + +class ChoiceMessageToolCall: + def __init__(self, id: str, type: str, function: ToolCallFunction): + self.id = id + self.type = type + self.function = function + + +class Usage: + def __init__( + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + ): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.total_tokens = total_tokens + + +class ChoiceMessage: + def __init__( + self, + role: str, + content: str, + tool_calls: Optional[List[ChoiceMessageToolCall]] = None, + ): + self.role = role + self.content = content + self.tool_calls = tool_calls + + +class Choice: + def __init__(self, index: int, finish_reason: str, message: ChoiceMessage): + self.index = index + self.finish_reason = finish_reason + self.message = message + + +class Chat: + def __init__( + self, + id: str, + object: str, + created: int, + system_fingerprint: str, + model: str, + usage: Usage, + choices: List[Choice], + ): + self.id = id + self.object = object + self.created = created + self.system_fingerprint = system_fingerprint + self.model = model + self.usage = usage + self.choices = choices + + +@pytest.mark.requires("writerai") +class TestChatWriterCustom: + """Test case for ChatWriter""" + + @pytest.fixture(autouse=True) + def mock_unstreaming_completion(self) -> Chat: + """Fixture providing a mock API response.""" + return Chat( + id="chat-12345", + object="chat.completion", + created=1699000000, + model="palmyra-x-004", + system_fingerprint="v1", + usage=Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18), + choices=[ + 
Choice( + index=0, + finish_reason="stop", + message=ChoiceMessage( + role="assistant", + content="Hello! How can I help you?", + ), + ) + ], + ) + + @pytest.fixture(autouse=True) + def mock_tool_call_choice_response(self) -> Chat: + return Chat( + id="chat-12345", + object="chat.completion", + created=1699000000, + model="palmyra-x-004", + system_fingerprint="v1", + usage=Usage(prompt_tokens=29, completion_tokens=32, total_tokens=61), + choices=[ + Choice( + index=0, + finish_reason="tool_calls", + message=ChoiceMessage( + role="assistant", + content="", + tool_calls=[ + ChoiceMessageToolCall( + id="call_abc123", + type="function", + function=ToolCallFunction( + name="GetWeather", + arguments='{"location": "London"}', + ), + ) + ], + ), + ) + ], + ) + + @pytest.fixture(autouse=True) + def mock_streaming_chunks(self) -> List[ChatCompletionChunk]: + """Fixture providing mock streaming response chunks.""" + return [ + ChatCompletionChunk( + id="chat-12345", + object="chat.completion", + created=1699000000, + model="palmyra-x-004", + choices=[ + ChunkChoice( + index=0, + finish_reason="stop", + delta=ChoiceDelta(content="Hello! 
"), + ) + ], + ), + ChatCompletionChunk( + id="chat-12345", + object="chat.completion", + created=1699000000, + model="palmyra-x-004", + choices=[ + ChunkChoice( + index=0, + finish_reason="stop", + delta=ChoiceDelta(content="How can I help you?"), + ) + ], + ), + ] -class TestChatWriter: def test_writer_model_param(self) -> None: """Test different ways to initialize the chat model.""" test_cases: List[dict] = [ - {"model_name": "palmyra-x-004", "writer_api_key": "test-key"}, - {"model": "palmyra-x-004", "writer_api_key": "test-key"}, - {"model_name": "palmyra-x-004", "writer_api_key": "test-key"}, + { + "model_name": "palmyra-x-004", + "api_key": "key", + }, + { + "model": "palmyra-x-004", + "api_key": "key", + }, + { + "model_name": "palmyra-x-004", + "api_key": "key", + }, { "model": "palmyra-x-004", - "writer_api_key": "test-key", "temperature": 0.5, + "api_key": "key", }, ] for case in test_cases: chat = ChatWriter(**case) assert chat.model_name == "palmyra-x-004" - assert chat.writer_api_key - assert chat.writer_api_key.get_secret_value() == "test-key" assert chat.temperature == (0.5 if "temperature" in case else 0.7) - def test_convert_dict_to_message_human(self) -> None: + def test_convert_writer_to_langchain_human(self) -> None: """Test converting a human message dict to a LangChain message.""" message = {"role": "user", "content": "Hello"} - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, HumanMessage) assert result.content == "Hello" - def test_convert_dict_to_message_ai(self) -> None: + def test_convert_writer_to_langchain_ai(self) -> None: """Test converting an AI message dict to a LangChain message.""" message = {"role": "assistant", "content": "Hello"} - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, AIMessage) assert result.content == "Hello" - def test_convert_dict_to_message_system(self) 
-> None: + def test_convert_writer_to_langchain_system(self) -> None: """Test converting a system message dict to a LangChain message.""" message = {"role": "system", "content": "You are a helpful assistant"} - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, SystemMessage) assert result.content == "You are a helpful assistant" - def test_convert_dict_to_message_tool_call(self) -> None: + def test_convert_writer_to_langchain_tool_call(self) -> None: """Test converting a tool call message dict to a LangChain message.""" content = json.dumps({"result": 42}) message = { @@ -64,12 +254,12 @@ def test_convert_dict_to_message_tool_call(self) -> None: "content": content, "tool_call_id": "call_abc123", } - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, ToolMessage) assert result.name == "get_number" assert result.content == content - def test_convert_dict_to_message_with_tool_calls(self) -> None: + def test_convert_writer_to_langchain_with_tool_calls(self) -> None: """Test converting an AIMessage with tool calls.""" message = { "role": "assistant", @@ -85,131 +275,55 @@ def test_convert_dict_to_message_with_tool_calls(self) -> None: } ], } - result = _convert_dict_to_message(message) + result = ChatWriter._convert_writer_to_langchain(message) assert isinstance(result, AIMessage) assert result.tool_calls assert len(result.tool_calls) == 1 assert result.tool_calls[0]["name"] == "get_weather" assert result.tool_calls[0]["args"]["location"] == "London" - @pytest.fixture(autouse=True) - def mock_completion(self) -> Dict[str, Any]: - """Fixture providing a mock API response.""" - return { - "id": "chat-12345", - "object": "chat.completion", - "created": 1699000000, - "model": "palmyra-x-004", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! 
How can I help you?", - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}, - } - - @pytest.fixture(autouse=True) - def mock_response(self) -> Dict[str, Any]: - response = { - "id": "chat-12345", - "choices": [ - { - "message": { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_abc123", - "type": "function", - "function": { - "name": "GetWeather", - "arguments": '{"location": "London"}', - }, - } - ], - }, - "finish_reason": "tool_calls", - } - ], - } - return response - - @pytest.fixture(autouse=True) - def mock_streaming_chunks(self) -> List[Dict[str, Any]]: - """Fixture providing mock streaming response chunks.""" - return [ - { - "id": "chat-12345", - "object": "chat.completion.chunk", - "created": 1699000000, - "model": "palmyra-x-004", - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - "content": "Hello", - }, - "finish_reason": None, - } - ], - }, - { - "id": "chat-12345", - "object": "chat.completion.chunk", - "created": 1699000000, - "model": "palmyra-x-004", - "choices": [ - { - "index": 0, - "delta": { - "content": "!", - }, - "finish_reason": "stop", - } - ], - }, - ] - - def test_sync_completion(self, mock_completion: Dict[str, Any]) -> None: + def test_sync_completion( + self, mock_unstreaming_completion: List[ChatCompletionChunk] + ) -> None: """Test basic chat completion with mocked response.""" - chat = ChatWriter(api_key=SecretStr("test-key")) + chat = ChatWriter(api_key=SecretStr("key")) + mock_client = MagicMock() - mock_client.chat.chat.return_value = mock_completion + mock_client.chat.chat.return_value = mock_unstreaming_completion - with patch.object(chat, "client", mock_client): + with mock.patch.object(chat, "client", mock_client): message = HumanMessage(content="Hi there!") response = chat.invoke([message]) assert isinstance(response, AIMessage) assert response.content == "Hello! How can I help you?" 
- async def test_async_completion(self, mock_completion: Dict[str, Any]) -> None: + @pytest.mark.asyncio + async def test_async_completion( + self, mock_unstreaming_completion: List[ChatCompletionChunk] + ) -> None: """Test async chat completion with mocked response.""" - chat = ChatWriter(api_key=SecretStr("test-key")) - mock_client = AsyncMock() - mock_client.chat.chat.return_value = mock_completion + chat = ChatWriter(api_key=SecretStr("key")) + + mock_async_client = AsyncMock() + mock_async_client.chat.chat.return_value = mock_unstreaming_completion - with patch.object(chat, "async_client", mock_client): + with mock.patch.object(chat, "async_client", mock_async_client): message = HumanMessage(content="Hi there!") response = await chat.ainvoke([message]) assert isinstance(response, AIMessage) assert response.content == "Hello! How can I help you?" - def test_sync_streaming(self, mock_streaming_chunks: List[Dict[str, Any]]) -> None: + def test_sync_streaming( + self, mock_streaming_chunks: List[ChatCompletionChunk] + ) -> None: """Test sync streaming with callback handler.""" callback_handler = FakeCallbackHandler() callback_manager = CallbackManager([callback_handler]) chat = ChatWriter( - streaming=True, + api_key=SecretStr("key"), callback_manager=callback_manager, max_tokens=10, - api_key=SecretStr("test-key"), ) mock_client = MagicMock() @@ -217,42 +331,46 @@ def test_sync_streaming(self, mock_streaming_chunks: List[Dict[str, Any]]) -> No mock_response.__iter__.return_value = mock_streaming_chunks mock_client.chat.chat.return_value = mock_response - with patch.object(chat, "client", mock_client): + with mock.patch.object(chat, "client", mock_client): message = HumanMessage(content="Hi") - response = chat.invoke([message]) - - assert isinstance(response, AIMessage) + response = chat.stream([message]) + response_message = "" + for chunk in response: + response_message += str(chunk.content) assert callback_handler.llm_streams > 0 - assert response.content == 
"Hello!" + assert response_message == "Hello! How can I help you?" + @pytest.mark.asyncio async def test_async_streaming( - self, mock_streaming_chunks: List[Dict[str, Any]] + self, mock_streaming_chunks: List[ChatCompletionChunk] ) -> None: """Test async streaming with callback handler.""" callback_handler = FakeCallbackHandler() callback_manager = CallbackManager([callback_handler]) chat = ChatWriter( - streaming=True, + api_key=SecretStr("key"), callback_manager=callback_manager, max_tokens=10, - api_key=SecretStr("test-key"), ) - mock_client = AsyncMock() + mock_async_client = AsyncMock() mock_response = AsyncMock() mock_response.__aiter__.return_value = mock_streaming_chunks - mock_client.chat.chat.return_value = mock_response + mock_async_client.chat.chat.return_value = mock_response - with patch.object(chat, "async_client", mock_client): + with mock.patch.object(chat, "async_client", mock_async_client): message = HumanMessage(content="Hi") - response = await chat.ainvoke([message]) - - assert isinstance(response, AIMessage) + response = chat.astream([message]) + response_message = "" + async for chunk in response: + response_message += str(chunk.content) assert callback_handler.llm_streams > 0 - assert response.content == "Hello!" + assert response_message == "Hello! How can I help you?" 
- def test_sync_tool_calling(self, mock_response: Dict[str, Any]) -> None: + def test_sync_tool_calling( + self, mock_tool_call_choice_response: Dict[str, Any] + ) -> None: """Test synchronous tool calling functionality.""" from pydantic import BaseModel, Field @@ -261,23 +379,27 @@ class GetWeather(BaseModel): location: str = Field(..., description="The location to get weather for") - mock_client = MagicMock() - mock_client.chat.chat.return_value = mock_response + chat = ChatWriter(api_key=SecretStr("key")) - chat = ChatWriter(api_key=SecretStr("test-key"), client=mock_client) + mock_client = MagicMock() + mock_client.chat.chat.return_value = mock_tool_call_choice_response chat_with_tools = chat.bind_tools( tools=[GetWeather], tool_choice="GetWeather", ) - response = chat_with_tools.invoke("What's the weather in London?") - assert isinstance(response, AIMessage) - assert response.tool_calls - assert response.tool_calls[0]["name"] == "GetWeather" - assert response.tool_calls[0]["args"]["location"] == "London" + with mock.patch.object(chat, "client", mock_client): + response = chat_with_tools.invoke("What's the weather in London?") + assert isinstance(response, AIMessage) + assert response.tool_calls + assert response.tool_calls[0]["name"] == "GetWeather" + assert response.tool_calls[0]["args"]["location"] == "London" - async def test_async_tool_calling(self, mock_response: Dict[str, Any]) -> None: + @pytest.mark.asyncio + async def test_async_tool_calling( + self, mock_tool_call_choice_response: Dict[str, Any] + ) -> None: """Test asynchronous tool calling functionality.""" from pydantic import BaseModel, Field @@ -286,18 +408,101 @@ class GetWeather(BaseModel): location: str = Field(..., description="The location to get weather for") - mock_client = AsyncMock() - mock_client.chat.chat.return_value = mock_response + mock_async_client = AsyncMock() + mock_async_client.chat.chat.return_value = mock_tool_call_choice_response - chat = 
ChatWriter(api_key=SecretStr("test-key"), async_client=mock_client) + chat = ChatWriter(api_key=SecretStr("key")) chat_with_tools = chat.bind_tools( tools=[GetWeather], tool_choice="GetWeather", ) - response = await chat_with_tools.ainvoke("What's the weather in London?") - assert isinstance(response, AIMessage) - assert response.tool_calls - assert response.tool_calls[0]["name"] == "GetWeather" - assert response.tool_calls[0]["args"]["location"] == "London" + with mock.patch.object(chat, "async_client", mock_async_client): + response = await chat_with_tools.ainvoke("What's the weather in London?") + assert isinstance(response, AIMessage) + assert response.tool_calls + assert response.tool_calls[0]["name"] == "GetWeather" + assert response.tool_calls[0]["args"]["location"] == "London" + + +@pytest.mark.requires("writerai") +class TestChatWriterStandart(ChatModelUnitTests): + """Test case for ChatWriter that inherits from standard LangChain tests.""" + + @property + def chat_model_class(self) -> Type[BaseChatModel]: + """Return ChatWriter model class.""" + return ChatWriter + + @property + def chat_model_params(self) -> Dict: + """Return any additional parameters needed.""" + return { + "api_key": "fake-api-key", + "model_name": "palmyra-x-004", + } + + @property + def has_tool_calling(self) -> bool: + """Writer supports tool/function calling.""" + return True + + @property + def tool_choice_value(self) -> Optional[str]: + """Value to use for tool choice in tests.""" + return "auto" + + @property + def has_structured_output(self) -> bool: + """Writer does not yet support structured output.""" + return False + + @property + def supports_image_inputs(self) -> bool: + """Writer does not support image inputs.""" + return False + + @property + def supports_video_inputs(self) -> bool: + """Writer does not support video inputs.""" + return False + + @property + def returns_usage_metadata(self) -> bool: + """Writer returns token usage information.""" + return True + + 
@property + def supports_anthropic_inputs(self) -> bool: + """Writer does not support anthropic inputs.""" + return False + + @property + def supports_image_tool_message(self) -> bool: + """Writer does not support image tool message.""" + return False + + @property + def supported_usage_metadata_details( + self, + ) -> Dict[ + Literal["invoke", "stream"], + List[ + Literal[ + "audio_input", + "audio_output", + "reasoning_output", + "cache_read_input", + "cache_creation_input", + ] + ], + ]: + """Return which types of usage metadata your model supports.""" + return {"invoke": ["cache_creation_input"], "stream": ["reasoning_output"]} + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + """Return env vars, init args, and expected instance attrs for initializing + from env vars.""" + return {"WRITER_API_KEY": "key"}, {"api_key": "key"}, {"api_key": "key"} diff --git a/libs/community/tests/unit_tests/llms/test_writer.py b/libs/community/tests/unit_tests/llms/test_writer.py new file mode 100644 index 0000000000000..ffdee04db0796 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_writer.py @@ -0,0 +1,202 @@ +from typing import List +from unittest import mock +from unittest.mock import AsyncMock, MagicMock + +import pytest +from langchain_core.callbacks import CallbackManager +from pydantic import SecretStr + +from langchain_community.llms.writer import Writer +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + +"""Classes for mocking Writer responses.""" + + +class Choice: + def __init__(self, text: str): + self.text = text + + +class Completion: + def __init__(self, choices: List[Choice]): + self.choices = choices + + +class StreamingData: + def __init__(self, value: str): + self.value = value + + +@pytest.mark.requires("writerai") +class TestWriterLLM: + """Unit tests for Writer LLM integration.""" + + @pytest.fixture(autouse=True) + def mock_unstreaming_completion(self) -> Completion: + """Fixture 
providing a mock API response.""" + return Completion(choices=[Choice(text="Hello! How can I help you?")]) + + @pytest.fixture(autouse=True) + def mock_streaming_completion(self) -> List[StreamingData]: + """Fixture providing mock streaming response chunks.""" + return [ + StreamingData(value="Hello! "), + StreamingData(value="How can I"), + StreamingData(value=" help you?"), + ] + + def test_sync_unstream_completion( + self, mock_unstreaming_completion: Completion + ) -> None: + """Test basic llm call with mocked response.""" + mock_client = MagicMock() + mock_client.completions.create.return_value = mock_unstreaming_completion + + llm = Writer(api_key=SecretStr("key")) + + with mock.patch.object(llm, "client", mock_client): + response_text = llm.invoke(input="Hello") + + assert response_text == "Hello! How can I help you?" + + def test_sync_unstream_completion_with_params( + self, mock_unstreaming_completion: Completion + ) -> None: + """Test llm call with passed params with mocked response.""" + mock_client = MagicMock() + mock_client.completions.create.return_value = mock_unstreaming_completion + + llm = Writer(api_key=SecretStr("key"), temperature=1) + + with mock.patch.object(llm, "client", mock_client): + response_text = llm.invoke(input="Hello") + + assert response_text == "Hello! How can I help you?" + + @pytest.mark.asyncio + async def test_async_unstream_completion( + self, mock_unstreaming_completion: Completion + ) -> None: + """Test async chat completion with mocked response.""" + mock_async_client = AsyncMock() + mock_async_client.completions.create.return_value = mock_unstreaming_completion + + llm = Writer(api_key=SecretStr("key")) + + with mock.patch.object(llm, "async_client", mock_async_client): + response_text = await llm.ainvoke(input="Hello") + + assert response_text == "Hello! How can I help you?" 
+ + @pytest.mark.asyncio + async def test_async_unstream_completion_with_params( + self, mock_unstreaming_completion: Completion + ) -> None: + """Test async llm call with passed params with mocked response.""" + mock_async_client = AsyncMock() + mock_async_client.completions.create.return_value = mock_unstreaming_completion + + llm = Writer(api_key=SecretStr("key"), temperature=1) + + with mock.patch.object(llm, "async_client", mock_async_client): + response_text = await llm.ainvoke(input="Hello") + + assert response_text == "Hello! How can I help you?" + + def test_sync_streaming_completion( + self, mock_streaming_completion: List[StreamingData] + ) -> None: + """Test sync streaming.""" + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.__iter__.return_value = mock_streaming_completion + mock_client.completions.create.return_value = mock_response + + llm = Writer(api_key=SecretStr("key")) + + with mock.patch.object(llm, "client", mock_client): + response = llm.stream(input="Hello") + + response_message = "" + for chunk in response: + response_message += chunk + + assert response_message == "Hello! How can I help you?" + + def test_sync_streaming_completion_with_callback_handler( + self, mock_streaming_completion: List[StreamingData] + ) -> None: + """Test sync streaming with callback handler.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.__iter__.return_value = mock_streaming_completion + mock_client.completions.create.return_value = mock_response + + llm = Writer( + api_key=SecretStr("key"), + callback_manager=callback_manager, + ) + + with mock.patch.object(llm, "client", mock_client): + response = llm.stream(input="Hello") + + response_message = "" + for chunk in response: + response_message += chunk + + assert callback_handler.llm_streams == 3 + assert response_message == "Hello! How can I help you?" 
+
+    @pytest.mark.asyncio
+    async def test_async_streaming_completion(
+        self, mock_streaming_completion: List[StreamingData]
+    ) -> None:
+        """Test async streaming."""
+
+        mock_async_client = AsyncMock()
+        mock_response = AsyncMock()
+        mock_response.__aiter__.return_value = mock_streaming_completion
+        mock_async_client.completions.create.return_value = mock_response
+
+        llm = Writer(api_key=SecretStr("key"))
+
+        with mock.patch.object(llm, "async_client", mock_async_client):
+            response = llm.astream(input="Hello")
+
+            response_message = ""
+            async for chunk in response:
+                response_message += str(chunk)
+
+        assert response_message == "Hello! How can I help you?"
+
+    @pytest.mark.asyncio
+    async def test_async_streaming_completion_with_callback_handler(
+        self, mock_streaming_completion: List[StreamingData]
+    ) -> None:
+        """Test async streaming with callback handler."""
+        callback_handler = FakeCallbackHandler()
+        callback_manager = CallbackManager([callback_handler])
+
+        mock_async_client = AsyncMock()
+        mock_response = AsyncMock()
+        mock_response.__aiter__.return_value = mock_streaming_completion
+        mock_async_client.completions.create.return_value = mock_response
+
+        llm = Writer(
+            api_key=SecretStr("key"),
+            callback_manager=callback_manager,
+        )
+
+        with mock.patch.object(llm, "async_client", mock_async_client):
+            response = llm.astream(input="Hello")
+
+            response_message = ""
+            async for chunk in response:
+                response_message += str(chunk)
+
+        assert callback_handler.llm_streams == 3
+        assert response_message == "Hello! How can I help you?"