Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

升级 gbi 使用 Component 的 http_client #26

Merged
merged 2 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions appbuilder/core/components/gbi/nl2sql/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

r"""GBI nl2sql component.
"""
import uuid
import json
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, ValidationError

Expand Down Expand Up @@ -90,11 +88,12 @@ def run(self,
1. query: 用户问题
2. session: gbi session 的历史 列表, 参考 SessionRecord
3. column_constraint: 列选约束 参考 ColumnItem 具体定义
timeout: 超时时间
retry: 重试次数
Returns:
NL2SqlResult 的 message
"""


try:
inputs = self.meta(**message.content)
except ValidationError as e:
Expand Down Expand Up @@ -134,11 +133,11 @@ def _run_nl2sql(self, query: str, session: List[SessionRecord], table_schemas: L

"""

headers = self.auth_header()
headers = self.http_client.auth_header()
headers["Content-Type"] = "application/json"

if retry != self.retry.total:
self.retry.total = retry
if retry != self.http_client.retry.total:
self.http_client.retry.total = retry

payload = {"query": query,
"table_schemas": table_schemas,
Expand All @@ -148,13 +147,13 @@ def _run_nl2sql(self, query: str, session: List[SessionRecord], table_schemas: L
"knowledge": knowledge,
"prompt_template": prompt_template}

server_url = self.service_url(prefix="", sub_path=self.server_sub_path)
response = self.s.post(url=server_url, headers=headers,
json=payload, timeout=timeout)
super().check_response_header(response)
server_url = self.http_client.service_url(prefix="", sub_path=self.server_sub_path)
response = self.http_client.session.post(url=server_url, headers=headers,
json=payload, timeout=timeout)
self.http_client.check_response_header(response)
data = response.json()
super().check_response_json(data)
self.http_client.check_response_json(data)

request_id = self.response_request_id(response)
request_id = self.http_client.response_request_id(response)
response.request_id = request_id
return response
29 changes: 13 additions & 16 deletions appbuilder/core/components/gbi/select_table/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

r"""GBI nl2sql component.
"""
import uuid
import json
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, ValidationError

Expand Down Expand Up @@ -67,8 +65,6 @@ def __init__(self, model_name: str, table_descriptions: Dict[str, str],
问题:{query}
回答:
```
secret_key:
gateway:
"""
super().__init__(meta=SelectTableArgs)
if model_name not in SUPPORTED_MODEL_NAME:
Expand All @@ -79,12 +75,14 @@ def __init__(self, model_name: str, table_descriptions: Dict[str, str],
self.prompt_template = prompt_template

def run(self,
message: Message, timeout: int = 60,retry: int = 0) -> Message[List[str]]:
message: Message, timeout: int = 60, retry: int = 0) -> Message[List[str]]:
"""
Args:
message: message.content 字典包含 key:
1. query - 用户的问题输入
2. session - 对话历史, 可选
timeout: 超时时间
retry: 重试次数

Returns: 识别的表名的列表 ["table_name"]
"""
Expand Down Expand Up @@ -122,26 +120,25 @@ def _run_select_table(self, query: str, session: List[SessionRecord],
obj:`ShortSpeechRecognitionResponse`: 接口返回的输出消息。
"""

headers = self.auth_header()
headers = self.http_client.auth_header()
headers["Content_Type"] = "application/json"

if retry != self.retry.total:
self.retry.total = retry
if retry != self.http_client.retry.total:
self.http_client.retry.total = retry

payload = {"query": query,
"table_descriptions": table_descriptions,
"session": [session_record.to_json() for session_record in session],
"session": [session_record.dict() for session_record in session],
"model_name": model_name,
"prompt_template": prompt_template}

server_url = self.service_url(sub_path=self.server_sub_path)
response = self.s.post(url=server_url, headers=headers,
json=payload, timeout=timeout)
super().check_response_header(response)
server_url = self.http_client.service_url(sub_path=self.server_sub_path)
response = self.http_client.session.post(url=server_url, headers=headers,
json=payload, timeout=timeout)
self.http_client.check_response_header(response)
data = response.json()
super().check_response_json(data)
self.http_client.check_response_json(data)

request_id = self.response_request_id(response)
request_id = self.http_client.response_request_id(response)
response.request_id = request_id
return response

8 changes: 4 additions & 4 deletions appbuilder/tests/test_gbi_nl2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def test_run_with_session(self):
session = list()
session_record = SessionRecord(query="列出商品类别是水果的的利润率",
answer=NL2SqlResult(
llm_result="根据问题分析得到 sql 如下: \n "
"```sql\nSELECT * FROM `超市营收明细` "
"WHERE `商品类别` = '水果'\n```",
sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'"))
llm_result="根据问题分析得到 sql 如下: \n "
"```sql\nSELECT * FROM `超市营收明细` "
"WHERE `商品类别` = '水果'\n```",
sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'"))
session.append(session_record)

query = "列出所有的商品类别"
Expand Down
26 changes: 23 additions & 3 deletions appbuilder/tests/test_gbi_select_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import appbuilder
from appbuilder.core.message import Message
from appbuilder.core.components.gbi.basic import SessionRecord

from appbuilder.core.components.gbi.basic import NL2SqlResult

SUPER_MARKET_SCHEMA = """
```
Expand Down Expand Up @@ -68,6 +68,7 @@
回答:
"""


class TestGBISelectTable(unittest.TestCase):

def setUp(self):
Expand All @@ -79,7 +80,7 @@ def setUp(self):
self.select_table_node = \
appbuilder.SelectTable(model_name=model_name,
table_descriptions={"supper_market_info": "超市营收明细表,包含超市各种信息等",
"product_sales_info": "产品销售表"})
"product_sales_info": "产品销售表"})

def test_run_with_default_param(self):
"""测试 run 方法使用有效参数"""
Expand All @@ -95,7 +96,6 @@ def test_run_with_prompt_template(self):
"""测试 run 方法中 prompt template 模版"""
query = "列出超市中的所有数据"
msg = Message({"query": query})
result_message = self.select_table_node(message=msg)
self.select_table_node.prompt_template = PROMPT_TEMPLATE
result_message = self.select_table_node(msg)

Expand All @@ -104,6 +104,26 @@ def test_run_with_prompt_template(self):
self.assertEqual(result_message.content[0], "supper_market_info")
self.select_table_node.prompt_template = ""

def test_run_with_session(self):
"""测试 run 方法中 prompt template 模版"""

session = list()
session_record = SessionRecord(query="列出商品类别是水果的的利润率",
answer=NL2SqlResult(
llm_result="根据问题分析得到 sql 如下: \n "
"```sql\nSELECT * FROM `超市营收明细` "
"WHERE `商品类别` = '水果'\n```",
sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'"))
session.append(session_record)

query = "列出超市中的所有数据"
msg = Message({"query": query, "session": session})
result_message = self.select_table_node(msg)

self.assertIsNotNone(result_message)
self.assertEqual(len(result_message.content), 1)
self.assertEqual(result_message.content[0], "supper_market_info")


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion cookbooks/gbi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"# GBI\n",
"\n",
"## 目标\n",
"通过 GBI sdk 接口完成选表和问表的能力。\n",
"通过 GBI sdk 接口完成选表和问表的能力。 \n",
"\n",
"## 准备工作\n",
"### 平台注册\n",
Expand Down
Loading