From 6258f9fbcd36a63256a07290d748fabcd5937c54 Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Sat, 16 Dec 2023 01:03:30 +0800 Subject: [PATCH 1/6] feat: new integration `wasm_chat` Signed-off-by: Xin Liu --- docs/docs/integrations/chat/wasm_chat.ipynb | 85 +++++++++++ .../chat_models/__init__.py | 2 + .../chat_models/wasm_chat.py | 144 ++++++++++++++++++ .../chat_models/test_wasm_chat.py | 28 ++++ 4 files changed, 259 insertions(+) create mode 100644 docs/docs/integrations/chat/wasm_chat.ipynb create mode 100644 libs/community/langchain_community/chat_models/wasm_chat.py create mode 100644 libs/community/tests/integration_tests/chat_models/test_wasm_chat.py diff --git a/docs/docs/integrations/chat/wasm_chat.ipynb b/docs/docs/integrations/chat/wasm_chat.ipynb new file mode 100644 index 0000000000000..297b676027a4f --- /dev/null +++ b/docs/docs/integrations/chat/wasm_chat.ipynb @@ -0,0 +1,85 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Wasm Chat\n", + "\n", + "`Wasm-chat` allows you to chat with LLMs of [GGUF](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/README.md) format both locally and via chat service.\n", + "\n", + "- `WasmChatService` provides developers an OpenAI API compatible service to chat with LLMs via HTTP requests.\n", + "\n", + "- `WasmChatLocal` enables developers to chat with LLMs locally (coming soon).\n", + "\n", + "Both `WasmChatService` and `WasmChatLocal` run on the infrastructure driven by [WasmEdge Runtime](https://wasmedge.org/), which provides a lightweight and portable WebAssembly container environment for LLM inference tasks.\n", + "\n", + "## Chat via API Service\n", + "\n", + "`WasmChatService` provides chat services by the `llama-api-server`. Following the steps in [llama-api-server quick-start](https://github.com/second-state/llama-utils/tree/main/api-server#readme), you can host your own API service so that you can chat with any models you like on any device you have anywhere as long as the internet is available." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models.wasm_chat import WasmChatService\n", + "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Bot] Paris\n" + ] + } + ], + "source": [ + "# service url\n", + "service_url = \"https://f370-50-112-58-64.ngrok-free.app\"\n", + "\n", + "# create wasm-chat service instance\n", + "chat = WasmChatService(service_url=service_url)\n", + "\n", + "# create message sequence\n", + "system_message = SystemMessage(content=\"You are an AI assistant\")\n", + "user_message = HumanMessage(content=\"What is the capital of France?\")\n", + "messages = [system_message, user_message]\n", + "\n", + "# chat with wasm-chat service\n", + "response = chat(messages)\n", + "\n", + "print(f\"[Bot] {response.content}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 7064f7b758565..aba1a46c03877 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -46,9 +46,11 @@ from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain_community.chat_models.vertexai import ChatVertexAI from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat +from langchain_community.chat_models.wasm_chat import WasmChatService from langchain_community.chat_models.yandex import ChatYandexGPT __all__ = [ + "WasmChatService", "ChatOpenAI", "BedrockChat", "AzureChatOpenAI", diff --git a/libs/community/langchain_community/chat_models/wasm_chat.py b/libs/community/langchain_community/chat_models/wasm_chat.py new file mode 100644 index 0000000000000..bc42506e430d1 --- /dev/null +++ b/libs/community/langchain_community/chat_models/wasm_chat.py @@ -0,0 +1,144 @@ +import json +import logging +from typing import Any, Dict, List, Mapping, Optional + +import requests +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.pydantic_v1 import root_validator +from langchain_core.utils import get_pydantic_field_names + +logger = logging.getLogger(__name__) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + return AIMessage(content=_dict.get("content", "") or "") + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + message_dict: Dict[str, Any] + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + else: + raise TypeError(f"Got unknown type {message}") + + return message_dict + + +class WasmChatService(BaseChatModel): + """Chat with LLMs via `llama-api-server` + + For the information about `llama-api-server`, visit https://github.com/second-state/llama-utils + """ + + request_timeout: int = 60 + """request timeout for chat http requests""" + service_url: Optional[str] = None + """URL of WasmChat service""" + model: str = "NA" + """model name, default is `NA`.""" + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." + ) + + values["model_kwargs"] = extra + return values + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + res = self._chat(messages, **kwargs) + + if res.status_code != 200: + raise ValueError(f"Error code: {res.status_code}, reason: {res.reason}") + + response = res.json() + + return self._create_chat_result(response) + + def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response: + if self.service_url is None: + res = requests.models.Response() + res.status_code = 503 + res.reason = "The IP address or port of the chat service is incorrect." + return res + + service_url = f"{self.service_url}/v1/chat/completions" + + payload = { + "model": self.model, + "messages": [_convert_message_to_dict(m) for m in messages], + } + + res = requests.post( + url=service_url, + timeout=self.request_timeout, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + }, + data=json.dumps(payload), + ) + + return res + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + message = _convert_dict_to_message(response["choices"][0].get("message")) + generations = [ChatGeneration(message=message)] + + token_usage = response["usage"] + llm_output = {"token_usage": token_usage, "model": self.model} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _llm_type(self) -> str: + return "wasmedge-chat" diff --git a/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py new file mode 100644 index 0000000000000..b1e4a5d42b8d1 --- /dev/null +++ b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py @@ -0,0 +1,28 @@ +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from langchain_community.chat_models.wasm_chat import WasmChatService + + +@pytest.mark.enable_socket +def test_chat_wasm_service() -> None: + """This test requires the port 8080 is not occupied.""" + + # service url + service_url = "https://f370-50-112-58-64.ngrok-free.app" + + # create wasm-chat service instance + chat = WasmChatService(service_url=service_url) + + # create message sequence + system_message = SystemMessage(content="You are an AI assistant") + user_message = HumanMessage(content="What is the capital of France?") + messages = [system_message, user_message] + + # chat with wasm-chat service + response = chat(messages) + + # check response + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "Paris" in response.content From 7340ddc1fcb844f86f4cce3b727327962ef83ffd Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Mon, 18 Dec 2023 22:17:59 +0800 Subject: [PATCH 2/6] chore: update `wasm_chat` Signed-off-by: Xin Liu --- libs/community/langchain_community/chat_models/wasm_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/langchain_community/chat_models/wasm_chat.py b/libs/community/langchain_community/chat_models/wasm_chat.py index bc42506e430d1..fb0826502c31b 100644 --- a/libs/community/langchain_community/chat_models/wasm_chat.py +++ b/libs/community/langchain_community/chat_models/wasm_chat.py @@ -141,4 +141,4 @@ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: @property def _llm_type(self) -> str: - return "wasmedge-chat" + return "wasm-chat" From 8662853eaa51d6e2983b776df51eb7b6c1b54a54 Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Mon, 18 Dec 2023 22:19:14 +0800 Subject: [PATCH 3/6] test: add unit tests for `wasm_chat` Signed-off-by: Xin Liu --- .../unit_tests/chat_models/test_imports.py | 1 + .../unit_tests/chat_models/test_wasmchat.py | 79 +++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 libs/community/tests/unit_tests/chat_models/test_wasmchat.py diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index e0cc187443913..c8429090c2ce5 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -31,6 +31,7 @@ "ChatHunyuan", "GigaChat", "VolcEngineMaasChat", + "WasmChatService", ] diff --git a/libs/community/tests/unit_tests/chat_models/test_wasmchat.py b/libs/community/tests/unit_tests/chat_models/test_wasmchat.py new file mode 100644 index 0000000000000..688131cafc4f6 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_wasmchat.py @@ -0,0 +1,79 @@ +import pytest +from langchain_community.chat_models.wasm_chat import ( + WasmChatService, + _convert_dict_to_message, + _convert_message_to_dict, +) +from langchain_core.messages import ( + AIMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) + + +def test__convert_message_to_dict_human() -> None: + message = HumanMessage(content="foo") + result = _convert_message_to_dict(message) + expected_output = {"role": "user", "content": "foo"} + assert result == expected_output + + +def test__convert_message_to_dict_ai() -> None: + message = AIMessage(content="foo") + result = _convert_message_to_dict(message) + expected_output = {"role": "assistant", "content": "foo"} + assert result == expected_output + + +def test__convert_message_to_dict_system() -> None: + message = SystemMessage(content="foo") + result = _convert_message_to_dict(message) + expected_output = {"role": "system", "content": "foo"} + assert result == expected_output + + +def test__convert_message_to_dict_function() -> None: + message = FunctionMessage(name="foo", content="bar") + with pytest.raises(TypeError) as e: + _convert_message_to_dict(message) + assert "Got unknown type" in str(e) + + +def test__convert_dict_to_message_human() -> None: + message_dict = {"role": "user", "content": "foo"} + result = _convert_dict_to_message(message_dict) + expected_output = HumanMessage(content="foo") + assert result == expected_output + + +def test__convert_dict_to_message_ai() -> None: + message_dict = {"role": "assistant", "content": "foo"} + result = _convert_dict_to_message(message_dict) + expected_output = AIMessage(content="foo") + assert result == expected_output + + +def test__convert_dict_to_message_other_role() -> None: + message_dict = {"role": "system", "content": "foo"} + result = _convert_dict_to_message(message_dict) + expected_output = ChatMessage(role="system", content="foo") + assert result == expected_output + + +def test_wasm_chat_without_service_url() -> None: + chat = WasmChatService() + + # create message sequence + system_message = SystemMessage(content="You are an AI assistant") + user_message = HumanMessage(content="What is the capital of France?") + messages = [system_message, user_message] + + with pytest.raises(ValueError) as e: + chat(messages) + + assert ( + "Error code: 503, reason: The IP address or port of the chat service is incorrect." + in str(e) + ) From 501a106aecb77f80898879d6613b8353f7b3f0a6 Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Wed, 20 Dec 2023 21:23:39 +0800 Subject: [PATCH 4/6] fix: lint issues Signed-off-by: Xin Liu --- .../unit_tests/chat_models/test_wasmchat.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/libs/community/tests/unit_tests/chat_models/test_wasmchat.py b/libs/community/tests/unit_tests/chat_models/test_wasmchat.py index 688131cafc4f6..6905130de6753 100644 --- a/libs/community/tests/unit_tests/chat_models/test_wasmchat.py +++ b/libs/community/tests/unit_tests/chat_models/test_wasmchat.py @@ -1,9 +1,4 @@ import pytest -from langchain_community.chat_models.wasm_chat import ( - WasmChatService, - _convert_dict_to_message, - _convert_message_to_dict, -) from langchain_core.messages import ( AIMessage, ChatMessage, @@ -12,6 +7,12 @@ SystemMessage, ) +from langchain_community.chat_models.wasm_chat import ( + WasmChatService, + _convert_dict_to_message, + _convert_message_to_dict, +) + def test__convert_message_to_dict_human() -> None: message = HumanMessage(content="foo") @@ -73,7 +74,5 @@ def test_wasm_chat_without_service_url() -> None: with pytest.raises(ValueError) as e: chat(messages) - assert ( - "Error code: 503, reason: The IP address or port of the chat service is incorrect." - in str(e) - ) + assert "Error code: 503" in str(e) + assert "reason: The IP address or port of the chat service is incorrect." in str(e) From c510b056824a088d240399a367661298fe707c6a Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Thu, 28 Dec 2023 22:04:19 +0800 Subject: [PATCH 5/6] chore: update service url for integration test and tutorial Signed-off-by: Xin Liu --- docs/docs/integrations/chat/wasm_chat.ipynb | 2 +- .../tests/integration_tests/chat_models/test_wasm_chat.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/docs/integrations/chat/wasm_chat.ipynb b/docs/docs/integrations/chat/wasm_chat.ipynb index 297b676027a4f..ed7859d3b97a1 100644 --- a/docs/docs/integrations/chat/wasm_chat.ipynb +++ b/docs/docs/integrations/chat/wasm_chat.ipynb @@ -44,7 +44,7 @@ ], "source": [ "# service url\n", - "service_url = \"https://f370-50-112-58-64.ngrok-free.app\"\n", + "service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n", "\n", "# create wasm-chat service instance\n", "chat = WasmChatService(service_url=service_url)\n", diff --git a/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py index b1e4a5d42b8d1..417d25683acde 100644 --- a/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py +++ b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py @@ -1,7 +1,6 @@ import pytest -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage - from langchain_community.chat_models.wasm_chat import WasmChatService +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage @pytest.mark.enable_socket @@ -9,7 +8,7 @@ def test_chat_wasm_service() -> None: """This test requires the port 8080 is not occupied.""" # service url - service_url = "https://f370-50-112-58-64.ngrok-free.app" + service_url = "https://b008-54-186-154-209.ngrok-free.app" # create wasm-chat service instance chat = WasmChatService(service_url=service_url) From 332746a4cff6b4221f70bf30e63cff30cea34993 Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Tue, 2 Jan 2024 08:58:23 +0800 Subject: [PATCH 6/6] chore: format code Signed-off-by: Xin Liu --- .../tests/integration_tests/chat_models/test_wasm_chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py index 417d25683acde..369908a2d263f 100644 --- a/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py +++ b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py @@ -1,7 +1,8 @@ import pytest -from langchain_community.chat_models.wasm_chat import WasmChatService from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_community.chat_models.wasm_chat import WasmChatService + @pytest.mark.enable_socket def test_chat_wasm_service() -> None: