Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 24 changed files with 407 additions and 10 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Extra | ||
|
||
This folder contains research scripts related to dbally. Links are provided where descriptions exist: | ||
|
||
- [`Prompt tuning`](prompt_tuning/README.md) | ||
|
||
## Setup environment | ||
|
||
Before installing any package, make sure you have Python 3.9 or higher installed on your machine. From the root directory of the project, install the dependencies: | ||
|
||
```bash | ||
pip install -e '.[dev]' | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Prompt tuning | ||
|
||
This folder contains scripts for prompt tuning and evaluation. Prompts (programs) used in dbally: | ||
|
||
- `FILTERING_ASSESSOR` - assesses whether a question requires filtering. | ||
|
||
All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database. | ||
|
||
## Usage | ||
|
||
Run evaluation of the filtering assessor baseline on the `superhero` database with `gpt-3.5-turbo`: | ||
|
||
```bash | ||
python evaluate.py program=filtering-assessor-baseline | ||
``` | ||
|
||
Test multiple programs: | ||
|
||
```bash | ||
python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot | ||
``` | ||
|
||
Compare prompt performance on multiple LLMs: | ||
|
||
```bash | ||
python evaluate.py --multirun program=filtering-assessor-baseline llm=gpt-3.5-turbo,claude-3.5-sonnet | ||
``` | ||
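With `--multirun`, Hydra's default sweeper launches one job per combination of the comma-separated override values. The sketch below illustrates that expansion in plain Python; it is not Hydra's implementation, and `expand_sweep` is a hypothetical helper introduced only for illustration:

```python
from itertools import product
from typing import Dict, List


def expand_sweep(overrides: Dict[str, List[str]]) -> List[Dict[str, str]]:
    """Return one override mapping per job (Cartesian product of the values)."""
    keys = list(overrides)
    combos = product(*(overrides[key] for key in keys))
    return [dict(zip(keys, combo)) for combo in combos]


jobs = expand_sweep(
    {
        "program": ["filtering-assessor-baseline", "filtering-assessor-cot"],
        "llm": ["gpt-3.5-turbo", "claude-3.5-sonnet"],
    }
)

# Print the command-line overrides each job would receive.
for job in jobs:
    print(" ".join(f"{key}={value}" for key, value in job.items()))
```

Sweeping two programs over two LLMs therefore yields four runs, so cost grows multiplicatively with each swept parameter.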
|
||
### Log to Neptune | ||
|
||
Before running the evaluation with Neptune, configure the following environment variables: | ||
|
||
```bash | ||
export NEPTUNE_API_TOKEN="API_TOKEN" | ||
export NEPTUNE_PROJECT="WORKSPACE_NAME/PROJECT_NAME" | ||
``` | ||
|
||
Export evaluation results to Neptune: | ||
|
||
```bash | ||
python evaluate.py neptune=True | ||
``` |
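Since a long evaluation can finish before the Neptune upload is attempted, it may help to confirm that both variables are set up front. A minimal sketch of such a check, assuming only the two variable names from the README; `missing_neptune_vars` is a hypothetical helper and not part of the commit:

```python
import os
from typing import List, Mapping

REQUIRED_VARS = ("NEPTUNE_API_TOKEN", "NEPTUNE_PROJECT")


def missing_neptune_vars(env: Mapping[str, str]) -> List[str]:
    """Return the names of required variables that are unset or empty."""
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_neptune_vars(os.environ)
if missing:
    print("Set these before passing neptune=True:", ", ".join(missing))
```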
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
defaults: | ||
- data: superhero | ||
- llm: gpt-3.5-turbo | ||
- program: filtering-assessor-baseline | ||
- _self_ | ||
|
||
neptune: False |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
path: "micpst/bird-iql" | ||
split: "dev" | ||
db_ids: ["superhero"] | ||
difficulties: ["simple", "moderate", "challenging"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
model_name: claude-3-haiku-20240307 | ||
provider: Claude |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
model_name: claude-3-opus-20240229 | ||
provider: Claude |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
model_name: claude-3-5-sonnet-20240620 | ||
provider: Claude |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
model_name: gpt-3.5-turbo | ||
provider: OpenAI |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
model_name: gpt-4-turbo | ||
provider: OpenAI |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
model_name: gpt-4o | ||
provider: OpenAI |
2 changes: 2 additions & 0 deletions
extra/prompt_tuning/config/program/filtering-assessor-baseline.yaml
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
type: FILTERING_ASSESSOR | ||
name: FilteringAssessorBaseline |
2 changes: 2 additions & 0 deletions
extra/prompt_tuning/config/program/filtering-assessor-cot.yaml
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
type: FILTERING_ASSESSOR | ||
name: FilteringAssessorCoT |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import asyncio | ||
import logging | ||
from enum import Enum | ||
from pathlib import Path | ||
|
||
import dspy | ||
import hydra | ||
import neptune | ||
from dspy.evaluate import Evaluate | ||
from neptune.utils import stringify_unsupported | ||
from omegaconf import DictConfig | ||
from tuning.loaders import IQLGenerationDataLoader | ||
from tuning.metrics import filtering_assess_acc | ||
from tuning.programs import PROGRAMS | ||
from tuning.utils import save, serialize_results | ||
|
||
logging.getLogger("httpx").setLevel(logging.ERROR) | ||
logging.getLogger("anthropic").setLevel(logging.ERROR) | ||
log = logging.getLogger(__name__) | ||
|
||
|
||
class EvaluationType(Enum): | ||
""" | ||
Enum representing the evaluation type. | ||
""" | ||
|
||
FILTERING_ASSESSOR = "FILTERING_ASSESSOR" | ||
|
||
|
||
EVALUATION_DATALOADERS = { | ||
EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader, | ||
} | ||
|
||
EVALUATION_METRICS = { | ||
EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc, | ||
} | ||
|
||
|
||
async def evaluate(config: DictConfig) -> None: | ||
""" | ||
Function running evaluation for all datasets and evaluation tasks defined in hydra config. | ||
Args: | ||
config: Hydra configuration. | ||
""" | ||
log.info("Starting evaluation: %s", config.program.name) | ||
|
||
dataloader = EVALUATION_DATALOADERS[config.program.type](config) | ||
metric = EVALUATION_METRICS[config.program.type] | ||
program = PROGRAMS[config.program.name]() | ||
|
||
dataset = await dataloader.load() | ||
|
||
lm = getattr(dspy, config.llm.provider)(model=config.llm.model_name) | ||
dspy.settings.configure(lm=lm) | ||
|
||
evaluator = Evaluate( | ||
devset=dataset, | ||
metric=metric, | ||
num_threads=32, | ||
display_progress=True, | ||
return_outputs=True, | ||
) | ||
metric, results = evaluator(program) | ||
|
||
log.info("Evaluation finished. Saving results...") | ||
|
||
output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) | ||
results_file = output_dir / "results.json" | ||
save(results_file, results=serialize_results(results)) | ||
|
||
log.info("Evaluation results saved under directory: %s", output_dir) | ||
|
||
if config.neptune: | ||
run = neptune.init_run() | ||
run["sys/tags"].add( | ||
[ | ||
config.program.type, | ||
config.program.name, | ||
*config.data.db_ids, | ||
*config.data.difficulties, | ||
] | ||
) | ||
run["config"] = stringify_unsupported(config) | ||
run["evaluation/metrics/ACC"] = stringify_unsupported(metric) | ||
run["evaluation/results.json"].upload(results_file.as_posix()) | ||
|
||
|
||
@hydra.main(config_path="config", config_name="config", version_base="3.2") | ||
def main(config: DictConfig) -> None: | ||
""" | ||
Function running evaluation for all datasets and evaluation tasks defined in hydra config. | ||
Args: | ||
config: Hydra configuration. | ||
""" | ||
asyncio.run(evaluate(config)) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() # pylint: disable=no-value-for-parameter |
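The entry point follows a common pattern: a synchronous function (required by the `hydra.main` decorator) drives an async coroutine to completion with `asyncio.run`. A stripped-down sketch of that pattern using only the standard library; the dictionary config and the placeholder body of `evaluate` are assumptions for illustration, not the commit's code:

```python
import asyncio
from typing import Dict


async def evaluate(config: Dict[str, str]) -> str:
    """Stand-in for the async evaluation routine."""
    await asyncio.sleep(0)  # placeholder for awaited I/O such as dataset loading
    return f"evaluated {config['program']}"


def main(config: Dict[str, str]) -> str:
    """Synchronous entry point that runs the coroutine on a fresh event loop."""
    return asyncio.run(evaluate(config))


if __name__ == "__main__":
    print(main({"program": "FilteringAssessorBaseline"}))
```

`asyncio.run` creates and closes its own event loop, so it cannot be called from code that is already running inside one.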
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Dict, Iterable, List | ||
|
||
import dspy.datasets | ||
from dspy import Example | ||
|
||
|
||
class DataLoader(ABC): | ||
""" | ||
Data loader. | ||
""" | ||
|
||
def __init__(self, config: Dict) -> None: | ||
self.config = config | ||
|
||
@abstractmethod | ||
async def load(self) -> Iterable: | ||
""" | ||
Load the data. | ||
Returns: | ||
The loaded data. | ||
""" | ||
|
||
|
||
class HuggingFaceDataLoader(DataLoader): | ||
""" | ||
Hugging Face data loader. | ||
""" | ||
|
||
async def load(self) -> List[Example]: | ||
""" | ||
Load the data from Hugging Face. | ||
Returns: | ||
The loaded data. | ||
""" | ||
dataloader = dspy.datasets.DataLoader() | ||
dataset = dataloader.from_huggingface( | ||
dataset_name=self.config.data.path, split=self.config.data.split, input_keys=("question",) | ||
) | ||
return [ | ||
data | ||
for data in dataset | ||
if data["question"] | ||
# Each optional filter is parenthesized so that both apply when both are configured. | ||
if ( | ||
(data["db_id"] in self.config.data.db_ids if self.config.data.db_ids else True) | ||
and (data["difficulty"] in self.config.data.difficulties if self.config.data.difficulties else True) | ||
) | ||
] | ||
|
||
|
||
class IQLGenerationDataLoader(HuggingFaceDataLoader): | ||
""" | ||
Data loader for IQL generation evaluation. | ||
""" | ||
|
||
async def load(self) -> List[Example]: | ||
""" | ||
Load the data from Hugging Face and filter out samples without views. | ||
Returns: | ||
The loaded data. | ||
""" | ||
dataset = await super().load() | ||
return [data for data in dataset if data["view_name"]] |
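When two optional filters are combined with conditional expressions, precedence matters: Python parses `A if B else C and D if E else F` as `A if B else ((C and D) if E else F)`, so without parentheses the second filter is skipped whenever the first one is configured. The standalone sketch below compares an unparenthesized chain with an explicit version; both function names and the sample row are hypothetical:

```python
from typing import Dict, List


def keep_unparenthesized(row: Dict[str, str], db_ids: List[str], difficulties: List[str]) -> bool:
    """Chained conditional expressions without parentheses."""
    return (
        row["db_id"] in db_ids
        if db_ids
        else True and row["difficulty"] in difficulties
        if difficulties
        else True
    )


def keep_explicit(row: Dict[str, str], db_ids: List[str], difficulties: List[str]) -> bool:
    """Apply each filter only if it is configured, then require both to pass."""
    db_ok = not db_ids or row["db_id"] in db_ids
    difficulty_ok = not difficulties or row["difficulty"] in difficulties
    return db_ok and difficulty_ok


row = {"db_id": "superhero", "difficulty": "challenging"}

# With both filters configured, the unparenthesized form never checks the difficulty.
print(keep_unparenthesized(row, ["superhero"], ["simple"]))
print(keep_explicit(row, ["superhero"], ["simple"]))
```

The two versions agree when `db_ids` is empty and diverge only when both filters are set.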
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .iql import filtering_assess_acc | ||
|
||
__all__ = ["filtering_assess_acc"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import Dict | ||
|
||
from dspy import Prediction | ||
|
||
|
||
def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool: | ||
""" | ||
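Accuracy metric for the filtering decision: true when the predicted decision matches whether the gold sample has filters (supported or unsupported). | ||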
Args: | ||
gold: The ground truth data point. | ||
pred: The prediction. | ||
Returns: | ||
The decision metric. | ||
""" | ||
return ((gold["iql_filters"] is None and not gold["iql_filters_unsupported"]) and not pred.decision) or ( | ||
(gold["iql_filters"] is not None or gold["iql_filters_unsupported"]) and pred.decision | ||
) |
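The boolean expression in `filtering_assess_acc` reduces to an equality check: a sample "requires filtering" when it has `iql_filters` or is flagged `iql_filters_unsupported`, and the prediction is correct when its decision matches that. A minimal restatement for illustration; `requires_filtering` and `assess_acc` are hypothetical names, and `SimpleNamespace` stands in for `dspy.Prediction`:

```python
from types import SimpleNamespace
from typing import Any, Dict


def requires_filtering(gold: Dict[str, Any]) -> bool:
    """A gold sample needs filtering if it has filters, supported or not."""
    return gold["iql_filters"] is not None or bool(gold["iql_filters_unsupported"])


def assess_acc(gold: Dict[str, Any], pred: Any) -> bool:
    """Correct when the predicted decision equals the gold requirement."""
    return requires_filtering(gold) == bool(pred.decision)


cases = [
    # (iql_filters, iql_filters_unsupported, decision, expected)
    (None, False, False, True),
    (None, False, True, False),
    ("filter_by_name('x')", False, True, True),
    (None, True, True, True),
    (None, True, False, False),
]
for filters, unsupported, decision, expected in cases:
    gold = {"iql_filters": filters, "iql_filters_unsupported": unsupported}
    assert assess_acc(gold, SimpleNamespace(decision=decision)) == expected
```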
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from .iql import FilteringAssessorBaseline, FilteringAssessorCoT | ||
|
||
PROGRAMS = { | ||
FilteringAssessorBaseline.__name__: FilteringAssessorBaseline, | ||
FilteringAssessorCoT.__name__: FilteringAssessorCoT, | ||
} | ||
|
||
__all__ = ["PROGRAMS", "FilteringAssessorBaseline", "FilteringAssessorCoT"] |
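Keying the registry by `__name__` means the `name` field in each program config must match the class name exactly. A sketch of the same registry pattern with a clearer error for unknown names; the two placeholder classes and `build_program` are hypothetical and stand in for the real DSPy modules:

```python
from typing import Dict


class BaselineProgram:
    """Placeholder for a program class."""


class CoTProgram:
    """Placeholder for a program class."""


# Registry keyed by class name, mirroring the PROGRAMS mapping.
PROGRAMS: Dict[str, type] = {cls.__name__: cls for cls in (BaselineProgram, CoTProgram)}


def build_program(name: str) -> object:
    """Instantiate a registered program or fail with the list of valid names."""
    try:
        return PROGRAMS[name]()
    except KeyError:
        raise ValueError(f"Unknown program {name!r}; available: {sorted(PROGRAMS)}") from None


program = build_program("CoTProgram")
print(type(program).__name__)
```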
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from dspy import ChainOfThought, Module, Predict, Prediction | ||
|
||
from ..signatures.iql import CheckQuestionFiltering | ||
|
||
|
||
class FilteringAssessorBaseline(Module): | ||
""" | ||
Program that assesses whether a question requires filtering. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self.decide = Predict(CheckQuestionFiltering) | ||
|
||
def forward(self, question: str) -> Prediction: | ||
""" | ||
Assess whether a question requires filtering. | ||
Args: | ||
question: The question to assess. | ||
Returns: | ||
The prediction. | ||
""" | ||
decision = self.decide(question=question).decision | ||
return Prediction(decision=decision.lower() == "true") | ||
|
||
|
||
class FilteringAssessorCoT(Module): | ||
""" | ||
Program that assesses whether a question requires filtering, using chain-of-thought reasoning. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self.decide = ChainOfThought(CheckQuestionFiltering) | ||
|
||
def forward(self, question: str) -> Prediction: | ||
""" | ||
Assess whether a question requires filtering. | ||
Args: | ||
question: The question to assess. | ||
Returns: | ||
The prediction. | ||
""" | ||
decision = self.decide(question=question).decision | ||
return Prediction(decision=decision.lower() == "true") |
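Both programs map the model's text output to a boolean with an exact comparison against `"true"`, so a reply with surrounding whitespace or trailing punctuation would count as `False`. A slightly more tolerant parser could look like the sketch below; `parse_decision` is a hypothetical helper and not part of the commit:

```python
def parse_decision(raw: str) -> bool:
    """Map a model reply to a boolean, ignoring whitespace, quotes, and a trailing period."""
    token = raw.strip().strip(".\"'").lower()
    return token == "true"


for reply in ("True", " true\n", "True.", "False", "yes"):
    print(repr(reply), parse_decision(reply))
```

Anything other than a recognizable "true" still maps to `False`, which matches the behavior of the exact comparison for well-formed replies.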
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .iql import CheckQuestionFiltering | ||
|
||
__all__ = ["CheckQuestionFiltering"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from dspy import InputField, OutputField, Signature | ||
|
||
|
||
class CheckQuestionFiltering(Signature): | ||
""" | ||
Given a question, determine whether the answer requires initial data filtering in order to compute it. | ||
Initial data filtering is a process in which the result set is reduced to only include the rows that | ||
meet certain criteria specified in the question. | ||
""" | ||
|
||
question = InputField( | ||
prefix="Question: ", | ||
) | ||
decision = OutputField( | ||
prefix="Decision: ", | ||
desc=( | ||
"indicates whether the answer to the question requires initial data filtering. " | ||
"(Respond with True or False)" | ||
), | ||
) |