diff --git a/extra/prompt_tuning/README.md b/extra/prompt_tuning/README.md
index e5c86a5a..54ca81cb 100644
--- a/extra/prompt_tuning/README.md
+++ b/extra/prompt_tuning/README.md
@@ -2,31 +2,81 @@
 This folder contains scripts for prompt tuning and evaluation.
 
 Prompts (programs) used in dbally:
-- `FILTERING_ASSESSOR` - assesses whether a question requires filtering.
+- `FilteringAssessor` - assesses whether a question requires filtering.
+- `AggregationAssessor` - assesses whether a question requires aggregation.
 
 All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database.
 
 ## Usage
 
+### Train new prompts
+
+Tune the `filtering-assessor` prompt on the baseline signature using the [COPRO](https://dspy-docs.vercel.app/docs/deep-dive/teleprompter/signature-optimizer#how-copro-works) optimizer on the `superhero` database with `gpt-3.5-turbo`:
+
+```bash
+python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict
+```
+
+Change the optimizer to [MIPRO](https://dspy-docs.vercel.app/docs/cheatsheet#mipro):
+
+```bash
+python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict optimizer=mipro
+```
+
+Train multiple prompts:
+
+```bash
+python train.py --multirun \
+    prompt/type=filtering-assessor \
+    prompt/signature=baseline \
+    prompt/program=predict,cot
+```
+
+Tweak the optimizer params to get different results:
+
+```bash
+python train.py \
+    optimizer=copro \
+    optimizer.params.breadth=2 \
+    optimizer.params.depth=3 \
+    optimizer.params.init_temperature=1.0
+```
+
+### Evaluate prompts
+
 Run evaluation of the filtering assessor baseline on the `superhero` database with `gpt-3.5-turbo`:
 
 ```bash
-python evaluate.py program=filtering-assessor-baseline
+python evaluate.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict
 ```
 
-Test multiple programs:
+Test multiple prompts:
+
+```bash
+python evaluate.py --multirun \
+    prompt/type=filtering-assessor \
+    prompt/signature=baseline \
+    prompt/program=predict,cot
+```
 
 ```bash
-python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
+python evaluate.py --multirun \
+    prompt/type=aggregation-assessor \
+    prompt/signature=baseline \
+    prompt/program=predict,cot
 ```
 
 Compare prompt performance on multiple LLMs:
 
 ```bash
-python evaluate.py --multirun program=filtering-assessor-baseline llm=gpt-3.5-turbo,claude-3.5-sonnet
+python evaluate.py --multirun \
+    prompt/type=filtering-assessor \
+    prompt/signature=baseline \
+    prompt/program=predict \
+    llm=gpt-3.5-turbo,claude-3.5-sonnet
 ```
 
-### Log to Neptune
+#### Log to Neptune
 
 Before running the evaluation with Neptune, configure the following environment variables:
 
diff --git a/extra/prompt_tuning/config/data/superhero.yaml b/extra/prompt_tuning/config/data/superhero.yaml
index 23412721..50a8ba2f 100644
--- a/extra/prompt_tuning/config/data/superhero.yaml
+++ b/extra/prompt_tuning/config/data/superhero.yaml
@@ -1,4 +1,4 @@
-path: "micpst/bird-iql"
+path: "deepsense-ai/bird-iql"
 split: "dev"
 db_ids: ["superhero"]
 difficulties: ["simple", "moderate", "challenging"]
diff --git a/extra/prompt_tuning/config/config.yaml b/extra/prompt_tuning/config/evaluate.yaml
similarity index 66%
rename from extra/prompt_tuning/config/config.yaml
rename to extra/prompt_tuning/config/evaluate.yaml
index 9aed0232..cd72cc62 100644
--- a/extra/prompt_tuning/config/config.yaml
+++ b/extra/prompt_tuning/config/evaluate.yaml
@@ -1,7 +1,8 @@
 defaults:
   - data: superhero
   - llm: gpt-3.5-turbo
-  - program: filtering-assessor-baseline
+  - prompt: prompt
   - _self_
 
+num_threads: 32
 neptune: False
diff --git a/extra/prompt_tuning/config/optimizer/copro.yaml b/extra/prompt_tuning/config/optimizer/copro.yaml
new file mode 100644
index 00000000..8a5b1975
--- /dev/null
+++ b/extra/prompt_tuning/config/optimizer/copro.yaml
@@ -0,0 +1,6 @@
+name: COPRO
+params:
+  breadth: 4
+  depth: 15
+  init_temperature: 1.5
+compile:
diff --git a/extra/prompt_tuning/config/optimizer/mipro.yaml b/extra/prompt_tuning/config/optimizer/mipro.yaml
new file mode 100644
index 00000000..edf7e139
--- /dev/null
+++ b/extra/prompt_tuning/config/optimizer/mipro.yaml
@@ -0,0 +1,9 @@
+name: MIPRO
+params:
+  num_candidates: 3
+  init_temperature: 1.4
+
+compile:
+  max_bootstrapped_demos: 3
+  max_labeled_demos: 0
+  num_trials: 10
diff --git a/extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml b/extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml
deleted file mode 100644
index e0e6855d..00000000
--- a/extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-type: FILTERING_ASSESSOR
-name: FilteringAssessorBaseline
diff --git a/extra/prompt_tuning/config/program/filtering-assessor-cot.yaml b/extra/prompt_tuning/config/program/filtering-assessor-cot.yaml
deleted file mode 100644
index bd7f8850..00000000
--- a/extra/prompt_tuning/config/program/filtering-assessor-cot.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-type: FILTERING_ASSESSOR
-name: FilteringAssessorCoT
diff --git a/extra/prompt_tuning/config/prompt/program/cot.yaml b/extra/prompt_tuning/config/prompt/program/cot.yaml
new file mode 100644
index 00000000..3b09b5e1
--- /dev/null
+++ b/extra/prompt_tuning/config/prompt/program/cot.yaml
@@ -0,0 +1 @@
+id: CoT
diff --git a/extra/prompt_tuning/config/prompt/program/coth.yaml b/extra/prompt_tuning/config/prompt/program/coth.yaml
new file mode 100644
index 00000000..ba65ce99
--- /dev/null
+++ b/extra/prompt_tuning/config/prompt/program/coth.yaml
@@ -0,0 +1 @@
+id: CoTH
diff --git a/extra/prompt_tuning/config/prompt/program/predict.yaml b/extra/prompt_tuning/config/prompt/program/predict.yaml
new file mode 100644
index 00000000..e6d4044f
--- /dev/null
+++ b/extra/prompt_tuning/config/prompt/program/predict.yaml
@@ -0,0 +1 @@
+id: Predict
diff --git a/extra/prompt_tuning/config/prompt/prompt.yaml b/extra/prompt_tuning/config/prompt/prompt.yaml
new file mode 100644
index 00000000..e4018cea
--- /dev/null
+++ b/extra/prompt_tuning/config/prompt/prompt.yaml
@@ -0,0 +1,8 @@
+defaults:
+  - type: filtering-assessor
+  - signature: baseline
+  - program: predict
+  - _self_
+
+num_threads: 32
+neptune: False
diff --git a/extra/prompt_tuning/config/prompt/signature/baseline.yaml b/extra/prompt_tuning/config/prompt/signature/baseline.yaml
new file mode 100644
index 00000000..834aab90
--- /dev/null
+++ b/extra/prompt_tuning/config/prompt/signature/baseline.yaml
@@ -0,0 +1 @@
+id: Baseline
diff --git a/extra/prompt_tuning/config/prompt/signature/optimized.yaml b/extra/prompt_tuning/config/prompt/signature/optimized.yaml
new file mode 100644
index 00000000..22721da5
--- /dev/null
+++ b/extra/prompt_tuning/config/prompt/signature/optimized.yaml
@@ -0,0 +1 @@
+id: Optimized
diff --git a/extra/prompt_tuning/config/prompt/type/aggregation-assessor.yaml b/extra/prompt_tuning/config/prompt/type/aggregation-assessor.yaml
new file mode 100644
index 00000000..ff1f2279
--- /dev/null
+++ b/extra/prompt_tuning/config/prompt/type/aggregation-assessor.yaml
@@ -0,0 +1 @@
+id: AggregationAssessor
diff --git a/extra/prompt_tuning/config/prompt/type/filtering-assessor.yaml b/extra/prompt_tuning/config/prompt/type/filtering-assessor.yaml
new file mode 100644
index 00000000..24c4bd23
--- /dev/null
+++ b/extra/prompt_tuning/config/prompt/type/filtering-assessor.yaml
@@ -0,0 +1 @@
+id: FilteringAssessor
diff --git a/extra/prompt_tuning/config/train.yaml b/extra/prompt_tuning/config/train.yaml
new file mode 100644
index 00000000..e0d6cefb
--- /dev/null
+++ b/extra/prompt_tuning/config/train.yaml
@@ -0,0 +1,8 @@
+defaults:
+  - data: superhero
+  - llm: gpt-3.5-turbo
+  - prompt: prompt
+  - optimizer: copro
+  - _self_
+
+num_threads: 32
diff --git a/extra/prompt_tuning/evaluate.py b/extra/prompt_tuning/evaluate.py
index 35bcf2e8..69a751b2 100644
--- a/extra/prompt_tuning/evaluate.py
+++ b/extra/prompt_tuning/evaluate.py
@@ -1,6 +1,5 @@
 import asyncio
 import logging
-from enum import Enum
 from pathlib import Path
 
 import dspy
@@ -9,9 +8,9 @@
 from dspy.evaluate import Evaluate
 from neptune.utils import stringify_unsupported
 from omegaconf import DictConfig
-from tuning.loaders import IQLGenerationDataLoader
-from tuning.metrics import filtering_assess_acc
+from tuning import DATALOADERS, METRICS
 from tuning.programs import PROGRAMS
+from tuning.signatures import SIGNATURES
 from tuning.utils import save, serialize_results
 
 logging.getLogger("httpx").setLevel(logging.ERROR)
@@ -19,23 +18,6 @@
 log = logging.getLogger(__name__)
 
 
-class EvaluationType(Enum):
-    """
-    Enum representing the evaluation type.
-    """
-
-    FILTERING_ASSESSOR = "FILTERING_ASSESSOR"
-
-
-EVALUATION_DATALOADERS = {
-    EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader,
-}
-
-EVALUATION_METRICS = {
-    EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc,
-}
-
-
 async def evaluate(config: DictConfig) -> None:
     """
     Function running evaluation for all datasets and evaluation tasks defined in hydra config.
@@ -43,11 +25,15 @@ async def evaluate(config: DictConfig) -> None:
     Args:
         config: Hydra configuration.
""" - log.info("Starting evaluation: %s", config.program.name) + signature_name = f"{config.prompt.type.id}{config.prompt.signature.id}" + program_name = f"{config.prompt.type.id}{config.prompt.program.id}" + + log.info("Starting evaluation: %s(%s) program", program_name, signature_name) - dataloader = EVALUATION_DATALOADERS[config.program.type](config) - metric = EVALUATION_METRICS[config.program.type] - program = PROGRAMS[config.program.name]() + dataloader = DATALOADERS[config.prompt.type.id](config) + metric = METRICS[config.prompt.type.id] + signature = SIGNATURES[signature_name] + program = PROGRAMS[program_name](signature) dataset = await dataloader.load() @@ -57,7 +43,7 @@ async def evaluate(config: DictConfig) -> None: evaluator = Evaluate( devset=dataset, metric=metric, - num_threads=32, + num_threads=config.num_threads, display_progress=True, return_outputs=True, ) @@ -75,8 +61,9 @@ async def evaluate(config: DictConfig) -> None: run = neptune.init_run() run["sys/tags"].add( [ - config.program.type, - config.program.name, + config.prompt.type.id, + config.prompt.signature.id, + config.prompt.program.id, *config.data.db_ids, *config.data.difficulties, ] @@ -86,7 +73,7 @@ async def evaluate(config: DictConfig) -> None: run["evaluation/results.json"].upload(results_file.as_posix()) -@hydra.main(config_path="config", config_name="config", version_base="3.2") +@hydra.main(config_path="config", config_name="evaluate", version_base="3.2") def main(config: DictConfig) -> None: """ Function running evaluation for all datasets and evaluation tasks defined in hydra config. diff --git a/extra/prompt_tuning/train.py b/extra/prompt_tuning/train.py new file mode 100644 index 00000000..2d91426b --- /dev/null +++ b/extra/prompt_tuning/train.py @@ -0,0 +1,72 @@ +import asyncio +import logging +from pathlib import Path + +import dspy +import dspy.teleprompt +import hydra +from omegaconf import DictConfig +from tuning import DATALOADERS, METRICS +from tuning.programs import PROGRAMS +from tuning.signatures import SIGNATURES + +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("anthropic").setLevel(logging.ERROR) +log = logging.getLogger(__name__) + + +async def train(config: DictConfig) -> None: + """ + Function running training for all datasets and training tasks defined in hydra config. + + Args: + config: Hydra configuration. + """ + signature_name = f"{config.prompt.type.id}{config.prompt.signature.id}" + program_name = f"{config.prompt.type.id}{config.prompt.program.id}" + + log.info("Starting training: %s(%s) program with %s optimizer", program_name, signature_name, config.optimizer.name) + + dataloader = DATALOADERS[config.prompt.type.id](config) + metric = METRICS[config.prompt.type.id] + signature = SIGNATURES[signature_name] + program = PROGRAMS[program_name](signature) + + dataset = await dataloader.load() + + lm = dspy.__dict__[config.llm.provider](model=config.llm.model_name) + dspy.settings.configure(lm=lm) + + optimizer = dspy.teleprompt.__dict__[config.optimizer.name](metric=metric, **config.optimizer.params) + compiled_program = optimizer.compile( + student=program, + trainset=dataset, + eval_kwargs={ + "num_threads": config.num_threads, + "display_progress": True, + }, + **(config.optimizer.compile or {}), + ) + + log.info("Training finished. 
+
+    output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
+    program_file = output_dir / f"{program.__class__.__name__}Optimized.json"
+    compiled_program.save(program_file)
+
+    log.info("Compiled program saved under directory: %s", output_dir)
+
+
+@hydra.main(config_path="config", config_name="train", version_base="3.2")
+def main(config: DictConfig) -> None:
+    """
+    Function running training for all datasets and training tasks defined in hydra config.
+
+    Args:
+        config: Hydra configuration.
+    """
+    asyncio.run(train(config))
+
+
+if __name__ == "__main__":
+    main()  # pylint: disable=no-value-for-parameter
diff --git a/extra/prompt_tuning/tuning/__init__.py b/extra/prompt_tuning/tuning/__init__.py
index e69de29b..f41fd945 100644
--- a/extra/prompt_tuning/tuning/__init__.py
+++ b/extra/prompt_tuning/tuning/__init__.py
@@ -0,0 +1,24 @@
+from enum import Enum
+
+from .loaders import IQLGenerationDataLoader
+from .metrics import aggregation_assess_acc, filtering_assess_acc
+
+
+class ProgramType(Enum):
+    """
+    Program types.
+    """
+
+    FILTERING_ASSESSOR = "FilteringAssessor"
+    AGGREGATION_ASSESSOR = "AggregationAssessor"
+
+
+DATALOADERS = {
+    ProgramType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader,
+    ProgramType.AGGREGATION_ASSESSOR.value: IQLGenerationDataLoader,
+}
+
+METRICS = {
+    ProgramType.FILTERING_ASSESSOR.value: filtering_assess_acc,
+    ProgramType.AGGREGATION_ASSESSOR.value: aggregation_assess_acc,
+}
diff --git a/extra/prompt_tuning/tuning/loaders.py b/extra/prompt_tuning/tuning/loaders.py
index 2cc7dc96..9553643c 100644
--- a/extra/prompt_tuning/tuning/loaders.py
+++ b/extra/prompt_tuning/tuning/loaders.py
@@ -1,8 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Iterable, List
+from typing import Iterable, List
 
 import dspy.datasets
 from dspy import Example
+from omegaconf import DictConfig
 
 
 class DataLoader(ABC):
@@ -10,7 +11,7 @@ class DataLoader(ABC):
     Data loader.
     """
 
-    def __init__(self, config: Dict) -> None:
+    def __init__(self, config: DictConfig) -> None:
         self.config = config
 
     @abstractmethod
diff --git a/extra/prompt_tuning/tuning/metrics/__init__.py b/extra/prompt_tuning/tuning/metrics/__init__.py
index 56615738..60597b40 100644
--- a/extra/prompt_tuning/tuning/metrics/__init__.py
+++ b/extra/prompt_tuning/tuning/metrics/__init__.py
@@ -1,3 +1,3 @@
-from .iql import filtering_assess_acc
+from .iql import aggregation_assess_acc, filtering_assess_acc
 
-__all__ = ["filtering_assess_acc"]
+__all__ = ["aggregation_assess_acc", "filtering_assess_acc"]
diff --git a/extra/prompt_tuning/tuning/metrics/iql.py b/extra/prompt_tuning/tuning/metrics/iql.py
index d47b0689..b2ad6a42 100644
--- a/extra/prompt_tuning/tuning/metrics/iql.py
+++ b/extra/prompt_tuning/tuning/metrics/iql.py
@@ -1,19 +1,35 @@
-from typing import Dict
+from typing import Dict, List, Optional
 
 from dspy import Prediction
 
 
-def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool:
+def filtering_assess_acc(gold: Dict, pred: Prediction, _trace: Optional[List] = None) -> bool:
     """
-    IQL decision metric.
+    IQL filtering decision metric.
 
     Args:
         gold: The ground truth data point.
         pred: The prediction.
 
     Returns:
-        The decision metric.
+        The filtering decision accuracy.
""" return ((gold["iql_filters"] is None and not gold["iql_filters_unsupported"]) and not pred.decision) or ( (gold["iql_filters"] is not None or gold["iql_filters_unsupported"]) and pred.decision ) + + +def aggregation_assess_acc(gold: Dict, pred: Prediction, _trace: Optional[List] = None) -> bool: + """ + IQL aggregation decision metric. + + Args: + gold: The ground truth data point. + pred: The prediction. + + Returns: + The aggregation decision accuracy. + """ + return ((gold["iql_aggregation"] is None and not gold["iql_aggregation_unsupported"]) and not pred.decision) or ( + (gold["iql_aggregation"] is not None or gold["iql_aggregation_unsupported"]) and pred.decision + ) diff --git a/extra/prompt_tuning/tuning/programs/__init__.py b/extra/prompt_tuning/tuning/programs/__init__.py index 1961d77d..54038fbf 100644 --- a/extra/prompt_tuning/tuning/programs/__init__.py +++ b/extra/prompt_tuning/tuning/programs/__init__.py @@ -1,8 +1,27 @@ -from .iql import FilteringAssessorBaseline, FilteringAssessorCoT +from .iql import ( + AggregationAssessorCoT, + AggregationAssessorCoTH, + AggregationAssessorPredict, + FilteringAssessorCoT, + FilteringAssessorCoTH, + FilteringAssessorPredict, +) PROGRAMS = { - FilteringAssessorBaseline.__name__: FilteringAssessorBaseline, + FilteringAssessorPredict.__name__: FilteringAssessorPredict, FilteringAssessorCoT.__name__: FilteringAssessorCoT, + FilteringAssessorCoTH.__name__: FilteringAssessorCoTH, + AggregationAssessorPredict.__name__: AggregationAssessorPredict, + AggregationAssessorCoT.__name__: AggregationAssessorCoT, + AggregationAssessorCoTH.__name__: AggregationAssessorCoTH, } -__all__ = ["PROGRAMS", "FilteringAssessorBaseline", "FilteringAssessorCoT"] +__all__ = [ + "PROGRAMS", + "AggregationAssessorPredict", + "AggregationAssessorCoT", + "FilteringAssessorPredict", + "FilteringAssessorCoT", + "FilteringAssessorCoTH", + "AggregationAssessorCoTH", +] diff --git a/extra/prompt_tuning/tuning/programs/iql.py b/extra/prompt_tuning/tuning/programs/iql.py index 2a20da47..f33c444e 100644 --- a/extra/prompt_tuning/tuning/programs/iql.py +++ b/extra/prompt_tuning/tuning/programs/iql.py @@ -1,16 +1,18 @@ -from dspy import ChainOfThought, Module, Predict, Prediction +from typing import Type -from ..signatures.iql import CheckQuestionFiltering +from dspy import ChainOfThought, ChainOfThoughtWithHint, Module, Predict, Prediction +from ..signatures.iql import AggregationAssessor, FilteringAssessor -class FilteringAssessorBaseline(Module): + +class FilteringAssessorPredict(Module): """ Program that assesses whether a question requires filtering. """ - def __init__(self) -> None: + def __init__(self, signature: Type[FilteringAssessor]) -> None: super().__init__() - self.decide = Predict(CheckQuestionFiltering) + self.decide = Predict(signature) def forward(self, question: str) -> Prediction: """ @@ -31,9 +33,9 @@ class FilteringAssessorCoT(Module): Program that assesses whether a question requires filtering. """ - def __init__(self) -> None: + def __init__(self, signature: Type[FilteringAssessor]) -> None: super().__init__() - self.decide = ChainOfThought(CheckQuestionFiltering) + self.decide = ChainOfThought(signature) def forward(self, question: str) -> Prediction: """ @@ -47,3 +49,101 @@ def forward(self, question: str) -> Prediction: """ decision = self.decide(question=question).decision return Prediction(decision=decision.lower() == "true") + + +class FilteringAssessorCoTH(Module): + """ + Program that assesses whether a question requires filtering. 
+ """ + + def __init__(self, signature: Type[FilteringAssessor]) -> None: + super().__init__() + self.decide = ChainOfThoughtWithHint(signature) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires filtering. + + Args: + question: The question to assess. + + Returns: + The prediction. + """ + decision = self.decide( + question=question, + hint="Look for words indicating data specific features.", + ).decision + return Prediction(decision=decision.lower() == "true") + + +class AggregationAssessorPredict(Module): + """ + Program that assesses whether a question requires aggregation. + """ + + def __init__(self, signature: Type[AggregationAssessor]) -> None: + super().__init__() + self.decide = Predict(signature) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires aggregation. + + Args: + question: The question to assess. + + Returns: + The prediction. + """ + decision = self.decide(question=question).decision + return Prediction(decision=decision.lower() == "true") + + +class AggregationAssessorCoT(Module): + """ + Program that assesses whether a question requires aggregation. + """ + + def __init__(self, signature: Type[AggregationAssessor]) -> None: + super().__init__() + self.decide = ChainOfThought(signature) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires aggregation. + + Args: + question: The question to assess. + + Returns: + The prediction. + """ + decision = self.decide(question=question).decision + return Prediction(decision=decision.lower() == "true") + + +class AggregationAssessorCoTH(Module): + """ + Program that assesses whether a question requires aggregation. + """ + + def __init__(self, signature: Type[AggregationAssessor]) -> None: + super().__init__() + self.decide = ChainOfThoughtWithHint(signature) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires aggregation. + + Args: + question: The question to assess. + + Returns: + The prediction. 
+ """ + decision = self.decide( + question=question, + hint="Look for words indicating aggregation functions.", + ).decision + return Prediction(decision=decision.lower() == "true") diff --git a/extra/prompt_tuning/tuning/signatures/__init__.py b/extra/prompt_tuning/tuning/signatures/__init__.py index dd3be583..a7c50ad1 100644 --- a/extra/prompt_tuning/tuning/signatures/__init__.py +++ b/extra/prompt_tuning/tuning/signatures/__init__.py @@ -1,3 +1,26 @@ -from .iql import CheckQuestionFiltering +from .iql import ( + AggregationAssessor, + AggregationAssessorBaseline, + AggregationAssessorOptimized, + FilteringAssessor, + FilteringAssessorBaseline, + FilteringAssessorOptimized, +) -__all__ = ["CheckQuestionFiltering"] +SIGNATURES = { + AggregationAssessorBaseline.__name__: AggregationAssessorBaseline, + AggregationAssessorOptimized.__name__: AggregationAssessorOptimized, + FilteringAssessorBaseline.__name__: FilteringAssessorBaseline, + FilteringAssessorOptimized.__name__: FilteringAssessorOptimized, +} + + +__all__ = [ + "AggregationAssessor", + "FilteringAssessor", + "AggregationAssessorBaseline", + "AggregationAssessorOptimized", + "FilteringAssessorBaseline", + "FilteringAssessorOptimized", + "SIGNATURES", +] diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py index 273edf60..0305ce81 100644 --- a/extra/prompt_tuning/tuning/signatures/iql.py +++ b/extra/prompt_tuning/tuning/signatures/iql.py @@ -1,20 +1,66 @@ +from abc import ABC + from dspy import InputField, OutputField, Signature -class CheckQuestionFiltering(Signature): +class FilteringAssessor(Signature, ABC): + """ + Abstract signature, should be implemented by a concrete signature, + describing the actual task for the LLM. + """ + + question = InputField( + prefix="Question: ", + ) + decision = OutputField( + prefix="Decision: ", + desc=("indicates whether the answer to the question requires data filtering. " "(Respond with True or False)"), + ) + + +class FilteringAssessorBaseline(FilteringAssessor): """ - Given a question, determine whether the answer requires initial data filtering in order to compute it. - Initial data filtering is a process in which the result set is reduced to only include the rows that + Given a question, determine whether the answer requires data filtering in order to compute it. + Data filtering is a process in which the result set is reduced to only include the rows that meet certain criteria specified in the question. """ + +class FilteringAssessorOptimized(FilteringAssessor): + """ + Given a question, determine whether the answer requires data filtering in order to compute it. + Data filtering is a process in which the result set is filtered based on the specific features + stated in the question. Such a question can be easily identified by using words that refer to + specific feature values (rather than feature names). Look for words indicating specific values + that the answer should contain. + """ + + +class AggregationAssessor(Signature, ABC): + """ + Abstract signature, should be implemented by a concrete signature, + describing the actual task for the LLM. + """ + question = InputField( prefix="Question: ", ) decision = OutputField( prefix="Decision: ", - desc=( - "indicates whether the answer to the question requires initial data filtering. " - "(Respond with True or False)" - ), + desc=("indicates whether the answer to the question requires data aggregation. 
     )
+
+
+class AggregationAssessorBaseline(AggregationAssessor):
+    """
+    Given a question, determine whether the answer requires data aggregation in order to compute it.
+    Data aggregation is a process in which we calculate a single value for a group of rows in the result set.
+    """
+
+
+class AggregationAssessorOptimized(AggregationAssessor):
+    """
+    Given a question, determine whether the answer requires data aggregation in order to compute it.
+    Data aggregation is a process in which we calculate a single value for a group of rows in the result set.
+    The most common aggregation functions are counting, averaging and summing, but other types of aggregation are possible.
+    """
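
For readers skimming the patch, here is a minimal illustrative sketch (not part of the diff) of how the new Hydra `prompt` config composes into a dspy program in `evaluate.py` and `train.py`. The literal id values below are assumed stand-ins for whatever the selected config files provide (e.g. `prompt/type=filtering-assessor prompt/signature=baseline prompt/program=cot`):

```python
# Sketch of the registry lookups added in this patch; the id strings are assumptions.
from tuning import DATALOADERS, METRICS
from tuning.programs import PROGRAMS
from tuning.signatures import SIGNATURES

# In the scripts these come from config.prompt.type.id / .signature.id / .program.id.
type_id, signature_id, program_id = "FilteringAssessor", "Baseline", "CoT"

signature = SIGNATURES[f"{type_id}{signature_id}"]        # FilteringAssessorBaseline
program = PROGRAMS[f"{type_id}{program_id}"](signature)   # FilteringAssessorCoT instance
metric = METRICS[type_id]                                  # filtering_assess_acc
dataloader_cls = DATALOADERS[type_id]                      # IQLGenerationDataLoader (instantiated with the Hydra config)
```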