feat: new integration wasm_chat #14787

Merged
85 changes: 85 additions & 0 deletions docs/docs/integrations/chat/wasm_chat.ipynb
@@ -0,0 +1,85 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wasm Chat\n",
"\n",
"`Wasm-chat` lets you chat with LLMs in [GGUF](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/README.md) format, both locally and via a chat service.\n",
"\n",
"- `WasmChatService` provides an OpenAI-API-compatible service for chatting with LLMs over HTTP.\n",
"\n",
"- `WasmChatLocal` enables developers to chat with LLMs locally (coming soon).\n",
"\n",
"Both `WasmChatService` and `WasmChatLocal` run on infrastructure powered by the [WasmEdge Runtime](https://wasmedge.org/), which provides a lightweight, portable WebAssembly container environment for LLM inference tasks.\n",
"\n",
"## Chat via API Service\n",
"\n",
"`WasmChatService` is backed by `llama-api-server`. By following the steps in the [llama-api-server quick-start](https://github.com/second-state/llama-utils/tree/main/api-server#readme), you can host your own API service and chat with any model you like, on any device, as long as you have an internet connection."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.wasm_chat import WasmChatService\n",
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Bot] Paris\n"
]
}
],
"source": [
"# service url\n",
"service_url = \"https://b008-54-186-154-209.ngrok-free.app\"\n",
"\n",
"# create wasm-chat service instance\n",
"chat = WasmChatService(service_url=service_url)\n",
"\n",
"# create message sequence\n",
"system_message = SystemMessage(content=\"You are an AI assistant\")\n",
"user_message = HumanMessage(content=\"What is the capital of France?\")\n",
"messages = [system_message, user_message]\n",
"\n",
"# chat with wasm-chat service\n",
"response = chat(messages)\n",
"\n",
"print(f\"[Bot] {response.content}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
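For reference, the `chat(messages)` call in the notebook reduces to a single OpenAI-style HTTP request against the service's `/v1/chat/completions` endpoint, the same request that `WasmChatService._chat` builds internally. A minimal sketch with plain `requests`; the ngrok URL is the notebook's temporary example endpoint, so substitute your own `llama-api-server` deployment:

```python
import requests

# Temporary example endpoint from the notebook; replace it with your own
# llama-api-server deployment.
service_url = "https://b008-54-186-154-209.ngrok-free.app"

# OpenAI-compatible chat payload; "NA" is the default model name
# that WasmChatService sends.
payload = {
    "model": "NA",
    "messages": [
        {"role": "system", "content": "You are an AI assistant"},
        {"role": "user", "content": "What is the capital of France?"},
    ],
}

res = requests.post(
    url=f"{service_url}/v1/chat/completions",
    timeout=60,
    headers={"accept": "application/json", "Content-Type": "application/json"},
    json=payload,
)
print(res.json()["choices"][0]["message"]["content"])  # e.g. "Paris"
```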
2 changes: 2 additions & 0 deletions libs/community/langchain_community/chat_models/__init__.py
@@ -47,9 +47,11 @@
from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain_community.chat_models.vertexai import ChatVertexAI
from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
from langchain_community.chat_models.wasm_chat import WasmChatService
from langchain_community.chat_models.yandex import ChatYandexGPT

__all__ = [
    "WasmChatService",
    "ChatOpenAI",
    "BedrockChat",
    "AzureChatOpenAI",
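With the export added to `__all__` above, the new class is importable from the package root. A quick sanity check; the localhost endpoint is hypothetical:

```python
from langchain_community.chat_models import WasmChatService

# Hypothetical local endpoint; service_url may also be omitted entirely,
# in which case requests fail with a synthetic 503 response.
chat = WasmChatService(service_url="http://localhost:8080")
print(chat._llm_type)  # -> "wasm-chat"
```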
144 changes: 144 additions & 0 deletions libs/community/langchain_community/chat_models/wasm_chat.py
@@ -0,0 +1,144 @@
import json
import logging
from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import get_pydantic_field_names

logger = logging.getLogger(__name__)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        return AIMessage(content=_dict.get("content", "") or "")
    else:
        return ChatMessage(content=_dict["content"], role=role)


def _convert_message_to_dict(message: BaseMessage) -> dict:
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    else:
        raise TypeError(f"Got unknown type {message}")

    return message_dict


class WasmChatService(BaseChatModel):
    """Chat with LLMs via `llama-api-server`.

    For information about `llama-api-server`, visit
    https://github.com/second-state/llama-utils
    """

    request_timeout: int = 60
    """Request timeout for chat HTTP requests, in seconds."""
    service_url: Optional[str] = None
    """URL of the WasmChat service."""
    model: str = "NA"
    """Model name; defaults to `NA`."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        res = self._chat(messages, **kwargs)

        if res.status_code != 200:
            raise ValueError(f"Error code: {res.status_code}, reason: {res.reason}")

        response = res.json()

        return self._create_chat_result(response)

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
        if self.service_url is None:
            # No endpoint configured: synthesize a 503 response so that
            # _generate surfaces a consistent error instead of raising here.
            res = requests.models.Response()
            res.status_code = 503
            res.reason = "The IP address or port of the chat service is incorrect."
            return res

        service_url = f"{self.service_url}/v1/chat/completions"

        payload = {
            "model": self.model,
            "messages": [_convert_message_to_dict(m) for m in messages],
        }

        res = requests.post(
            url=service_url,
            timeout=self.request_timeout,
            headers={
                "accept": "application/json",
                "Content-Type": "application/json",
            },
            data=json.dumps(payload),
        )

        return res

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        message = _convert_dict_to_message(response["choices"][0].get("message"))
        generations = [ChatGeneration(message=message)]

        token_usage = response["usage"]
        llm_output = {"token_usage": token_usage, "model": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        return "wasm-chat"
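To make the response parsing concrete, here is a sketch that feeds `_create_chat_result` a hand-rolled body in the OpenAI-compatible shape the service is assumed to return. Only `choices[0]["message"]` and `usage` are read; the token counts are illustrative, and the private method is called directly purely for demonstration:

```python
from langchain_community.chat_models.wasm_chat import WasmChatService

# Hand-rolled response in the assumed OpenAI-compatible shape;
# token counts are illustrative, not real server output.
response = {
    "choices": [{"message": {"role": "assistant", "content": "Paris"}}],
    "usage": {"prompt_tokens": 21, "completion_tokens": 2, "total_tokens": 23},
}

chat = WasmChatService()
result = chat._create_chat_result(response)
assert result.generations[0].message.content == "Paris"
assert result.llm_output == {"token_usage": response["usage"], "model": "NA"}
```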
@@ -0,0 +1,28 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.wasm_chat import WasmChatService


@pytest.mark.enable_socket
def test_chat_wasm_service() -> None:
    """This test requires that port 8080 is not occupied."""

    # service url
    service_url = "https://b008-54-186-154-209.ngrok-free.app"

    # create wasm-chat service instance
    chat = WasmChatService(service_url=service_url)

    # create message sequence
    system_message = SystemMessage(content="You are an AI assistant")
    user_message = HumanMessage(content="What is the capital of France?")
    messages = [system_message, user_message]

    # chat with wasm-chat service
    response = chat(messages)

    # check response
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert "Paris" in response.content
@@ -31,6 +31,7 @@
    "ChatHunyuan",
    "GigaChat",
    "VolcEngineMaasChat",
    "WasmChatService",
    "GPTRouter",
]

78 changes: 78 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_wasmchat.py
@@ -0,0 +1,78 @@
import pytest
from langchain_core.messages import (
    AIMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)

from langchain_community.chat_models.wasm_chat import (
    WasmChatService,
    _convert_dict_to_message,
    _convert_message_to_dict,
)


def test__convert_message_to_dict_human() -> None:
    message = HumanMessage(content="foo")
    result = _convert_message_to_dict(message)
    expected_output = {"role": "user", "content": "foo"}
    assert result == expected_output


def test__convert_message_to_dict_ai() -> None:
    message = AIMessage(content="foo")
    result = _convert_message_to_dict(message)
    expected_output = {"role": "assistant", "content": "foo"}
    assert result == expected_output


def test__convert_message_to_dict_system() -> None:
    message = SystemMessage(content="foo")
    result = _convert_message_to_dict(message)
    expected_output = {"role": "system", "content": "foo"}
    assert result == expected_output


def test__convert_message_to_dict_function() -> None:
    message = FunctionMessage(name="foo", content="bar")
    with pytest.raises(TypeError) as e:
        _convert_message_to_dict(message)
    assert "Got unknown type" in str(e)


def test__convert_dict_to_message_human() -> None:
    message_dict = {"role": "user", "content": "foo"}
    result = _convert_dict_to_message(message_dict)
    expected_output = HumanMessage(content="foo")
    assert result == expected_output


def test__convert_dict_to_message_ai() -> None:
    message_dict = {"role": "assistant", "content": "foo"}
    result = _convert_dict_to_message(message_dict)
    expected_output = AIMessage(content="foo")
    assert result == expected_output


def test__convert_dict_to_message_other_role() -> None:
    message_dict = {"role": "system", "content": "foo"}
    result = _convert_dict_to_message(message_dict)
    expected_output = ChatMessage(role="system", content="foo")
    assert result == expected_output


def test_wasm_chat_without_service_url() -> None:
    chat = WasmChatService()

    # create message sequence
    system_message = SystemMessage(content="You are an AI assistant")
    user_message = HumanMessage(content="What is the capital of France?")
    messages = [system_message, user_message]

    with pytest.raises(ValueError) as e:
        chat(messages)

    assert "Error code: 503" in str(e)
    assert "reason: The IP address or port of the chat service is incorrect." in str(e)