forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New add Baichuan Model (langchain-ai#11714)
Motivation and Context At present, the Baichuan Large Language Model is relatively popular and efficient in performance. Due to widespread market recognition, this model has been added to enhance the scalability of Langchain's ability to access the big language model, so as to facilitate application access and usage for interested users. System Info langchain: 0.0.295 python:3.8.3 IDE:vs code Description Add the following files: 1. Add baichuan_baichuaninc_endpoint.py in the libs/langchain/langchain/chat_models 2. Modify the __init__.py file,which is located in the libs/langchain/langchain/chat_models/__init__.py: a. Add "from langchain.chat_models.baichuan_baichuaninc_endpoint import BaichuanChatEndpoint" b. Add "BaichuanChatEndpoint" In the file's __ All__ method Your contribution I am willing to help implement this feature and submit a PR, but I would appreciate guidance from the maintainers or community to ensure the changes are made correctly and in line with the project's standards and practices.
- Loading branch information
Showing
3 changed files
with
258 additions
and
0 deletions.
There are no files selected for viewing
95 changes: 95 additions & 0 deletions
95
docs/docs/integrations/chat/baichuan_baichuaninc_endpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Baichuan Baihuan2\n", | ||
"\n", | ||
"Baichuan Intelligent announced the official open-source fine-tuning of Baihuan2-7B, Baihuan2-13B, Baihuan2-13B-Chat and their 4-bit quantified versions, all of which are free and commercially available.\n", | ||
"\n", | ||
"According to the introduction, both Baihuan2-7B-Base and Baihuan2-13B-Base are trained on 2.6 trillion high-quality multilingual data. While retaining the excellent generation and creation capabilities of the previous generation open source model, smooth multi round dialogue ability, and low deployment threshold, the two models have significantly improved their mathematical, code, security, logical reasoning, semantic understanding, and other abilities. Compared to the previous generation 13B model, Baihuan2-13B-Base has improved mathematical ability by 49%, code ability by 46%, security ability by 37%, logical reasoning ability by 25%, and semantic understanding ability by 15%.\n", | ||
"\n", | ||
"Basically, these models are classified into the following types:\n", | ||
"\n", | ||
"- Chat\n", | ||
"- Completion\n", | ||
"\n", | ||
"In this notebook, we will introduce how to use langchain with [Baichuan](https://api.baichuan-ai.com) mainly in `Chat` corresponding\n", | ||
" to the package `langchain/chat_models` in langchain:\n", | ||
"\n", | ||
"\n", | ||
"## API Initialization\n", | ||
"\n", | ||
"To use the LLM services based on Baichuan Baihuan2, you have to initialize these parameters:\n", | ||
"\n", | ||
"To use a wrapper, the following parameters must be set in your environment variable:\n", | ||
"\n", | ||
"```base\n", | ||
"Baichuan_AK=API_Key\n", | ||
"Baichuan_SK=secret_Key\n", | ||
"```\n", | ||
"\n", | ||
"Both of the above need to be applied for at https://api.baichuan-ai.com\n", | ||
"\n", | ||
"## Current supported models:\n", | ||
"\n", | ||
"- Baichuan2-7B\n", | ||
"- Baichuan2-13B\n", | ||
"- Baichuan2-13B-Chat" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## requesting llm api endpoint:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\"\"\"For basic init and call\"\"\"\n", | ||
"import os\n", | ||
"from langchain.chat_models import BaichuanChatEndpoint\n", | ||
"\n", | ||
"baichuan_ak = os.getenv('Baichuan_AK')\n", | ||
"baichuan_sk = os.getenv('Baichuan_SK') \n", | ||
"\n", | ||
"chat_model = BaichuanChatEndpoint(baichuan_ak, baichuan_sk, \"Baichuan2-13B\")\n", | ||
"res = chat_model.predict(\"Hello, please introduce yourself!\")\n", | ||
"print(f\"Answer:{res.text}\")\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
161 changes: 161 additions & 0 deletions
161
libs/langchain/langchain/chat_models/baichuan_baichuaninc_endpoint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
"""Baichuan chat wrapper.""" | ||
from __future__ import annotations | ||
|
||
import requests | ||
import json | ||
import time | ||
import hashlib | ||
import logging | ||
from typing import ( | ||
Any, | ||
AsyncIterator, | ||
Dict, | ||
Iterator, | ||
List, | ||
Mapping, | ||
Optional, | ||
) | ||
|
||
from langchain.callbacks.manager import ( | ||
AsyncCallbackManagerForLLMRun, | ||
CallbackManagerForLLMRun, | ||
) | ||
from langchain.chat_models.base import BaseChatModel | ||
from langchain.pydantic_v1 import Field, root_validator | ||
from langchain.schema import ChatGeneration, ChatResult | ||
from langchain.schema.messages import ( | ||
AIMessage, | ||
AIMessageChunk, | ||
BaseMessage, | ||
BaseMessageChunk, | ||
ChatMessage, | ||
FunctionMessage, | ||
HumanMessage, | ||
SystemMessage, | ||
) | ||
from langchain.schema.output import ChatGenerationChunk | ||
from langchain.utils import get_from_dict_or_env | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk: | ||
return AIMessageChunk( | ||
content=resp["result"], | ||
role="assistant", | ||
) | ||
|
||
|
||
def convert_message_to_dict(message: BaseMessage) -> dict: | ||
"""Convert a message to a dictionary that can be passed to the API.""" | ||
message_dict: Dict[str, Any] | ||
if isinstance(message, ChatMessage): | ||
message_dict = {"role": message.role, "content": message.content} | ||
elif isinstance(message, HumanMessage): | ||
message_dict = {"role": "user", "content": message.content} | ||
elif isinstance(message, AIMessage): | ||
message_dict = {"role": "assistant", "content": message.content} | ||
if "function_call" in message.additional_kwargs: | ||
message_dict["functions"] = message.additional_kwargs["function_call"] | ||
# If function call only, content is None not empty string | ||
if message_dict["content"] == "": | ||
message_dict["content"] = None | ||
elif isinstance(message, FunctionMessage): | ||
message_dict = { | ||
"role": "function", | ||
"content": message.content, | ||
"name": message.name, | ||
} | ||
else: | ||
raise TypeError(f"Got unknown type {message}") | ||
|
||
return message_dict | ||
|
||
|
||
def calculate_md5(input_string): | ||
md5 = hashlib.md5() | ||
md5.update(input_string.encode('utf-8')) | ||
encrypted = md5.hexdigest() | ||
return encrypted | ||
|
||
|
||
class BaichuanChatEndpoint(): | ||
""" | ||
Currently only enterprise registration is supported for use | ||
To use, you should have the environment variable ``Baichuan_AK`` and ``Baichuan_SK`` set with your | ||
api_key and secret_key. | ||
ak, sk are required parameters | ||
which you could get from https: // api.baichuan-ai.com | ||
Example: | ||
.. code-block: : python | ||
from langchain.chat_models import BaichuanChatEndpoint | ||
baichuan_chat = BaichuanChatEndpoint("your_ak", "your_sk","Baichuan2-13B") | ||
result=baichuan_chat.predict(message) | ||
print(result.text") | ||
Because Baichuan was no pip package made,So we will temporarily use this method and iterate and upgrade in the future | ||
Args: They cannot be empty | ||
baichuan_ak (str): api_key | ||
baichuan_sk (str): secret_key | ||
model (str): Default Baichuan2-7B,Baichuan2-13B,Baichuan2-53B which is commercial. | ||
streaming (bool): Defautlt False | ||
Returns: | ||
Execute predict return response. | ||
""" | ||
|
||
baichuan_ak: Optional[str] = None | ||
baichuan_sk: Optional[str] = None | ||
|
||
request_timeout: Optional[int] = 60 | ||
"""request timeout for chat http requests""" | ||
|
||
top_p: Optional[float] = 0.8 | ||
temperature: Optional[float] = 0.95 | ||
|
||
endpoint: Optional[str] = None | ||
"""Endpoint of the Qianfan LLM, required if custom model used.""" | ||
|
||
def __init__(self, baichuan_ak, baichuan_sk, model="Baichuan2-7B", streaming=False): | ||
self.baichuan_ak = baichuan_ak | ||
self.baichuan_sk = baichuan_sk | ||
self.model = "Baichuan2-7B" if model is None else model | ||
self.streaming = False if streaming is not None and streaming is False else True | ||
|
||
def predict(self, messages: List[BaseMessage]) -> Response: | ||
|
||
if self.streaming is not None and self.streaming is False: | ||
url = "https://api.baichuan-ai.com/v1/chat" | ||
elif self.streaming is not None and self.streaming is True: | ||
url = "https://api.baichuan-ai.com/v1/stream/chat" | ||
|
||
data = { | ||
"model": self.model, | ||
"messages": [ | ||
{ | ||
"role": "user", | ||
"content": messages | ||
} | ||
] | ||
} | ||
|
||
json_data = json.dumps(data) | ||
time_stamp = int(time.time()) | ||
signature = calculate_md5(self.baichuan_sk + json_data + str(time_stamp)) | ||
|
||
headers = { | ||
"Content-Type": "application/json", | ||
"Authorization": "Bearer " + self.baichuan_ak, | ||
"X-BC-Request-Id": "your requestId", | ||
"X-BC-Timestamp": str(time_stamp), | ||
"X-BC-Signature": signature, | ||
"X-BC-Sign-Algo": "MD5", | ||
} | ||
response = requests.post(url, data=json_data, headers=headers) | ||
return response |