Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Jul 1, 2024
1 parent 6474eed commit cc7cad3
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 49 deletions.
6 changes: 4 additions & 2 deletions wren-ai-service/eval/.env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
OPENAI_API_KEY=
OPENAI_GENERATION_MODEL=gpt-3.5-turbo
WREN_ENGINE_ENDPOINT=http://localhost:8080
WREN_UI_ENDPOINT=http://localhost:3000
WREN_IBIS_ENDPOINT=http://localhost:8000
bigquery.project-id=
bigquery.dataset-id=
bigquery.credentials-key=
57 changes: 44 additions & 13 deletions wren-ai-service/eval/data_curation/app.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import asyncio
import os
import re
import uuid
from datetime import datetime

import orjson
import streamlit as st
import tomlkit
from openai import AsyncClient
from streamlit_tags import st_tags
from utils import (
DATA_SOURCES,
get_contexts_from_sqls,
get_current_manifest,
get_eval_dataset_in_toml_string,
get_llm_client,
get_question_sql_pairs,
is_sql_valid,
prettify_sql,
show_er_diagram,
)

st.set_page_config(layout="wide")
Expand All @@ -24,6 +27,8 @@


# session states
if "deployment_id" not in st.session_state:
st.session_state["deployment_id"] = str(uuid.uuid4())
if "mdl_json" not in st.session_state:
st.session_state["mdl_json"] = None
if "llm_question_sql_pairs" not in st.session_state:
Expand All @@ -32,11 +37,15 @@
st.session_state["user_question_sql_pair"] = {}
if "candidate_dataset" not in st.session_state:
st.session_state["candidate_dataset"] = []
if "data_source" not in st.session_state:
st.session_state["data_source"] = None
if "connection_info" not in st.session_state:
st.session_state["connection_info"] = None


# widget callbacks
def on_change_upload_eval_dataset():
doc = tomlkit.parse(st.session_state.uploaded_file.getvalue().decode("utf-8"))
doc = tomlkit.parse(st.session_state.uploaded_eval_file.getvalue().decode("utf-8"))
assert (
doc["mdl"] == st.session_state["mdl_json"]
), "The model in the uploaded dataset is different from the deployed model"
Expand All @@ -50,6 +59,31 @@ def on_click_generate_question_sql_pairs(llm_client: AsyncClient):
)


def on_click_setup_uploaded_file():
    """Load the uploaded MDL json file into the Streamlit session state.

    Reads the file stored under the ``uploaded_mdl_file`` widget key, checks
    that its name follows ``[xxx]_[datasource]_mdl.json`` for one of the
    entries in ``DATA_SOURCES``, and then populates the ``data_source``,
    ``mdl_json`` and ``connection_info`` session-state entries.

    If the file name does not match, an error is shown and the script run is
    stopped. If no file is present, the callback returns without touching the
    session state.
    """
    uploaded_file = st.session_state.get("uploaded_mdl_file")
    if uploaded_file is None:
        # The widget value is None when no file is selected (for example
        # after the user clears the uploader); there is nothing to parse.
        return

    # Escape each name so a data source containing regex metacharacters
    # is still matched literally.
    match = re.match(
        r".+_(" + "|".join(re.escape(name) for name in DATA_SOURCES) + r")_mdl\.json$",
        uploaded_file.name,
    )
    if not match:
        st.error(
            f"the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}"
        )
        st.stop()

    data_source = match.group(1)
    st.session_state["data_source"] = data_source
    st.session_state["mdl_json"] = orjson.loads(
        uploaded_file.getvalue().decode("utf-8")
    )

    # NOTE(review): the connection info is always read from the bigquery.*
    # environment variables, regardless of which data source matched --
    # revisit when DATA_SOURCES grows beyond bigquery.
    st.session_state["connection_info"] = {
        "project_id": os.getenv("bigquery.project-id"),
        "dataset_id": os.getenv("bigquery.dataset-id"),
        "credentials": os.getenv("bigquery.credentials-key"),
    }


def on_change_sql(i: int, key: str):
sql = st.session_state[key]

Expand Down Expand Up @@ -170,20 +204,17 @@ def on_click_remove_candidate_dataset_button(i: int):
st.file_uploader(
"Upload Evaluation Dataset",
type="toml",
key="uploaded_file",
key="uploaded_eval_file",
on_change=on_change_upload_eval_dataset,
)


