diff --git a/docs/docs/integrations/chat/wasm_chat.ipynb b/docs/docs/integrations/chat/wasm_chat.ipynb
new file mode 100644
index 0000000000000..ed7859d3b97a1
--- /dev/null
+++ b/docs/docs/integrations/chat/wasm_chat.ipynb
@@ -0,0 +1,85 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Wasm Chat\n",
+    "\n",
+    "`Wasm-chat` allows you to chat with LLMs of [GGUF](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/README.md) format both locally and via chat service.\n",
+    "\n",
+    "- `WasmChatService` provides developers an OpenAI API compatible service to chat with LLMs via HTTP requests.\n",
+    "\n",
+    "- `WasmChatLocal` enables developers to chat with LLMs locally (coming soon).\n",
+    "\n",
+    "Both `WasmChatService` and `WasmChatLocal` run on the infrastructure driven by [WasmEdge Runtime](https://wasmedge.org/), which provides a lightweight and portable WebAssembly container environment for LLM inference tasks.\n",
+    "\n",
+    "## Chat via API Service\n",
+    "\n",
+    "`WasmChatService` provides chat services by the `llama-api-server`. Following the steps in [llama-api-server quick-start](https://github.com/second-state/llama-utils/tree/main/api-server#readme), you can host your own API service so that you can chat with any models you like on any device you have anywhere as long as the internet is available."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.chat_models.wasm_chat import WasmChatService\n",
+    "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Bot] Paris\n"
+     ]
+    }
+   ],
+   "source": [
+    "# service url\n",
+    "service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n",
+    "\n",
+    "# create wasm-chat service instance\n",
+    "chat = WasmChatService(service_url=service_url)\n",
+    "\n",
+    "# create message sequence\n",
+    "system_message = SystemMessage(content=\"You are an AI assistant\")\n",
+    "user_message = HumanMessage(content=\"What is the capital of France?\")\n",
+    "messages = [system_message, user_message]\n",
+    "\n",
+    "# chat with wasm-chat service\n",
+    "response = chat(messages)\n",
+    "\n",
+    "print(f\"[Bot] {response.content}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py
index bdb50fd29f0e4..d93698e2a9a79 100644
--- a/libs/community/langchain_community/chat_models/__init__.py
+++ b/libs/community/langchain_community/chat_models/__init__.py
@@ -47,9 +47,11 @@
 from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI
 from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
+from langchain_community.chat_models.wasm_chat import WasmChatService
 from langchain_community.chat_models.yandex import ChatYandexGPT
 
 __all__ = [
+    "WasmChatService",
     "ChatOpenAI",
     "BedrockChat",
     "AzureChatOpenAI",
diff --git a/libs/community/langchain_community/chat_models/wasm_chat.py b/libs/community/langchain_community/chat_models/wasm_chat.py
new file mode 100644
index 0000000000000..fb0826502c31b
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/wasm_chat.py
@@ -0,0 +1,166 @@
+import json
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.pydantic_v1 import root_validator
+from langchain_core.utils import get_pydantic_field_names
+
+logger = logging.getLogger(__name__)
+
+
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+    """Convert a role/content mapping from the service into a message object."""
+    role = _dict["role"]
+    if role == "user":
+        return HumanMessage(content=_dict["content"])
+    elif role == "assistant":
+        # A missing or null assistant content becomes an empty string.
+        return AIMessage(content=_dict.get("content", "") or "")
+    else:
+        return ChatMessage(content=_dict["content"], role=role)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+    """Convert a message object into the role/content dict sent to the service.
+
+    Raises:
+        TypeError: If the message is not one of the handled message types.
+    """
+    message_dict: Dict[str, Any]
+    if isinstance(message, ChatMessage):
+        message_dict = {"role": message.role, "content": message.content}
+    elif isinstance(message, SystemMessage):
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
+    else:
+        raise TypeError(f"Got unknown type {message}")
+
+    return message_dict
+
+
+class WasmChatService(BaseChatModel):
+    """Chat with LLMs via `llama-api-server`
+
+    For the information about `llama-api-server`, visit https://github.com/second-state/llama-utils
+    """
+
+    request_timeout: int = 60
+    """request timeout for chat http requests"""
+    service_url: Optional[str] = None
+    """URL of WasmChat service"""
+    model: str = "NA"
+    """model name, default is `NA`."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name == "model_kwargs":
+                # `extra` already is this dict; popping it into `extra` would
+                # nest the dict inside itself and log a misleading warning.
+                continue
+            if field_name in extra:
+                raise ValueError(f"Found {field_name} supplied twice.")
+            if field_name not in all_required_field_names:
+                logger.warning(
+                    f"""WARNING! {field_name} is not default parameter.
+                    {field_name} was transferred to model_kwargs.
+                    Please confirm that {field_name} is what you intended."""
+                )
+                extra[field_name] = values.pop(field_name)
+
+        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+        if invalid_model_kwargs:
+            raise ValueError(
+                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in as part of `model_kwargs` parameter."
+            )
+
+        values["model_kwargs"] = extra
+        return values
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Send `messages` to the service and wrap the reply in a `ChatResult`.
+
+        Raises:
+            ValueError: If the service answers with a non-200 status code.
+        """
+        res = self._chat(messages, **kwargs)
+
+        if res.status_code != 200:
+            raise ValueError(f"Error code: {res.status_code}, reason: {res.reason}")
+
+        response = res.json()
+
+        return self._create_chat_result(response)
+
+    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
+        """POST `messages` to the `/v1/chat/completions` endpoint of the service."""
+        if self.service_url is None:
+            # No endpoint configured: hand back a synthetic 503 response so that
+            # `_generate` reports it through its usual status-code check.
+            res = requests.models.Response()
+            res.status_code = 503
+            res.reason = "The IP address or port of the chat service is incorrect."
+            return res
+
+        # Tolerate a trailing slash so the path never becomes `//v1/...`.
+        service_url = f"{self.service_url.rstrip('/')}/v1/chat/completions"
+
+        payload = {
+            "model": self.model,
+            "messages": [_convert_message_to_dict(m) for m in messages],
+        }
+
+        res = requests.post(
+            url=service_url,
+            timeout=self.request_timeout,
+            headers={
+                "accept": "application/json",
+                "Content-Type": "application/json",
+            },
+            data=json.dumps(payload),
+        )
+
+        return res
+
+    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+        """Build a `ChatResult` from the first choice of the service response."""
+        message = _convert_dict_to_message(response["choices"][0].get("message"))
+        generations = [ChatGeneration(message=message)]
+
+        # A response without `usage` yields an empty mapping, not a KeyError.
+        token_usage = response.get("usage", {})
+        llm_output = {"token_usage": token_usage, "model": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    @property
+    def _llm_type(self) -> str:
+        return "wasm-chat"
diff --git a/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py
new file mode 100644
index 0000000000000..369908a2d263f
--- /dev/null
+++ b/libs/community/tests/integration_tests/chat_models/test_wasm_chat.py
@@ -0,0 +1,28 @@
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+
+from langchain_community.chat_models.wasm_chat import WasmChatService
+
+
+@pytest.mark.enable_socket
+def test_chat_wasm_service() -> None:
+    """This test requires a reachable `llama-api-server` at `service_url`."""
+
+    # service url
+    service_url = "https://b008-54-186-154-209.ngrok-free.app"
+
+    # create wasm-chat service instance
+    chat = WasmChatService(service_url=service_url)
+
+    # create message sequence
+    system_message = SystemMessage(content="You are an AI assistant")
+    user_message = HumanMessage(content="What is the capital of France?")
+    messages = [system_message, user_message]
+
+    # chat with wasm-chat service
+    response = chat(messages)
+
+    # check response
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert "Paris" in response.content
diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py
index 020c160121606..e49353c1e54c6 100644
--- a/libs/community/tests/unit_tests/chat_models/test_imports.py
+++ b/libs/community/tests/unit_tests/chat_models/test_imports.py
@@ -31,6 +31,7 @@
     "ChatHunyuan",
     "GigaChat",
     "VolcEngineMaasChat",
+    "WasmChatService",
     "GPTRouter",
 ]
 
diff --git a/libs/community/tests/unit_tests/chat_models/test_wasmchat.py b/libs/community/tests/unit_tests/chat_models/test_wasmchat.py
new file mode 100644
index 0000000000000..6905130de6753
--- /dev/null
+++ b/libs/community/tests/unit_tests/chat_models/test_wasmchat.py
@@ -0,0 +1,78 @@
+import pytest
+from langchain_core.messages import (
+    AIMessage,
+    ChatMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
+
+from langchain_community.chat_models.wasm_chat import (
+    WasmChatService,
+    _convert_dict_to_message,
+    _convert_message_to_dict,
+)
+
+
+def test__convert_message_to_dict_human() -> None:
+    message = HumanMessage(content="foo")
+    result = _convert_message_to_dict(message)
+    expected_output = {"role": "user", "content": "foo"}
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_ai() -> None:
+    message = AIMessage(content="foo")
+    result = _convert_message_to_dict(message)
+    expected_output = {"role": "assistant", "content": "foo"}
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_system() -> None:
+    message = SystemMessage(content="foo")
+    result = _convert_message_to_dict(message)
+    expected_output = {"role": "system", "content": "foo"}
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_function() -> None:
+    message = FunctionMessage(name="foo", content="bar")
+    with pytest.raises(TypeError) as e:
+        _convert_message_to_dict(message)
+    assert "Got unknown type" in str(e.value)
+
+
+def test__convert_dict_to_message_human() -> None:
+    message_dict = {"role": "user", "content": "foo"}
+    result = _convert_dict_to_message(message_dict)
+    expected_output = HumanMessage(content="foo")
+    assert result == expected_output
+
+
+def test__convert_dict_to_message_ai() -> None:
+    message_dict = {"role": "assistant", "content": "foo"}
+    result = _convert_dict_to_message(message_dict)
+    expected_output = AIMessage(content="foo")
+    assert result == expected_output
+
+
+def test__convert_dict_to_message_other_role() -> None:
+    message_dict = {"role": "system", "content": "foo"}
+    result = _convert_dict_to_message(message_dict)
+    expected_output = ChatMessage(role="system", content="foo")
+    assert result == expected_output
+
+
+def test_wasm_chat_without_service_url() -> None:
+    chat = WasmChatService()
+
+    # create message sequence
+    system_message = SystemMessage(content="You are an AI assistant")
+    user_message = HumanMessage(content="What is the capital of France?")
+    messages = [system_message, user_message]
+
+    with pytest.raises(ValueError) as e:
+        chat(messages)
+
+    assert "Error code: 503" in str(e.value)
+    assert "reason: The IP address or port of the chat service is incorrect." in str(e.value)