Commit dde4452: refactor

micpst committed Aug 30, 2024
1 parent 111fe8c commit dde4452
Showing 21 changed files with 137 additions and 53 deletions.
28 changes: 20 additions & 8 deletions extra/prompt_tuning/README.md
@@ -11,23 +11,25 @@ All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.i

### Train new prompts

Tune `filtering-assessor-baseline` prompt using [COPRO](https://dspy-docs.vercel.app/docs/deep-dive/teleprompter/signature-optimizer#how-copro-works) optimizer on the `superhero` database with `gpt-3.5-turbo`:
Tune `filtering-assessor` prompt on base signature using [COPRO](https://dspy-docs.vercel.app/docs/deep-dive/teleprompter/signature-optimizer#how-copro-works) optimizer on the `superhero` database with `gpt-3.5-turbo`:

```bash
python train.py program=filtering-assessor-baseline
python train.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict
```

Train multiple prompts:

```bash
python train.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
python train.py --multirun \
prompt/type=filtering-assessor \
prompt/signature=baseline \
prompt/program=predict,cot
```

Tweak optimizer params to get different results:

```bash
python train.py \
program=filtering-assessor-baseline \
optimizer.params.breadth=2 \
optimizer.params.depth=3 \
optimizer.params.init_temperature=1.0
@@ -38,23 +40,33 @@ python train.py \
Run evaluation of the filtering assessor baseline on the `superhero` database with `gpt-3.5-turbo`:

```bash
python evaluate.py program=filtering-assessor-baseline
python evaluate.py prompt/type=filtering-assessor prompt/signature=baseline prompt/program=predict
```

Test multiple prompts:

```bash
python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
python evaluate.py --multirun \
prompt/type=filtering-assessor \
prompt/signature=baseline \
prompt/program=predict,cot
```

```bash
python evaluate.py --multirun program=aggregation-assessor-baseline,aggregation-assessor-cot
python evaluate.py --multirun \
prompt/type=aggregation-assessor \
prompt/signature=baseline \
prompt/program=predict,cot
```

Compare prompt performance on multiple LLMs:

```bash
python evaluate.py --multirun program=filtering-assessor-baseline llm=gpt-3.5-turbo,claude-3.5-sonnet
python evaluate.py --multirun \
prompt/type=filtering-assessor \
prompt/signature=baseline \
prompt/program=predict \
llm=gpt-3.5-turbo,claude-3.5-sonnet
```
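Under the hood, these override groups compose into CamelCase registry keys: the lowercase option names select config files whose `id` fields are concatenated into class names. A minimal sketch of that naming scheme, with the mapping tables assumed from the config files added in this commit:

```python
# Hypothetical sketch: how the CLI override groups map to registry keys.
# The id values mirror the config files added in this commit.
TYPE_IDS = {"filtering-assessor": "FilteringAssessor",
            "aggregation-assessor": "AggregationAssessor"}
SIGNATURE_IDS = {"baseline": "Baseline", "optimized": "Optimized"}
PROGRAM_IDS = {"predict": "Predict", "cot": "CoT"}

def registry_keys(type_: str, signature: str, program: str) -> tuple[str, str]:
    """Compose the signature and program lookup keys used by the scripts."""
    signature_name = TYPE_IDS[type_] + SIGNATURE_IDS[signature]
    program_name = TYPE_IDS[type_] + PROGRAM_IDS[program]
    return signature_name, program_name

print(registry_keys("filtering-assessor", "baseline", "predict"))
# → ('FilteringAssessorBaseline', 'FilteringAssessorPredict')
```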

#### Log to Neptune
2 changes: 1 addition & 1 deletion extra/prompt_tuning/config/evaluate.yaml
@@ -1,7 +1,7 @@
defaults:
- data: superhero
- llm: gpt-3.5-turbo
- program: filtering-assessor-baseline
- prompt: prompt
- _self_

num_threads: 32

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions extra/prompt_tuning/config/prompt/program/cot.yaml
@@ -0,0 +1 @@
id: CoT
1 change: 1 addition & 0 deletions extra/prompt_tuning/config/prompt/program/predict.yaml
@@ -0,0 +1 @@
id: Predict
8 changes: 8 additions & 0 deletions extra/prompt_tuning/config/prompt/prompt.yaml
@@ -0,0 +1,8 @@
defaults:
- type: filtering-assessor
- signature: baseline
- program: predict
- _self_

num_threads: 32
neptune: False
1 change: 1 addition & 0 deletions extra/prompt_tuning/config/prompt/signature/baseline.yaml
@@ -0,0 +1 @@
id: Baseline
1 change: 1 addition & 0 deletions extra/prompt_tuning/config/prompt/signature/optimized.yaml
@@ -0,0 +1 @@
id: Optimized
1 change: 1 addition & 0 deletions extra/prompt_tuning/config/prompt/type/aggregation-assessor.yaml
@@ -0,0 +1 @@
id: AggregationAssessor
1 change: 1 addition & 0 deletions extra/prompt_tuning/config/prompt/type/filtering-assessor.yaml
@@ -0,0 +1 @@
id: FilteringAssessor
2 changes: 1 addition & 1 deletion extra/prompt_tuning/config/train.yaml
@@ -1,7 +1,7 @@
defaults:
- data: superhero
- llm: gpt-3.5-turbo
- program: filtering-assessor-baseline
- prompt: prompt
- optimizer: copro
- _self_

13 changes: 9 additions & 4 deletions extra/prompt_tuning/evaluate.py
@@ -10,6 +10,7 @@
from omegaconf import DictConfig
from tuning import DATALOADERS, METRICS
from tuning.programs import PROGRAMS
from tuning.signatures import SIGNATURES
from tuning.utils import save, serialize_results

logging.getLogger("httpx").setLevel(logging.ERROR)
@@ -24,11 +25,15 @@ async def evaluate(config: DictConfig) -> None:
Args:
config: Hydra configuration.
"""
log.info("Starting evaluation: %s", config.program.name)
signature_name = f"{config.prompt.type.id}{config.prompt.signature.id}"
program_name = f"{config.prompt.type.id}{config.prompt.program.id}"

dataloader = DATALOADERS[config.program.type](config)
metric = METRICS[config.program.type]
program = PROGRAMS[config.program.name]()
log.info("Starting evaluation: %s(%s) program", program_name, signature_name)

dataloader = DATALOADERS[config.prompt.type.id](config)
metric = METRICS[config.prompt.type.id]
signature = SIGNATURES[signature_name]
program = PROGRAMS[program_name](signature)

dataset = await dataloader.load()

13 changes: 9 additions & 4 deletions extra/prompt_tuning/train.py
@@ -8,6 +8,7 @@
from omegaconf import DictConfig
from tuning import DATALOADERS, METRICS
from tuning.programs import PROGRAMS
from tuning.signatures import SIGNATURES

logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("anthropic").setLevel(logging.ERROR)
@@ -21,11 +22,15 @@ async def train(config: DictConfig) -> None:
Args:
config: Hydra configuration.
"""
log.info("Starting training %s with %s", config.program.name, config.optimizer.name)
signature_name = f"{config.prompt.type.id}{config.prompt.signature.id}"
program_name = f"{config.prompt.type.id}{config.prompt.program.id}"

dataloader = DATALOADERS[config.program.type](config)
metric = METRICS[config.program.type]
program = PROGRAMS[config.program.name]()
log.info("Starting training: %s(%s) program with %s optimizer", program_name, signature_name, config.optimizer.name)

dataloader = DATALOADERS[config.prompt.type.id](config)
metric = METRICS[config.prompt.type.id]
signature = SIGNATURES[signature_name]
program = PROGRAMS[program_name](signature)

dataset = await dataloader.load()

4 changes: 2 additions & 2 deletions extra/prompt_tuning/tuning/__init__.py
@@ -9,8 +9,8 @@ class ProgramType(Enum):
Program types.
"""

FILTERING_ASSESSOR = "FILTERING_ASSESSOR"
AGGREGATION_ASSESSOR = "AGGREGATION_ASSESSOR"
FILTERING_ASSESSOR = "FilteringAssessor"
AGGREGATION_ASSESSOR = "AggregationAssessor"


DATALOADERS = {
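Switching the enum values from SCREAMING_CASE strings to the CamelCase class prefixes lets the `id` string coming from the Hydra config index the dataloader and metric registries directly. A sketch of the pattern, with placeholder registry contents rather than the real loaders:

```python
from enum import Enum

class ProgramType(Enum):
    FILTERING_ASSESSOR = "FilteringAssessor"
    AGGREGATION_ASSESSOR = "AggregationAssessor"

# Placeholder registries keyed by the enum *values*, so a plain string id
# from the config resolves without any case conversion.
METRICS = {
    ProgramType.FILTERING_ASSESSOR.value: "filtering_metric",
    ProgramType.AGGREGATION_ASSESSOR.value: "aggregation_metric",
}

print(METRICS["FilteringAssessor"])  # → filtering_metric
```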
10 changes: 5 additions & 5 deletions extra/prompt_tuning/tuning/programs/__init__.py
@@ -1,16 +1,16 @@
from .iql import AggregationAssessorBaseline, AggregationAssessorCoT, FilteringAssessorBaseline, FilteringAssessorCoT
from .iql import AggregationAssessorCoT, AggregationAssessorPredict, FilteringAssessorCoT, FilteringAssessorPredict

PROGRAMS = {
FilteringAssessorBaseline.__name__: FilteringAssessorBaseline,
FilteringAssessorPredict.__name__: FilteringAssessorPredict,
FilteringAssessorCoT.__name__: FilteringAssessorCoT,
AggregationAssessorBaseline.__name__: AggregationAssessorBaseline,
AggregationAssessorPredict.__name__: AggregationAssessorPredict,
AggregationAssessorCoT.__name__: AggregationAssessorCoT,
}

__all__ = [
"PROGRAMS",
"AggregationAssessorBaseline",
"AggregationAssessorPredict",
"AggregationAssessorCoT",
"FilteringAssessorBaseline",
"FilteringAssessorPredict",
"FilteringAssessorCoT",
]
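Keying the registry by `__name__` keeps it in sync with class renames automatically: a rename such as Baseline → Predict only touches the import line and the class list. A stand-in sketch of the pattern (simplified classes, not the real dspy modules):

```python
# Stand-ins for the real dspy programs; only the registry pattern matters here.
class FilteringAssessorPredict:
    def __init__(self, signature):
        self.signature = signature

class FilteringAssessorCoT:
    def __init__(self, signature):
        self.signature = signature

# Keying by __name__ means the dict never drifts from the class names.
PROGRAMS = {cls.__name__: cls for cls in (FilteringAssessorPredict, FilteringAssessorCoT)}

program = PROGRAMS["FilteringAssessorPredict"]("FilteringAssessorBaseline")
print(type(program).__name__)  # → FilteringAssessorPredict
```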
24 changes: 13 additions & 11 deletions extra/prompt_tuning/tuning/programs/iql.py
@@ -1,16 +1,18 @@
from typing import Type

from dspy import ChainOfThought, Module, Predict, Prediction

from ..signatures.iql import CheckQuestionAggregation, CheckQuestionFiltering
from ..signatures.iql import AggregationAssessor, FilteringAssessor


class FilteringAssessorBaseline(Module):
class FilteringAssessorPredict(Module):
"""
Program that assesses whether a question requires filtering.
"""

def __init__(self) -> None:
def __init__(self, signature: Type[FilteringAssessor]) -> None:
super().__init__()
self.decide = Predict(CheckQuestionFiltering)
self.decide = Predict(signature)

def forward(self, question: str) -> Prediction:
"""
@@ -31,9 +33,9 @@ class FilteringAssessorCoT(Module):
Program that assesses whether a question requires filtering.
"""

def __init__(self) -> None:
def __init__(self, signature: Type[FilteringAssessor]) -> None:
super().__init__()
self.decide = ChainOfThought(CheckQuestionFiltering)
self.decide = ChainOfThought(signature)

def forward(self, question: str) -> Prediction:
"""
@@ -49,14 +51,14 @@ def forward(self, question: str) -> Prediction:
return Prediction(decision=decision.lower() == "true")


class AggregationAssessorBaseline(Module):
class AggregationAssessorPredict(Module):
"""
Program that assesses whether a question requires aggregation.
"""

def __init__(self) -> None:
def __init__(self, signature: Type[AggregationAssessor]) -> None:
super().__init__()
self.decide = Predict(CheckQuestionAggregation)
self.decide = Predict(signature)

def forward(self, question: str) -> Prediction:
"""
@@ -77,9 +79,9 @@ class AggregationAssessorCoT(Module):
Program that assesses whether a question requires aggregation.
"""

def __init__(self) -> None:
def __init__(self, signature: Type[AggregationAssessor]) -> None:
super().__init__()
self.decide = ChainOfThought(CheckQuestionAggregation)
self.decide = ChainOfThought(signature)

def forward(self, question: str) -> Prediction:
"""
27 changes: 25 additions & 2 deletions extra/prompt_tuning/tuning/signatures/__init__.py
@@ -1,3 +1,26 @@
from .iql import CheckQuestionFiltering
from .iql import (
AggregationAssessor,
AggregationAssessorBaseline,
AggregationAssessorOptimized,
FilteringAssessor,
FilteringAssessorBaseline,
FilteringAssessorOptimized,
)

__all__ = ["CheckQuestionFiltering"]
SIGNATURES = {
AggregationAssessorBaseline.__name__: AggregationAssessorBaseline,
AggregationAssessorOptimized.__name__: AggregationAssessorOptimized,
FilteringAssessorBaseline.__name__: FilteringAssessorBaseline,
FilteringAssessorOptimized.__name__: FilteringAssessorOptimized,
}


__all__ = [
"AggregationAssessor",
"FilteringAssessor",
"AggregationAssessorBaseline",
"AggregationAssessorOptimized",
"FilteringAssessorBaseline",
"FilteringAssessorOptimized",
"SIGNATURES",
]
45 changes: 38 additions & 7 deletions extra/prompt_tuning/tuning/signatures/iql.py
@@ -1,11 +1,12 @@
from abc import ABC

from dspy import InputField, OutputField, Signature


class CheckQuestionFiltering(Signature):
class FilteringAssessor(Signature, ABC):
"""
Given a question, determine whether the answer requires initial data filtering in order to compute it.
Initial data filtering is a process in which the result set is reduced to only include the rows that
meet certain criteria specified in the question.
Abstract signature, should be implemented by a concrete signature,
describing the actual task for the LLM.
"""

question = InputField(
@@ -20,10 +21,26 @@ class CheckQuestionFiltering(Signature):
)


class CheckQuestionAggregation(Signature):
class FilteringAssessorBaseline(FilteringAssessor):
"""
Given a question, determine whether the answer requires data aggregation in order to compute it.
Data aggregation is a process in which we calculate a single value for a group of rows in the result set.
Given a question, determine whether the answer requires initial data filtering in order to compute it.
Initial data filtering is a process in which the result set is reduced to only include the rows that
meet certain criteria specified in the question.
"""


class FilteringAssessorOptimized(FilteringAssessor):
"""
Given a question, determine whether the answer requires initial data filtering in order to compute it.
Initial data filtering is a process in which the result set is reduced to only include the rows that
meet certain criteria specified in the question.
"""


class AggregationAssessor(Signature, ABC):
"""
Abstract signature, should be implemented by a concrete signature,
describing the actual task for the LLM.
"""

question = InputField(
@@ -33,3 +50,17 @@ class CheckQuestionAggregation(Signature):
prefix="Decision: ",
desc=("indicates whether the answer to the question requires data aggregation. (Respond with True or False)"),
)


class AggregationAssessorBaseline(AggregationAssessor):
"""
Given a question, determine whether the answer requires data aggregation in order to compute it.
Data aggregation is a process in which we calculate a single value for a group of rows in the result set.
"""


class AggregationAssessorOptimized(AggregationAssessor):
"""
Given a question, determine whether the answer requires data aggregation in order to compute it.
Data aggregation is a process in which we calculate a single value for a group of rows in the result set.
"""
