feat(extra): prompt tuning (#79)
micpst authored Aug 9, 2024
1 parent 11a7b21 commit 2714e7c
Showing 24 changed files with 407 additions and 10 deletions.
4 changes: 2 additions & 2 deletions benchmarks/README.md
@@ -4,10 +4,10 @@ This folder contains scripts that produce reproducible timings and evaluation me

## Setup environment

-Before installing any package, make sure you have Python 3.8 or higher installed on your machine. From the root directory of the project, install the dependencies:
+Before installing any package, make sure you have Python 3.9 or higher installed on your machine. From the root directory of the project, install the dependencies:

```bash
-pip install -e '.[benchmarks]'
+pip install -e '.[dev]'
```

## Benchmark list
13 changes: 13 additions & 0 deletions extra/README.md
@@ -0,0 +1,13 @@
# Extra

This folder contains scripts for research related to dbally. Links are provided where descriptions exist:

- [`Prompt tuning`](prompt_tuning/README.md)

## Setup environment

Before installing any package, make sure you have Python 3.9 or higher installed on your machine. From the root directory of the project, install the dependencies:

```bash
pip install -e '.[dev]'
```
42 changes: 42 additions & 0 deletions extra/prompt_tuning/README.md
@@ -0,0 +1,42 @@
# Prompt tuning

This folder contains scripts for prompt tuning and evaluation. The following prompts (programs) are used in dbally:

- `FILTERING_ASSESSOR` - assesses whether a question requires filtering, e.g. "Which superheroes are taller than 200 cm?" does, while "List all superheroes" does not.

All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database.

## Usage

Run evaluation of the filtering assessor baseline on the `superhero` database with `gpt-3.5-turbo`:

```bash
python evaluate.py program=filtering-assessor-baseline
```

Test multiple programs:

```bash
python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
```

Compare prompt performance on multiple LLMs:

```bash
python evaluate.py --multirun program=filtering-assessor-baseline llm=gpt-3.5-turbo,claude-3.5-sonnet
```

### Log to Neptune

Before running the evaluation with Neptune, configure the following environment variables:

```bash
export NEPTUNE_API_TOKEN="API_TOKEN"
export NEPTUNE_PROJECT="WORKSPACE_NAME/PROJECT_NAME"
```

Export evaluation results to Neptune:

```bash
python evaluate.py neptune=True
```
7 changes: 7 additions & 0 deletions extra/prompt_tuning/config/config.yaml
@@ -0,0 +1,7 @@
defaults:
  - data: superhero
  - llm: gpt-3.5-turbo
  - program: filtering-assessor-baseline
  - _self_

neptune: False
4 changes: 4 additions & 0 deletions extra/prompt_tuning/config/data/superhero.yaml
@@ -0,0 +1,4 @@
path: "micpst/bird-iql"
split: "dev"
db_ids: ["superhero"]
difficulties: ["simple", "moderate", "challenging"]
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/claude-3-haiku.yaml
@@ -0,0 +1,2 @@
model_name: claude-3-haiku-20240307
provider: Claude
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/claude-3-opus.yaml
@@ -0,0 +1,2 @@
model_name: claude-3-opus-20240229
provider: Claude
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/claude-3.5-sonnet.yaml
@@ -0,0 +1,2 @@
model_name: claude-3-5-sonnet-20240620
provider: Claude
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/gpt-3.5-turbo.yaml
@@ -0,0 +1,2 @@
model_name: gpt-3.5-turbo
provider: OpenAI
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/gpt-4-turbo.yaml
@@ -0,0 +1,2 @@
model_name: gpt-4-turbo
provider: OpenAI
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/llm/gpt-4o.yaml
@@ -0,0 +1,2 @@
model_name: gpt-4o
provider: OpenAI
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml
@@ -0,0 +1,2 @@
type: FILTERING_ASSESSOR
name: FilteringAssessorBaseline
2 changes: 2 additions & 0 deletions extra/prompt_tuning/config/program/filtering-assessor-cot.yaml
@@ -0,0 +1,2 @@
type: FILTERING_ASSESSOR
name: FilteringAssessorCoT
101 changes: 101 additions & 0 deletions extra/prompt_tuning/evaluate.py
@@ -0,0 +1,101 @@
import asyncio
import logging
from enum import Enum
from pathlib import Path

import dspy
import hydra
import neptune
from dspy.evaluate import Evaluate
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig
from tuning.loaders import IQLGenerationDataLoader
from tuning.metrics import filtering_assess_acc
from tuning.programs import PROGRAMS
from tuning.utils import save, serialize_results

logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("anthropic").setLevel(logging.ERROR)
log = logging.getLogger(__name__)


class EvaluationType(Enum):
    """
    Enum representing the evaluation type.
    """

    FILTERING_ASSESSOR = "FILTERING_ASSESSOR"


EVALUATION_DATALOADERS = {
    EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader,
}

EVALUATION_METRICS = {
    EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc,
}


async def evaluate(config: DictConfig) -> None:
    """
    Function running evaluation for all datasets and evaluation tasks defined in hydra config.

    Args:
        config: Hydra configuration.
    """
    log.info("Starting evaluation: %s", config.program.name)

    dataloader = EVALUATION_DATALOADERS[config.program.type](config)
    metric = EVALUATION_METRICS[config.program.type]
    program = PROGRAMS[config.program.name]()

    dataset = await dataloader.load()

    lm = dspy.__dict__[config.llm.provider](model=config.llm.model_name)
    dspy.settings.configure(lm=lm)

    evaluator = Evaluate(
        devset=dataset,
        metric=metric,
        num_threads=32,
        display_progress=True,
        return_outputs=True,
    )
    metric, results = evaluator(program)

    log.info("Evaluation finished. Saving results...")

    output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
    results_file = output_dir / "results.json"
    save(results_file, results=serialize_results(results))

    log.info("Evaluation results saved under directory: %s", output_dir)

    if config.neptune:
        run = neptune.init_run()
        run["sys/tags"].add(
            [
                config.program.type,
                config.program.name,
                *config.data.db_ids,
                *config.data.difficulties,
            ]
        )
        run["config"] = stringify_unsupported(config)
        run["evaluation/metrics/ACC"] = stringify_unsupported(metric)
        run["evaluation/results.json"].upload(results_file.as_posix())


@hydra.main(config_path="config", config_name="config", version_base="3.2")
def main(config: DictConfig) -> None:
    """
    Function running evaluation for all datasets and evaluation tasks defined in hydra config.

    Args:
        config: Hydra configuration.
    """
    asyncio.run(evaluate(config))


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
Empty file.
69 changes: 69 additions & 0 deletions extra/prompt_tuning/tuning/loaders.py
@@ -0,0 +1,69 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterable, List

import dspy.datasets
from dspy import Example


class DataLoader(ABC):
    """
    Data loader.
    """

    def __init__(self, config: Dict) -> None:
        self.config = config

    @abstractmethod
    async def load(self) -> Iterable:
        """
        Load the data.

        Returns:
            The loaded data.
        """


class HuggingFaceDataLoader(DataLoader):
    """
    Hugging Face data loader.
    """

    async def load(self) -> List[Example]:
        """
        Load the data from Hugging Face.

        Returns:
            The loaded data.
        """
        dataloader = dspy.datasets.DataLoader()
        dataset = dataloader.from_huggingface(
            dataset_name=self.config.data.path, split=self.config.data.split, input_keys=("question",)
        )
        return [
            data
            for data in dataset
            if data["question"]
            and (data["db_id"] in self.config.data.db_ids if self.config.data.db_ids else True)
            and (data["difficulty"] in self.config.data.difficulties if self.config.data.difficulties else True)
        ]


class IQLGenerationDataLoader(HuggingFaceDataLoader):
    """
    Data loader for IQL generation evaluation.
    """

    async def load(self) -> List[Example]:
        """
        Load the data from Hugging Face and filter out samples without views.

        Returns:
            The loaded data.
        """
        dataset = await super().load()
        return [data for data in dataset if data["view_name"]]
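
For orientation, a sketch of the record fields these loaders and the metric below rely on; the values and the exact set of columns are illustrative assumptions, not actual rows from the `micpst/bird-iql` dataset.

```python
# Illustrative record shape inferred from the fields accessed in loaders.py and metrics/iql.py.
# Values are made up; the real dataset may contain additional columns.
example_record = {
    "question": "Which superheroes are taller than 200 cm?",
    "db_id": "superhero",               # matched against config.data.db_ids
    "difficulty": "simple",             # matched against config.data.difficulties
    "view_name": "SuperheroView",       # records without a view are dropped by IQLGenerationDataLoader
    "iql_filters": "taller_than(200)",  # None when no filtering is required
    "iql_filters_unsupported": False,   # True when filtering is needed but cannot be expressed
}
```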
3 changes: 3 additions & 0 deletions extra/prompt_tuning/tuning/metrics/__init__.py
@@ -0,0 +1,3 @@
from .iql import filtering_assess_acc

__all__ = ["filtering_assess_acc"]
19 changes: 19 additions & 0 deletions extra/prompt_tuning/tuning/metrics/iql.py
@@ -0,0 +1,19 @@
from typing import Dict

from dspy import Prediction


def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool:
    """
    IQL decision metric.

    Args:
        gold: The ground truth data point.
        pred: The prediction.

    Returns:
        True if the predicted filtering decision matches the ground truth, False otherwise.
    """
    return ((gold["iql_filters"] is None and not gold["iql_filters_unsupported"]) and not pred.decision) or (
        (gold["iql_filters"] is not None or gold["iql_filters_unsupported"]) and pred.decision
    )
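
A minimal sketch of how `filtering_assess_acc` scores two cases, assuming record fields shaped like the illustrative example above; the `Prediction` objects are constructed directly with the boolean `decision` attribute that the programs below produce.

```python
from dspy import Prediction

from tuning.metrics import filtering_assess_acc

# Gold record needs no filtering and the program also predicted False -> counted as correct.
gold_no_filters = {"iql_filters": None, "iql_filters_unsupported": False}
assert filtering_assess_acc(gold_no_filters, Prediction(decision=False))

# Gold record contains a filter but the program predicted False -> counted as incorrect.
gold_with_filters = {"iql_filters": "taller_than(200)", "iql_filters_unsupported": False}
assert not filtering_assess_acc(gold_with_filters, Prediction(decision=False))
```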
8 changes: 8 additions & 0 deletions extra/prompt_tuning/tuning/programs/__init__.py
@@ -0,0 +1,8 @@
from .iql import FilteringAssessorBaseline, FilteringAssessorCoT

PROGRAMS = {
    FilteringAssessorBaseline.__name__: FilteringAssessorBaseline,
    FilteringAssessorCoT.__name__: FilteringAssessorCoT,
}

__all__ = ["PROGRAMS", "FilteringAssessorBaseline", "FilteringAssessorCoT"]
49 changes: 49 additions & 0 deletions extra/prompt_tuning/tuning/programs/iql.py
@@ -0,0 +1,49 @@
from dspy import ChainOfThought, Module, Predict, Prediction

from ..signatures.iql import CheckQuestionFiltering


class FilteringAssessorBaseline(Module):
    """
    Program that assesses whether a question requires filtering.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decide = Predict(CheckQuestionFiltering)

    def forward(self, question: str) -> Prediction:
        """
        Assess whether a question requires filtering.

        Args:
            question: The question to assess.

        Returns:
            The prediction.
        """
        decision = self.decide(question=question).decision
        return Prediction(decision=decision.lower() == "true")


class FilteringAssessorCoT(Module):
    """
    Program that assesses whether a question requires filtering, using chain-of-thought prompting.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decide = ChainOfThought(CheckQuestionFiltering)

    def forward(self, question: str) -> Prediction:
        """
        Assess whether a question requires filtering.

        Args:
            question: The question to assess.

        Returns:
            The prediction.
        """
        decision = self.decide(question=question).decision
        return Prediction(decision=decision.lower() == "true")
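
A minimal standalone sketch of invoking one of these programs outside `evaluate.py`, mirroring how that script configures the LM (looking the provider class up on the `dspy` module); the OpenAI key in the environment and the example question are assumptions for illustration.

```python
import dspy

from tuning.programs import FilteringAssessorBaseline

# Configure the LM the same way evaluate.py does for provider "OpenAI".
lm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=lm)

program = FilteringAssessorBaseline()
prediction = program(question="Which superheroes are taller than 200 cm?")
print(prediction.decision)  # True is expected for a question that needs filtering
```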
3 changes: 3 additions & 0 deletions extra/prompt_tuning/tuning/signatures/__init__.py
@@ -0,0 +1,3 @@
from .iql import CheckQuestionFiltering

__all__ = ["CheckQuestionFiltering"]
20 changes: 20 additions & 0 deletions extra/prompt_tuning/tuning/signatures/iql.py
@@ -0,0 +1,20 @@
from dspy import InputField, OutputField, Signature


class CheckQuestionFiltering(Signature):
    """
    Given a question, determine whether the answer requires initial data filtering in order to compute it.
    Initial data filtering is a process in which the result set is reduced to only include the rows that
    meet certain criteria specified in the question.
    """

    question = InputField(
        prefix="Question: ",
    )
    decision = OutputField(
        prefix="Decision: ",
        desc=(
            "indicates whether the answer to the question requires initial data filtering. "
            "(Respond with True or False)"
        ),
    )