Skip to content

Commit

Permalink
simplify pipeline
Browse files (browse the repository at this point in the history)
  • Loading branch information
cyyeh committed Jul 2, 2024
1 parent bd2ae99 commit 8484a63
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,3 @@
Given the original description and the each step of the SQL query, SQL summary, cte name(some of the steps include regenerated SQL queries, SQL summary from the subtask1),
your job is to regenerate the description considering all steps and regenerate the SQL query considering if regenerated SQL query would affectes original SQL query in subsequent steps.
"""

description_regeneration_system_prompt = """
### TASK ###
Given the steps of the SQL query, SQL summary, cte name and the original description,
your job is to regenerate the original description using less than 30 words considering all steps.
"""
255 changes: 73 additions & 182 deletions wren-ai-service/src/pipelines/sql_regeneration/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
from haystack import component
from haystack.components.builders.prompt_builder import PromptBuilder

from src.core.engine import clean_generation_result
from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.sql_regeneration.components.prompts import (
description_regeneration_system_prompt,
sql_regeneration_system_prompt,
)
from src.utils import async_timer, init_providers, timer
Expand All @@ -26,143 +24,92 @@
sql_regeneration_user_prompt_template = """
### TASK ###
Given each step of the SQL query, SQL summary, cte name and a list of user corrections,
your job is to regenerate the corresponding SQL query, SQL summary and cte given the user corrections.
- Given a list of user corrections, regenerate the corresponding SQL query.
- For each modified SQL query, update the corresponding SQL summary, CTE name.
- If subsequent steps are dependent on the corrected step, make sure to update the SQL query, SQL summary and CTE name in subsequent steps if needed.
- Regenerate the description after correcting all of the steps.
### INPUT STRUCTURE ###
{
"index": <step_index>,
"step": {
"summary": "<original_sql_summary_string>",
"sql": "<original_sql_string>",
"cte_name": "<original_cte_name_string>",
"corrections": [
{
"before": {
"type": "<filter/selectItems/relation/groupByKeys/sortings>",
"value": "<original_value_string>"
},
"after": {
"type": "<sql_expression/nl_expression>",
"value": "<new_value_string>"
}
},...
]
},...
"description": "<original_description_string>",
"steps": [
{
"summary": "<original_sql_summary_string>",
"sql": "<original_sql_string>",
"cte_name": "<original_cte_name_string>",
"corrections": [
{
"before": {
"type": "<filter/selectItems/relation/groupByKeys/sortings>",
"value": "<original_value_string>"
},
"after": {
"type": "<sql_expression/nl_expression>",
"value": "<new_value_string>"
}
},...
]
},...
]
}
### INPUT ###
{% for result in results %}
{{ result }}
{% endfor %}
{{ results }}
### OUTPUT STRUCTURE ###
Generate modified results according to the following in JSON format:
{
"results": [
"description": "<modified_description_string>",
"steps": [
{
"index": <step_index>,
"summary": "<modified_sql_summary_string>",
"sql": "<modified_sql_string>",
"cte_name": "<modified_cte_name_string>"
},
{
"index": <step_index>,
"summary": "<modified_sql_summary_string>",
"sql": "<modified_sql_string>",
"cte_name": "<modified_cte_name_string>"
},
...
"cte_name": "<modified_cte_name_string>",
},...
]
}
Think step by step
"""


description_regeneration_user_prompt_template = """
### INPUT ###
description: {{ description }}
steps:
{% for step in steps %}
{{ step }}
{% endfor %}
### OUTPUT STRUCTURE ###
{
"description": "<modified_description_string>"
}
Generate modified description according to the OUTPUT STRUCTURE in JSON format
Think step by step
"""


@component
class StepsWithUserCorrectionsFilter:
class SQLRegenerationRreprocesser:
@component.output_types(
results=List[Dict[str, Any]],
results=Dict[str, Any],
)
def run(
self, steps: List[SQLExplanationWithUserCorrections]
) -> Dict[str, List[Dict]]:
self,
description: str,
steps: List[SQLExplanationWithUserCorrections],
) -> Dict[str, Any]:
return {
"results": list(
map(
lambda step: {
"index": step["index"],
"step": step["step"].model_dump_json(),
},
filter(
lambda step: step["step"].corrections,
[{"index": i, "step": step} for i, step in enumerate(steps)],
),
)
)
"results": {
"description": description,
"steps": steps,
}
}


@component
class SQLRegenerationByStepPostProcessor:
class SQLRegenerationPostProcessor:
@component.output_types(
description=str,
steps=List[str],
)
def run(
self,
replies: List[str],
original_description: str,
original_steps: List[str],
) -> Dict[str, Any]:
try:
modified_steps = orjson.loads(replies[0]).get("results", [])
new_steps = [
{
"sql": clean_generation_result(step.sql),
"summary": step.summary,
"cte_name": step.cte_name,
}
for step in original_steps
]
if new_steps:
for modified_step in modified_steps:
new_steps[modified_step["index"]] = {
"sql": clean_generation_result(modified_step.get("sql", "")),
"summary": modified_step.get("summary", ""),
"cte_name": modified_step.get("cte_name", ""),
}

return {"description": original_description, "steps": new_steps}
return {"results": orjson.loads(replies[0])}
except Exception as e:
logger.exception(f"Error in SQLRegenerationByStepPostProcessor: {e}")
return {"description": original_description, "steps": original_steps}
logger.exception(f"Error in SQLRegenerationPostProcessor: {e}")
return {"results": None}


@component
Expand Down Expand Up @@ -190,88 +137,46 @@ def run(
## Start of Pipeline
@timer
def preprocess(
description: str,
steps: List[SQLExplanationWithUserCorrections],
steps_with_user_corrections_filter: StepsWithUserCorrectionsFilter,
sql_regeneration_preprocesser: SQLRegenerationRreprocesser,
) -> dict[str, Any]:
logger.debug(f"steps: {steps}")
return steps_with_user_corrections_filter.run(steps)["results"]
logger.debug(f"description: {description}")
return sql_regeneration_preprocesser.run(
description=description,
steps=steps,
)


@timer
def sql_regeneration_by_step_prompt(
def sql_regeneration_prompt(
preprocess: Dict[str, Any],
sql_regeneration_by_step_prompt_builder: PromptBuilder,
sql_regeneration_prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"preprocess: {preprocess}")
return sql_regeneration_by_step_prompt_builder.run(results=preprocess)
return sql_regeneration_prompt_builder.run(results=preprocess["results"])


@async_timer
async def sql_regeneration_by_step_generate(
sql_regeneration_by_step_prompt: dict,
sql_regeneration_by_step_generator: Any,
async def sql_regeneration_generate(
sql_regeneration_prompt: dict,
sql_regeneration_generator: Any,
) -> dict:
logger.debug(f"sql_regeneration_by_step_prompt: {sql_regeneration_by_step_prompt}")
return await sql_regeneration_by_step_generator.run(
prompt=sql_regeneration_by_step_prompt.get("prompt")
logger.debug(f"sql_regeneration_prompt: {sql_regeneration_prompt}")
return await sql_regeneration_generator.run(
prompt=sql_regeneration_prompt.get("prompt")
)


@timer
def sql_regeneration_post_process(
sql_regeneration_by_step_generate: dict,
description: str,
steps: List[SQLExplanationWithUserCorrections],
sql_regeneration_by_step_post_processor: SQLRegenerationByStepPostProcessor,
) -> dict:
logger.debug(
f"sql_regeneration_by_step_generate: {sql_regeneration_by_step_generate}"
)
logger.debug(f"description: {description}")
logger.debug(f"steps: {steps}")
return sql_regeneration_by_step_post_processor.run(
replies=sql_regeneration_by_step_generate.get("replies"),
original_description=description,
original_steps=steps,
)


@timer
def description_regeneration_prompt(
sql_regeneration_post_process: dict,
description_regeneration_prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"sql_regeneration_post_process: {sql_regeneration_post_process}")
return description_regeneration_prompt_builder.run(
description=sql_regeneration_post_process.get("description"),
steps=sql_regeneration_post_process.get("steps"),
)


@async_timer
async def description_regeneration_generate(
description_regeneration_prompt: dict,
description_regeneration_generator: Any,
) -> dict:
logger.debug(f"description_regeneration_prompt: {description_regeneration_prompt}")
return await description_regeneration_generator.run(
prompt=description_regeneration_prompt.get("prompt")
)


@timer
def description_regeneration_post_process(
description_regeneration_generate: dict,
sql_regeneration_post_process: dict,
description_regeneration_post_processor: DescriptionRegenerationPostProcessor,
sql_regeneration_generate: dict,
sql_regeneration_post_processor: SQLRegenerationPostProcessor,
) -> dict:
logger.debug(
f"description_regeneration_generate: {description_regeneration_generate}"
)
logger.debug(f"sql_regeneration_post_process: {sql_regeneration_post_process}")
return description_regeneration_post_processor.run(
replies=description_regeneration_generate.get("replies"),
steps=sql_regeneration_post_process.get("steps"),
logger.debug(f"sql_regeneration_generate: {sql_regeneration_generate}")
return sql_regeneration_post_processor.run(
replies=sql_regeneration_generate.get("replies"),
)


Expand All @@ -283,25 +188,14 @@ def __init__(
self,
llm_provider: LLMProvider,
):
self.steps_with_user_corrections_filter = StepsWithUserCorrectionsFilter()
self.sql_regeneration_by_step_prompt_builder = PromptBuilder(
self.sql_regeneration_preprocesser = SQLRegenerationRreprocesser()
self.sql_regeneration_prompt_builder = PromptBuilder(
template=sql_regeneration_user_prompt_template
)
self.sql_regeneration_by_step_generator = llm_provider.get_generator(
self.sql_regeneration_generator = llm_provider.get_generator(
system_prompt=sql_regeneration_system_prompt
)
self.sql_regeneration_by_step_post_processor = (
SQLRegenerationByStepPostProcessor()
)
self.description_regeneration_prompt_builder = PromptBuilder(
template=description_regeneration_user_prompt_template
)
self.description_regeneration_generator = llm_provider.get_generator(
system_prompt=description_regeneration_system_prompt
)
self.description_regeneration_post_processor = (
DescriptionRegenerationPostProcessor()
)
self.sql_regeneration_post_processor = SQLRegenerationPostProcessor()

super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
Expand All @@ -315,17 +209,14 @@ async def run(
):
logger.info("SQL Regeneration Generation pipeline is running...")
return await self._pipe.execute(
["description_regeneration_post_process"],
["sql_regeneration_post_process"],
inputs={
"description": description,
"steps": steps,
"steps_with_user_corrections_filter": self.steps_with_user_corrections_filter,
"sql_regeneration_by_step_prompt_builder": self.sql_regeneration_by_step_prompt_builder,
"sql_regeneration_by_step_generator": self.sql_regeneration_by_step_generator,
"sql_regeneration_by_step_post_processor": self.sql_regeneration_by_step_post_processor,
"description_regeneration_prompt_builder": self.description_regeneration_prompt_builder,
"description_regeneration_generator": self.description_regeneration_generator,
"description_regeneration_post_processor": self.description_regeneration_post_processor,
"sql_regeneration_preprocesser": self.sql_regeneration_preprocesser,
"sql_regeneration_prompt_builder": self.sql_regeneration_prompt_builder,
"sql_regeneration_generator": self.sql_regeneration_generator,
"sql_regeneration_post_processor": self.sql_regeneration_post_processor,
},
)

Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/web/v1/services/sql_regeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def sql_regeneration(
)

sql_regeneration_result = generation_result[
"description_regeneration_post_process"
"sql_regeneration_post_process"
]["results"]

logger.debug(f"sql regeneration results: {sql_regeneration_result}")
Expand Down

0 comments on commit 8484a63

Please sign in to comment.