Skip to content

Commit

Permalink
improve aa signature
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 30, 2024
1 parent 3386d56 commit fa64aee
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
3 changes: 3 additions & 0 deletions extra/prompt_tuning/tuning/programs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .iql import (
AggregationAssessorCoT,
AggregationAssessorCoTH,
AggregationAssessorPredict,
FilteringAssessorCoT,
FilteringAssessorCoTH,
Expand All @@ -12,6 +13,7 @@
FilteringAssessorCoTH.__name__: FilteringAssessorCoTH,
AggregationAssessorPredict.__name__: AggregationAssessorPredict,
AggregationAssessorCoT.__name__: AggregationAssessorCoT,
AggregationAssessorCoTH.__name__: AggregationAssessorCoTH,
}

__all__ = [
Expand All @@ -21,4 +23,5 @@
"FilteringAssessorPredict",
"FilteringAssessorCoT",
"FilteringAssessorCoTH",
"AggregationAssessorCoTH",
]
26 changes: 26 additions & 0 deletions extra/prompt_tuning/tuning/programs/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,29 @@ def forward(self, question: str) -> Prediction:
"""
decision = self.decide(question=question).decision
return Prediction(decision=decision.lower() == "true")


class AggregationAssessorCoTH(Module):
"""
Program that assesses whether a question requires aggregation.
"""

def __init__(self, signature: Type[AggregationAssessor]) -> None:
super().__init__()
self.decide = ChainOfThoughtWithHint(signature)

def forward(self, question: str) -> Prediction:
"""
Assess whether a question requires aggregation.
Args:
question: The question to assess.
Returns:
The prediction.
"""
decision = self.decide(
question=question,
hint="Look for words indicating aggregation functions.",
).decision
return Prediction(decision=decision.lower() == "true")
10 changes: 3 additions & 7 deletions extra/prompt_tuning/tuning/signatures/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ class AggregationAssessorBaseline(AggregationAssessor):

class AggregationAssessorOptimized(AggregationAssessor):
"""
Look at the dependencies between the elements in the question and distinguish whether a single value can be obtained
for a groupof entities in the data table by aggregating necessary values.
Given a question, determine whether the answer requires data aggregation in order to compute it.
Data aggregation is a process in which we calculate a single values for a group of rows in the result set.
Most common aggregation functions are counting, averaging, summing, but other types of aggregation are possible.
"""

decision = OutputField(
prefix="Instructions to identify aggregated computations given a question, analyze dependencies -> ",
desc="indicates whether the answer to the question requires data aggregation. (Respond with True or False)",
)

0 comments on commit fa64aee

Please sign in to comment.