Skip to content

Commit

Permalink
improve fa signature
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 30, 2024
1 parent 42edc8e commit 3386d56
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 11 deletions.
6 changes: 3 additions & 3 deletions extra/prompt_tuning/config/optimizer/copro.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: COPRO
params:
breadth: 3
depth: 10
init_temperature: 1.4
breadth: 4
depth: 15
init_temperature: 1.5
compile:
1 change: 1 addition & 0 deletions extra/prompt_tuning/config/prompt/program/coth.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
id: CoTH
5 changes: 3 additions & 2 deletions extra/prompt_tuning/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ async def evaluate(config: DictConfig) -> None:
run = neptune.init_run()
run["sys/tags"].add(
[
config.program.type,
config.program.name,
config.prompt.type.id,
config.prompt.signature.id,
config.prompt.program.id,
*config.data.db_ids,
*config.data.difficulties,
]
Expand Down
10 changes: 9 additions & 1 deletion extra/prompt_tuning/tuning/programs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from .iql import AggregationAssessorCoT, AggregationAssessorPredict, FilteringAssessorCoT, FilteringAssessorPredict
from .iql import (
AggregationAssessorCoT,
AggregationAssessorPredict,
FilteringAssessorCoT,
FilteringAssessorCoTH,
FilteringAssessorPredict,
)

PROGRAMS = {
FilteringAssessorPredict.__name__: FilteringAssessorPredict,
FilteringAssessorCoT.__name__: FilteringAssessorCoT,
FilteringAssessorCoTH.__name__: FilteringAssessorCoTH,
AggregationAssessorPredict.__name__: AggregationAssessorPredict,
AggregationAssessorCoT.__name__: AggregationAssessorCoT,
}
Expand All @@ -13,4 +20,5 @@
"AggregationAssessorCoT",
"FilteringAssessorPredict",
"FilteringAssessorCoT",
"FilteringAssessorCoTH",
]
28 changes: 27 additions & 1 deletion extra/prompt_tuning/tuning/programs/iql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Type

from dspy import ChainOfThought, Module, Predict, Prediction
from dspy import ChainOfThought, ChainOfThoughtWithHint, Module, Predict, Prediction

from ..signatures.iql import AggregationAssessor, FilteringAssessor

Expand Down Expand Up @@ -51,6 +51,32 @@ def forward(self, question: str) -> Prediction:
return Prediction(decision=decision.lower() == "true")


class FilteringAssessorCoTH(Module):
"""
Program that assesses whether a question requires filtering.
"""

def __init__(self, signature: Type[FilteringAssessor]) -> None:
super().__init__()
self.decide = ChainOfThoughtWithHint(signature)

def forward(self, question: str) -> Prediction:
"""
Assess whether a question requires filtering.
Args:
question: The question to assess.
Returns:
The prediction.
"""
decision = self.decide(
question=question,
hint="Look for words indicating data specific features.",
).decision
return Prediction(decision=decision.lower() == "true")


class AggregationAssessorPredict(Module):
"""
Program that assesses whether a question requires aggregation.
Expand Down
13 changes: 9 additions & 4 deletions extra/prompt_tuning/tuning/signatures/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ class FilteringAssessorBaseline(FilteringAssessor):
class FilteringAssessorOptimized(FilteringAssessor):
"""
Given a question, determine whether the answer requires initial data filtering in order to compute it.
Initial data filtering is a process in which the result set is reduced to only include the rows that
meet certain criteria specified in the question.
Initial data filtering is a process in which the result set is filtered based on the specific features
stated in the question.
"""


Expand Down Expand Up @@ -61,6 +61,11 @@ class AggregationAssessorBaseline(AggregationAssessor):

class AggregationAssessorOptimized(AggregationAssessor):
"""
Given a question, determine whether the answer requires data aggregation in order to compute it.
Data aggregation is a process in which we calculate a single values for a group of rows in the result set.
Look at the dependencies between the elements in the question and distinguish whether a single value can be obtained
for a groupof entities in the data table by aggregating necessary values.
"""

decision = OutputField(
prefix="Instructions to identify aggregated computations given a question, analyze dependencies -> ",
desc="indicates whether the answer to the question requires data aggregation. (Respond with True or False)",
)

0 comments on commit 3386d56

Please sign in to comment.