diff --git a/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml b/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml new file mode 100644 index 00000000..98d1d68b --- /dev/null +++ b/extra/prompt_tuning/config/program/aggregation-assessor-baseline.yaml @@ -0,0 +1,2 @@ +type: AGGREGATION_ASSESSOR +name: AggregationAssessorBaseline diff --git a/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml b/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml new file mode 100644 index 00000000..f6a7be97 --- /dev/null +++ b/extra/prompt_tuning/config/program/aggregation-assessor-cot.yaml @@ -0,0 +1,2 @@ +type: AGGREGATION_ASSESSOR +name: AggregationAssessorCoT diff --git a/extra/prompt_tuning/evaluate.py b/extra/prompt_tuning/evaluate.py index 35bcf2e8..0ae140e4 100644 --- a/extra/prompt_tuning/evaluate.py +++ b/extra/prompt_tuning/evaluate.py @@ -10,7 +10,7 @@ from neptune.utils import stringify_unsupported from omegaconf import DictConfig from tuning.loaders import IQLGenerationDataLoader -from tuning.metrics import filtering_assess_acc +from tuning.metrics import aggregation_assess_acc, filtering_assess_acc from tuning.programs import PROGRAMS from tuning.utils import save, serialize_results @@ -25,14 +25,17 @@ class EvaluationType(Enum): """ FILTERING_ASSESSOR = "FILTERING_ASSESSOR" + AGGREGATION_ASSESSOR = "AGGREGATION_ASSESSOR" EVALUATION_DATALOADERS = { EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader, + EvaluationType.AGGREGATION_ASSESSOR.value: IQLGenerationDataLoader, } EVALUATION_METRICS = { EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc, + EvaluationType.AGGREGATION_ASSESSOR.value: aggregation_assess_acc, } diff --git a/extra/prompt_tuning/tuning/metrics/__init__.py b/extra/prompt_tuning/tuning/metrics/__init__.py index 56615738..60597b40 100644 --- a/extra/prompt_tuning/tuning/metrics/__init__.py +++ b/extra/prompt_tuning/tuning/metrics/__init__.py @@ -1,3 +1,3 @@ -from .iql import filtering_assess_acc +from .iql import aggregation_assess_acc, filtering_assess_acc -__all__ = ["filtering_assess_acc"] +__all__ = ["aggregation_assess_acc", "filtering_assess_acc"] diff --git a/extra/prompt_tuning/tuning/metrics/iql.py b/extra/prompt_tuning/tuning/metrics/iql.py index d47b0689..3340f929 100644 --- a/extra/prompt_tuning/tuning/metrics/iql.py +++ b/extra/prompt_tuning/tuning/metrics/iql.py @@ -5,15 +5,31 @@ def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool: """ - IQL decision metric. + IQL filtering decision metric. Args: gold: The ground truth data point. pred: The prediction. Returns: - The decision metric. + The filtering decision accuracy. """ return ((gold["iql_filters"] is None and not gold["iql_filters_unsupported"]) and not pred.decision) or ( (gold["iql_filters"] is not None or gold["iql_filters_unsupported"]) and pred.decision ) + + +def aggregation_assess_acc(gold: Dict, pred: Prediction) -> bool: + """ + IQL aggregation decision metric. + + Args: + gold: The ground truth data point. + pred: The prediction. + + Returns: + The aggregation decision accuracy. + """ + return ((gold["iql_aggregation"] is None and not gold["iql_aggregation_unsupported"]) and not pred.decision) or ( + (gold["iql_aggregation"] is not None or gold["iql_aggregation_unsupported"]) and pred.decision + ) diff --git a/extra/prompt_tuning/tuning/programs/iql.py b/extra/prompt_tuning/tuning/programs/iql.py index 2a20da47..a7209672 100644 --- a/extra/prompt_tuning/tuning/programs/iql.py +++ b/extra/prompt_tuning/tuning/programs/iql.py @@ -1,6 +1,6 @@ from dspy import ChainOfThought, Module, Predict, Prediction -from ..signatures.iql import CheckQuestionFiltering +from ..signatures.iql import CheckQuestionAggregation, CheckQuestionFiltering class FilteringAssessorBaseline(Module): @@ -47,3 +47,49 @@ def forward(self, question: str) -> Prediction: """ decision = self.decide(question=question).decision return Prediction(decision=decision.lower() == "true") + + +class AggregationAssessorBaseline(Module): + """ + Program that assesses whether a question requires aggregation. + """ + + def __init__(self) -> None: + super().__init__() + self.decide = Predict(CheckQuestionAggregation) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires aggregation. + + Args: + question: The question to assess. + + Returns: + The prediction. + """ + decision = self.decide(question=question).decision + return Prediction(decision=decision.lower() == "true") + + +class AggregationAssessorCoT(Module): + """ + Program that assesses whether a question requires aggregation. + """ + + def __init__(self) -> None: + super().__init__() + self.decide = ChainOfThought(CheckQuestionAggregation) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires aggregation. + + Args: + question: The question to assess. + + Returns: + The prediction. + """ + decision = self.decide(question=question).decision + return Prediction(decision=decision.lower() == "true") diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py index 273edf60..d3a6c531 100644 --- a/extra/prompt_tuning/tuning/signatures/iql.py +++ b/extra/prompt_tuning/tuning/signatures/iql.py @@ -18,3 +18,22 @@ class CheckQuestionFiltering(Signature): "(Respond with True or False)" ), ) + + +class CheckQuestionAggregation(Signature): + """ + Given a question, determine whether the answer requires initial data filtering in order to compute it. + Initial data filtering is a process in which the result set is reduced to only include the rows that + meet certain criteria specified in the question. + """ + + question = InputField( + prefix="Question: ", + ) + decision = OutputField( + prefix="Decision: ", + desc=( + "indicates whether the answer to the question requires initial data filtering. " + "(Respond with True or False)" + ), + )