diff --git a/wren-ai-service/eval/.env.example b/wren-ai-service/eval/.env.example index f1211a108b..7f3a06b87a 100644 --- a/wren-ai-service/eval/.env.example +++ b/wren-ai-service/eval/.env.example @@ -1,4 +1,6 @@ OPENAI_API_KEY= OPENAI_GENERATION_MODEL=gpt-3.5-turbo -WREN_ENGINE_ENDPOINT=http://localhost:8080 -WREN_UI_ENDPOINT=http://localhost:3000 \ No newline at end of file +WREN_IBIS_ENDPOINT=http://localhost:8000 +bigquery.project-id= +bigquery.dataset-id= +bigquery.credentials-key= \ No newline at end of file diff --git a/wren-ai-service/eval/data_curation/app.py b/wren-ai-service/eval/data_curation/app.py index bbeda3b644..3171d2d868 100644 --- a/wren-ai-service/eval/data_curation/app.py +++ b/wren-ai-service/eval/data_curation/app.py @@ -1,19 +1,22 @@ import asyncio +import os +import re +import uuid from datetime import datetime +import orjson import streamlit as st import tomlkit from openai import AsyncClient from streamlit_tags import st_tags from utils import ( + DATA_SOURCES, get_contexts_from_sqls, - get_current_manifest, get_eval_dataset_in_toml_string, get_llm_client, get_question_sql_pairs, is_sql_valid, prettify_sql, - show_er_diagram, ) st.set_page_config(layout="wide") @@ -24,6 +27,8 @@ # session states +if "deployment_id" not in st.session_state: + st.session_state["deployment_id"] = str(uuid.uuid4()) if "mdl_json" not in st.session_state: st.session_state["mdl_json"] = None if "llm_question_sql_pairs" not in st.session_state: @@ -32,11 +37,15 @@ st.session_state["user_question_sql_pair"] = {} if "candidate_dataset" not in st.session_state: st.session_state["candidate_dataset"] = [] +if "data_source" not in st.session_state: + st.session_state["data_source"] = None +if "connection_info" not in st.session_state: + st.session_state["connection_info"] = None # widget callbacks def on_change_upload_eval_dataset(): - doc = tomlkit.parse(st.session_state.uploaded_file.getvalue().decode("utf-8")) + doc = tomlkit.parse(st.session_state.uploaded_eval_file.getvalue().decode("utf-8")) assert ( doc["mdl"] == st.session_state["mdl_json"] ), "The model in the uploaded dataset is different from the deployed model" @@ -50,6 +59,31 @@ def on_click_generate_question_sql_pairs(llm_client: AsyncClient): ) +def on_click_setup_uploaded_file(): + uploaded_file = st.session_state.get("uploaded_mdl_file") + match = re.match( + r".+_(" + "|".join(DATA_SOURCES) + r")_mdl\.json$", + uploaded_file.name, + ) + if not match: + st.error( + f"the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}" + ) + st.stop() + + data_source = match.group(1) + st.session_state["data_source"] = data_source + st.session_state["mdl_json"] = orjson.loads( + uploaded_file.getvalue().decode("utf-8") + ) + + st.session_state["connection_info"] = { + "project_id": os.getenv("bigquery.project-id"), + "dataset_id": os.getenv("bigquery.dataset-id"), + "credentials": os.getenv("bigquery.credentials-key"), + } + + def on_change_sql(i: int, key: str): sql = st.session_state[key] @@ -170,20 +204,17 @@ def on_click_remove_candidate_dataset_button(i: int): st.file_uploader( "Upload Evaluation Dataset", type="toml", - key="uploaded_file", + key="uploaded_eval_file", on_change=on_change_upload_eval_dataset, ) -if manifest := get_current_manifest(): - st.session_state["mdl_json"] = manifest - st.markdown("### Deployed Model Information") - st.json(st.session_state["mdl_json"], expanded=False) - show_er_diagram( - st.session_state["mdl_json"]["models"], - st.session_state["mdl_json"]["relationships"], - ) - st.markdown("---") +st.file_uploader( + f"Upload an MDL json file, and the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}", + type="json", + key="uploaded_mdl_file", + on_change=on_click_setup_uploaded_file, +) if st.session_state["mdl_json"] is not None: col1, col2 = st.columns(2) diff --git a/wren-ai-service/eval/data_curation/utils.py b/wren-ai-service/eval/data_curation/utils.py index da115ab1f4..ec67a29cb9 100644 --- a/wren-ai-service/eval/data_curation/utils.py +++ b/wren-ai-service/eval/data_curation/utils.py @@ -1,4 +1,5 @@ import asyncio +import base64 import os import re from datetime import datetime @@ -6,7 +7,6 @@ import aiohttp import orjson -import requests import sqlglot import sqlparse import streamlit as st @@ -16,14 +16,8 @@ load_dotenv() - -def get_current_manifest() -> Tuple[str, dict]: - response = requests.get( - f"{os.getenv("WREN_ENGINE_ENDPOINT", "http://localhost:8080")}/v1/mdl", - ) - - assert response.status_code == 200 - return response.json() +WREN_IBIS_ENDPOINT = os.getenv("WREN_IBIS_ENDPOINT", "http://localhost:8000") +DATA_SOURCES = ["bigquery"] def get_llm_client() -> AsyncClient: @@ -47,22 +41,21 @@ async def is_sql_valid(sql: str) -> Tuple[bool, str]: sql = sql[:-1] if sql.endswith(";") else sql async with aiohttp.request( "POST", - f'{os.getenv("WREN_UI_ENDPOINT", "http://localhost:3000")}/api/graphql', + f'{WREN_IBIS_ENDPOINT}/v2/ibis/{st.session_state['data_source']}/query?dryRun=true', json={ - "query": "mutation PreviewSql($data: PreviewSQLDataInput) { previewSql(data: $data) }", - "variables": { - "data": { - "dryRun": True, - "limit": 1, - "sql": remove_limit_statement(add_quotes(sql)), - } - }, + "sql": add_quotes(sql), + "manifestStr": base64.b64encode( + orjson.dumps(st.session_state["mdl_json"]) + ).decode(), + "connectionInfo": st.session_state["connection_info"], }, + timeout=aiohttp.ClientTimeout(total=60), ) as response: - res = await response.json() - if res.get("data"): + if response.status == 204: return True, None - return False, res.get("errors", [{}])[0].get("message", "Unknown error") + res = await response.text() + + return False, res async def get_validated_question_sql_pairs( @@ -206,10 +199,15 @@ def _format(view: dict[str, Any]) -> str: return [_format(view) for view in views] + models = mdl_json.get("models", []) + relationships = mdl_json.get("relationships", []) + metrics = mdl_json.get("metrics", []) + views = mdl_json.get("views", []) + ddl_commands = ( - _convert_models_and_relationships(mdl_json["models"], mdl_json["relationships"]) - + _convert_metrics(mdl_json["metrics"]) - + _convert_views(mdl_json["views"]) + _convert_models_and_relationships(models, relationships) + + _convert_metrics(metrics) + + _convert_views(views) ) return "\n\n".join(ddl_commands) diff --git a/wren-ai-service/eval/poetry.lock b/wren-ai-service/eval/poetry.lock index 487cf9411d..cb8d9b69f0 100644 --- a/wren-ai-service/eval/poetry.lock +++ b/wren-ai-service/eval/poetry.lock @@ -1816,18 +1816,18 @@ files = [ [[package]] name = "sqlglot" -version = "25.1.0" +version = "25.3.0" description = "An easily customizable SQL parser and transpiler" optional = false python-versions = ">=3.7" files = [ - {file = "sqlglot-25.1.0-py3-none-any.whl", hash = "sha256:a9b23f0ea455ca3a7bc8daf4b404f9a19d4ab201b0220907e121f935c8a01013"}, - {file = "sqlglot-25.1.0.tar.gz", hash = "sha256:1df2c5c0ef56961c4269a626414d3c349340f2aa329dffdb9d3f90f2eef50dbf"}, + {file = "sqlglot-25.3.0-py3-none-any.whl", hash = "sha256:1cca732e7c2ba4fe86665d8e05d7af49b3f2a5e7ec45a8f3ab0649d4207351b9"}, + {file = "sqlglot-25.3.0.tar.gz", hash = "sha256:c4ce5e38148c29f3bb19d8dcf1bfaaa4833eb283e9e6771e39a7e90dd6b79a0c"}, ] [package.extras] dev = ["duckdb (>=0.6)", "maturin (>=1.4,<2.0)", "mypy", "pandas", "pandas-stubs", "pdoc", "pre-commit", "python-dateutil", "ruff (==0.4.3)", "types-python-dateutil", "typing-extensions"] -rs = ["sqlglotrs (==0.2.6)"] +rs = ["sqlglotrs (==0.2.7)"] [[package]] name = "sqlparse" @@ -1895,13 +1895,13 @@ streamlit = ">=0.63" [[package]] name = "tenacity" -version = "8.3.0" +version = "8.4.2" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" files = [ - {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"}, - {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"}, + {file = "tenacity-8.4.2-py3-none-any.whl", hash = "sha256:9e6f7cf7da729125c7437222f8a522279751cdfbe6b67bfe64f75d3a348661b2"}, + {file = "tenacity-8.4.2.tar.gz", hash = "sha256:cd80a53a79336edba8489e767f729e4f391c896956b57140b5d7511a64bbd3ef"}, ] [package.extras] @@ -2005,13 +2005,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras]