Skip to content

Commit

Permalink
New add Baichuan Model (langchain-ai#11714)
Browse files Browse the repository at this point in the history
Motivation and Context
At present, the Baichuan Large Language Model is relatively popular and
efficient in performance. Due to widespread market recognition, this
model has been added to enhance the scalability of Langchain's ability
to access the big language model, so as to facilitate application access
and usage for interested users.

System Info
langchain: 0.0.295
python:3.8.3
IDE:vs code

Description
Add the following files:

1. Add baichuan_baichuaninc_endpoint.py in the
libs/langchain/langchain/chat_models
2. Modify the __init__.py file,which is located in the
libs/langchain/langchain/chat_models/__init__.py:
a. Add "from langchain.chat_models.baichuan_baichuaninc_endpoint import
BaichuanChatEndpoint"
    b. Add "BaichuanChatEndpoint" In the file's __ All__  method

Your contribution
I am willing to help implement this feature and submit a PR, but I would
appreciate guidance from the maintainers or community to ensure the
changes are made correctly and in line with the project's standards and
practices.
  • Loading branch information
AIGCool authored Oct 13, 2023
1 parent 694d768 commit 56653c5
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 0 deletions.
95 changes: 95 additions & 0 deletions docs/docs/integrations/chat/baichuan_baichuaninc_endpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Baichuan Baihuan2\n",
"\n",
"Baichuan Intelligent announced the official open-source fine-tuning of Baihuan2-7B, Baihuan2-13B, Baihuan2-13B-Chat and their 4-bit quantified versions, all of which are free and commercially available.\n",
"\n",
"According to the introduction, both Baihuan2-7B-Base and Baihuan2-13B-Base are trained on 2.6 trillion high-quality multilingual data. While retaining the excellent generation and creation capabilities of the previous generation open source model, smooth multi round dialogue ability, and low deployment threshold, the two models have significantly improved their mathematical, code, security, logical reasoning, semantic understanding, and other abilities. Compared to the previous generation 13B model, Baihuan2-13B-Base has improved mathematical ability by 49%, code ability by 46%, security ability by 37%, logical reasoning ability by 25%, and semantic understanding ability by 15%.\n",
"\n",
"Basically, these models are classified into the following types:\n",
"\n",
"- Chat\n",
"- Completion\n",
"\n",
"In this notebook, we will introduce how to use langchain with [Baichuan](https://api.baichuan-ai.com) mainly in `Chat` corresponding\n",
" to the package `langchain/chat_models` in langchain:\n",
"\n",
"\n",
"## API Initialization\n",
"\n",
"To use the LLM services based on Baichuan Baihuan2, you have to initialize these parameters:\n",
"\n",
"To use a wrapper, the following parameters must be set in your environment variable:\n",
"\n",
"```base\n",
"Baichuan_AK=API_Key\n",
"Baichuan_SK=secret_Key\n",
"```\n",
"\n",
"Both of the above need to be applied for at https://api.baichuan-ai.com\n",
"\n",
"## Current supported models:\n",
"\n",
"- Baichuan2-7B\n",
"- Baichuan2-13B\n",
"- Baichuan2-13B-Chat"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## requesting llm api endpoint:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"For basic init and call\"\"\"\n",
"import os\n",
"from langchain.chat_models import BaichuanChatEndpoint\n",
"\n",
"baichuan_ak = os.getenv('Baichuan_AK')\n",
"baichuan_sk = os.getenv('Baichuan_SK') \n",
"\n",
"chat_model = BaichuanChatEndpoint(baichuan_ak, baichuan_sk, \"Baichuan2-13B\")\n",
"res = chat_model.predict(\"Hello, please introduce yourself!\")\n",
"print(f\"Answer:{res.text}\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"vscode": {
"interpreter": {
"hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 2 additions & 0 deletions libs/langchain/langchain/chat_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.chat_models.baichuan_baichuaninc_endpoint import BaichuanChatEndpoint
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.cohere import ChatCohere
from langchain.chat_models.ernie import ErnieBotChat
Expand Down Expand Up @@ -61,4 +62,5 @@
"ChatKonko",
"QianfanChatEndpoint",
"ChatFireworks",
"BaichuanChatEndpoint",
]
161 changes: 161 additions & 0 deletions libs/langchain/langchain/chat_models/baichuan_baichuaninc_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Baichuan chat wrapper."""
from __future__ import annotations

import requests
import json
import time
import hashlib
import logging
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Mapping,
Optional,
)

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
return AIMessageChunk(
content=resp["result"],
role="assistant",
)


def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a message to a dictionary that can be passed to the API."""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["functions"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")

return message_dict


def calculate_md5(input_string):
md5 = hashlib.md5()
md5.update(input_string.encode('utf-8'))
encrypted = md5.hexdigest()
return encrypted


class BaichuanChatEndpoint():
"""
Currently only enterprise registration is supported for use
To use, you should have the environment variable ``Baichuan_AK`` and ``Baichuan_SK`` set with your
api_key and secret_key.
ak, sk are required parameters
which you could get from https: // api.baichuan-ai.com
Example:
.. code-block: : python
from langchain.chat_models import BaichuanChatEndpoint
baichuan_chat = BaichuanChatEndpoint("your_ak", "your_sk","Baichuan2-13B")
result=baichuan_chat.predict(message)
print(result.text")
Because Baichuan was no pip package made,So we will temporarily use this method and iterate and upgrade in the future
Args: They cannot be empty
baichuan_ak (str): api_key
baichuan_sk (str): secret_key
model (str): Default Baichuan2-7B,Baichuan2-13B,Baichuan2-53B which is commercial.
streaming (bool): Defautlt False

Check failure on line 106 in libs/langchain/langchain/chat_models/baichuan_baichuaninc_endpoint.py

View workflow job for this annotation

GitHub Actions / Check for spelling errors

Defautlt ==> Default
Returns:
Execute predict return response.
"""

baichuan_ak: Optional[str] = None
baichuan_sk: Optional[str] = None

request_timeout: Optional[int] = 60
"""request timeout for chat http requests"""

top_p: Optional[float] = 0.8
temperature: Optional[float] = 0.95

endpoint: Optional[str] = None
"""Endpoint of the Qianfan LLM, required if custom model used."""

def __init__(self, baichuan_ak, baichuan_sk, model="Baichuan2-7B", streaming=False):
self.baichuan_ak = baichuan_ak
self.baichuan_sk = baichuan_sk
self.model = "Baichuan2-7B" if model is None else model
self.streaming = False if streaming is not None and streaming is False else True

def predict(self, messages: List[BaseMessage]) -> Response:

if self.streaming is not None and self.streaming is False:
url = "https://api.baichuan-ai.com/v1/chat"
elif self.streaming is not None and self.streaming is True:
url = "https://api.baichuan-ai.com/v1/stream/chat"

data = {
"model": self.model,
"messages": [
{
"role": "user",
"content": messages
}
]
}

json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(self.baichuan_sk + json_data + str(time_stamp))

headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.baichuan_ak,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
response = requests.post(url, data=json_data, headers=headers)
return response

0 comments on commit 56653c5

Please sign in to comment.