Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Jul 30, 2024
1 parent 1f0214a commit ef0b0be
Show file tree
Hide file tree
Showing 22 changed files with 757 additions and 729 deletions.
10 changes: 8 additions & 2 deletions benchmarks/sql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ Before starting, download the `superhero.sqlite` database file from [BIRD](https
Run the whole suite on the `superhero` database with `gpt-3.5-turbo`:

```bash
python bench.py --multirun setup=iql-view,sql-view,collection data=superhero
python bench.py --multirun setup=iql-view,sql-view,collection
```

Run on multiple databases:

```bash
python bench.py setup=sql-view setup/views/freeform@setup.views='[superhero,...]' data=bird
```

You can also run each evaluation separately or in subgroups:
Expand All @@ -34,7 +40,7 @@ python bench.py --multirun setup=iql-view setup/llm=gpt-3.5-turbo,claude-3.5-son
python bench.py --multirun setup=sql-view setup/llm=gpt-3.5-turbo,claude-3.5-sonnet
```

For the `collection` steup, you need to specify models for both the view selection and the IQL generation step:
For the `collection` setup, you need to specify models for both the view selection and the IQL generation step:

```bash
python bench.py --multirun \
Expand Down
52 changes: 32 additions & 20 deletions benchmarks/sql/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@
from bench.evaluator import Evaluator
from bench.loaders import CollectionDataLoader, IQLViewDataLoader, SQLViewDataLoader
from bench.metrics import (
ExactMatchAggregationIQL,
ExactMatchFiltersIQL,
ExactMatchIQL,
ExactMatchSQL,
ExecutionAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
IQLFiltersPrecision,
IQLFiltersRecall,
MetricSet,
UnsupportedIQL,
ValidIQL,
SQLExactMatch,
ViewSelectionAccuracy,
ViewSelectionPrecision,
ViewSelectionRecall,
)
from bench.pipelines import CollectionEvaluationPipeline, IQLViewEvaluationPipeline, SQLViewEvaluationPipeline
from bench.utils import save
Expand Down Expand Up @@ -52,26 +57,33 @@ class EvaluationType(Enum):

EVALUATION_METRICS = {
EvaluationType.IQL.value: MetricSet(
ExactMatchIQL,
ExactMatchFiltersIQL,
ExactMatchAggregationIQL,
ValidIQL,
ViewSelectionAccuracy,
UnsupportedIQL,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLFiltersAccuracy,
IQLFiltersPrecision,
IQLFiltersRecall,
IQLFiltersParseability,
IQLFiltersCorrectness,
ExecutionAccuracy,
),
EvaluationType.SQL.value: MetricSet(
ExactMatchSQL,
SQLExactMatch,
ExecutionAccuracy,
),
EvaluationType.E2E.value: MetricSet(
ExactMatchIQL,
ExactMatchFiltersIQL,
ExactMatchAggregationIQL,
ValidIQL,
UnsupportedIQL,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLFiltersAccuracy,
IQLFiltersPrecision,
IQLFiltersRecall,
IQLFiltersParseability,
IQLFiltersCorrectness,
ViewSelectionAccuracy,
ExactMatchSQL,
ViewSelectionPrecision,
ViewSelectionRecall,
SQLExactMatch,
ExecutionAccuracy,
),
}
Expand Down Expand Up @@ -113,7 +125,7 @@ async def bench(config: DictConfig) -> None:
run["sys/tags"].add(
[
config.setup.name,
config.data.db_id,
*config.data.db_ids,
*config.data.difficulties,
]
)
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/sql/bench/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from dataclasses import asdict
from typing import Any, Callable, Dict, List, Tuple

from datasets import Dataset
Expand Down Expand Up @@ -81,7 +82,7 @@ def _results_processor(self, results: List[EvaluationResult]) -> Dict[str, Any]:
Returns:
The processed results.
"""
return {"results": [result.dict() for result in results]}
return {"results": [asdict(result) for result in results]}

def _compute_metrics(self, metrics: MetricSet, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Expand Down
8 changes: 4 additions & 4 deletions benchmarks/sql/bench/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ async def load(self) -> Dataset:
"""
dataset = await super().load()
return dataset.filter(
lambda x: x["db_id"] == self.config.data.db_id
lambda x: x["db_id"] in self.config.data.db_ids
and x["difficulty"] in self.config.data.difficulties
and x["view"] is not None
and x["view_name"] is not None
)


Expand All @@ -74,7 +74,7 @@ async def load(self) -> Dataset:
"""
dataset = await super().load()
return dataset.filter(
lambda x: x["db_id"] == self.config.data.db_id and x["difficulty"] in self.config.data.difficulties
lambda x: x["db_id"] in self.config.data.db_ids and x["difficulty"] in self.config.data.difficulties
)


Expand All @@ -92,5 +92,5 @@ async def load(self) -> Dataset:
"""
dataset = await super().load()
return dataset.filter(
lambda x: x["db_id"] == self.config.data.db_id and x["difficulty"] in self.config.data.difficulties
lambda x: x["db_id"] in self.config.data.db_ids and x["difficulty"] in self.config.data.difficulties
)
32 changes: 23 additions & 9 deletions benchmarks/sql/bench/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
from .base import Metric, MetricSet
from .iql import ExactMatchAggregationIQL, ExactMatchFiltersIQL, ExactMatchIQL, UnsupportedIQL, ValidIQL
from .selector import ViewSelectionAccuracy
from .sql import ExactMatchSQL, ExecutionAccuracy
from .iql import (
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
IQLFiltersPrecision,
IQLFiltersRecall,
)
from .selector import ViewSelectionAccuracy, ViewSelectionPrecision, ViewSelectionRecall
from .sql import ExecutionAccuracy, SQLExactMatch

__all__ = [
"Metric",
"MetricSet",
"ExactMatchSQL",
"ExactMatchIQL",
"ExactMatchFiltersIQL",
"ExactMatchAggregationIQL",
"ValidIQL",
"FilteringAccuracy",
"FilteringPrecision",
"FilteringRecall",
"IQLFiltersAccuracy",
"IQLFiltersPrecision",
"IQLFiltersRecall",
"IQLFiltersParseability",
"IQLFiltersCorrectness",
"SQLExactMatch",
"ViewSelectionAccuracy",
"UnsupportedIQL",
"ViewSelectionPrecision",
"ViewSelectionRecall",
"ExecutionAccuracy",
]
Loading

0 comments on commit ef0b0be

Please sign in to comment.