diff --git a/wren-ai-service/eval/data_curation/.env.example b/wren-ai-service/eval/.env.example
similarity index 100%
rename from wren-ai-service/eval/data_curation/.env.example
rename to wren-ai-service/eval/.env.example
diff --git a/wren-ai-service/eval/__init__.py b/wren-ai-service/eval/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/wren-ai-service/eval/data_curation/app.py b/wren-ai-service/eval/data_curation/app.py
index 4d49b5ddd..0872ca64e 100644
--- a/wren-ai-service/eval/data_curation/app.py
+++ b/wren-ai-service/eval/data_curation/app.py
@@ -13,7 +13,7 @@ from utils import (
     DATA_SOURCES,
     get_contexts_from_sqls,
-    get_data_from_wren_engine,
+    get_data_from_wren_engine_with_sqls,
     get_eval_dataset_in_toml_string,
     get_openai_client,
     get_question_sql_pairs,
@@ -373,9 +373,11 @@ def on_click_remove_candidate_dataset_button(i: int):
         ):
             st.success("SQL is valid")
             data = asyncio.run(
-                get_data_from_wren_engine(
+                get_data_from_wren_engine_with_sqls(
                     [st.session_state["user_question_sql_pair"]["sql"]],
                     st.session_state["data_source"],
+                    st.session_state["mdl_json"],
+                    st.session_state["connection_info"],
                 )
             )[0]
             st.dataframe(
diff --git a/wren-ai-service/eval/data_curation/utils.py b/wren-ai-service/eval/data_curation/utils.py
index 0b6ab8166..e110700f9 100644
--- a/wren-ai-service/eval/data_curation/utils.py
+++ b/wren-ai-service/eval/data_curation/utils.py
@@ -8,15 +8,15 @@
 
 import aiohttp
 import orjson
-import sqlglot
 import sqlparse
 import streamlit as st
 import tomlkit
 from dotenv import load_dotenv
 from openai import AsyncClient
 
-# in order to import the DDLConverter class from the indexing module
+# add wren-ai-service to sys.path
 sys.path.append(f"{Path().parent.parent.resolve()}")
+from eval.utils import add_quotes, get_contexts_from_sql, get_data_from_wren_engine
 from src.pipelines.indexing.indexing import DDLConverter
 
 load_dotenv()
@@ -37,10 +37,6 @@ def get_openai_client(
     )
 
 
-def add_quotes(sql: str) -> str:
-    return sqlglot.transpile(sql, read="trino", identify=True)[0]
-
-
 async def is_sql_valid(
     sql: str,
     data_source: str,
@@ -49,7 +45,7 @@ async def is_sql_valid(
     api_endpoint: str = WREN_IBIS_ENDPOINT,
     timeout: float = TIMEOUT_SECONDS,
 ) -> Tuple[bool, str]:
-    sql = sql[:-1] if sql.endswith(";") else sql
+    sql = sql.rstrip(";") if sql.endswith(";") else sql
     async with aiohttp.request(
         "POST",
         f"{api_endpoint}/v2/connector/{data_source}/query?dryRun=true",
@@ -99,95 +95,28 @@ async def get_validated_question_sql_pairs(
     ]
 
 
-async def get_sql_analysis(
-    sql: str,
-    mdl_json: dict,
-    api_endpoint: str = WREN_ENGINE_ENDPOINT,
-    timeout: float = TIMEOUT_SECONDS,
-) -> List[dict]:
-    sql = sql[:-1] if sql.endswith(";") else sql
-    async with aiohttp.request(
-        "GET",
-        f"{api_endpoint}/v1/analysis/sql",
-        json={
-            "sql": add_quotes(sql),
-            "manifest": mdl_json,
-        },
-        timeout=aiohttp.ClientTimeout(total=timeout),
-    ) as response:
-        return await response.json()
-
-
 async def get_contexts_from_sqls(
     sqls: list[str],
     mdl_json: dict,
-) -> list[str]:
-    def _compose_contexts_of_select_type(select_items: list[dict]):
-        return [
-            f'{expr_source['sourceDataset']}.{expr_source['expression']}'
-            for select_item in select_items
-            for expr_source in select_item["exprSources"]
-        ]
-
-    def _compose_contexts_of_filter_type(filter: dict):
-        contexts = []
-        if filter["type"] == "EXPR":
-            contexts += [
-                f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
-                for expr_source in filter["exprSources"]
-            ]
-        elif filter["type"] in ("AND", "OR"):
-            contexts += _compose_contexts_of_filter_type(filter["left"])
-            contexts += _compose_contexts_of_filter_type(filter["right"])
-
-        return contexts
-
-    def _compose_contexts_of_groupby_type(groupby_keys: list[list[dict]]):
-        contexts = []
-        for groupby_key_list in groupby_keys:
-            contexts += [
-                f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
-                for groupby_key in groupby_key_list
-                for expr_source in groupby_key["exprSources"]
-            ]
-        return contexts
-
-    def _compose_contexts_of_sorting_type(sortings: list[dict]):
-        return [
-            f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
-            for sorting in sortings
-            for expr_source in sorting["exprSources"]
-        ]
-
-    def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):
-        contexts = []
-        for result in sql_analysis_results:
-            if "selectItems" in result:
-                contexts += _compose_contexts_of_select_type(result["selectItems"])
-            if "filter" in result:
-                contexts += _compose_contexts_of_filter_type(result["filter"])
-            if "groupByKeys" in result:
-                contexts += _compose_contexts_of_groupby_type(result["groupByKeys"])
-            if "sortings" in result:
-                contexts += _compose_contexts_of_sorting_type(result["sortings"])
-
-        # print(
-        #     f'SQL ANALYSIS RESULTS: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode("utf-8")}'
-        # )
-        # print(f"CONTEXTS: {sorted(set(contexts))}")
-        # print("\n\n")
-
-        return sorted(set(contexts))
-
+    api_endpoint: str = WREN_ENGINE_ENDPOINT,
+    timeout: float = TIMEOUT_SECONDS,
+) -> list[list[str]]:
     async with aiohttp.ClientSession():
         tasks = []
         for sql in sqls:
-            task = asyncio.ensure_future(get_sql_analysis(sql, mdl_json))
+            task = asyncio.ensure_future(
+                get_contexts_from_sql(
+                    sql,
+                    mdl_json,
+                    api_endpoint,
+                    timeout,
+                )
+            )
             tasks.append(task)
 
         results = await asyncio.gather(*tasks)
-        return [_get_contexts_from_sql_analysis_results(result) for result in results]
+        return results
 
 
 async def get_question_sql_pairs(
@@ -256,7 +185,9 @@ async def get_question_sql_pairs(
         )
         sqls = [question_sql_pair["sql"] for question_sql_pair in question_sql_pairs]
         contexts = await get_contexts_from_sqls(sqls, mdl_json)
-        sqls_data = await get_data_from_wren_engine(sqls, data_source)
+        sqls_data = await get_data_from_wren_engine_with_sqls(
+            sqls, data_source, mdl_json, connection_info
+        )
         return [
             {**quesiton_sql_pair, "context": context, "data": sql_data}
             for quesiton_sql_pair, context, sql_data in zip(
@@ -276,40 +207,30 @@ def prettify_sql(sql: str) -> str:
     )
 
 
-async def get_data_from_wren_engine(
+async def get_data_from_wren_engine_with_sqls(
     sqls: List[str],
     data_source: str,
+    mdl_json: dict,
+    connection_info: dict,
     api_endpoint: str = WREN_IBIS_ENDPOINT,
-    data_sources: list[str] = DATA_SOURCES,
     timeout: float = TIMEOUT_SECONDS,
 ) -> List[dict]:
-    assert data_source in data_sources, f"Invalid data source: {data_source}"
-
-    async def _get_data(sql: str, data_source: str):
-        async with aiohttp.request(
-            "POST",
-            f"{api_endpoint}/v2/connector/{data_source}/query",
-            json={
-                "sql": add_quotes(sql),
-                "manifestStr": base64.b64encode(
-                    orjson.dumps(st.session_state["mdl_json"])
-                ).decode(),
-                "connectionInfo": st.session_state["connection_info"],
-            },
-            timeout=aiohttp.ClientTimeout(total=timeout),
-        ) as response:
-            if response.status != 200:
-                return {"data": [], "columns": []}
-
-            data = await response.json()
-            column_names = [f"{i}_{col}" for i, col in enumerate(data["columns"])]
-
-            return {"data": data["data"], "columns": column_names}
+    assert data_source in DATA_SOURCES, f"Invalid data source: {data_source}"
source: {data_source}" async with aiohttp.ClientSession(): tasks = [] for sql in sqls: - task = asyncio.ensure_future(_get_data(sql, data_source)) + task = asyncio.ensure_future( + get_data_from_wren_engine( + sql, + data_source, + mdl_json, + connection_info, + api_endpoint, + timeout, + limit=50, + ) + ) tasks.append(task) return await asyncio.gather(*tasks) diff --git a/wren-ai-service/eval/utils.py b/wren-ai-service/eval/utils.py new file mode 100644 index 000000000..34bf9bc17 --- /dev/null +++ b/wren-ai-service/eval/utils.py @@ -0,0 +1,121 @@ +import base64 +from typing import List, Optional + +import aiohttp +import orjson +import sqlglot + + +def add_quotes(sql: str) -> str: + return sqlglot.transpile(sql, read="trino", identify=True)[0] + + +async def get_data_from_wren_engine( + sql: str, + data_source: str, + mdl_json: dict, + connection_info: dict, + api_endpoint: str, + timeout: float, + limit: Optional[int] = None, +): + url = f"{api_endpoint}/v2/connector/{data_source}/query" + if limit is not None: + url += f"?limit={limit}" + + async with aiohttp.request( + "POST", + url, + json={ + "sql": add_quotes(sql), + "manifestStr": base64.b64encode(orjson.dumps(mdl_json)).decode(), + "connectionInfo": connection_info, + }, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status != 200: + return {"data": [], "columns": []} + + data = await response.json() + column_names = [f"{i}_{col}" for i, col in enumerate(data["columns"])] + + return {"data": data["data"], "columns": column_names} + + +async def get_contexts_from_sql( + sql: str, + mdl_json: dict, + api_endpoint: str, + timeout: float, +) -> list[str]: + def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]): + def _compose_contexts_of_select_type(select_items: list[dict]): + return [ + f'{expr_source['sourceDataset']}.{expr_source['expression']}' + for select_item in select_items + for expr_source in select_item["exprSources"] + ] + + def _compose_contexts_of_filter_type(filter: dict): + contexts = [] + if filter["type"] == "EXPR": + contexts += [ + f'{expr_source["sourceDataset"]}.{expr_source["expression"]}' + for expr_source in filter["exprSources"] + ] + elif filter["type"] in ("AND", "OR"): + contexts += _compose_contexts_of_filter_type(filter["left"]) + contexts += _compose_contexts_of_filter_type(filter["right"]) + + return contexts + + def _compose_contexts_of_groupby_type(groupby_keys: list[list[dict]]): + contexts = [] + for groupby_key_list in groupby_keys: + contexts += [ + f'{expr_source["sourceDataset"]}.{expr_source["expression"]}' + for groupby_key in groupby_key_list + for expr_source in groupby_key["exprSources"] + ] + return contexts + + def _compose_contexts_of_sorting_type(sortings: list[dict]): + return [ + f'{expr_source["sourceDataset"]}.{expr_source["expression"]}' + for sorting in sortings + for expr_source in sorting["exprSources"] + ] + + contexts = [] + for result in sql_analysis_results: + if "selectItems" in result: + contexts += _compose_contexts_of_select_type(result["selectItems"]) + if "filter" in result: + contexts += _compose_contexts_of_filter_type(result["filter"]) + if "groupByKeys" in result: + contexts += _compose_contexts_of_groupby_type(result["groupByKeys"]) + if "sortings" in result: + contexts += _compose_contexts_of_sorting_type(result["sortings"]) + + return sorted(set(contexts)) + + async def _get_sql_analysis( + sql: str, + mdl_json: dict, + api_endpoint: str, + timeout: float, + ) -> List[dict]: + sql = sql.rstrip(";") if 
sql.endswith(";") else sql + async with aiohttp.request( + "GET", + f"{api_endpoint}/v1/analysis/sql", + json={ + "sql": add_quotes(sql), + "manifest": mdl_json, + }, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + return await response.json() + + contexts = await _get_sql_analysis(sql, mdl_json, api_endpoint, timeout) + return _get_contexts_from_sql_analysis_results(contexts)