chore(wren-ai-service): eval refactor for common functions (#536)
* refactor get_data_from_wren_engine

* add limit

* move common functions for eval to eval/utils.py

* update env
cyyeh authored Jul 18, 2024
1 parent c9d0251 commit 0a953d4
Showing 5 changed files with 158 additions and 114 deletions.
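At a glance, the refactor pulls the Wren-engine helpers out of the Streamlit-coupled eval/data_curation/utils.py into a shared eval/utils.py, renames the batch fetcher to get_data_from_wren_engine_with_sqls, and makes callers pass mdl_json and connection_info explicitly instead of having the helper read st.session_state. A minimal before/after sketch of the call shape (values are illustrative, not taken from the diff):

# before: the helper reached into st.session_state for the manifest
# and connection info itself
data = await get_data_from_wren_engine(sqls, data_source)

# after: everything is an explicit argument, so the helper also works
# outside Streamlit (e.g. in batch eval runs)
data = await get_data_from_wren_engine_with_sqls(
    sqls, data_source, mdl_json, connection_info
)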
File renamed without changes.
Empty file.
6 changes: 4 additions & 2 deletions wren-ai-service/eval/data_curation/app.py
@@ -13,7 +13,7 @@
from utils import (
DATA_SOURCES,
get_contexts_from_sqls,
get_data_from_wren_engine,
get_data_from_wren_engine_with_sqls,
get_eval_dataset_in_toml_string,
get_openai_client,
get_question_sql_pairs,
@@ -373,9 +373,11 @@ def on_click_remove_candidate_dataset_button(i: int):
):
st.success("SQL is valid")
data = asyncio.run(
get_data_from_wren_engine(
get_data_from_wren_engine_with_sqls(
[st.session_state["user_question_sql_pair"]["sql"]],
st.session_state["data_source"],
st.session_state["mdl_json"],
st.session_state["connection_info"],
)
)[0]
st.dataframe(
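Since the helper no longer touches st.session_state, the same call can be driven from a plain script. A hedged sketch, assuming a reachable Wren ibis endpoint, a "bigquery" entry in DATA_SOURCES, and a hypothetical manifest path (connection_info's shape depends on the data source):

import asyncio

import orjson

from eval.data_curation.utils import get_data_from_wren_engine_with_sqls

# hypothetical manifest path and connection info
mdl_json = orjson.loads(open("eval/mdl.json", "rb").read())
connection_info = {"project_id": "my-project"}

results = asyncio.run(
    get_data_from_wren_engine_with_sqls(
        ["SELECT 1"], "bigquery", mdl_json, connection_info
    )
)
# each entry is {"data": [...], "columns": [...]}, one per input SQL
print(results[0]["columns"])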
145 changes: 33 additions & 112 deletions wren-ai-service/eval/data_curation/utils.py
@@ -8,15 +8,15 @@

import aiohttp
import orjson
import sqlglot
import sqlparse
import streamlit as st
import tomlkit
from dotenv import load_dotenv
from openai import AsyncClient

# in order to import the DDLConverter class from the indexing module
# add wren-ai-service to sys.path
sys.path.append(f"{Path().parent.parent.resolve()}")
from eval.utils import add_quotes, get_contexts_from_sql, get_data_from_wren_engine
from src.pipelines.indexing.indexing import DDLConverter

load_dotenv()
@@ -37,10 +37,6 @@ def get_openai_client(
)


def add_quotes(sql: str) -> str:
return sqlglot.transpile(sql, read="trino", identify=True)[0]


async def is_sql_valid(
sql: str,
data_source: str,
@@ -49,7 +45,7 @@
api_endpoint: str = WREN_IBIS_ENDPOINT,
timeout: float = TIMEOUT_SECONDS,
) -> Tuple[bool, str]:
sql = sql[:-1] if sql.endswith(";") else sql
sql = sql.rstrip(";") if sql.endswith(";") else sql
async with aiohttp.request(
"POST",
f"{api_endpoint}/v2/connector/{data_source}/query?dryRun=true",
@@ -99,95 +95,28 @@ async def get_validated_question_sql_pairs(
]


async def get_sql_analysis(
sql: str,
mdl_json: dict,
api_endpoint: str = WREN_ENGINE_ENDPOINT,
timeout: float = TIMEOUT_SECONDS,
) -> List[dict]:
sql = sql[:-1] if sql.endswith(";") else sql
async with aiohttp.request(
"GET",
f"{api_endpoint}/v1/analysis/sql",
json={
"sql": add_quotes(sql),
"manifest": mdl_json,
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
return await response.json()


async def get_contexts_from_sqls(
sqls: list[str],
mdl_json: dict,
) -> list[str]:
def _compose_contexts_of_select_type(select_items: list[dict]):
return [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
for select_item in select_items
for expr_source in select_item["exprSources"]
]

def _compose_contexts_of_filter_type(filter: dict):
contexts = []
if filter["type"] == "EXPR":
contexts += [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
for expr_source in filter["exprSources"]
]
elif filter["type"] in ("AND", "OR"):
contexts += _compose_contexts_of_filter_type(filter["left"])
contexts += _compose_contexts_of_filter_type(filter["right"])

return contexts

def _compose_contexts_of_groupby_type(groupby_keys: list[list[dict]]):
contexts = []
for groupby_key_list in groupby_keys:
contexts += [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
for groupby_key in groupby_key_list
for expr_source in groupby_key["exprSources"]
]
return contexts

def _compose_contexts_of_sorting_type(sortings: list[dict]):
return [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
for sorting in sortings
for expr_source in sorting["exprSources"]
]

def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):
contexts = []
for result in sql_analysis_results:
if "selectItems" in result:
contexts += _compose_contexts_of_select_type(result["selectItems"])
if "filter" in result:
contexts += _compose_contexts_of_filter_type(result["filter"])
if "groupByKeys" in result:
contexts += _compose_contexts_of_groupby_type(result["groupByKeys"])
if "sortings" in result:
contexts += _compose_contexts_of_sorting_type(result["sortings"])

# print(
# f'SQL ANALYSIS RESULTS: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode("utf-8")}'
# )
# print(f"CONTEXTS: {sorted(set(contexts))}")
# print("\n\n")

return sorted(set(contexts))

api_endpoint: str = WREN_ENGINE_ENDPOINT,
timeout: float = TIMEOUT_SECONDS,
) -> list[list[str]]:
async with aiohttp.ClientSession():
tasks = []
for sql in sqls:
task = asyncio.ensure_future(get_sql_analysis(sql, mdl_json))
task = asyncio.ensure_future(
get_contexts_from_sql(
sql,
mdl_json,
api_endpoint,
timeout,
)
)
tasks.append(task)

results = await asyncio.gather(*tasks)

return [_get_contexts_from_sql_analysis_results(result) for result in results]
return results


async def get_question_sql_pairs(
@@ -256,7 +185,9 @@ async def get_question_sql_pairs(
)
sqls = [question_sql_pair["sql"] for question_sql_pair in question_sql_pairs]
contexts = await get_contexts_from_sqls(sqls, mdl_json)
sqls_data = await get_data_from_wren_engine(sqls, data_source)
sqls_data = await get_data_from_wren_engine_with_sqls(
sqls, data_source, mdl_json, connection_info
)
return [
{**question_sql_pair, "context": context, "data": sql_data}
for question_sql_pair, context, sql_data in zip(
Expand All @@ -276,40 +207,30 @@ def prettify_sql(sql: str) -> str:
)


async def get_data_from_wren_engine(
async def get_data_from_wren_engine_with_sqls(
sqls: List[str],
data_source: str,
mdl_json: dict,
connection_info: dict,
api_endpoint: str = WREN_IBIS_ENDPOINT,
data_sources: list[str] = DATA_SOURCES,
timeout: float = TIMEOUT_SECONDS,
) -> List[dict]:
assert data_source in data_sources, f"Invalid data source: {data_source}"

async def _get_data(sql: str, data_source: str):
async with aiohttp.request(
"POST",
f"{api_endpoint}/v2/connector/{data_source}/query",
json={
"sql": add_quotes(sql),
"manifestStr": base64.b64encode(
orjson.dumps(st.session_state["mdl_json"])
).decode(),
"connectionInfo": st.session_state["connection_info"],
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status != 200:
return {"data": [], "columns": []}

data = await response.json()
column_names = [f"{i}_{col}" for i, col in enumerate(data["columns"])]

return {"data": data["data"], "columns": column_names}
assert data_source in DATA_SOURCES, f"Invalid data source: {data_source}"

async with aiohttp.ClientSession():
tasks = []
for sql in sqls:
task = asyncio.ensure_future(_get_data(sql, data_source))
task = asyncio.ensure_future(
get_data_from_wren_engine(
sql,
data_source,
mdl_json,
connection_info,
api_endpoint,
timeout,
limit=50,
)
)
tasks.append(task)

return await asyncio.gather(*tasks)
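Note that the wrapper now passes limit=50, which the shared helper turns into a query-string parameter, so each curation preview fetches at most 50 rows. A small sketch of the resulting URL (endpoint and data source are illustrative):

api_endpoint = "http://localhost:8000"  # illustrative Wren ibis endpoint
data_source, limit = "postgres", 50

url = f"{api_endpoint}/v2/connector/{data_source}/query"
if limit is not None:
    url += f"?limit={limit}"
# -> http://localhost:8000/v2/connector/postgres/query?limit=50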
121 changes: 121 additions & 0 deletions wren-ai-service/eval/utils.py
@@ -0,0 +1,121 @@
import base64
from typing import List, Optional

import aiohttp
import orjson
import sqlglot


def add_quotes(sql: str) -> str:
return sqlglot.transpile(sql, read="trino", identify=True)[0]


async def get_data_from_wren_engine(
sql: str,
data_source: str,
mdl_json: dict,
connection_info: dict,
api_endpoint: str,
timeout: float,
limit: Optional[int] = None,
):
url = f"{api_endpoint}/v2/connector/{data_source}/query"
if limit is not None:
url += f"?limit={limit}"

async with aiohttp.request(
"POST",
url,
json={
"sql": add_quotes(sql),
"manifestStr": base64.b64encode(orjson.dumps(mdl_json)).decode(),
"connectionInfo": connection_info,
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status != 200:
return {"data": [], "columns": []}

data = await response.json()
column_names = [f"{i}_{col}" for i, col in enumerate(data["columns"])]

return {"data": data["data"], "columns": column_names}


async def get_contexts_from_sql(
sql: str,
mdl_json: dict,
api_endpoint: str,
timeout: float,
) -> list[str]:
def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):
def _compose_contexts_of_select_type(select_items: list[dict]):
return [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
for select_item in select_items
for expr_source in select_item["exprSources"]
]

def _compose_contexts_of_filter_type(filter: dict):
contexts = []
if filter["type"] == "EXPR":
contexts += [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
for expr_source in filter["exprSources"]
]
elif filter["type"] in ("AND", "OR"):
contexts += _compose_contexts_of_filter_type(filter["left"])
contexts += _compose_contexts_of_filter_type(filter["right"])

return contexts

def _compose_contexts_of_groupby_type(groupby_keys: list[list[dict]]):
contexts = []
for groupby_key_list in groupby_keys:
contexts += [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
for groupby_key in groupby_key_list
for expr_source in groupby_key["exprSources"]
]
return contexts

def _compose_contexts_of_sorting_type(sortings: list[dict]):
return [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
for sorting in sortings
for expr_source in sorting["exprSources"]
]

contexts = []
for result in sql_analysis_results:
if "selectItems" in result:
contexts += _compose_contexts_of_select_type(result["selectItems"])
if "filter" in result:
contexts += _compose_contexts_of_filter_type(result["filter"])
if "groupByKeys" in result:
contexts += _compose_contexts_of_groupby_type(result["groupByKeys"])
if "sortings" in result:
contexts += _compose_contexts_of_sorting_type(result["sortings"])

return sorted(set(contexts))

async def _get_sql_analysis(
sql: str,
mdl_json: dict,
api_endpoint: str,
timeout: float,
) -> List[dict]:
sql = sql.rstrip(";") if sql.endswith(";") else sql
async with aiohttp.request(
"GET",
f"{api_endpoint}/v1/analysis/sql",
json={
"sql": add_quotes(sql),
"manifest": mdl_json,
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
return await response.json()

contexts = await _get_sql_analysis(sql, mdl_json, api_endpoint, timeout)
return _get_contexts_from_sql_analysis_results(contexts)
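For reference, here is how the relocated helpers might be exercised. add_quotes round-trips the SQL through sqlglot with identifier quoting enabled, and get_contexts_from_sql boils the engine's /v1/analysis/sql response down to a sorted, de-duplicated list of sourceDataset.expression strings. A hedged sketch (endpoint, SQL, and exact output all depend on the environment and sqlglot version):

import asyncio

from eval.utils import add_quotes, get_contexts_from_sql

print(add_quotes("select o.id from orders o"))
# e.g. SELECT "o"."id" FROM "orders" AS "o"

mdl_json: dict = {}  # stand-in; use the project's MDL manifest here
contexts = asyncio.run(
    get_contexts_from_sql(
        "SELECT customer_id, SUM(total) FROM orders GROUP BY 1",
        mdl_json,
        api_endpoint="http://localhost:8080",  # illustrative wren-engine endpoint
        timeout=30,
    )
)
# e.g. ["orders.customer_id", "orders.total"]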
