From 42edc8e242148bcfe82a29b9d5e1b5f8aa9c67d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 30 Aug 2024 18:42:08 +0200 Subject: [PATCH] add mipro optimizer --- extra/prompt_tuning/README.md | 11 +++++++++-- extra/prompt_tuning/config/optimizer/mipro.yaml | 9 +++++++++ extra/prompt_tuning/tuning/metrics/iql.py | 6 +++--- 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 extra/prompt_tuning/config/optimizer/mipro.yaml diff --git a/extra/prompt_tuning/README.md b/extra/prompt_tuning/README.md index 93e6afb2..54ca81cb 100644 --- a/extra/prompt_tuning/README.md +++ b/extra/prompt_tuning/README.md @@ -2,8 +2,8 @@ This folder contains scripts for prompt tuning and evaluation. Prompts (programs) used in dbally: -- `FILTERING_ASSESSOR` - assesses whether a question requires filtering. -- `AGGREGATION_ASSESSOR` - assesses whether a question requires aggregation. +- `FilteringAssessor` - assesses whether a question requires filtering. +- `AggregationAssessor` - assesses whether a question requires aggregation. All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database. @@ -17,6 +17,12 @@ Tune `filtering-assessor` prompt on base signature using [COPRO](https://dspy-do python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict ``` +Change optimizer to [MIPRO](https://dspy-docs.vercel.app/docs/cheatsheet#mipro): + +```bash +python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict optimizer=mipro +``` + Train multiple prompts: ```bash @@ -30,6 +36,7 @@ Tweak optimizer params to get different results: ```bash python train.py \ + optimizer=copro \ optimizer.params.breadth=2 \ optimizer.params.depth=3 \ optimizer.params.init_temperature=1.0 diff --git a/extra/prompt_tuning/config/optimizer/mipro.yaml b/extra/prompt_tuning/config/optimizer/mipro.yaml new file mode 100644 index 00000000..edf7e139 --- /dev/null +++ b/extra/prompt_tuning/config/optimizer/mipro.yaml @@ -0,0 +1,9 @@ +name: MIPRO +params: + num_candidates: 3 + init_temperature: 1.4 + +compile: + max_bootstrapped_demos: 3 + max_labeled_demos: 0 + num_trials: 10 diff --git a/extra/prompt_tuning/tuning/metrics/iql.py b/extra/prompt_tuning/tuning/metrics/iql.py index 3340f929..b2ad6a42 100644 --- a/extra/prompt_tuning/tuning/metrics/iql.py +++ b/extra/prompt_tuning/tuning/metrics/iql.py @@ -1,9 +1,9 @@ -from typing import Dict +from typing import Dict, List, Optional from dspy import Prediction -def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool: +def filtering_assess_acc(gold: Dict, pred: Prediction, _trace: Optional[List] = None) -> bool: """ IQL filtering decision metric. @@ -19,7 +19,7 @@ def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool: ) -def aggregation_assess_acc(gold: Dict, pred: Prediction) -> bool: +def aggregation_assess_acc(gold: Dict, pred: Prediction, _trace: Optional[List] = None) -> bool: """ IQL aggregation decision metric.