Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Jul 15, 2024
1 parent fc0119e commit 97533eb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 50 deletions.
4 changes: 2 additions & 2 deletions wren-ai-service/eval/data_curation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from streamlit_tags import st_tags
from utils import (
DATA_SOURCES,
get_contexts_from_sqls_v2,
get_contexts_from_sqls,
get_eval_dataset_in_toml_string,
get_llm_client,
get_question_sql_pairs,
Expand Down Expand Up @@ -94,7 +94,7 @@ def on_change_sql(i: int, key: str):

valid, error = asyncio.run(is_sql_valid(sql))
if valid:
new_context = asyncio.run(get_contexts_from_sqls_v2([sql]))
new_context = asyncio.run(get_contexts_from_sqls([sql]))
if i != -1:
st.session_state["llm_question_sql_pairs"][i]["sql"] = sql
st.session_state["llm_question_sql_pairs"][i]["is_valid"] = valid
Expand Down
54 changes: 6 additions & 48 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ async def get_sql_analysis(
return await response.json()


async def get_contexts_from_sqls_v2(
async def get_contexts_from_sqls(
sqls: list[str],
) -> list[str]:
def _compose_contexts_of_select_type(select_items: list[dict]):
Expand Down Expand Up @@ -281,6 +281,9 @@ def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):
if "sortings" in result:
contexts += _compose_contexts_of_sorting_type(result["sortings"])

print(
f'SQL ANALYSIS RESULTS: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode("utf-8")}'
)
print(f"CONTEXTS: {sorted(set(contexts))}")
return sorted(set(contexts))

Expand Down Expand Up @@ -335,7 +338,7 @@ async def get_question_sql_pairs(

try:
response = await llm_client.chat.completions.create(
model=os.getenv("OPENAI_GENERATION_MODEL", "gpt-3.5-turbo"),
model=os.getenv("GENERATION_MODEL", "gpt-3.5-turbo"),
messages=messages,
response_format={"type": "json_object"},
max_tokens=4096,
Expand All @@ -345,7 +348,7 @@ async def get_question_sql_pairs(
results = orjson.loads(response.choices[0].message.content)["results"]
question_sql_pairs = await get_validated_question_sql_pairs(results)
sqls = [question_sql_pair["sql"] for question_sql_pair in question_sql_pairs]
contexts = await get_contexts_from_sqls_v2(sqls)
contexts = await get_contexts_from_sqls(sqls)
return [
{**quesiton_sql_pair, "context": context}
for quesiton_sql_pair, context in zip(question_sql_pairs, contexts)
Expand All @@ -355,51 +358,6 @@ async def get_question_sql_pairs(
return []


def show_er_diagram(models: List[dict], relationships: List[dict]):
    """Render an entity-relationship diagram for the given models in Streamlit.

    Builds a Graphviz ``digraph`` source string containing one HTML-like table
    node per model and one labelled edge per relationship, then hands that
    string to ``st.graphviz_chart``.

    Args:
        models: Dicts read for ``"name"`` and ``"columns"``; every column dict
            is read for ``"name"`` and ``"type"``.
        relationships: Dicts read for ``"name"``, ``"models"`` (unpacked as a
            ``(from_model, to_model)`` pair) and ``"condition"``.
            ``"joinType"`` is read only when both column names could be
            extracted from the condition.

    Raises:
        KeyError: If a model, column or relationship lacks a key that is read.
        ValueError: If ``relationship["models"]`` does not unpack into exactly
            two items.
    """

    def format_label(name, columns):
        # Graphviz HTML-like label: a header row followed by one row per column.
        rows = "".join(
            f'<TR><TD>{column["name"]} : {column["type"]}</TD></TR>'
            for column in columns
        )
        return (
            '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">'
            f"<TR><TD><B>{name}</B></TD></TR>{rows}</TABLE>>"
        )

    def extract_columns(condition):
        # Only the first "<model>.<column> = <model>.<column>" occurrence is
        # used; anything that does not match that shape yields empty names.
        match = re.search(r"(\w+)\.(\w+) = (\w+)\.(\w+)", condition)
        if match is None:
            return "", ""
        return match.group(2), match.group(4)  # (from_column, to_column)

    # Collect the pieces and join once instead of repeated string "+=".
    parts = [
        "digraph ERD {\n",
        ' graph [pad="0.5", nodesep="0.5", ranksep="2"];\n',
        " node [shape=plain]\n",
        " rankdir=LR;\n\n",
    ]

    # Models (entities) become table-shaped nodes.
    for model in models:
        node_label = format_label(model["name"], model["columns"])
        parts.append(f' {model["name"]} [label={node_label}];\n')

    parts.append("\n")

    # Relationships become edges between the two participating models.
    for relationship in relationships:
        from_model, to_model = relationship["models"]
        from_column, to_column = extract_columns(relationship["condition"])
        if from_column and to_column:
            # Values are pulled into locals so the f-string does not nest the
            # same quote character, which is a SyntaxError before Python 3.12.
            name = relationship["name"]
            join_type = relationship["joinType"]
            # "\\n" emits a literal backslash-n, Graphviz's in-label line break.
            label = f"{name}\\n({from_column} to {to_column}) ({join_type})"
        else:
            label = relationship["name"]
        parts.append(f' {from_model} -> {to_model} [label="{label}"];\n')

    parts.append("}")

    st.graphviz_chart("".join(parts))


def prettify_sql(sql: str) -> str:
return sqlparse.format(
sql,
Expand Down

0 comments on commit 97533eb

Please sign in to comment.