diff --git a/wren-ai-service/eval/dspy_modules/__init__.py b/wren-ai-service/dspy_modules/__init__.py
similarity index 100%
rename from wren-ai-service/eval/dspy_modules/__init__.py
rename to wren-ai-service/dspy_modules/__init__.py
diff --git a/wren-ai-service/eval/dspy_modules/ask_generation.py b/wren-ai-service/dspy_modules/ask_generation.py
similarity index 100%
rename from wren-ai-service/eval/dspy_modules/ask_generation.py
rename to wren-ai-service/dspy_modules/ask_generation.py
diff --git a/wren-ai-service/eval/dspy_modules/prompt_optimizer.py b/wren-ai-service/dspy_modules/prompt_optimizer.py
similarity index 95%
rename from wren-ai-service/eval/dspy_modules/prompt_optimizer.py
rename to wren-ai-service/dspy_modules/prompt_optimizer.py
index 95b699c5c..51e3c4ae4 100644
--- a/wren-ai-service/eval/dspy_modules/prompt_optimizer.py
+++ b/wren-ai-service/dspy_modules/prompt_optimizer.py
@@ -13,8 +13,14 @@
 sys.path.append(f"{Path().parent.resolve()}")
 
 import src.utils as utils
-from eval.dspy_modules.ask_generation import AskGenerationV1
-from eval.utils import parse_toml
+from dspy_modules.ask_generation import AskGenerationV1
+from tomlkit import parse
+from typing import Any
+
+
+def parse_toml(path: str) -> dict[str, Any]:
+    with open(path) as file:
+        return parse(file.read())
 
 
 def parse_args() -> Tuple[str]:
@@ -62,7 +68,7 @@ def clean_sql(sql: str) -> str:
 
 
 def prepare_dataset(path: str, train_ratio: float = 0.5):
-    eval_dataset = parse_toml(f"eval/dataset/{path}")["eval_dataset"]
+    eval_dataset = parse_toml(path)["eval_dataset"]
 
     dspy_dataset = []
     for data in eval_dataset:
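For reference, a minimal sketch of how the relocated `parse_toml` helper behaves. Note that `prepare_dataset` no longer prepends `eval/dataset/`, so callers must now pass the full dataset path; the filename below is illustrative:

```python
from tomlkit import parse


def parse_toml(path: str) -> dict:
    # tomlkit returns a TOMLDocument, which behaves like a plain dict
    with open(path) as file:
        return parse(file.read())


# Callers now pass the full path rather than a bare filename
eval_dataset = parse_toml("eval/dataset/spider_car_1_eval_dataset.toml")["eval_dataset"]
```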
diff --git a/wren-ai-service/eval/README.md b/wren-ai-service/eval/README.md
index 082f6b142..0370643fd 100644
--- a/wren-ai-service/eval/README.md
+++ b/wren-ai-service/eval/README.md
@@ -68,6 +68,64 @@
 The evaluation results will be presented on Langfuse as follows:
 
 ![shallow_trace_example](../docs/imgs/shallow_trace_example.png)
 
+## How to use DSPy in Wren AI
+
+### Step 1: Generate an optimized DSPy module
+
+1. Prepare a prediction result and a training dataset.
+
+Generate a prediction dataset without DSPy; it is used to initialize the evaluation pipeline (metrics). Refer to https://github.com/Canner/WrenAI/blob/main/wren-ai-service/eval/README.md#eval-dataset-preparation-if-using-spider-10-dataset
+
+```
+just predict
+```
+
+The output is a prediction result, such as `prediction_eval_ask_9df57d69-250c-4a10-b6a5-6595509fed6b_2024_10_23_132136.toml`.
+
+2. Train a DSPy module.
+
+Use the prediction result and the training dataset above to train a DSPy module:
+
+```
+python wren-ai-service/dspy_modules/prompt_optimizer.py --training-dataset eval/dataset/spider_car_1_eval_dataset.toml --file prediction_eval_ask_9df57d69-250c-4a10-b6a5-6595509fed6b_2024_10_23_132136.toml
+```
+
+The output, e.g. `eval/optimized/AskGenerationV1_optimized_2024_10_21_181426.json`, is the trained DSPy module.
+
+### Step 2: Use the optimized module in the pipeline
+
+1. Set the `DSPY_OPTIMAZED_MODEL` environment variable to the trained DSPy module from the step above:
+
+```
+export DSPY_OPTIMAZED_MODEL=eval/optimized/AskGenerationV1_optimized_2024_10_21_181426.json
+```
+
+2. Start the prediction pipeline and get the predicted result:
+
+```
+just predict eval/dataset/spider_car_1_eval_dataset.toml
+```
+
+The output is generated by DSPy, e.g.:
+
+```
+outputs/predictions/prediction_eval_ask_f5103405-09b2-448c-829d-cedd3c3b12d0_2024_10_22_184950.toml
+```
+
+### Step 3 (optional): Evaluate and compare
+
+1. Evaluate the DSPy prediction result:
+
+```
+just eval prediction_eval_ask_f5103405-09b2-448c-829d-cedd3c3b12d0_2024_10_22_184950.toml
+```
+
+2. Compare the results with and without DSPy:
+
+![image](https://github.com/user-attachments/assets/34ee0c25-dcdc-45b7-8cc0-cb2fe55211af)
+
+Note: `wren-ai-service/dspy_modules/prompt_optimizer.py` can be improved by incorporating additional training examples or by using other DSPy modules.
+
 ## Terms
 This section describes the terms used in the evaluation framework:
diff --git a/wren-ai-service/src/pipelines/common.py b/wren-ai-service/src/pipelines/common.py
index 8b91c4e52..cc093c084 100644
--- a/wren-ai-service/src/pipelines/common.py
+++ b/wren-ai-service/src/pipelines/common.py
@@ -183,8 +183,8 @@ async def _task(result: Dict[str, str]):
                     {
                         "sql": quoted_sql,
                         "type": "DRY_RUN",
-                        "error": addition.get("error_message", ""),
-                        "correlation_id": addition.get("correlation_id", ""),
+                        "error": addition.get("error_message", "") if isinstance(addition, dict) else addition,
+                        "correlation_id": addition.get("correlation_id", "") if isinstance(addition, dict) else "",
                     }
                 )
             else:
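A small illustration of the two shapes the guard above tolerates: `addition` may be a dict carrying structured error details, or, judging by the `else` branch, a bare error string. A minimal sketch with hypothetical inputs:

```python
from typing import Any


def error_fields(addition: Any) -> dict:
    """Normalize a dry-run `addition` into error/correlation_id fields."""
    if isinstance(addition, dict):
        return {
            "error": addition.get("error_message", ""),
            "correlation_id": addition.get("correlation_id", ""),
        }
    # A bare string is treated as the error message itself
    return {"error": addition, "correlation_id": ""}


assert error_fields({"error_message": "syntax error", "correlation_id": "abc"}) == {
    "error": "syntax error",
    "correlation_id": "abc",
}
assert error_fields("connection refused") == {
    "error": "connection refused",
    "correlation_id": "",
}
```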
"generator": llm_provider.get_generator( system_prompt=sql_generation_system_prompt, @@ -167,6 +194,7 @@ def __init__( template=sql_generation_user_prompt_template ), "post_processor": SQLGenPostProcessor(engine=engine), + "dspy_module": self.dspy_module } self._configs = {