Feat/ai service/dspy #902

Draft · wants to merge 22 commits into base: feat/ai-service/dspy (changes shown from 19 of 22 commits)

Commits
9b3e97c
#692 integrate dspy into sql generation
xuayan-nokia Oct 23, 2024
2f5c9e4
merge main and solve conflict and add prompt_optimizer.py
xuayan-nokia Oct 23, 2024
91cb9ea
CompletionTokensDetails Type is not JSON serializable how to fix it
xuayan-nokia Oct 23, 2024
f7bdd76
Merge branch 'main' into chore/ai-service/dspy_integrate
xuayan-nokia Oct 30, 2024
27b8682
use wrenAI metrics
xuayan-nokia Oct 30, 2024
6669ba1
use predict in sql_generate instead of evaluate
xuayan-nokia Oct 30, 2024
04c06ee
Merge branch 'main' into chore/ai-service/dspy_integrate
tedyyan Nov 1, 2024
afa0ad0
Merge branch 'main' into chore/ai-service/dspy_integrate
cyyeh Nov 2, 2024
e2247a3
fixed two bugs
xuayan-nokia Nov 2, 2024
5666a10
Merge branch 'chore/ai-service/dspy_integrate' of github.com:tedyyan/…
xuayan-nokia Nov 2, 2024
1bc791a
remove import CompletionTokensDetails
xuayan-nokia Nov 2, 2024
197fa58
Merge branch 'main' into chore/ai-service/dspy_integrate
tedyyan Nov 5, 2024
7191998
Merge branch 'main' into chore/ai-service/dspy_integrate
cyyeh Nov 8, 2024
d0426a2
regenerate poetry.lock
xuayan-nokia Nov 8, 2024
f81b1d8
put dspy behind a flag
xuayan-nokia Nov 8, 2024
119a26f
add README
xuayan-nokia Nov 9, 2024
237e79c
:Merge branch 'chore/ai-service/dspy_README' into feat/ai-service/dspy
xuayan-nokia Nov 12, 2024
f66fc81
merge back the dspy sql generation
xuayan-nokia Nov 12, 2024
e06dd8c
updated README
xuayan-nokia Nov 12, 2024
0d96467
updated README 2
xuayan-nokia Nov 12, 2024
ff61bed
updated README 3
xuayan-nokia Nov 12, 2024
fd9915d
move dspy out of eval
xuayan-nokia Nov 12, 2024
59 changes: 59 additions & 0 deletions wren-ai-service/eval/README.md
@@ -68,6 +68,65 @@ The evaluation results will be presented on Langfuse as follows:

![shallow_trace_example](../docs/imgs/shallow_trace_example.png)


## How to use DSPy in Wren AI
### Step 1: Generate optimized DSPy module

1. Prepare a prediction result and a training dataset
Generate a prediction dataset without DSPy; it is used to initialize the evaluation pipeline (metrics). Refer to https://github.com/Canner/WrenAI/blob/main/wren-ai-service/eval/README.md#eval-dataset-preparationif-using-spider-10-dataset

```
just predict <evaluation-dataset>
```
The output is a prediction result, such as `prediction_eval_ask_9df57d69-250c-4a10-b6a5-6595509fed6b_2024_10_23_132136.toml`.
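The prediction filename encodes the pipeline name, a run UUID, and a timestamp. A small sketch of pulling those apart; the layout is inferred from the example filename above, not from a documented naming scheme:

```python
import re
from datetime import datetime


def parse_prediction_filename(name: str) -> dict:
    """Split a prediction filename into pipeline, run id, and timestamp.

    Layout inferred from the example above:
    prediction_eval_<pipeline>_<uuid>_<YYYY_MM_DD_HHMMSS>.toml
    """
    m = re.match(
        r"prediction_eval_(?P<pipeline>\w+?)_"
        r"(?P<run_id>[0-9a-f-]{36})_"
        r"(?P<ts>\d{4}_\d{2}_\d{2}_\d{6})\.toml",
        name,
    )
    if not m:
        raise ValueError(f"unrecognized prediction filename: {name}")
    parts = m.groupdict()
    # Convert the timestamp segment into a datetime for easy sorting
    parts["timestamp"] = datetime.strptime(parts.pop("ts"), "%Y_%m_%d_%H%M%S")
    return parts
```

This makes it easy to pick the newest prediction file for a given pipeline when several runs have accumulated.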

2. Train a DSPy module.
Use the prediction result above together with the training dataset to train a DSPy module.
```
python wren-ai-service/eval/dspy_modules/prompt_optimizer.py --training-dataset spider_car_1_eval_dataset.toml --file prediction_eval_ask_9df57d69-250c-4a10-b6a5-6595509fed6b_2024_10_23_132136.toml
```

The output, such as `eval/optimized/AskGenerationV1_optimized_2024_10_21_181426.json`, is the trained DSPy module.

### Step 2: Use the optimized module in pipeline

1. Set the environment variable `DSPY_OPTIMAZED_MODEL` to the path of the trained DSPy module from the step above.

```
export DSPY_OPTIMAZED_MODEL=eval/optimized/AskGenerationV1_optimized_2024_10_21_181426.json
```

2. Start the predict pipeline and get the prediction result.

```
just predict eval/dataset/spider_car_1_eval_dataset.toml
```

The output is generated by DSPy:

```
outputs/predictions/prediction_eval_ask_f5103405-09b2-448c-829d-cedd3c3b12d0_2024_10_22_184950.toml

```

### Step 3: Evaluate and compare (optional)

1. Evaluate the DSPy prediction result

```
just eval prediction_eval_ask_f5103405-09b2-448c-829d-cedd3c3b12d0_2024_10_22_184950.toml

```

2. Compare the two results, with and without DSPy

![image](https://github.com/user-attachments/assets/34ee0c25-dcdc-45b7-8cc0-cb2fe55211af)
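A back-of-the-envelope way to compare the two runs, assuming each evaluation yields a list of per-question scores; the field names below are illustrative, not the actual Langfuse schema:

```python
from statistics import mean


def compare_runs(baseline: list[float], dspy: list[float]) -> dict:
    """Compare mean metric scores of two evaluation runs."""
    b, d = mean(baseline), mean(dspy)
    # A positive delta means the DSPy-optimized run scored higher on average
    return {"baseline": b, "dspy": d, "delta": d - b}
```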


Notes:
`wren-ai-service/eval/dspy_modules/prompt_optimizer.py` can be improved by incorporating additional training examples or by using other DSPy modules.


## Terms

This section describes the terms used in the evaluation framework:
6 changes: 3 additions & 3 deletions wren-ai-service/src/pipelines/common.py
@@ -175,16 +175,16 @@ async def _task(result: Dict[str, str]):
valid_generation_results.append(
{
"sql": quoted_sql,
"correlation_id": addition.get("correlation_id", ""),
"correlation_id": addition.get("correlation_id", "") if isinstance(addition, dict) else addition
}
)
else:
invalid_generation_results.append(
{
"sql": quoted_sql,
"type": "DRY_RUN",
"error": addition.get("error_message", ""),
"correlation_id": addition.get("correlation_id", ""),
"error": addition.get("error_message", "") if isinstance(addition, dict) else addition,
> Contributor comment: If using DSPy, the addition is a string instead of a dict.
"correlation_id": addition.get("correlation_id", "") if isinstance(addition, dict) else addition
}
)
else:
38 changes: 33 additions & 5 deletions wren-ai-service/src/pipelines/generation/sql_generation.py
@@ -23,6 +23,12 @@
from src.utils import async_timer, timer
from src.web.v1.services.ask import AskConfigurations

import os
import dspy

from eval.dspy_modules.ask_generation import AskGenerationV1
from eval.dspy_modules.prompt_optimizer import configure_llm_provider

logger = logging.getLogger("wren-ai-service")


@@ -87,6 +93,7 @@ def prompt(
prompt_builder: PromptBuilder,
configurations: AskConfigurations | None = None,
samples: List[Dict] | None = None,
dspy_module: dspy.Module | None = None,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"documents: {documents}")
@@ -97,6 +104,16 @@
if samples:
logger.debug(f"samples: {samples}")

if dspy_module:
# DSPy path: the module's inputs are the question and the retrieved documents
context = [str(doc) for doc in documents]
return {"context": context, "question": query}

return prompt_builder.run(
query=query,
documents=documents,
@@ -110,10 +127,12 @@

@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql(
prompt: dict,
generator: Any,
) -> dict:
async def generate_sql(prompt: dict, generator: Any, dspy_module: dspy.Module | None) -> dict:
if dspy_module:
# DSPy path: predict from the question and the space-joined context
prediction = dspy_module(question=prompt["question"], context=" ".join(prompt["context"]))
return {"replies": [prediction.answer]}

logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))
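The DSPy/generator branch in this hunk can be exercised with stand-ins for both paths; the fake module and generator below are illustrative placeholders, not the real DSPy or LLM-provider APIs:

```python
import asyncio
from types import SimpleNamespace


async def generate_sql(prompt: dict, generator, dspy_module=None) -> dict:
    """Simplified mirror of the branch in this PR: DSPy if a module is
    loaded, otherwise the regular generator."""
    if dspy_module:
        prediction = dspy_module(
            question=prompt["question"], context=" ".join(prompt["context"])
        )
        return {"replies": [prediction.answer]}
    return await generator.run(prompt=prompt.get("prompt"))


# Illustrative stand-ins for testing the branch:
def fake_module(question, context):
    return SimpleNamespace(answer=f"SELECT 1 -- {question}")


class FakeGenerator:
    async def run(self, prompt):
        return {"replies": [f"SELECT 2 -- {prompt}"]}
```

Both paths return the same `{"replies": [...]}` shape, which is what lets the downstream post-processor stay unchanged.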

@@ -150,14 +169,22 @@ class GenerationResults(BaseModel):
}
}


class SQLGeneration(BasicPipeline):
def __init__(
self,
llm_provider: LLMProvider,
engine: Engine,
**kwargs,
):
self.dspy_module = None
optimized_path = os.getenv("DSPY_OPTIMAZED_MODEL", "")
if optimized_path:
# load the optimized DSPy module and use it for SQL generation
configure_llm_provider(
os.getenv("GENERATION_MODEL"), os.getenv("LLM_OPENAI_API_KEY")
)
self.dspy_module = AskGenerationV1()
self.dspy_module.load(optimized_path)
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
@@ -167,6 +194,7 @@ def __init__(
template=sql_generation_user_prompt_template
),
"post_processor": SQLGenPostProcessor(engine=engine),
"dspy_module": self.dspy_module
}

self._configs = {