diff --git a/wren-ai-service/demo/app.py b/wren-ai-service/demo/app.py index 73656469a4..82a701332d 100644 --- a/wren-ai-service/demo/app.py +++ b/wren-ai-service/demo/app.py @@ -49,18 +49,14 @@ st.session_state["preview_sql"] = None if "query_history" not in st.session_state: st.session_state["query_history"] = None -if "sql_explanation_button_index" not in st.session_state: - st.session_state["sql_explanation_button_index"] = None if "sql_explanation_question" not in st.session_state: st.session_state["sql_explanation_question"] = None -if "sql_explanation_sql" not in st.session_state: - st.session_state["sql_explanation_sql"] = None -if "sql_explanation_full_sql" not in st.session_state: - st.session_state["sql_explanation_full_sql"] = None -if "sql_explanation_sql_summary" not in st.session_state: - st.session_state["sql_explanation_sql_summary"] = None -if "sql_explanation_sql_analysis" not in st.session_state: - st.session_state["sql_explanation_sql_analysis"] = None +if "sql_explanation_steps_with_analysis" not in st.session_state: + st.session_state["sql_explanation_steps_with_analysis"] = None +if "sql_analysis_results" not in st.session_state: + st.session_state["sql_analysis_results"] = None +if "sql_explanation_results" not in st.session_state: + st.session_state["sql_explanation_results"] = None def onchange_demo_dataset(): diff --git a/wren-ai-service/demo/utils.py b/wren-ai-service/demo/utils.py index 162c1873c2..33412f54a0 100644 --- a/wren-ai-service/demo/utils.py +++ b/wren-ai-service/demo/utils.py @@ -332,12 +332,12 @@ def show_asks_details_results(query: str): args=[query, sqls, summaries], ) - if st.session_state["sql_explanation_sql"] is not None: - sql_explanation_results = sql_explanation() - - if sql_explanation_results: - st.markdown("### SQL Explanation Results") - st.json(sql_explanation_results) + if st.session_state["sql_analysis_results"]: + st.markdown("### SQL Analysis Results") + st.json(st.session_state["sql_analysis_results"]) + if st.session_state["sql_explanation_results"]: + st.markdown("### SQL Explanation Results") + st.json(st.session_state["sql_explanation_results"]) def on_click_preview_data_button(index: int, full_sqls: List[str]): @@ -348,6 +348,7 @@ def on_click_preview_data_button(index: int, full_sqls: List[str]): def get_sql_analysis_results(sqls: List[str]): results = [] for sql in sqls: + print(f"SQL: {sql}") response = requests.get( f"{WREN_ENGINE_API_URL}/v1/analysis/sql", json={ @@ -367,10 +368,10 @@ def on_click_sql_explanation_button( sqls: List[str], summaries: List[str], ): - st.session_state["sql_explanation_sql"] = True sql_analysis_results = get_sql_analysis_results(sqls) st.session_state["sql_explanation_question"] = question + st.session_state["sql_analysis_results"] = sql_analysis_results st.session_state["sql_explanation_steps_with_analysis"] = [ {"sql": sql, "summary": summary, "sql_analysis_results": sql_analysis_results} for sql, summary, sql_analysis_results in zip( @@ -378,6 +379,8 @@ def on_click_sql_explanation_button( ) ] + st.session_state["sql_explanation_results"] = sql_explanation() + # ai service api related def generate_mdl_metadata(mdl_model_json: dict): diff --git a/wren-ai-service/src/pipelines/sql_explanation/generation.py b/wren-ai-service/src/pipelines/sql_explanation/generation.py index 1c97367872..008699c55c 100644 --- a/wren-ai-service/src/pipelines/sql_explanation/generation.py +++ b/wren-ai-service/src/pipelines/sql_explanation/generation.py @@ -229,7 +229,7 @@ def preprocess( sql_analysis_results: List[dict], pre_processor: SQLAnalysisPreprocessor ) -> List[dict]: logger.debug(f"sql_analysis_results: {sql_analysis_results}") - return pre_processor.run(sql_analysis_results)["preprocessed_sql_analysis_results"] + return pre_processor.run(sql_analysis_results) @timer @@ -243,13 +243,13 @@ def prompt( ) -> dict: logger.debug(f"question: {question}") logger.debug(f"sql: {sql}") - logger.debug(f"preprocessed_sql_analysis_results: {preprocess}") + logger.debug(f"preprocess: {preprocess}") logger.debug(f"sql_summary: {sql_summary}") logger.debug(f"full_sql: {full_sql}") return prompt_builder.run( question=question, sql=sql, - sql_analysis_results=preprocess, + sql_analysis_results=preprocess["preprocessed_sql_analysis_results"], sql_summary=sql_summary, full_sql=full_sql, )