From fa64aeeb29749192d894317aa55ddba9f6fe5ce8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 30 Aug 2024 20:23:06 +0200 Subject: [PATCH] improve aa signature --- .../prompt_tuning/tuning/programs/__init__.py | 3 +++ extra/prompt_tuning/tuning/programs/iql.py | 26 +++++++++++++++++++ extra/prompt_tuning/tuning/signatures/iql.py | 10 +++---- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/extra/prompt_tuning/tuning/programs/__init__.py b/extra/prompt_tuning/tuning/programs/__init__.py index 221900cf..54038fbf 100644 --- a/extra/prompt_tuning/tuning/programs/__init__.py +++ b/extra/prompt_tuning/tuning/programs/__init__.py @@ -1,5 +1,6 @@ from .iql import ( AggregationAssessorCoT, + AggregationAssessorCoTH, AggregationAssessorPredict, FilteringAssessorCoT, FilteringAssessorCoTH, @@ -12,6 +13,7 @@ FilteringAssessorCoTH.__name__: FilteringAssessorCoTH, AggregationAssessorPredict.__name__: AggregationAssessorPredict, AggregationAssessorCoT.__name__: AggregationAssessorCoT, + AggregationAssessorCoTH.__name__: AggregationAssessorCoTH, } __all__ = [ @@ -21,4 +23,5 @@ "FilteringAssessorPredict", "FilteringAssessorCoT", "FilteringAssessorCoTH", + "AggregationAssessorCoTH", ] diff --git a/extra/prompt_tuning/tuning/programs/iql.py b/extra/prompt_tuning/tuning/programs/iql.py index 15aaf13e..f33c444e 100644 --- a/extra/prompt_tuning/tuning/programs/iql.py +++ b/extra/prompt_tuning/tuning/programs/iql.py @@ -121,3 +121,29 @@ def forward(self, question: str) -> Prediction: """ decision = self.decide(question=question).decision return Prediction(decision=decision.lower() == "true") + + +class AggregationAssessorCoTH(Module): + """ + Program that assesses whether a question requires aggregation. + """ + + def __init__(self, signature: Type[AggregationAssessor]) -> None: + super().__init__() + self.decide = ChainOfThoughtWithHint(signature) + + def forward(self, question: str) -> Prediction: + """ + Assess whether a question requires aggregation. + + Args: + question: The question to assess. + + Returns: + The prediction. + """ + decision = self.decide( + question=question, + hint="Look for words indicating aggregation functions.", + ).decision + return Prediction(decision=decision.lower() == "true") diff --git a/extra/prompt_tuning/tuning/signatures/iql.py b/extra/prompt_tuning/tuning/signatures/iql.py index 73b44a94..ea113514 100644 --- a/extra/prompt_tuning/tuning/signatures/iql.py +++ b/extra/prompt_tuning/tuning/signatures/iql.py @@ -61,11 +61,7 @@ class AggregationAssessorBaseline(AggregationAssessor): class AggregationAssessorOptimized(AggregationAssessor): """ - Look at the dependencies between the elements in the question and distinguish whether a single value can be obtained - for a groupof entities in the data table by aggregating necessary values. + Given a question, determine whether the answer requires data aggregation in order to compute it. + Data aggregation is a process in which we calculate a single values for a group of rows in the result set. + Most common aggregation functions are counting, averaging, summing, but other types of aggregation are possible. """ - - decision = OutputField( - prefix="Instructions to identify aggregated computations given a question, analyze dependencies -> ", - desc="indicates whether the answer to the question requires data aggregation. (Respond with True or False)", - )