From b202ea3614883a07837c4a35b1a03eb7397f74e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Sat, 17 Aug 2024 02:58:06 +0200 Subject: [PATCH 01/13] add aggregation assessor program --- .../aggregation-assessor-baseline.yaml | 2 + .../program/aggregation-assessor-cot.yaml | 2 + extra/prompt_tuning/evaluate.py | 5 +- .../prompt_tuning/tuning/metrics/__init__.py | 4 +- extra/prompt_tuning/tuning/metrics/iql.py | 20 +++++++- extra/prompt_tuning/tuning/programs/iql.py | 48 ++++++++++++++++++- extra/prompt_tuning/tuning/signatures/iql.py | 19 ++++++++ 7 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml create mode 100644 extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml diff --git a/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml b/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml new file mode 100644 index 00000000..98d1d68b --- /dev/null +++ b/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml @@ -0,0 +1,2 @@ +type: AGGREGATION_ASSESSOR +name: AggregationAssessorBaseline diff --git a/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml b/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml new file mode 100644 index 00000000..f6a7be97 --- /dev/null +++ b/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml @@ -0,0 +1,2 @@ +type: AGGREGATION_ASSESSOR +name: AggregationAssessorCoT diff --git a/extra/prompt_tuning/evaluate.py b/extra/prompt_tuning/evaluate.py index 35bcf2e8..0ae140e4 100644 --- a/extra/prompt_tuning/evaluate.py +++ b/extra/prompt_tuning/evaluate.py @@ -10,7 +10,7 @@ from neptune.utils import stringify_unsupported from omegaconf import DictConfig from tuning.loaders import IQLGenerationDataLoader -from tuning.metrics import filtering_assess_acc +from tuning.metrics import aggregation_assess_acc, filtering_assess_acc from tuning.programs import PROGRAMS from tuning.utils import save, serialize_results @@ -25,14 +25,17 @@ class EvaluationType(Enum): """ FILTERING_ASSESSOR = "FILTERING_ASSESSOR" + AGGREGATION_ASSESSOR = "AGGREGATION_ASSESSOR" EVALUATION_DATALOADERS = { EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader, + EvaluationType.AGGREGATION_ASSESSOR.value: IQLGenerationDataLoader, } EVALUATION_METRICS = { EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc, + EvaluationType.AGGREGATION_ASSESSOR.value: aggregation_assess_acc, } diff --git a/extra/prompt_tuning/tuning/metrics/__init__.py b/extra/prompt_tuning/tuning/metrics/__init__.py index 56615738..60597b40 100644 --- a/extra/prompt_tuning/tuning/metrics/__init__.py +++ b/extra/prompt_tuning/tuning/metrics/__init__.py @@ -1,3 +1,3 @@ -from .iql import filtering_assess_acc +from .iql import aggregation_assess_acc, filtering_assess_acc -__all__ = ["filtering_assess_acc"] +__all__ = ["aggregation_assess_acc", "filtering_assess_acc"] diff --git a/extra/prompt_tuning/tuning/metrics/iql.py b/extra/prompt_tuning/tuning/metrics/iql.py index d47b0689..3340f929 100644 --- a/extra/prompt_tuning/tuning/metrics/iql.py +++ b/extra/prompt_tuning/tuning/metrics/iql.py @@ -5,15 +5,31 @@ def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool: """ - IQL decision metric. + IQL filtering decision metric. Args: gold: The ground truth data point. pred: The prediction. Returns: - The decision metric. + The filtering decision accuracy. 
""" return ((gold["iql_filters"] is None and not gold["iql_filters_unsupported"]) and not pred.decision) or ( (gold["iql_filters"] is not None or gold["iql_filters_unsupported"]) and pred.decision ) + + +def aggregation_assess_acc(gold: Dict, pred: Prediction) -> bool: + """ + IQL aggregation decision metric. + + Args: + gold: The ground truth data point. + pred: The prediction. + + Returns: + The aggregation decision accuracy. + """ + return ((gold["iql_aggregation"] is None and not gold["iql_aggregation_unsupported"]) and not pred.decision) or ( + (gold["iql_aggregation"] is not None or gold["iql_aggregation_unsupported"]) and pred.decision + ) diff --git a/extra/prompt_tuning/tuning/programs/iql.py b/extra/prompt_tuning/tuning/programs/iql.py index 2a20da47..a7209672 100644 --- a/extra/prompt_tuning/tuning/programs/iql.py +++ b/extra/prompt_tuning/tuning/programs/iql.py @@ -1,6 +1,6 @@ from dspy import ChainOfThought, Module, Predict, Prediction -from ..signatures.iql import CheckQuestionFiltering +from ..signatures.iql import CheckQuestionAggregation, CheckQuestionFiltering class FilteringAssessorBaseline(Module): @@ -47,3 +47,49 @@ def forward(self, question: str) -> Prediction: """ decision = self.decide(question=question).decision return Prediction(decision=decision.lower() == "true") + + +class AggregationAssessorBaseline(Module): + """ + Program that assesses whether a question requires aggregation. + """ + + def __init__(self) -> None: + super().__init__() + self.decide = Predict(CheckQuestionAggregation) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires aggregation. + + Args: + question: The question to assess. + + Returns: + The prediction. + """ + decision = self.decide(question=question).decision + return Prediction(decision=decision.lower() == "true") + + +class AggregationAssessorCoT(Module): + """ + Program that assesses whether a question requires aggregation. + """ + + def __init__(self) -> None: + super().__init__() + self.decide = ChainOfThought(CheckQuestionAggregation) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires aggregation. + + Args: + question: The question to assess. + + Returns: + The prediction. + """ + decision = self.decide(question=question).decision + return Prediction(decision=decision.lower() == "true") diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py index 273edf60..d3a6c531 100644 --- a/extra/prompt_tuning/tuning/signatures/iql.py +++ b/extra/prompt_tuning/tuning/signatures/iql.py @@ -18,3 +18,22 @@ class CheckQuestionFiltering(Signature): "(Respond with True or False)" ), ) + + +class CheckQuestionAggregation(Signature): + """ + Given a question, determine whether the answer requires initial data filtering in order to compute it. + Initial data filtering is a process in which the result set is reduced to only include the rows that + meet certain criteria specified in the question. + """ + + question = InputField( + prefix="Question: ", + ) + decision = OutputField( + prefix="Decision: ", + desc=( + "indicates whether the answer to the question requires initial data filtering. 
" + "(Respond with True or False)" + ), + ) From 596868e4ca2933050810706d534c796a2592d010 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Sat, 17 Aug 2024 03:05:08 +0200 Subject: [PATCH 02/13] fix types issues --- extra/prompt_tuning/tuning/loaders.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extra/prompt_tuning/tuning/loaders.py b/extra/prompt_tuning/tuning/loaders.py index 2cc7dc96..9553643c 100644 --- a/extra/prompt_tuning/tuning/loaders.py +++ b/extra/prompt_tuning/tuning/loaders.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from typing import Dict, Iterable, List +from typing import Iterable, List import dspy.datasets from dspy import Example +from omegaconf import DictConfig class DataLoader(ABC): @@ -10,7 +11,7 @@ class DataLoader(ABC): Data loader. """ - def __init__(self, config: Dict) -> None: + def __init__(self, config: DictConfig) -> None: self.config = config @abstractmethod From 2765afa8edb1cb931c64cc6f447d00332a65b263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 26 Aug 2024 10:16:29 +0200 Subject: [PATCH 03/13] update benchmark dataset --- extra/prompt_tuning/config/data/superhero.yaml | 2 +- extra/prompt_tuning/tuning/programs/__init__.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/extra/prompt_tuning/config/data/superhero.yaml b/extra/prompt_tuning/config/data/superhero.yaml index 23412721..50a8ba2f 100644 --- a/extra/prompt_tuning/config/data/superhero.yaml +++ b/extra/prompt_tuning/config/data/superhero.yaml @@ -1,4 +1,4 @@ -path: "micpst/bird-iql" +path: "deepsense-ai/bird-iql" split: "dev" db_ids: ["superhero"] difficulties: ["simple", "moderate", "challenging"] diff --git a/extra/prompt_tuning/tuning/programs/__init__.py b/extra/prompt_tuning/tuning/programs/__init__.py index 1961d77d..96e6c520 100644 --- a/extra/prompt_tuning/tuning/programs/__init__.py +++ b/extra/prompt_tuning/tuning/programs/__init__.py @@ -1,8 +1,16 @@ -from .iql import FilteringAssessorBaseline, FilteringAssessorCoT +from .iql import AggregationAssessorBaseline, AggregationAssessorCoT, FilteringAssessorBaseline, FilteringAssessorCoT PROGRAMS = { FilteringAssessorBaseline.__name__: FilteringAssessorBaseline, FilteringAssessorCoT.__name__: FilteringAssessorCoT, + AggregationAssessorBaseline.__name__: AggregationAssessorBaseline, + AggregationAssessorCoT.__name__: AggregationAssessorCoT, } -__all__ = ["PROGRAMS", "FilteringAssessorBaseline", "FilteringAssessorCoT"] +__all__ = [ + "PROGRAMS", + "AggregationAssessorBaseline", + "AggregationAssessorCoT", + "FilteringAssessorBaseline", + "FilteringAssessorCoT", +] From 80c932c01ac7436ee333f58c4b2ee1f31af72191 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 26 Aug 2024 12:07:00 +0200 Subject: [PATCH 04/13] correct prompt --- extra/prompt_tuning/tuning/signatures/iql.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py index d3a6c531..1834604c 100644 --- a/extra/prompt_tuning/tuning/signatures/iql.py +++ b/extra/prompt_tuning/tuning/signatures/iql.py @@ -22,9 +22,8 @@ class CheckQuestionFiltering(Signature): class CheckQuestionAggregation(Signature): """ - Given a question, determine whether the answer requires initial data filtering in order to compute it. 
-    Initial data filtering is a process in which the result set is reduced to only include the rows that
-    meet certain criteria specified in the question.
+    Given a question, determine whether the answer requires data aggregation in order to compute it.
+    Data aggregation is a process in which we calculate a single value for a group of rows in the result set.
     """
 
     question = InputField(
@@ -33,7 +32,6 @@
     decision = OutputField(
         prefix="Decision: ",
         desc=(
-            "indicates whether the answer to the question requires initial data filtering. "
-            "(Respond with True or False)"
+            "indicates whether the answer to the question requires data aggregation. " "(Respond with True or False)"
         ),
     )

From 2bea89d2131d19132318e8776877d9c8e8142822 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?=
Date: Mon, 26 Aug 2024 12:10:13 +0200
Subject: [PATCH 05/13] update README

---
 extra/prompt_tuning/README.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/extra/prompt_tuning/README.md b/extra/prompt_tuning/README.md
index e5c86a5a..c3e20e86 100644
--- a/extra/prompt_tuning/README.md
+++ b/extra/prompt_tuning/README.md
@@ -3,6 +3,7 @@ This folder contains scripts for prompt tuning and evaluation.
 
 Prompts (programs) used in dbally:
 - `FILTERING_ASSESSOR` - assesses whether a question requires filtering.
+- `AGGREGATION_ASSESSOR` - assesses whether a question requires aggregation.
 
 All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database.
 
@@ -20,6 +21,10 @@ Test multiple programs:
 python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
 ```
 
+```bash
+python evaluate.py --multirun program=aggregation-assessor-baseline,aggregation-assessor-cot
+```
+
 Compare prompt performance on multiple LLMs:
 
 ```bash

From 91f262ca1dd33926d704d85e361c8ff71662d505 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?=
Date: Wed, 28 Aug 2024 17:16:16 +0200
Subject: [PATCH 06/13] add train script

---
 .../config/{config.yaml => evaluate.yaml}    |  1 +
 extra/prompt_tuning/config/train.yaml        | 10 +++
 extra/prompt_tuning/evaluate.py              | 32 ++-------
 extra/prompt_tuning/train.py                 | 71 +++++++++++++++++++
 extra/prompt_tuning/tuning/__init__.py       | 24 +++++++
 extra/prompt_tuning/tuning/signatures/iql.py |  4 +-
 6 files changed, 112 insertions(+), 30 deletions(-)
 rename extra/prompt_tuning/config/{config.yaml => evaluate.yaml} (88%)
 create mode 100644 extra/prompt_tuning/config/train.yaml
 create mode 100644 extra/prompt_tuning/train.py

diff --git a/extra/prompt_tuning/config/config.yaml b/extra/prompt_tuning/config/evaluate.yaml
similarity index 88%
rename from extra/prompt_tuning/config/config.yaml
rename to extra/prompt_tuning/config/evaluate.yaml
index 9aed0232..50ede7be 100644
--- a/extra/prompt_tuning/config/config.yaml
+++ b/extra/prompt_tuning/config/evaluate.yaml
@@ -4,4 +4,5 @@ defaults:
   - program: filtering-assessor-baseline
   - _self_
 
+num_threads: 32
 neptune: False
diff --git a/extra/prompt_tuning/config/train.yaml b/extra/prompt_tuning/config/train.yaml
new file mode 100644
index 00000000..19668567
--- /dev/null
+++ b/extra/prompt_tuning/config/train.yaml
@@ -0,0 +1,10 @@
+defaults:
+  - data: superhero
+  - llm: gpt-3.5-turbo
+  - program: filtering-assessor-baseline
+  - _self_
+
+num_threads: 32
+breadth: 3
+depth: 10
+init_temperature: 1.4
diff --git a/extra/prompt_tuning/evaluate.py 
b/extra/prompt_tuning/evaluate.py index 0ae140e4..58bcdc8b 100644 --- a/extra/prompt_tuning/evaluate.py +++ b/extra/prompt_tuning/evaluate.py @@ -1,6 +1,5 @@ import asyncio import logging -from enum import Enum from pathlib import Path import dspy @@ -9,8 +8,7 @@ from dspy.evaluate import Evaluate from neptune.utils import stringify_unsupported from omegaconf import DictConfig -from tuning.loaders import IQLGenerationDataLoader -from tuning.metrics import aggregation_assess_acc, filtering_assess_acc +from tuning import DATALOADERS, METRICS from tuning.programs import PROGRAMS from tuning.utils import save, serialize_results @@ -19,26 +17,6 @@ log = logging.getLogger(__name__) -class EvaluationType(Enum): - """ - Enum representing the evaluation type. - """ - - FILTERING_ASSESSOR = "FILTERING_ASSESSOR" - AGGREGATION_ASSESSOR = "AGGREGATION_ASSESSOR" - - -EVALUATION_DATALOADERS = { - EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader, - EvaluationType.AGGREGATION_ASSESSOR.value: IQLGenerationDataLoader, -} - -EVALUATION_METRICS = { - EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc, - EvaluationType.AGGREGATION_ASSESSOR.value: aggregation_assess_acc, -} - - async def evaluate(config: DictConfig) -> None: """ Function running evaluation for all datasets and evaluation tasks defined in hydra config. @@ -48,8 +26,8 @@ async def evaluate(config: DictConfig) -> None: """ log.info("Starting evaluation: %s", config.program.name) - dataloader = EVALUATION_DATALOADERS[config.program.type](config) - metric = EVALUATION_METRICS[config.program.type] + dataloader = DATALOADERS[config.program.type](config) + metric = METRICS[config.program.type] program = PROGRAMS[config.program.name]() dataset = await dataloader.load() @@ -60,7 +38,7 @@ async def evaluate(config: DictConfig) -> None: evaluator = Evaluate( devset=dataset, metric=metric, - num_threads=32, + num_threads=config.num_threads, display_progress=True, return_outputs=True, ) @@ -89,7 +67,7 @@ async def evaluate(config: DictConfig) -> None: run["evaluation/results.json"].upload(results_file.as_posix()) -@hydra.main(config_path="config", config_name="config", version_base="3.2") +@hydra.main(config_path="config", config_name="evaluate", version_base="3.2") def main(config: DictConfig) -> None: """ Function running evaluation for all datasets and evaluation tasks defined in hydra config. diff --git a/extra/prompt_tuning/train.py b/extra/prompt_tuning/train.py new file mode 100644 index 00000000..8ecb6657 --- /dev/null +++ b/extra/prompt_tuning/train.py @@ -0,0 +1,71 @@ +import asyncio +import logging +from pathlib import Path + +import dspy +import hydra +from dspy.teleprompt import COPRO +from omegaconf import DictConfig +from tuning import DATALOADERS, METRICS +from tuning.programs import PROGRAMS + +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("anthropic").setLevel(logging.ERROR) +log = logging.getLogger(__name__) + + +async def train(config: DictConfig) -> None: + """ + Function running training for all datasets and training tasks defined in hydra config. + + Args: + config: Hydra configuration. 
+ """ + log.info("Starting training: %s", config.program.name) + + dataloader = DATALOADERS[config.program.type](config) + metric = METRICS[config.program.type] + program = PROGRAMS[config.program.name]() + + dataset = await dataloader.load() + + lm = dspy.__dict__[config.llm.provider](model=config.llm.model_name) + dspy.settings.configure(lm=lm) + + copro = COPRO( + metric=metric, + breadth=config.breadth, + depth=config.depth, + init_temperature=config.init_temperature, + ) + compiled_program = copro.compile( + student=program, + trainset=dataset, + eval_kwargs={ + "num_threads": config.num_threads, + "display_progress": True, + }, + ) + + log.info("Training finished. Saving compiled program...") + + output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) + program_file = output_dir / f"{program.__class__.__name__}Optimized.json" + compiled_program.save(program_file) + + log.info("Compiled program saved under directory: %s", output_dir) + + +@hydra.main(config_path="config", config_name="train", version_base="3.2") +def main(config: DictConfig) -> None: + """ + Function running evaluation for all datasets and evaluation tasks defined in hydra config. + + Args: + config: Hydra configuration. + """ + asyncio.run(train(config)) + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/extra/prompt_tuning/tuning/__init__.py b/extra/prompt_tuning/tuning/__init__.py index e69de29b..54cad5b9 100644 --- a/extra/prompt_tuning/tuning/__init__.py +++ b/extra/prompt_tuning/tuning/__init__.py @@ -0,0 +1,24 @@ +from enum import Enum + +from .loaders import IQLGenerationDataLoader +from .metrics import aggregation_assess_acc, filtering_assess_acc + + +class ProgramType(Enum): + """ + Program types. + """ + + FILTERING_ASSESSOR = "FILTERING_ASSESSOR" + AGGREGATION_ASSESSOR = "AGGREGATION_ASSESSOR" + + +DATALOADERS = { + ProgramType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader, + ProgramType.AGGREGATION_ASSESSOR.value: IQLGenerationDataLoader, +} + +METRICS = { + ProgramType.FILTERING_ASSESSOR.value: filtering_assess_acc, + ProgramType.AGGREGATION_ASSESSOR.value: aggregation_assess_acc, +} diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py index 1834604c..844dcfd7 100644 --- a/extra/prompt_tuning/tuning/signatures/iql.py +++ b/extra/prompt_tuning/tuning/signatures/iql.py @@ -31,7 +31,5 @@ class CheckQuestionAggregation(Signature): ) decision = OutputField( prefix="Decision: ", - desc=( - "indicates whether the answer to the question requires data aggregation. " "(Respond with True or False)" - ), + desc=("indicates whether the answer to the question requires data aggregation. 
(Respond with True or False)"), ) From bdd1159ba8b7d3c3f5a9df32736392fcb1419ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 28 Aug 2024 17:53:34 +0200 Subject: [PATCH 07/13] add dynamic optimizer config --- extra/prompt_tuning/config/optimizer/copro.yaml | 6 ++++++ extra/prompt_tuning/config/train.yaml | 4 +--- extra/prompt_tuning/train.py | 14 +++++--------- 3 files changed, 12 insertions(+), 12 deletions(-) create mode 100644 extra/prompt_tuning/config/optimizer/copro.yaml diff --git a/extra/prompt_tuning/config/optimizer/copro.yaml b/extra/prompt_tuning/config/optimizer/copro.yaml new file mode 100644 index 00000000..0e48ee9a --- /dev/null +++ b/extra/prompt_tuning/config/optimizer/copro.yaml @@ -0,0 +1,6 @@ +name: COPRO +params: + breadth: 3 + depth: 10 + init_temperature: 1.4 +compile: diff --git a/extra/prompt_tuning/config/train.yaml b/extra/prompt_tuning/config/train.yaml index 19668567..a95260c9 100644 --- a/extra/prompt_tuning/config/train.yaml +++ b/extra/prompt_tuning/config/train.yaml @@ -2,9 +2,7 @@ defaults: - data: superhero - llm: gpt-3.5-turbo - program: filtering-assessor-baseline + - optimizer: copro - _self_ num_threads: 32 -breadth: 3 -depth: 10 -init_temperature: 1.4 diff --git a/extra/prompt_tuning/train.py b/extra/prompt_tuning/train.py index 8ecb6657..a8d16cf4 100644 --- a/extra/prompt_tuning/train.py +++ b/extra/prompt_tuning/train.py @@ -3,8 +3,8 @@ from pathlib import Path import dspy +import dspy.teleprompt import hydra -from dspy.teleprompt import COPRO from omegaconf import DictConfig from tuning import DATALOADERS, METRICS from tuning.programs import PROGRAMS @@ -21,7 +21,7 @@ async def train(config: DictConfig) -> None: Args: config: Hydra configuration. """ - log.info("Starting training: %s", config.program.name) + log.info("Starting training %s with %s", config.program.name, config.optimizer.name) dataloader = DATALOADERS[config.program.type](config) metric = METRICS[config.program.type] @@ -32,19 +32,15 @@ async def train(config: DictConfig) -> None: lm = dspy.__dict__[config.llm.provider](model=config.llm.model_name) dspy.settings.configure(lm=lm) - copro = COPRO( - metric=metric, - breadth=config.breadth, - depth=config.depth, - init_temperature=config.init_temperature, - ) - compiled_program = copro.compile( + optimizer = dspy.teleprompt.__dict__[config.optimizer.name](metric=metric, **config.optimizer.params) + compiled_program = optimizer.compile( student=program, trainset=dataset, eval_kwargs={ "num_threads": config.num_threads, "display_progress": True, }, + **(config.optimizer.compile or {}), ) log.info("Training finished. 
Saving compiled program...")

From 111fe8c249f5ef3f415de4e9eb3df8c536b1277a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?=
Date: Wed, 28 Aug 2024 19:03:18 +0200
Subject: [PATCH 08/13] update docs

---
 extra/prompt_tuning/README.md | 30 ++++++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/extra/prompt_tuning/README.md b/extra/prompt_tuning/README.md
index c3e20e86..a7be3e60 100644
--- a/extra/prompt_tuning/README.md
+++ b/extra/prompt_tuning/README.md
@@ -9,13 +9,39 @@ All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database.
 
 ## Usage
 
+### Train new prompts
+
+Tune `filtering-assessor-baseline` prompt using [COPRO](https://dspy-docs.vercel.app/docs/deep-dive/teleprompter/signature-optimizer#how-copro-works) optimizer on the `superhero` database with `gpt-3.5-turbo`:
+
+```bash
+python train.py program=filtering-assessor-baseline
+```
+
+Train multiple prompts:
+
+```bash
+python train.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
+```
+
+Tweak optimizer params to get different results:
+
+```bash
+python train.py \
+    program=filtering-assessor-baseline \
+    optimizer.params.breadth=2 \
+    optimizer.params.depth=3 \
+    optimizer.params.init_temperature=1.0
+```
+
+### Evaluate prompts
+
 Run evaluation of filtering assessor baseline on the `superhero` database with `gpt-3.5-turbo`:
 
 ```bash
 python evaluate.py program=filtering-assessor-baseline
 ```
 
-Test multiple programs:
+Test multiple prompts:
 
 ```bash
 python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
 ```
 
 ```bash
 python evaluate.py --multirun program=aggregation-assessor-baseline,aggregation-assessor-cot
 ```
 
 Compare prompt performance on multiple LLMs:
 
 ```bash
 python evaluate.py --multirun program=filtering-assessor-baseline llm=gpt-3.5-turbo,claude-3.5-sonnet
 ```
 
-### Log to Neptune
+#### Log to Neptune
 
 Before running the evaluation with Neptune, configure the following environment variables:

From dde44524f8d630f4cbad28311bf523eb1fcfebc1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?=
Date: Fri, 30 Aug 2024 18:20:26 +0200
Subject: [PATCH 09/13] refactor

---
 extra/prompt_tuning/README.md                 | 28 ++++++++----
 extra/prompt_tuning/config/evaluate.yaml      |  2 +-
 .../aggregation-assessor-baseline.yaml        |  2 -
 .../program/aggregation-assessor-cot.yaml     |  2 -
 .../program/filtering-assessor-baseline.yaml  |  2 -
 .../program/filtering-assessor-cot.yaml       |  2 -
 .../config/prompt/program/cot.yaml            |  1 +
 .../config/prompt/program/predict.yaml        |  1 +
 extra/prompt_tuning/config/prompt/prompt.yaml |  8 ++++
 .../config/prompt/signature/baseline.yaml     |  1 +
 .../config/prompt/signature/optimized.yaml    |  1 +
 .../prompt/type/aggregation-assessor.yaml     |  1 +
 .../prompt/type/filtering-assessor.yaml       |  1 +
 extra/prompt_tuning/config/train.yaml         |  2 +-
 extra/prompt_tuning/evaluate.py               | 13 ++++--
 extra/prompt_tuning/train.py                  | 13 ++++--
 extra/prompt_tuning/tuning/__init__.py        |  4 +-
 .../prompt_tuning/tuning/programs/__init__.py | 10 ++---
 extra/prompt_tuning/tuning/programs/iql.py    | 24 +++++-----
 .../tuning/signatures/__init__.py             | 27 ++++++++++-
 extra/prompt_tuning/tuning/signatures/iql.py  | 45 ++++++++++++++---
 21 files changed, 137 insertions(+), 53 deletions(-)
 delete mode 100644 extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml
 delete mode 100644 extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml
 delete mode 100644 extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml
 delete mode 100644 extra/prompt_tuning/config/program/filtering-assessor-cot.yaml
 create mode 100644 extra/prompt_tuning/config/prompt/program/cot.yaml
 create mode 100644 extra/prompt_tuning/config/prompt/program/predict.yaml
 create mode 100644 extra/prompt_tuning/config/prompt/prompt.yaml
 create mode 100644 extra/prompt_tuning/config/prompt/signature/baseline.yaml
 create mode 100644 extra/prompt_tuning/config/prompt/signature/optimized.yaml
 create mode 100644 extra/prompt_tuning/config/prompt/type/aggregation-assessor.yaml
 create mode 100644 extra/prompt_tuning/config/prompt/type/filtering-assessor.yaml

diff --git a/extra/prompt_tuning/README.md b/extra/prompt_tuning/README.md
index a7be3e60..93e6afb2 100644
--- a/extra/prompt_tuning/README.md
+++ b/extra/prompt_tuning/README.md
@@ -11,23 +11,25 @@
 ### Train new prompts
 
-Tune `filtering-assessor-baseline` prompt using [COPRO](https://dspy-docs.vercel.app/docs/deep-dive/teleprompter/signature-optimizer#how-copro-works) optimizer on the `superhero` database with `gpt-3.5-turbo`:
+Tune `filtering-assessor` prompt with the baseline signature using [COPRO](https://dspy-docs.vercel.app/docs/deep-dive/teleprompter/signature-optimizer#how-copro-works) optimizer on the `superhero` database with `gpt-3.5-turbo`:
 
 ```bash
-python train.py program=filtering-assessor-baseline
+python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict
 ```
 
 Train multiple prompts:
 
 ```bash
-python train.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
+python train.py --multirun \
+    prompt/type=filtering-assessor \
+    prompt/signature=baseline \
+    prompt/program=predict,cot
 ```
 
 Tweak optimizer params to get different results:
 
 ```bash
 python train.py \
-    program=filtering-assessor-baseline \
     optimizer.params.breadth=2 \
     optimizer.params.depth=3 \
     optimizer.params.init_temperature=1.0
@@ -38,23 +40,33 @@ python train.py \
 
 ### Evaluate prompts
 
 Run evaluation of filtering assessor baseline on the `superhero` database with `gpt-3.5-turbo`:
 
 ```bash
-python evaluate.py program=filtering-assessor-baseline
+python evaluate.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict
 ```
 
 Test multiple prompts:
 
 ```bash
-python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
+python evaluate.py --multirun \
+    prompt/type=filtering-assessor \
+    prompt/signature=baseline \
+    prompt/program=predict,cot
 ```
 
 ```bash
-python evaluate.py --multirun program=aggregation-assessor-baseline,aggregation-assessor-cot
+python evaluate.py --multirun \
+    prompt/type=aggregation-assessor \
+    prompt/signature=baseline \
+    prompt/program=predict,cot
 ```
 
 Compare prompt performance on multiple LLMs:
 
 ```bash
-python evaluate.py --multirun program=filtering-assessor-baseline llm=gpt-3.5-turbo,claude-3.5-sonnet
+python evaluate.py --multirun \
+    prompt/type=filtering-assessor \
+    prompt/signature=baseline \
+    prompt/program=predict \
+    llm=gpt-3.5-turbo,claude-3.5-sonnet
 ```
 
 #### Log to Neptune
diff --git a/extra/prompt_tuning/config/evaluate.yaml b/extra/prompt_tuning/config/evaluate.yaml
index 50ede7be..cd72cc62 100644
--- a/extra/prompt_tuning/config/evaluate.yaml
+++ b/extra/prompt_tuning/config/evaluate.yaml
@@ -1,7 +1,7 @@
 defaults:
   - data: superhero
   - llm: gpt-3.5-turbo
-  - program: filtering-assessor-baseline
+  - prompt: prompt
   - _self_
 
 num_threads: 32
diff --git a/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml b/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml
deleted file mode 100644
index 
98d1d68b..00000000 --- a/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml +++ /dev/null @@ -1,2 +0,0 @@ -type: AGGREGATION_ASSESSOR -name: AggregationAssessorBaseline diff --git a/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml b/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml deleted file mode 100644 index f6a7be97..00000000 --- a/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml +++ /dev/null @@ -1,2 +0,0 @@ -type: AGGREGATION_ASSESSOR -name: AggregationAssessorCoT diff --git a/extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml b/extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml deleted file mode 100644 index e0e6855d..00000000 --- a/extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml +++ /dev/null @@ -1,2 +0,0 @@ -type: FILTERING_ASSESSOR -name: FilteringAssessorBaseline diff --git a/extra/prompt_tuning/config/program/filtering-assessor-cot.yaml b/extra/prompt_tuning/config/program/filtering-assessor-cot.yaml deleted file mode 100644 index bd7f8850..00000000 --- a/extra/prompt_tuning/config/program/filtering-assessor-cot.yaml +++ /dev/null @@ -1,2 +0,0 @@ -type: FILTERING_ASSESSOR -name: FilteringAssessorCoT diff --git a/extra/prompt_tuning/config/prompt/program/cot.yaml b/extra/prompt_tuning/config/prompt/program/cot.yaml new file mode 100644 index 00000000..3b09b5e1 --- /dev/null +++ b/extra/prompt_tuning/config/prompt/program/cot.yaml @@ -0,0 +1 @@ +id: CoT diff --git a/extra/prompt_tuning/config/prompt/program/predict.yaml b/extra/prompt_tuning/config/prompt/program/predict.yaml new file mode 100644 index 00000000..e6d4044f --- /dev/null +++ b/extra/prompt_tuning/config/prompt/program/predict.yaml @@ -0,0 +1 @@ +id: Predict diff --git a/extra/prompt_tuning/config/prompt/prompt.yaml b/extra/prompt_tuning/config/prompt/prompt.yaml new file mode 100644 index 00000000..e4018cea --- /dev/null +++ b/extra/prompt_tuning/config/prompt/prompt.yaml @@ -0,0 +1,8 @@ +defaults: + - type: filtering-assessor + - signature: baseline + - program: predict + - _self_ + +num_threads: 32 +neptune: False diff --git a/extra/prompt_tuning/config/prompt/signature/baseline.yaml b/extra/prompt_tuning/config/prompt/signature/baseline.yaml new file mode 100644 index 00000000..834aab90 --- /dev/null +++ b/extra/prompt_tuning/config/prompt/signature/baseline.yaml @@ -0,0 +1 @@ +id: Baseline diff --git a/extra/prompt_tuning/config/prompt/signature/optimized.yaml b/extra/prompt_tuning/config/prompt/signature/optimized.yaml new file mode 100644 index 00000000..22721da5 --- /dev/null +++ b/extra/prompt_tuning/config/prompt/signature/optimized.yaml @@ -0,0 +1 @@ +id: Optimized diff --git a/extra/prompt_tuning/config/prompt/type/aggregation-assessor.yaml b/extra/prompt_tuning/config/prompt/type/aggregation-assessor.yaml new file mode 100644 index 00000000..ff1f2279 --- /dev/null +++ b/extra/prompt_tuning/config/prompt/type/aggregation-assessor.yaml @@ -0,0 +1 @@ +id: AggregationAssessor diff --git a/extra/prompt_tuning/config/prompt/type/filtering-assessor.yaml b/extra/prompt_tuning/config/prompt/type/filtering-assessor.yaml new file mode 100644 index 00000000..24c4bd23 --- /dev/null +++ b/extra/prompt_tuning/config/prompt/type/filtering-assessor.yaml @@ -0,0 +1 @@ +id: FilteringAssessor diff --git a/extra/prompt_tuning/config/train.yaml b/extra/prompt_tuning/config/train.yaml index a95260c9..e0d6cefb 100644 --- a/extra/prompt_tuning/config/train.yaml +++ b/extra/prompt_tuning/config/train.yaml @@ 
-1,7 +1,7 @@ defaults: - data: superhero - llm: gpt-3.5-turbo - - program: filtering-assessor-baseline + - prompt: prompt - optimizer: copro - _self_ diff --git a/extra/prompt_tuning/evaluate.py b/extra/prompt_tuning/evaluate.py index 58bcdc8b..6c03c401 100644 --- a/extra/prompt_tuning/evaluate.py +++ b/extra/prompt_tuning/evaluate.py @@ -10,6 +10,7 @@ from omegaconf import DictConfig from tuning import DATALOADERS, METRICS from tuning.programs import PROGRAMS +from tuning.signatures import SIGNATURES from tuning.utils import save, serialize_results logging.getLogger("httpx").setLevel(logging.ERROR) @@ -24,11 +25,15 @@ async def evaluate(config: DictConfig) -> None: Args: config: Hydra configuration. """ - log.info("Starting evaluation: %s", config.program.name) + signature_name = f"{config.prompt.type.id}{config.prompt.signature.id}" + program_name = f"{config.prompt.type.id}{config.prompt.program.id}" - dataloader = DATALOADERS[config.program.type](config) - metric = METRICS[config.program.type] - program = PROGRAMS[config.program.name]() + log.info("Starting evaluation: %s(%s) program", program_name, signature_name) + + dataloader = DATALOADERS[config.prompt.type.id](config) + metric = METRICS[config.prompt.type.id] + signature = SIGNATURES[signature_name] + program = PROGRAMS[program_name](signature) dataset = await dataloader.load() diff --git a/extra/prompt_tuning/train.py b/extra/prompt_tuning/train.py index a8d16cf4..2d91426b 100644 --- a/extra/prompt_tuning/train.py +++ b/extra/prompt_tuning/train.py @@ -8,6 +8,7 @@ from omegaconf import DictConfig from tuning import DATALOADERS, METRICS from tuning.programs import PROGRAMS +from tuning.signatures import SIGNATURES logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("anthropic").setLevel(logging.ERROR) @@ -21,11 +22,15 @@ async def train(config: DictConfig) -> None: Args: config: Hydra configuration. """ - log.info("Starting training %s with %s", config.program.name, config.optimizer.name) + signature_name = f"{config.prompt.type.id}{config.prompt.signature.id}" + program_name = f"{config.prompt.type.id}{config.prompt.program.id}" - dataloader = DATALOADERS[config.program.type](config) - metric = METRICS[config.program.type] - program = PROGRAMS[config.program.name]() + log.info("Starting training: %s(%s) program with %s optimizer", program_name, signature_name, config.optimizer.name) + + dataloader = DATALOADERS[config.prompt.type.id](config) + metric = METRICS[config.prompt.type.id] + signature = SIGNATURES[signature_name] + program = PROGRAMS[program_name](signature) dataset = await dataloader.load() diff --git a/extra/prompt_tuning/tuning/__init__.py b/extra/prompt_tuning/tuning/__init__.py index 54cad5b9..f41fd945 100644 --- a/extra/prompt_tuning/tuning/__init__.py +++ b/extra/prompt_tuning/tuning/__init__.py @@ -9,8 +9,8 @@ class ProgramType(Enum): Program types. 
""" - FILTERING_ASSESSOR = "FILTERING_ASSESSOR" - AGGREGATION_ASSESSOR = "AGGREGATION_ASSESSOR" + FILTERING_ASSESSOR = "FilteringAssessor" + AGGREGATION_ASSESSOR = "AggregationAssessor" DATALOADERS = { diff --git a/extra/prompt_tuning/tuning/programs/__init__.py b/extra/prompt_tuning/tuning/programs/__init__.py index 96e6c520..6356e7c0 100644 --- a/extra/prompt_tuning/tuning/programs/__init__.py +++ b/extra/prompt_tuning/tuning/programs/__init__.py @@ -1,16 +1,16 @@ -from .iql import AggregationAssessorBaseline, AggregationAssessorCoT, FilteringAssessorBaseline, FilteringAssessorCoT +from .iql import AggregationAssessorCoT, AggregationAssessorPredict, FilteringAssessorCoT, FilteringAssessorPredict PROGRAMS = { - FilteringAssessorBaseline.__name__: FilteringAssessorBaseline, + FilteringAssessorPredict.__name__: FilteringAssessorPredict, FilteringAssessorCoT.__name__: FilteringAssessorCoT, - AggregationAssessorBaseline.__name__: AggregationAssessorBaseline, + AggregationAssessorPredict.__name__: AggregationAssessorPredict, AggregationAssessorCoT.__name__: AggregationAssessorCoT, } __all__ = [ "PROGRAMS", - "AggregationAssessorBaseline", + "AggregationAssessorPredict", "AggregationAssessorCoT", - "FilteringAssessorBaseline", + "FilteringAssessorPredict", "FilteringAssessorCoT", ] diff --git a/extra/prompt_tuning/tuning/programs/iql.py b/extra/prompt_tuning/tuning/programs/iql.py index a7209672..0d61a0b4 100644 --- a/extra/prompt_tuning/tuning/programs/iql.py +++ b/extra/prompt_tuning/tuning/programs/iql.py @@ -1,16 +1,18 @@ +from typing import Type + from dspy import ChainOfThought, Module, Predict, Prediction -from ..signatures.iql import CheckQuestionAggregation, CheckQuestionFiltering +from ..signatures.iql import AggregationAssessor, FilteringAssessor -class FilteringAssessorBaseline(Module): +class FilteringAssessorPredict(Module): """ Program that assesses whether a question requires filtering. """ - def __init__(self) -> None: + def __init__(self, signature: Type[FilteringAssessor]) -> None: super().__init__() - self.decide = Predict(CheckQuestionFiltering) + self.decide = Predict(signature) def forward(self, question: str) -> Prediction: """ @@ -31,9 +33,9 @@ class FilteringAssessorCoT(Module): Program that assesses whether a question requires filtering. """ - def __init__(self) -> None: + def __init__(self, signature: Type[FilteringAssessor]) -> None: super().__init__() - self.decide = ChainOfThought(CheckQuestionFiltering) + self.decide = ChainOfThought(signature) def forward(self, question: str) -> Prediction: """ @@ -49,14 +51,14 @@ def forward(self, question: str) -> Prediction: return Prediction(decision=decision.lower() == "true") -class AggregationAssessorBaseline(Module): +class AggregationAssessorPredict(Module): """ Program that assesses whether a question requires aggregation. """ - def __init__(self) -> None: + def __init__(self, signature: Type[AggregationAssessor]) -> None: super().__init__() - self.decide = Predict(CheckQuestionAggregation) + self.decide = Predict(signature) def forward(self, question: str) -> Prediction: """ @@ -77,9 +79,9 @@ class AggregationAssessorCoT(Module): Program that assesses whether a question requires aggregation. 
""" - def __init__(self) -> None: + def __init__(self, signature: Type[AggregationAssessor]) -> None: super().__init__() - self.decide = ChainOfThought(CheckQuestionAggregation) + self.decide = ChainOfThought(signature) def forward(self, question: str) -> Prediction: """ diff --git a/extra/prompt_tuning/tuning/signatures/__init__.py b/extra/prompt_tuning/tuning/signatures/__init__.py index dd3be583..a7c50ad1 100644 --- a/extra/prompt_tuning/tuning/signatures/__init__.py +++ b/extra/prompt_tuning/tuning/signatures/__init__.py @@ -1,3 +1,26 @@ -from .iql import CheckQuestionFiltering +from .iql import ( + AggregationAssessor, + AggregationAssessorBaseline, + AggregationAssessorOptimized, + FilteringAssessor, + FilteringAssessorBaseline, + FilteringAssessorOptimized, +) -__all__ = ["CheckQuestionFiltering"] +SIGNATURES = { + AggregationAssessorBaseline.__name__: AggregationAssessorBaseline, + AggregationAssessorOptimized.__name__: AggregationAssessorOptimized, + FilteringAssessorBaseline.__name__: FilteringAssessorBaseline, + FilteringAssessorOptimized.__name__: FilteringAssessorOptimized, +} + + +__all__ = [ + "AggregationAssessor", + "FilteringAssessor", + "AggregationAssessorBaseline", + "AggregationAssessorOptimized", + "FilteringAssessorBaseline", + "FilteringAssessorOptimized", + "SIGNATURES", +] diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py index 844dcfd7..680bd4f9 100644 --- a/extra/prompt_tuning/tuning/signatures/iql.py +++ b/extra/prompt_tuning/tuning/signatures/iql.py @@ -1,11 +1,12 @@ +from abc import ABC + from dspy import InputField, OutputField, Signature -class CheckQuestionFiltering(Signature): +class FilteringAssessor(Signature, ABC): """ - Given a question, determine whether the answer requires initial data filtering in order to compute it. - Initial data filtering is a process in which the result set is reduced to only include the rows that - meet certain criteria specified in the question. + Abstract signature, should be implemented by a concrete signature, + describing the actual task for the LLM. """ question = InputField( @@ -20,10 +21,26 @@ class CheckQuestionFiltering(Signature): ) -class CheckQuestionAggregation(Signature): +class FilteringAssessorBaseline(FilteringAssessor): """ - Given a question, determine whether the answer requires data aggregation in order to compute it. - Data aggregation is a process in which we calculate a single values for a group of rows in the result set. + Given a question, determine whether the answer requires initial data filtering in order to compute it. + Initial data filtering is a process in which the result set is reduced to only include the rows that + meet certain criteria specified in the question. + """ + + +class FilteringAssessorOptimized(FilteringAssessor): + """ + Given a question, determine whether the answer requires initial data filtering in order to compute it. + Initial data filtering is a process in which the result set is reduced to only include the rows that + meet certain criteria specified in the question. + """ + + +class AggregationAssessor(Signature, ABC): + """ + Abstract signature, should be implemented by a concrete signature, + describing the actual task for the LLM. """ question = InputField( @@ -33,3 +50,17 @@ class CheckQuestionAggregation(Signature): prefix="Decision: ", desc=("indicates whether the answer to the question requires data aggregation. 
(Respond with True or False)"), ) + + +class AggregationAssessorBaseline(AggregationAssessor): + """ + Given a question, determine whether the answer requires data aggregation in order to compute it. + Data aggregation is a process in which we calculate a single values for a group of rows in the result set. + """ + + +class AggregationAssessorOptimized(AggregationAssessor): + """ + Given a question, determine whether the answer requires data aggregation in order to compute it. + Data aggregation is a process in which we calculate a single values for a group of rows in the result set. + """ From 42edc8e242148bcfe82a29b9d5e1b5f8aa9c67d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 30 Aug 2024 18:42:08 +0200 Subject: [PATCH 10/13] add mipro optimizer --- extra/prompt_tuning/README.md | 11 +++++++++-- extra/prompt_tuning/config/optimizer/mipro.yaml | 9 +++++++++ extra/prompt_tuning/tuning/metrics/iql.py | 6 +++--- 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 extra/prompt_tuning/config/optimizer/mipro.yaml diff --git a/extra/prompt_tuning/README.md b/extra/prompt_tuning/README.md index 93e6afb2..54ca81cb 100644 --- a/extra/prompt_tuning/README.md +++ b/extra/prompt_tuning/README.md @@ -2,8 +2,8 @@ This folder contains scripts for prompt tuning and evaluation. Prompts (programs) used in dbally: -- `FILTERING_ASSESSOR` - assesses whether a question requires filtering. -- `AGGREGATION_ASSESSOR` - assesses whether a question requires aggregation. +- `FilteringAssessor` - assesses whether a question requires filtering. +- `AggregationAssessor` - assesses whether a question requires aggregation. All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database. @@ -17,6 +17,12 @@ Tune `filtering-assessor` prompt on base signature using [COPRO](https://dspy-do python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict ``` +Change optimizer to [MIPRO](https://dspy-docs.vercel.app/docs/cheatsheet#mipro): + +```bash +python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict optimizer=mipro +``` + Train multiple prompts: ```bash @@ -30,6 +36,7 @@ Tweak optimizer params to get different results: ```bash python train.py \ + optimizer=copro \ optimizer.params.breadth=2 \ optimizer.params.depth=3 \ optimizer.params.init_temperature=1.0 diff --git a/extra/prompt_tuning/config/optimizer/mipro.yaml b/extra/prompt_tuning/config/optimizer/mipro.yaml new file mode 100644 index 00000000..edf7e139 --- /dev/null +++ b/extra/prompt_tuning/config/optimizer/mipro.yaml @@ -0,0 +1,9 @@ +name: MIPRO +params: + num_candidates: 3 + init_temperature: 1.4 + +compile: + max_bootstrapped_demos: 3 + max_labeled_demos: 0 + num_trials: 10 diff --git a/extra/prompt_tuning/tuning/metrics/iql.py b/extra/prompt_tuning/tuning/metrics/iql.py index 3340f929..b2ad6a42 100644 --- a/extra/prompt_tuning/tuning/metrics/iql.py +++ b/extra/prompt_tuning/tuning/metrics/iql.py @@ -1,9 +1,9 @@ -from typing import Dict +from typing import Dict, List, Optional from dspy import Prediction -def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool: +def filtering_assess_acc(gold: Dict, pred: Prediction, _trace: Optional[List] = None) -> bool: """ IQL filtering decision metric. 
@@ -19,7 +19,7 @@ def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool: ) -def aggregation_assess_acc(gold: Dict, pred: Prediction) -> bool: +def aggregation_assess_acc(gold: Dict, pred: Prediction, _trace: Optional[List] = None) -> bool: """ IQL aggregation decision metric. From 3386d56a9d361dfe022e7539e9ec7f04020435cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 30 Aug 2024 19:52:40 +0200 Subject: [PATCH 11/13] improve fa signature --- .../prompt_tuning/config/optimizer/copro.yaml | 6 ++-- .../config/prompt/program/coth.yaml | 1 + extra/prompt_tuning/evaluate.py | 5 ++-- .../prompt_tuning/tuning/programs/__init__.py | 10 ++++++- extra/prompt_tuning/tuning/programs/iql.py | 28 ++++++++++++++++++- extra/prompt_tuning/tuning/signatures/iql.py | 13 ++++++--- 6 files changed, 52 insertions(+), 11 deletions(-) create mode 100644 extra/prompt_tuning/config/prompt/program/coth.yaml diff --git a/extra/prompt_tuning/config/optimizer/copro.yaml b/extra/prompt_tuning/config/optimizer/copro.yaml index 0e48ee9a..8a5b1975 100644 --- a/extra/prompt_tuning/config/optimizer/copro.yaml +++ b/extra/prompt_tuning/config/optimizer/copro.yaml @@ -1,6 +1,6 @@ name: COPRO params: - breadth: 3 - depth: 10 - init_temperature: 1.4 + breadth: 4 + depth: 15 + init_temperature: 1.5 compile: diff --git a/extra/prompt_tuning/config/prompt/program/coth.yaml b/extra/prompt_tuning/config/prompt/program/coth.yaml new file mode 100644 index 00000000..ba65ce99 --- /dev/null +++ b/extra/prompt_tuning/config/prompt/program/coth.yaml @@ -0,0 +1 @@ +id: CoTH diff --git a/extra/prompt_tuning/evaluate.py b/extra/prompt_tuning/evaluate.py index 6c03c401..69a751b2 100644 --- a/extra/prompt_tuning/evaluate.py +++ b/extra/prompt_tuning/evaluate.py @@ -61,8 +61,9 @@ async def evaluate(config: DictConfig) -> None: run = neptune.init_run() run["sys/tags"].add( [ - config.program.type, - config.program.name, + config.prompt.type.id, + config.prompt.signature.id, + config.prompt.program.id, *config.data.db_ids, *config.data.difficulties, ] diff --git a/extra/prompt_tuning/tuning/programs/__init__.py b/extra/prompt_tuning/tuning/programs/__init__.py index 6356e7c0..221900cf 100644 --- a/extra/prompt_tuning/tuning/programs/__init__.py +++ b/extra/prompt_tuning/tuning/programs/__init__.py @@ -1,8 +1,15 @@ -from .iql import AggregationAssessorCoT, AggregationAssessorPredict, FilteringAssessorCoT, FilteringAssessorPredict +from .iql import ( + AggregationAssessorCoT, + AggregationAssessorPredict, + FilteringAssessorCoT, + FilteringAssessorCoTH, + FilteringAssessorPredict, +) PROGRAMS = { FilteringAssessorPredict.__name__: FilteringAssessorPredict, FilteringAssessorCoT.__name__: FilteringAssessorCoT, + FilteringAssessorCoTH.__name__: FilteringAssessorCoTH, AggregationAssessorPredict.__name__: AggregationAssessorPredict, AggregationAssessorCoT.__name__: AggregationAssessorCoT, } @@ -13,4 +20,5 @@ "AggregationAssessorCoT", "FilteringAssessorPredict", "FilteringAssessorCoT", + "FilteringAssessorCoTH", ] diff --git a/extra/prompt_tuning/tuning/programs/iql.py b/extra/prompt_tuning/tuning/programs/iql.py index 0d61a0b4..15aaf13e 100644 --- a/extra/prompt_tuning/tuning/programs/iql.py +++ b/extra/prompt_tuning/tuning/programs/iql.py @@ -1,6 +1,6 @@ from typing import Type -from dspy import ChainOfThought, Module, Predict, Prediction +from dspy import ChainOfThought, ChainOfThoughtWithHint, Module, Predict, Prediction from ..signatures.iql import AggregationAssessor, FilteringAssessor @@ -51,6 +51,32 @@ def 
forward(self, question: str) -> Prediction:
         return Prediction(decision=decision.lower() == "true")
 
 
+class FilteringAssessorCoTH(Module):
+    """
+    Program that assesses whether a question requires filtering.
+    """
+
+    def __init__(self, signature: Type[FilteringAssessor]) -> None:
+        super().__init__()
+        self.decide = ChainOfThoughtWithHint(signature)
+
+    def forward(self, question: str) -> Prediction:
+        """
+        Assess whether a question requires filtering.
+
+        Args:
+            question: The question to assess.
+
+        Returns:
+            The prediction.
+        """
+        decision = self.decide(
+            question=question,
+            hint="Look for words indicating data-specific features.",
+        ).decision
+        return Prediction(decision=decision.lower() == "true")
+
+
 class AggregationAssessorPredict(Module):
     """
     Program that assesses whether a question requires aggregation.
diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py
index 680bd4f9..73b44a94 100644
--- a/extra/prompt_tuning/tuning/signatures/iql.py
+++ b/extra/prompt_tuning/tuning/signatures/iql.py
@@ -32,8 +32,8 @@ class FilteringAssessorOptimized(FilteringAssessor):
     """
     Given a question, determine whether the answer requires initial data filtering in order to compute it.
-    Initial data filtering is a process in which the result set is reduced to only include the rows that
-    meet certain criteria specified in the question.
+    Initial data filtering is a process in which the result set is filtered based on the specific features
+    stated in the question.
     """
@@ -61,6 +61,11 @@ class AggregationAssessorBaseline(AggregationAssessor):
 class AggregationAssessorOptimized(AggregationAssessor):
     """
-    Given a question, determine whether the answer requires data aggregation in order to compute it.
-    Data aggregation is a process in which we calculate a single value for a group of rows in the result set.
+    Look at the dependencies between the elements in the question and distinguish whether a single value can be obtained
+    for a group of entities in the data table by aggregating necessary values.
     """
+
+    decision = OutputField(
+        prefix="Instructions to identify aggregated computations given a question, analyze dependencies -> ",
+        desc="indicates whether the answer to the question requires data aggregation. (Respond with True or False)",
+    )

From fa64aeeb29749192d894317aa55ddba9f6fe5ce8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?=
Date: Fri, 30 Aug 2024 20:23:06 +0200
Subject: [PATCH 12/13] improve aa signature

---
 .../prompt_tuning/tuning/programs/__init__.py |  3 +++
 extra/prompt_tuning/tuning/programs/iql.py    | 26 +++++++++++++++++++
 extra/prompt_tuning/tuning/signatures/iql.py  | 10 +++-----
 3 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/extra/prompt_tuning/tuning/programs/__init__.py b/extra/prompt_tuning/tuning/programs/__init__.py
index 221900cf..54038fbf 100644
--- a/extra/prompt_tuning/tuning/programs/__init__.py
+++ b/extra/prompt_tuning/tuning/programs/__init__.py
@@ -1,5 +1,6 @@
 from .iql import (
     AggregationAssessorCoT,
+    AggregationAssessorCoTH,
     AggregationAssessorPredict,
     FilteringAssessorCoT,
     FilteringAssessorCoTH,
@@ -12,6 +13,7 @@
     FilteringAssessorCoTH.__name__: FilteringAssessorCoTH,
     AggregationAssessorPredict.__name__: AggregationAssessorPredict,
     AggregationAssessorCoT.__name__: AggregationAssessorCoT,
+    AggregationAssessorCoTH.__name__: AggregationAssessorCoTH,
 }
 
 __all__ = [
@@ -21,4 +23,5 @@
     "FilteringAssessorPredict",
     "FilteringAssessorCoT",
     "FilteringAssessorCoTH",
+    "AggregationAssessorCoTH",
 ]
diff --git a/extra/prompt_tuning/tuning/programs/iql.py b/extra/prompt_tuning/tuning/programs/iql.py
index 15aaf13e..f33c444e 100644
--- a/extra/prompt_tuning/tuning/programs/iql.py
+++ b/extra/prompt_tuning/tuning/programs/iql.py
@@ -121,3 +121,29 @@ def forward(self, question: str) -> Prediction:
         """
         decision = self.decide(question=question).decision
         return Prediction(decision=decision.lower() == "true")
+
+
+class AggregationAssessorCoTH(Module):
+    """
+    Program that assesses whether a question requires aggregation.
+    """
+
+    def __init__(self, signature: Type[AggregationAssessor]) -> None:
+        super().__init__()
+        self.decide = ChainOfThoughtWithHint(signature)
+
+    def forward(self, question: str) -> Prediction:
+        """
+        Assess whether a question requires aggregation.
+
+        Args:
+            question: The question to assess.
+
+        Returns:
+            The prediction.
+        """
+        decision = self.decide(
+            question=question,
+            hint="Look for words indicating aggregation functions.",
+        ).decision
+        return Prediction(decision=decision.lower() == "true")
diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py
index 73b44a94..ea113514 100644
--- a/extra/prompt_tuning/tuning/signatures/iql.py
+++ b/extra/prompt_tuning/tuning/signatures/iql.py
@@ -61,11 +61,7 @@ class AggregationAssessorBaseline(AggregationAssessor):
 class AggregationAssessorOptimized(AggregationAssessor):
     """
-    Look at the dependencies between the elements in the question and distinguish whether a single value can be obtained
-    for a group of entities in the data table by aggregating necessary values.
+    Given a question, determine whether the answer requires data aggregation in order to compute it.
+    Data aggregation is a process in which we calculate a single value for a group of rows in the result set.
+    Most common aggregation functions are counting, averaging, summing, but other types of aggregation are possible.
     """
-
-    decision = OutputField(
-        prefix="Instructions to identify aggregated computations given a question, analyze dependencies -> ",
-        desc="indicates whether the answer to the question requires data aggregation. 
(Respond with True or False)", - ) From 8e8f5b49dc9184380ca4afbfdeeff19eeedaa438 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 30 Aug 2024 21:35:32 +0200 Subject: [PATCH 13/13] update filtering assessor --- extra/prompt_tuning/tuning/signatures/iql.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py index ea113514..0305ce81 100644 --- a/extra/prompt_tuning/tuning/signatures/iql.py +++ b/extra/prompt_tuning/tuning/signatures/iql.py @@ -14,26 +14,25 @@ class FilteringAssessor(Signature, ABC): ) decision = OutputField( prefix="Decision: ", - desc=( - "indicates whether the answer to the question requires initial data filtering. " - "(Respond with True or False)" - ), + desc=("indicates whether the answer to the question requires data filtering. " "(Respond with True or False)"), ) class FilteringAssessorBaseline(FilteringAssessor): """ - Given a question, determine whether the answer requires initial data filtering in order to compute it. - Initial data filtering is a process in which the result set is reduced to only include the rows that + Given a question, determine whether the answer requires data filtering in order to compute it. + Data filtering is a process in which the result set is reduced to only include the rows that meet certain criteria specified in the question. """ class FilteringAssessorOptimized(FilteringAssessor): """ - Given a question, determine whether the answer requires initial data filtering in order to compute it. - Initial data filtering is a process in which the result set is filtered based on the specific features - stated in the question. + Given a question, determine whether the answer requires data filtering in order to compute it. + Data filtering is a process in which the result set is filtered based on the specific features + stated in the question. Such a question can be easily identified by using words that refer to + specific feature values (rather than feature names). Look for words indicating specific values + that the answer should contain. """
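
---

A minimal usage sketch of the state after PATCH 13, for reviewers who want to try the programs by hand (not part of the patch series). It assumes dspy is installed, an OpenAI API key is exported, the snippet is run from `extra/prompt_tuning` so the `tuning` package is importable, and the example question is hypothetical:

```python
import dspy

from tuning.programs import AggregationAssessorCoT
from tuning.signatures import AggregationAssessorBaseline

# Configure the LM backend; evaluate.py and train.py do the same thing
# dynamically via dspy.__dict__[config.llm.provider](model=...).
lm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=lm)

# Since PATCH 09, programs take the signature class as a constructor argument
# instead of hard-coding one, so any signature/program pair can be combined.
program = AggregationAssessorCoT(AggregationAssessorBaseline)

# Calling the module invokes forward(), which maps the LLM's "True"/"False"
# output to a Python bool.
prediction = program(question="What is the average height of superheroes per publisher?")
print(prediction.decision)  # expected True: the question calls for an AVG per group
```

The same pattern applies to the filtering programs. Swapping `AggregationAssessorBaseline` for `AggregationAssessorOptimized` changes only the instructions sent to the LLM, which is exactly the degree of freedom the COPRO/MIPRO runs in `train.py` are tuning.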