Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Developing #12

Merged
merged 7 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,24 @@ docker run -itd --name mongo -v /{path_of_mongo_data}:/data/db -p 27017:27017 mo

环境变量

下载https://github.com/hitsz-ids/airda/blob/main/.env.template文件,自定义embedding模型,mongo配置,以及openai配置
下载[.env.template](https://github.com/hitsz-ids/airda/blob/main/.env.template)自定义embedding模型,mongo配置,以及openai配置

```
airda env load -p {your_path}/.env_template
```

日志文件(非必须)

下载https://github.com/hitsz-ids/airda/blob/main/log_config.yml.template文件,自定义日志配置
下载[log_config.yml.template](https://github.com/hitsz-ids/airda/blob/main/log_config.yml.template),自定义日志配置

```
airda log load -p {your_path}/log_config.yml.template
```

Embedding Model

airda默认使用[stella-large-zh-v2](https://huggingface.co/infgrad/stella-large-zh-v2)模型, 模型默认下载到~/.cache/huggingface/hub/路径,目录下没有需要手动下载



### 相关配置命令
Expand Down Expand Up @@ -100,9 +104,8 @@ airda run cli -n {datasource_name}

我们欢迎各种贡献和建议,共同努力,使本项目更上一层楼!麻烦遵循以下步骤:

- **步骤1:** 如果您想添加任何额外的功能、增强功能或在使用过程中遇到任何问题,请发布一个 [问题](https://github.com/hitsz-ids/SQLAgent/issues) 。如果您能遵循 [问题模板](https://github.com/hitsz-ids/SQLAgent/issues/1) 我们将不胜感激。问题将在那里被讨论和分配。
- **步骤2:** 无论何时,当一个问题被分配后,您都可以按照 [PR模板](https://github.com/hitsz-ids/SQLAgent/pulls) 创建一个 [拉取请求](https://github.com/hitsz-ids/SQLAgent/pulls) 进行贡献。您也可以认领任何公开的问题。共同努力,我们可以使airda变得更好!
- **步骤1:** 如果您想添加任何额外的功能、增强功能或在使用过程中遇到任何问题,请发布一个 [问题](https://github.com/hitsz-ids/airda/issues) 。如果您能遵循 [问题模板](https://github.com/hitsz-ids/aird/issues/1) 我们将不胜感激。问题将在那里被讨论和分配。
- **步骤2:** 无论何时,当一个问题被分配后,您都可以按照 [PR模板](https://github.com/hitsz-ids/aird/pulls) 创建一个 [拉取请求](https://github.com/hitsz-ids/aird/pulls) 进行贡献。您也可以认领任何公开的问题。共同努力,我们可以使airda变得更好!
- **步骤3:** 在审查和讨论后,PR将被合并或迭代。感谢您的贡献!

在您开始之前,我们强烈建议您花一点时间检查 [这里](https://github.com/hitsz-ids/SQLAgent/blob/developing/CONTRIBUTING.md) 再进行贡献。

在您开始之前,我们强烈建议您花一点时间检查 [这里](https://github.com/hitsz-ids/aird/blob/developing/CONTRIBUTING.md) 再进行贡献。
6 changes: 3 additions & 3 deletions airda/cli/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from airda.agent.storage.entity.datasource import Datasource, Kind
from airda.agent.storage.repositories.datasource_repository import DatasourceRepository
from airda.connector.mysql import MysqlConnector
from airda.server.agent_server import DataAgentServer
from airda.server.agent_server.airda_server import AirdaServer

style = Style.from_dict(
{
Expand Down Expand Up @@ -108,8 +108,8 @@ async def execute():
help="服务端口号",
)
def server(port: int):
data_agent_server = DataAgentServer(port=port)
data_agent_server.run_server()
airda_server = AirdaServer(port=port)
airda_server.run_server()
pass


Expand Down
1 change: 1 addition & 0 deletions airda/connector/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,4 @@ def query_schema(self):
table_comment=table[1],
)
self.context.sync_instruction(instruction)
cursor.close()
40 changes: 0 additions & 40 deletions airda/server/agent_server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +0,0 @@
import fastapi
from overrides import overrides

from airda.agent.env import DataAgentEnv
from airda.server import WebFrameworkServer
from airda.server.api.api import APIImpl
from airda.server.protocol import ChatCompletionRequest


class DataAgentServer(WebFrameworkServer):
def __init__(self, host="0.0.0.0", port=8888):
super().__init__(host, port)
self.router = None

def init_api(self):
return APIImpl()

@overrides
def create_app(self):
return fastapi.FastAPI(debug=True)

@overrides
def run_server(self):
import uvicorn

uvicorn.run(self.app, host=self.host, port=self.port, log_level="info")

@overrides
def add_routes(self):
self.router = fastapi.APIRouter()
self.router.add_api_route(
"/v1/chat/completions",
self.create_completion,
methods=["POST"],
tags=["chat completions"],
)
self.app.include_router(self.router)

async def create_completion(self, request: ChatCompletionRequest):
return await self._api.create_completion(request)
48 changes: 48 additions & 0 deletions airda/server/agent_server/airda_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import fastapi
from overrides import overrides

from airda.server import WebFrameworkServer
from airda.server.api.api import APIImpl
from airda.server.protocol import AddDatasourceRequest, ChatCompletionRequest


class AirdaServer(WebFrameworkServer):
def __init__(self, host="0.0.0.0", port=8888):
super().__init__(host, port)
self.router = None

def init_api(self):
return APIImpl()

@overrides
def create_app(self):
return fastapi.FastAPI(debug=True)

@overrides
def run_server(self):
import uvicorn

uvicorn.run(self.app, host=self.host, port=self.port, log_level="info")

@overrides
def add_routes(self):
self.router = fastapi.APIRouter()
self.router.add_api_route(
"/v1/chat/completions",
self.create_completion,
methods=["POST"],
tags=["chat completions"],
)
self.router.add_api_route(
"/v1/datasource/add",
self.create_completion,
methods=["POST"],
tags=["datasource add"],
)
self.app.include_router(self.router)

async def create_completion(self, request: ChatCompletionRequest):
return await self._api.create_completion(request)

def add_datasource(self, request: AddDatasourceRequest):
return self._api.add_datasource(request)
6 changes: 5 additions & 1 deletion airda/server/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from abc import ABC, abstractmethod

from airda.server.protocol import ChatCompletionRequest
from airda.server.protocol import AddDatasourceRequest, ChatCompletionRequest


class API(ABC):
@abstractmethod
async def create_completion(self, request: ChatCompletionRequest):
pass

@abstractmethod
async def add_datasource(self, request: AddDatasourceRequest):
pass
42 changes: 37 additions & 5 deletions airda/server/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
from typing import AsyncGenerator

from fastapi.responses import JSONResponse, StreamingResponse
from overrides import override
from overrides import overrides

from airda.agent.agent import DataAgent
from airda.agent.data_agent_context import DataAgentContext
from airda.agent.exception.already_exists_error import AlreadyExistsError
from airda.agent.planner.data_agent_planner_params import DataAgentPlannerParams
from airda.agent.storage import StorageKey
from airda.agent.storage.entity.datasource import Datasource, Kind
from airda.agent.storage.repositories.datasource_repository import DatasourceRepository
from airda.server.api import API
from airda.server.protocol import (
AddDatasourceRequest,
ChatCompletionRequest,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
Expand All @@ -22,24 +27,51 @@

class APIImpl(API):
_cache: dict[str, StreamingResponse] = {}
agent: DataAgentContext
context: DataAgentContext

def __init__(self):
super().__init__()
self.agent = DataAgent().run()
self.context = DataAgent().run()

@override
@overrides
async def create_completion(self, request: ChatCompletionRequest):
async def stream_generator() -> AsyncGenerator[str, None]:
try:
pipeline = self.agent.get_planner().plan(DataAgentPlannerParams(**vars(request)))
pipeline = self.context.get_planner().plan(DataAgentPlannerParams(**vars(request)))
async for item in pipeline.execute():
yield f"data: {make_stream_data(content=item)}\n\n"
except Exception as e:
print(e)

return StreamingResponse(stream_generator(), media_type="text/event-stream")

@overrides
def add_datasource(self, request: AddDatasourceRequest):
kind = Kind.getKind(request.kind)
if kind is None:
# output_colored_text(f"不支持的数据源类型[{kind}], PS: 支持类型: [{Kind.MYSQL.value}]", "error")
message = f"不支持的数据源类型[{kind}], PS: 支持类型: [{Kind.MYSQL.value}]"
return JSONResponse(ErrorResponse(message=message, code=-1).dict())
datasource_repository = self.context.get_repository(StorageKey.DATASOURCE).convert(
DatasourceRepository
)
try:
datasource_repository.add(
Datasource(
name=request.name,
host=request.host,
port=request.port,
database=request.database,
kind=kind,
username=request.username,
password=request.password,
)
)
# output_colored_text("执行成功", "success")
except AlreadyExistsError:
pass
# output_colored_text(f"执行失败, [{name}]数据源已存在", "error")


def make_stream_data(
content: str | dict | list,
Expand Down
16 changes: 10 additions & 6 deletions airda/server/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@ class ErrorResponse(BaseModel):

class ChatCompletionRequest(BaseModel):
question: str
datasource_id: str
datasource_name: str


class AddDatasourceRequest(BaseModel):
name: str
host: str
port: int
database: str
knowledge: str
session_id: str
sql_type: str = "mysql"
file_name: str
file_id: str
kind: str
username: str | None
password: str | None


class DeltaMessage(BaseModel):
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ torch = "2.0.1"
pymongo = "^4.6.2"
prompt-toolkit = "^3.0.43"
pyyaml = "^6.0.1"
mysql-connector-python = "^8.3.0"
fastapi = "0.99.0"

[tool.poetry.scripts]
Expand Down
Loading