Skip to content

Commit

Permalink
add aggregation assessor program
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 17, 2024
1 parent 2714e7c commit b202ea3
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
type: AGGREGATION_ASSESSOR
name: AggregationAssessorBaseline
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
type: AGGREGATION_ASSESSOR
name: AggregationAssessorCoT
5 changes: 4 additions & 1 deletion extra/prompt_tuning/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig
from tuning.loaders import IQLGenerationDataLoader
from tuning.metrics import filtering_assess_acc
from tuning.metrics import aggregation_assess_acc, filtering_assess_acc
from tuning.programs import PROGRAMS
from tuning.utils import save, serialize_results

Expand All @@ -25,14 +25,17 @@ class EvaluationType(Enum):
"""

FILTERING_ASSESSOR = "FILTERING_ASSESSOR"
AGGREGATION_ASSESSOR = "AGGREGATION_ASSESSOR"


# Data loader class for each evaluation type, keyed by the enum member's string value.
# Both evaluation types currently share IQLGenerationDataLoader.
EVALUATION_DATALOADERS = {
EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader,
EvaluationType.AGGREGATION_ASSESSOR.value: IQLGenerationDataLoader,
}

# Accuracy metric function for each evaluation type, keyed by the enum member's string value.
EVALUATION_METRICS = {
EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc,
EvaluationType.AGGREGATION_ASSESSOR.value: aggregation_assess_acc,
}


Expand Down
4 changes: 2 additions & 2 deletions extra/prompt_tuning/tuning/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .iql import filtering_assess_acc
from .iql import aggregation_assess_acc, filtering_assess_acc

__all__ = ["filtering_assess_acc"]
__all__ = ["aggregation_assess_acc", "filtering_assess_acc"]
20 changes: 18 additions & 2 deletions extra/prompt_tuning/tuning/metrics/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,31 @@

def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool:
    """
    IQL filtering decision metric.

    The prediction is counted as correct when its decision agrees with the ground truth:
    filtering is expected if the data point has IQL filters or is marked as unsupported.

    Args:
        gold: The ground truth data point. Must provide the "iql_filters" key and, when
            that value is None, the "iql_filters_unsupported" key.
        pred: The prediction. Its `decision` attribute is interpreted by truthiness.

    Returns:
        The filtering decision accuracy: True if the decision matches the ground truth,
        False otherwise.
    """
    # Coerce to bool so the result is always a strict bool, even when the
    # "unsupported" flag or the decision is None or another non-bool value.
    filtering_expected = gold["iql_filters"] is not None or bool(gold["iql_filters_unsupported"])
    return bool(pred.decision) == filtering_expected


def aggregation_assess_acc(gold: Dict, pred: Prediction) -> bool:
    """
    IQL aggregation decision metric.

    The prediction is counted as correct when its decision agrees with the ground truth:
    aggregation is expected if the data point has an IQL aggregation or is marked as unsupported.

    Args:
        gold: The ground truth data point. Must provide the "iql_aggregation" key and, when
            that value is None, the "iql_aggregation_unsupported" key.
        pred: The prediction. Its `decision` attribute is interpreted by truthiness.

    Returns:
        The aggregation decision accuracy: True if the decision matches the ground truth,
        False otherwise.
    """
    # Coerce to bool so the result is always a strict bool, even when the
    # "unsupported" flag or the decision is None or another non-bool value.
    aggregation_expected = gold["iql_aggregation"] is not None or bool(gold["iql_aggregation_unsupported"])
    return bool(pred.decision) == aggregation_expected
48 changes: 47 additions & 1 deletion extra/prompt_tuning/tuning/programs/iql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dspy import ChainOfThought, Module, Predict, Prediction

from ..signatures.iql import CheckQuestionFiltering
from ..signatures.iql import CheckQuestionAggregation, CheckQuestionFiltering


class FilteringAssessorBaseline(Module):
Expand Down Expand Up @@ -47,3 +47,49 @@ def forward(self, question: str) -> Prediction:
"""
decision = self.decide(question=question).decision
return Prediction(decision=decision.lower() == "true")


class AggregationAssessorBaseline(Module):
    """
    Baseline program deciding whether a question requires aggregation.

    Wraps a plain `Predict` module built on the `CheckQuestionAggregation` signature.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decide = Predict(CheckQuestionAggregation)

    def forward(self, question: str) -> Prediction:
        """
        Decide whether the given question requires aggregation.

        Args:
            question: The question to assess.

        Returns:
            The prediction, whose `decision` is True only when the predictor's
            textual answer equals "true" (compared case-insensitively).
        """
        raw_answer = self.decide(question=question).decision
        requires_aggregation = raw_answer.lower() == "true"
        return Prediction(decision=requires_aggregation)


class AggregationAssessorCoT(Module):
    """
    Chain-of-thought program deciding whether a question requires aggregation.

    Wraps a `ChainOfThought` module built on the `CheckQuestionAggregation` signature.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decide = ChainOfThought(CheckQuestionAggregation)

    def forward(self, question: str) -> Prediction:
        """
        Decide whether the given question requires aggregation.

        Args:
            question: The question to assess.

        Returns:
            The prediction, whose `decision` is True only when the predictor's
            textual answer equals "true" (compared case-insensitively).
        """
        raw_answer = self.decide(question=question).decision
        requires_aggregation = raw_answer.lower() == "true"
        return Prediction(decision=requires_aggregation)
19 changes: 19 additions & 0 deletions extra/prompt_tuning/tuning/signatures/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,22 @@ class CheckQuestionFiltering(Signature):
"(Respond with True or False)"
),
)


class CheckQuestionAggregation(Signature):
    """
    Given a question, determine whether the answer requires data aggregation in order to compute it.
    Data aggregation is a process in which a single value (such as a count, sum, average, minimum
    or maximum) is calculated for a group of rows in the result set.
    """

    # NOTE: dspy uses the signature docstring as the task instruction and the field
    # descriptions as prompt text, so this wording must describe aggregation, not filtering.

    # The natural-language question to assess.
    question = InputField(
        prefix="Question: ",
    )
    # Textual True/False verdict; callers compare it case-insensitively with "true".
    decision = OutputField(
        prefix="Decision: ",
        desc=(
            "indicates whether the answer to the question requires data aggregation. "
            "(Respond with True or False)"
        ),
    )

0 comments on commit b202ea3

Please sign in to comment.