Commit ed2af7c: update
cyyeh committed Jul 10, 2024 (1 parent: c5111a8)

Showing 3 changed files with 107 additions and 39 deletions.
@@ -22,9 +22,7 @@
         }
       ]
     }
-}
-{
+} | {
   "relation": [
     {
       "type": "INNER_JOIN" | "LEFT_JOIN" | "RIGHT_JOIN" | "FULL_JOIN" | "CROSS_JOIN" | "IMPLICIT_JOIN"
@@ -41,18 +39,11 @@
       "tableName": "<expression_string>"
     }
   ]
-}
-{
+} | {
   "filter": <expression_string>
-}
-{
+} | {
   "groupByKeys": [<expression_string>, ...]
-}
-{
+} | {
   "sortings": [<expression_string>, ...]
 }
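The change above replaces back-to-back object literals in the schema fragment with a "} | {" union notation, so each entry in the analysis list is exactly one of the variants (relation, filter, groupByKeys, sortings, or the select items shown further up). A few illustrative instances, with hypothetical values, shown as Python dict literals:

# Each entry is one variant of the union described by the schema fragment above.
{"filter": "o_totalprice > 100"}                      # hypothetical filter entry
{"groupByKeys": ["o_orderstatus", "o_orderpriority"]} # hypothetical group-by entry
{"sortings": ["o_orderdate DESC"]}                    # hypothetical sorting entry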
wren-ai-service/src/pipelines/sql_explanation/generation.py (127 changes: 102 additions & 25 deletions)
@@ -30,22 +30,33 @@
 """


-def _compose_sql_expression_of_filter_type(filter_analysis: Dict) -> str:
+def _compose_sql_expression_of_filter_type(filter_analysis: Dict) -> Dict:
     if filter_analysis["type"] == "EXPR":
-        return filter_analysis["node"]
+        return {
+            "values": filter_analysis["node"],
+            "id": filter_analysis.get("id", ""),
+        }
     elif filter_analysis["type"] in ("AND", "OR"):
         left_expr = _compose_sql_expression_of_filter_type(filter_analysis["left"])
         right_expr = _compose_sql_expression_of_filter_type(filter_analysis["right"])
-        return f"{left_expr} {filter_analysis['type']} {right_expr}"
+        return {
+            "values": f"{left_expr} {filter_analysis['type']} {right_expr}",
+            "id": filter_analysis.get("id", ""),
+        }

-    return ""
+    return {"values": "", "id": ""}


 def _compose_sql_expression_of_groupby_type(
     groupby_keys: List[List[dict]],
 ) -> List[str]:
     return [
-        ", ".join([expression["expression"] for expression in groupby_key])
+        {
+            "values": ", ".join(
+                [expression["expression"] for expression in groupby_key]
+            ),
+            "id": "",
+        }
         for groupby_key in groupby_keys
     ]
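A minimal usage sketch of the two helpers after this change: they now return {"values": ..., "id": ...} dicts instead of bare strings, so later stages can tie each composed expression back to its analysis node. The import path and sample payloads below are assumptions for illustration, not part of the commit:

# Sketch only: module path and sample analysis nodes are hypothetical.
from src.pipelines.sql_explanation.generation import (
    _compose_sql_expression_of_filter_type,
    _compose_sql_expression_of_groupby_type,
)

filter_analysis = {"type": "EXPR", "node": "o_totalprice > 100", "id": "f-1"}
groupby_keys = [[{"expression": "o_orderstatus"}, {"expression": "o_orderpriority"}]]

print(_compose_sql_expression_of_filter_type(filter_analysis))
# {'values': 'o_totalprice > 100', 'id': 'f-1'}

print(_compose_sql_expression_of_groupby_type(groupby_keys))
# [{'values': 'o_orderstatus, o_orderpriority', 'id': ''}]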

@@ -67,13 +78,30 @@ def _collect_relations(relation, result, top_level: bool = True):
         return

     if relation["type"] == "TABLE" and top_level:
-        result.append(relation)
+        result.append(
+            {
+                "values": {
+                    "type": relation["type"],
+                    "tableName": relation["tableName"],
+                },
+                "id": relation.get("id", ""),
+            }
+        )
     elif relation["type"].endswith("_JOIN"):
         result.append(
             {
-                "type": relation["type"],
-                "criteria": relation["criteria"],
-                "exprSources": relation["exprSources"],
+                "values": {
+                    "type": relation["type"],
+                    "criteria": relation["criteria"],
+                    "exprSources": [
+                        {
+                            "expression": expr_source["expression"],
+                            "sourceDataset": expr_source["sourceDataset"],
+                        }
+                        for expr_source in relation["exprSources"]
+                    ],
+                },
+                "id": relation.get("id", ""),
             }
         )
         _collect_relations(relation["left"], result, top_level=False)
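For reference, the entries collected for a join now carry the analysis node id next to a "values" sub-object. A hypothetical example of one appended entry (all field values are illustrative, not taken from a real analysis):

# illustrative entry appended by _collect_relations for a join node
{
    "values": {
        "type": "INNER_JOIN",
        "criteria": "ON orders.o_custkey = customer.c_custkey",
        "exprSources": [
            {
                "expression": "orders.o_custkey = customer.c_custkey",
                "sourceDataset": "orders",
            }
        ],
    },
    "id": "relation-2",
}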
@@ -96,21 +124,36 @@ def _compose_sql_expression_of_select_type(select_items: List[Dict]) -> Dict:
             or select_item["properties"]["includeMathematicalOperation"] == "true"
         ):
             result["withFunctionCallOrMathematicalOperation"].append(
-                {"alias": select_item["alias"], "expression": select_item["expression"]}
+                {
+                    "values": {
+                        "alias": select_item["alias"],
+                        "expression": select_item["expression"],
+                    },
+                    "id": select_item.get("id", ""),
+                }
             )
         else:
             result["withoutFunctionCallOrMathematicalOperation"].append(
                 {
-                    "alias": select_item["alias"],
-                    "expression": select_item["expression"],
+                    "values": {
+                        "alias": select_item["alias"],
+                        "expression": select_item["expression"],
+                    },
+                    "id": select_item.get("id", ""),
                 }
             )

     return result


 def _compose_sql_expression_of_sortings_type(sortings: List[Dict]) -> List[str]:
-    return [f'{sorting["expression"]} {sorting["ordering"]}' for sorting in sortings]
+    return [
+        {
+            "values": f'{sorting["expression"]} {sorting["ordering"]}',
+            "id": sorting.get("id", ""),
+        }
+        for sorting in sortings
+    ]


 def _extract_to_str(data):
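A quick sketch of the sortings helper after this change; the sample input is hypothetical and the import path is assumed from the file path above:

# Sketch only: module path and sample sorting node are hypothetical.
from src.pipelines.sql_explanation.generation import (
    _compose_sql_expression_of_sortings_type,
)

sortings = [{"expression": "o_orderdate", "ordering": "DESC", "id": "sort-1"}]
print(_compose_sql_expression_of_sortings_type(sortings))
# [{'values': 'o_orderdate DESC', 'id': 'sort-1'}]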
@@ -221,9 +264,12 @@ def run(
                 {
                     "type": "filter",
                     "payload": {
+                        "id": preprocessed_sql_analysis_results["filter"][
+                            "id"
+                        ],
                         "expression": preprocessed_sql_analysis_results[
                             "filter"
-                        ],
+                        ]["values"],
                         "explanation": _extract_to_str(
                             sql_explanation_results["filter"]
                         ),
@@ -245,7 +291,8 @@
                 {
                     "type": "groupByKeys",
                     "payload": {
-                        "expression": groupby_key,
+                        "id": "",
+                        "expression": groupby_key["values"],
                         "explanation": _extract_to_str(sql_explanation),
                     },
                 }
@@ -265,7 +312,8 @@
                 {
                     "type": "relation",
                     "payload": {
-                        **relation,
+                        "id": relation["id"],
+                        **relation["values"],
                         "explanation": _extract_to_str(sql_explanation),
                     },
                 }
@@ -277,15 +325,15 @@
             sql_analysis_result_for_select_items = [
                 {
                     **select_item,
-                    "type": "withFunctionCallOrMathematicalOperation",
+                    "isFunctionCallOrMathematicalOperation": True,
                 }
                 for select_item in preprocessed_sql_analysis_results[
                     "selectItems"
                 ]["withFunctionCallOrMathematicalOperation"]
             ] + [
                 {
                     **select_item,
-                    "type": "withoutFunctionCallOrMathematicalOperation",
+                    "isFunctionCallOrMathematicalOperation": False,
                 }
                 for select_item in preprocessed_sql_analysis_results[
                     "selectItems"
@@ -303,7 +351,8 @@
                 {
                     "type": "selectItems",
                     "payload": {
-                        **select_item,
+                        "id": select_item["id"],
+                        **select_item["values"],
                         "explanation": _extract_to_str(sql_explanation),
                     },
                 }
@@ -323,18 +372,19 @@
                 {
                     "type": "sortings",
                     "payload": {
-                        "expression": sorting,
+                        "id": sorting["id"],
+                        "expression": sorting["values"],
                         "explanation": _extract_to_str(sql_explanation),
                     },
                 }
             )
         except Exception as e:
             logger.exception(f"Error in GenerationPostProcessor: {e}")

-        print(
-            f"PREPROCESSED_SQL_ANALYSIS_RESULTS: {orjson.dumps(preprocessed_sql_analysis_results, option=orjson.OPT_INDENT_2).decode()}"
-        )
-        print(f"RESULTS: {orjson.dumps(results, option=orjson.OPT_INDENT_2).decode()}")
+        # print(
+        #     f"PREPROCESSED_SQL_ANALYSIS_RESULTS: {orjson.dumps(preprocessed_sql_analysis_results, option=orjson.OPT_INDENT_2).decode()}"
+        # )
+        # print(f"RESULTS: {orjson.dumps(results, option=orjson.OPT_INDENT_2).decode()}")

         return {"results": results}
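After these payload changes, each post-processed result carries the originating analysis node's id alongside the unwrapped values. A hypothetical example of one element of the returned results list (field values are illustrative, not from a real run):

# illustrative element of the "results" list returned above
{
    "type": "sortings",
    "payload": {
        "id": "sort-1",
        "expression": "o_orderdate DESC",
        "explanation": "Orders the rows by order date, newest first.",
    },
}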

@@ -371,7 +421,34 @@ def prompts(
         ]:
             for key, value in preprocessed_sql_analysis_result.items():
                 if value:
-                    preprocessed_sql_analysis_results_with_values.append({key: value})
+                    if key != "selectItems":
+                        if isinstance(value, list):
+                            preprocessed_sql_analysis_results_with_values.append(
+                                {key: [v["values"] for v in value]}
+                            )
+                        else:
+                            preprocessed_sql_analysis_results_with_values.append(
+                                {key: value["values"]}
+                            )
+                    else:
+                        preprocessed_sql_analysis_results_with_values.append(
+                            {
+                                key: {
+                                    "withFunctionCallOrMathematicalOperation": [
+                                        v["values"]
+                                        for v in value[
+                                            "withFunctionCallOrMathematicalOperation"
+                                        ]
+                                    ],
+                                    "withoutFunctionCallOrMathematicalOperation": [
+                                        v["values"]
+                                        for v in value[
+                                            "withoutFunctionCallOrMathematicalOperation"
+                                        ]
+                                    ],
+                                }
+                            }
+                        )

         logger.debug(
             f"preprocessed_sql_analysis_results_with_values: {orjson.dumps(preprocessed_sql_analysis_results_with_values, option=orjson.OPT_INDENT_2).decode()}"
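The new branch strips the ids before the analysis goes into the prompt, keeping only the human-readable values. An illustrative input/output, assuming a single hypothetical preprocessed analysis result:

# hypothetical preprocessed analysis result fed to the loop above
preprocessed = {
    "filter": {"values": "o_totalprice > 100", "id": "filter-1"},
    "sortings": [{"values": "o_orderdate DESC", "id": "sort-1"}],
    "groupByKeys": [],  # falsy values are skipped
}
# after the loop, only the "values" reach the prompt:
# [{"filter": "o_totalprice > 100"}, {"sortings": ["o_orderdate DESC"]}]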
@@ -24,7 +24,7 @@
 sql_regeneration_user_prompt_template = """
 inputs: {{ results }}
-Think step by step
+Let's think step by step.
 """
