Skip to content

Commit

Permalink
feat: 1. Add system parameters, 2. Align with the QianfanChatEndpoint…
Browse files Browse the repository at this point in the history
… for function calling (#14275)

- **Description:** 
1. Add system parameters to the ERNIE LLM API to set the role of the
LLM.
2. Add support for the ERNIE-Bot-turbo-AI model according to the
documentation at https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Alp0kdm0n.
3. For the function call of ErnieBotChat, align with the
QianfanChatEndpoint.

With this PR, the `QianfanChatEndpoint()` can use the `function calling`
ability with `create_ernie_fn_chain()`. An example is as follows:

```
from langchain.prompts import ChatPromptTemplate
import json
from langchain.prompts.chat import (
    ChatPromptTemplate,
)

from langchain.chat_models import QianfanChatEndpoint
from langchain.chains.ernie_functions import (
    create_ernie_fn_chain,
)

def get_current_news(location: str) -> str:
    """Get the current news based on the location.

    Args:
        location (str): The location to query.

    Returns:
        str: Current news for the given location, serialized as a JSON string.
    """
    # Stub implementation for the example: returns a fixed news payload.
    news_info = {
        "location": location,
        "news": [
            "I have a Book.",
            "It's a nice day, today.",
        ],
    }

    return json.dumps(news_info)

def get_current_weather(location: str, unit: str = "celsius") -> str:
    """Get the current weather in a given location.

    Args:
        location (str): Location of the weather.
        unit (str): Unit of the temperature.

    Returns:
        str: Weather in the given location, serialized as a JSON string.
    """
    # Stub implementation for the example: returns fixed weather data.
    weather_info = {
        "location": location,
        "temperature": "27",
        "unit": unit,
        "forecast": ["sunny", "windy"],
    }
    return json.dumps(weather_info)

# Build a single-message prompt; the user's text is injected via {user_input}.
template = ChatPromptTemplate.from_messages([
    ("user", "{user_input}"),
])

# ERNIE-Bot-4 via the Qianfan endpoint supports function calling.
chat = QianfanChatEndpoint(model="ERNIE-Bot-4")
# The chain exposes both functions to the model; the model picks which to call.
chain = create_ernie_fn_chain([get_current_weather, get_current_news], chat, template, verbose=True)
# Ask (in Chinese) "What is today's news in Beijing?" — expect a
# get_current_news function-call result rather than free-form text.
res = chain.run("北京今天的新闻是什么?")
print(res)
```

The result of the above code:
```
> Entering new LLMChain chain...
Prompt after formatting:
Human: 北京今天的新闻是什么?
> Finished chain.
{'name': 'get_current_news', 'arguments': {'location': '北京'}}
```

For the `ErnieBotChat`, now can use the `system` parameter to set the
role of the LLM.

```
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models import ErnieBotChat

# The new `system` parameter sets the assistant's persona (here, in Chinese:
# "You are a very capable robot named Xiaodingdang...").
llm = ErnieBotChat(model_name="ERNIE-Bot-turbo-AI", system="你是一个能力很强的机器人,你的名字叫 小叮当。无论问你什么问题,你都可以给出答案。")
# Single human message; the question is injected via {query}.
prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{query}"),
    ]
)
chain = LLMChain(llm=llm, prompt=prompt, verbose=True)
# Ask (in Chinese) "Who are you?" — the reply should reflect the persona.
res = chain.run(query="你是谁?")
print(res)
```

The result of the above code:

```
> Entering new LLMChain chain...
Prompt after formatting:
Human: 你是谁?
> Finished chain.
我是小叮当,一个智能机器人。我可以为你提供各种服务,包括回答问题、提供信息、进行计算等。如果你需要任何帮助,请随时告诉我,我会尽力为你提供最好的服务。
```
  • Loading branch information
wangwei1237 authored Dec 6, 2023
1 parent fd5be55 commit 7205bfd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
30 changes: 21 additions & 9 deletions libs/langchain/langchain/chat_models/ernie.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
import threading
from typing import Any, Dict, List, Mapping, Optional
Expand Down Expand Up @@ -46,7 +45,8 @@ class ErnieBotChat(BaseChatModel):
and will be regenerated after expiration (30 days).
Default model is `ERNIE-Bot-turbo`,
currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`
currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`, `ERNIE-Bot-8K`,
`ERNIE-Bot-4`, `ERNIE-Bot-turbo-AI`.
Example:
.. code-block:: python
Expand Down Expand Up @@ -87,6 +87,11 @@ class ErnieBotChat(BaseChatModel):
"""model name of ernie, default is `ERNIE-Bot-turbo`.
Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`"""

system: Optional[str] = None
"""system is mainly used for model character design,
for example, you are an AI assistant produced by xxx company.
    The length of the system is limited to 1024 characters."""

request_timeout: Optional[int] = 60
"""request timeout for chat http requests"""

Expand Down Expand Up @@ -123,6 +128,7 @@ def _chat(self, payload: object) -> dict:
"ERNIE-Bot": "completions",
"ERNIE-Bot-8K": "ernie_bot_8k",
"ERNIE-Bot-4": "completions_pro",
"ERNIE-Bot-turbo-AI": "ai_apaas",
"BLOOMZ-7B": "bloomz_7b1",
"Llama-2-7b-chat": "llama_2_7b",
"Llama-2-13b-chat": "llama_2_13b",
Expand Down Expand Up @@ -180,6 +186,7 @@ def _generate(
"top_p": self.top_p,
"temperature": self.temperature,
"penalty_score": self.penalty_score,
"system": self.system,
**kwargs,
}
logger.debug(f"Payload for ernie api is {payload}")
Expand All @@ -195,14 +202,19 @@ def _generate(

def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
if "function_call" in response:
fc_str = '{{"function_call": {}}}'.format(
json.dumps(response.get("function_call"))
)
generations = [ChatGeneration(message=AIMessage(content=fc_str))]
additional_kwargs = {
"function_call": dict(response.get("function_call", {}))
}
else:
generations = [
ChatGeneration(message=AIMessage(content=response.get("result")))
]
additional_kwargs = {}
generations = [
ChatGeneration(
message=AIMessage(
content=response.get("result"),
additional_kwargs={**additional_kwargs},
)
)
]
token_usage = response.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)
Expand Down
8 changes: 2 additions & 6 deletions libs/langchain/langchain/output_parsers/ernie_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,8 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
"This output parser can only be used with a chat generation."
)
message = generation.message
message.additional_kwargs["function_call"] = {}
if "function_call" in message.content:
function_call = json.loads(str(message.content))
if "function_call" in function_call:
fc = function_call["function_call"]
message.additional_kwargs["function_call"] = fc
if "function_call" not in message.additional_kwargs:
return None
try:
function_call = message.additional_kwargs["function_call"]
except KeyError as exc:
Expand Down

0 comments on commit 7205bfd

Please sign in to comment.