Skip to content

Commit

Permalink
add mipro optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 30, 2024
1 parent dde4452 commit 42edc8e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
11 changes: 9 additions & 2 deletions extra/prompt_tuning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

This folder contains scripts for prompt tuning and evaluation. Prompts (programs) used in dbally:

- `FILTERING_ASSESSOR` - assesses whether a question requires filtering.
- `AGGREGATION_ASSESSOR` - assesses whether a question requires aggregation.
- `FilteringAssessor` - assesses whether a question requires filtering.
- `AggregationAssessor` - assesses whether a question requires aggregation.

All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database.

Expand All @@ -17,6 +17,12 @@ Tune `filtering-assessor` prompt on base signature using [COPRO](https://dspy-do
python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict
```

Change optimizer to [MIPRO](https://dspy-docs.vercel.app/docs/cheatsheet#mipro):

```bash
python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict optimizer=mipro
```

Train multiple prompts:

```bash
Expand All @@ -30,6 +36,7 @@ Tweak optimizer params to get different results:

```bash
python train.py \
optimizer=copro \
optimizer.params.breadth=2 \
optimizer.params.depth=3 \
optimizer.params.init_temperature=1.0
Expand Down
9 changes: 9 additions & 0 deletions extra/prompt_tuning/config/optimizer/mipro.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: MIPRO
params:
num_candidates: 3
init_temperature: 1.4

compile:
max_bootstrapped_demos: 3
max_labeled_demos: 0
num_trials: 10
6 changes: 3 additions & 3 deletions extra/prompt_tuning/tuning/metrics/iql.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict
from typing import Dict, List, Optional

from dspy import Prediction


def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool:
def filtering_assess_acc(gold: Dict, pred: Prediction, _trace: Optional[List] = None) -> bool:
"""
IQL filtering decision metric.
Expand All @@ -19,7 +19,7 @@ def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool:
)


def aggregation_assess_acc(gold: Dict, pred: Prediction) -> bool:
def aggregation_assess_acc(gold: Dict, pred: Prediction, _trace: Optional[List] = None) -> bool:
"""
IQL aggregation decision metric.
Expand Down

0 comments on commit 42edc8e

Please sign in to comment.