diff --git a/appbuilder/core/components/gbi/nl2sql/component.py b/appbuilder/core/components/gbi/nl2sql/component.py index b90f8066..9410a006 100644 --- a/appbuilder/core/components/gbi/nl2sql/component.py +++ b/appbuilder/core/components/gbi/nl2sql/component.py @@ -14,8 +14,6 @@ r"""GBI nl2sql component. """ -import uuid -import json from typing import Dict, List, Optional from pydantic import BaseModel, Field, ValidationError @@ -90,11 +88,12 @@ def run(self, 1. query: 用户问题 2. session: gbi session 的历史 列表, 参考 SessionRecord 3. column_constraint: 列选约束 参考 ColumnItem 具体定义 + timeout: 超时时间 + retry: 重试次数 Returns: NL2SqlResult 的 message """ - try: inputs = self.meta(**message.content) except ValidationError as e: @@ -134,11 +133,11 @@ def _run_nl2sql(self, query: str, session: List[SessionRecord], table_schemas: L """ - headers = self.auth_header() + headers = self.http_client.auth_header() headers["Content-Type"] = "application/json" - if retry != self.retry.total: - self.retry.total = retry + if retry != self.http_client.retry.total: + self.http_client.retry.total = retry payload = {"query": query, "table_schemas": table_schemas, @@ -148,13 +147,13 @@ def _run_nl2sql(self, query: str, session: List[SessionRecord], table_schemas: L "knowledge": knowledge, "prompt_template": prompt_template} - server_url = self.service_url(prefix="", sub_path=self.server_sub_path) - response = self.s.post(url=server_url, headers=headers, - json=payload, timeout=timeout) - super().check_response_header(response) + server_url = self.http_client.service_url(prefix="", sub_path=self.server_sub_path) + response = self.http_client.session.post(url=server_url, headers=headers, + json=payload, timeout=timeout) + self.http_client.check_response_header(response) data = response.json() - super().check_response_json(data) + self.http_client.check_response_json(data) - request_id = self.response_request_id(response) + request_id = self.http_client.response_request_id(response) response.request_id = request_id return response diff --git a/appbuilder/core/components/gbi/select_table/component.py b/appbuilder/core/components/gbi/select_table/component.py index d2550447..5bb0c516 100644 --- a/appbuilder/core/components/gbi/select_table/component.py +++ b/appbuilder/core/components/gbi/select_table/component.py @@ -14,8 +14,6 @@ r"""GBI nl2sql component. """ -import uuid -import json from typing import Dict, List, Optional from pydantic import BaseModel, Field, ValidationError @@ -67,8 +65,6 @@ def __init__(self, model_name: str, table_descriptions: Dict[str, str], 问题:{query} 回答: ``` - secret_key: - gateway: """ super().__init__(meta=SelectTableArgs) if model_name not in SUPPORTED_MODEL_NAME: @@ -79,12 +75,14 @@ def __init__(self, model_name: str, table_descriptions: Dict[str, str], self.prompt_template = prompt_template def run(self, - message: Message, timeout: int = 60,retry: int = 0) -> Message[List[str]]: + message: Message, timeout: int = 60, retry: int = 0) -> Message[List[str]]: """ Args: message: message.content 字典包含 key: 1. query - 用户的问题输入 2. session - 对话历史, 可选 + timeout: 超时时间 + retry: 重试次数 Returns: 识别的表名的列表 ["table_name"] """ @@ -122,26 +120,25 @@ def _run_select_table(self, query: str, session: List[SessionRecord], obj:`ShortSpeechRecognitionResponse`: 接口返回的输出消息。 """ - headers = self.auth_header() + headers = self.http_client.auth_header() headers["Content_Type"] = "application/json" - if retry != self.retry.total: - self.retry.total = retry + if retry != self.http_client.retry.total: + self.http_client.retry.total = retry payload = {"query": query, "table_descriptions": table_descriptions, - "session": [session_record.to_json() for session_record in session], + "session": [session_record.dict() for session_record in session], "model_name": model_name, "prompt_template": prompt_template} - server_url = self.service_url(sub_path=self.server_sub_path) - response = self.s.post(url=server_url, headers=headers, - json=payload, timeout=timeout) - super().check_response_header(response) + server_url = self.http_client.service_url(sub_path=self.server_sub_path) + response = self.http_client.session.post(url=server_url, headers=headers, + json=payload, timeout=timeout) + self.http_client.check_response_header(response) data = response.json() - super().check_response_json(data) + self.http_client.check_response_json(data) - request_id = self.response_request_id(response) + request_id = self.http_client.response_request_id(response) response.request_id = request_id return response - diff --git a/appbuilder/tests/test_gbi_nl2sql.py b/appbuilder/tests/test_gbi_nl2sql.py index 5617f157..64c5931c 100644 --- a/appbuilder/tests/test_gbi_nl2sql.py +++ b/appbuilder/tests/test_gbi_nl2sql.py @@ -145,10 +145,10 @@ def test_run_with_session(self): session = list() session_record = SessionRecord(query="列出商品类别是水果的的利润率", answer=NL2SqlResult( - llm_result="根据问题分析得到 sql 如下: \n " - "```sql\nSELECT * FROM `超市营收明细` " - "WHERE `商品类别` = '水果'\n```", - sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'")) + llm_result="根据问题分析得到 sql 如下: \n " + "```sql\nSELECT * FROM `超市营收明细` " + "WHERE `商品类别` = '水果'\n```", + sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'")) session.append(session_record) query = "列出所有的商品类别" diff --git a/appbuilder/tests/test_gbi_select_table.py b/appbuilder/tests/test_gbi_select_table.py index 32485058..4abbaebc 100644 --- a/appbuilder/tests/test_gbi_select_table.py +++ b/appbuilder/tests/test_gbi_select_table.py @@ -18,7 +18,7 @@ import appbuilder from appbuilder.core.message import Message from appbuilder.core.components.gbi.basic import SessionRecord - +from appbuilder.core.components.gbi.basic import NL2SqlResult SUPER_MARKET_SCHEMA = """ ``` @@ -68,6 +68,7 @@ 回答: """ + class TestGBISelectTable(unittest.TestCase): def setUp(self): @@ -79,7 +80,7 @@ def setUp(self): self.select_table_node = \ appbuilder.SelectTable(model_name=model_name, table_descriptions={"supper_market_info": "超市营收明细表,包含超市各种信息等", - "product_sales_info": "产品销售表"}) + "product_sales_info": "产品销售表"}) def test_run_with_default_param(self): """测试 run 方法使用有效参数""" @@ -95,7 +96,6 @@ def test_run_with_prompt_template(self): """测试 run 方法中 prompt template 模版""" query = "列出超市中的所有数据" msg = Message({"query": query}) - result_message = self.select_table_node(message=msg) self.select_table_node.prompt_template = PROMPT_TEMPLATE result_message = self.select_table_node(msg) @@ -104,6 +104,26 @@ def test_run_with_prompt_template(self): self.assertEqual(result_message.content[0], "supper_market_info") self.select_table_node.prompt_template = "" + def test_run_with_session(self): + """测试 run 方法中 prompt template 模版""" + + session = list() + session_record = SessionRecord(query="列出商品类别是水果的的利润率", + answer=NL2SqlResult( + llm_result="根据问题分析得到 sql 如下: \n " + "```sql\nSELECT * FROM `超市营收明细` " + "WHERE `商品类别` = '水果'\n```", + sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'")) + session.append(session_record) + + query = "列出超市中的所有数据" + msg = Message({"query": query, "session": session}) + result_message = self.select_table_node(msg) + + self.assertIsNotNone(result_message) + self.assertEqual(len(result_message.content), 1) + self.assertEqual(result_message.content[0], "supper_market_info") + if __name__ == '__main__': unittest.main() diff --git a/cookbooks/gbi.ipynb b/cookbooks/gbi.ipynb index f7b4809f..eac1b491 100644 --- a/cookbooks/gbi.ipynb +++ b/cookbooks/gbi.ipynb @@ -10,7 +10,7 @@ "# GBI\n", "\n", "## 目标\n", - "通过 GBI sdk 接口完成选表和问表的能力。\n", + "通过 GBI sdk 接口完成选表和问表的能力。 \n", "\n", "## 准备工作\n", "### 平台注册\n",