if manifest := get_current_manifest():
st.session_state["mdl_json"] = manifest
st.markdown("### Deployed Model Information")
st.json(st.session_state["mdl_json"], expanded=False)
show_er_diagram(
st.session_state["mdl_json"]["models"],
st.session_state["mdl_json"]["relationships"],
)
st.markdown("---")
st.file_uploader(
f"Upload an MDL json file, and the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}",
type="json",
key="uploaded_mdl_file",
on_change=on_click_setup_uploaded_file,
)

if st.session_state["mdl_json"] is not None:
col1, col2 = st.columns(2)
Expand Down
46 changes: 22 additions & 24 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
import base64
import os
import re
from datetime import datetime
from typing import Any, List, Tuple

import aiohttp
import orjson
import requests
import sqlglot
import sqlparse
import streamlit as st
Expand All @@ -16,14 +16,8 @@

load_dotenv()


def get_current_manifest() -> dict:
    """Fetch the currently deployed MDL manifest from the Wren engine.

    The engine base URL is read from the ``WREN_ENGINE_ENDPOINT`` environment
    variable, falling back to ``http://localhost:8080``.

    Returns:
        The decoded JSON body of the ``/v1/mdl`` response.

    Raises:
        AssertionError: if the engine does not answer with HTTP 200.
    """
    # Resolve the endpoint first: nesting double quotes inside a
    # double-quoted f-string is a SyntaxError before Python 3.12.
    endpoint = os.getenv("WREN_ENGINE_ENDPOINT", "http://localhost:8080")
    # Bound the wait so an unreachable engine cannot hang the app forever.
    response = requests.get(f"{endpoint}/v1/mdl", timeout=60)

    # Raised explicitly (same exception type as before) so the check is not
    # stripped when Python runs with -O.
    if response.status_code != 200:
        raise AssertionError(
            f"Unexpected status {response.status_code} from {endpoint}/v1/mdl"
        )
    return response.json()
WREN_IBIS_ENDPOINT = os.getenv("WREN_IBIS_ENDPOINT", "http://localhost:8000")
DATA_SOURCES = ["bigquery"]


def get_llm_client() -> AsyncClient:
Expand All @@ -47,22 +41,21 @@ async def is_sql_valid(sql: str) -> Tuple[bool, str]:
sql = sql[:-1] if sql.endswith(";") else sql
async with aiohttp.request(
"POST",
f'{os.getenv("WREN_UI_ENDPOINT", "http://localhost:3000")}/api/graphql',
f'{WREN_IBIS_ENDPOINT}/v2/ibis/{st.session_state['data_source']}/query?dryRun=true',
json={
"query": "mutation PreviewSql($data: PreviewSQLDataInput) { previewSql(data: $data) }",
"variables": {
"data": {
"dryRun": True,
"limit": 1,
"sql": remove_limit_statement(add_quotes(sql)),
}
},
"sql": add_quotes(sql),
"manifestStr": base64.b64encode(
orjson.dumps(st.session_state["mdl_json"])
).decode(),
"connectionInfo": st.session_state["connection_info"],
},
timeout=aiohttp.ClientTimeout(total=60),
) as response:
res = await response.json()
if res.get("data"):
if response.status == 204:
return True, None
return False, res.get("errors", [{}])[0].get("message", "Unknown error")
res = await response.text()

return False, res


async def get_validated_question_sql_pairs(
Expand Down Expand Up @@ -206,10 +199,15 @@ def _format(view: dict[str, Any]) -> str:

return [_format(view) for view in views]

models = mdl_json.get("models", [])
relationships = mdl_json.get("relationships", [])
metrics = mdl_json.get("metrics", [])
views = mdl_json.get("views", [])

ddl_commands = (
_convert_models_and_relationships(mdl_json["models"], mdl_json["relationships"])
+ _convert_metrics(mdl_json["metrics"])
+ _convert_views(mdl_json["views"])
_convert_models_and_relationships(models, relationships)
+ _convert_metrics(metrics)
+ _convert_views(views)
)
return "\n\n".join(ddl_commands)

Expand Down
20 changes: 10 additions & 10 deletions wren-ai-service/eval/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit cc7cad3

Please sign in to comment.