feat(benchmarks): aggregation benchmarks (#84)
micpst authored Sep 25, 2024
1 parent 61e82b7 commit 26033f7
Showing 5 changed files with 232 additions and 183 deletions.
9 changes: 9 additions & 0 deletions benchmarks/sql/bench.py
@@ -8,10 +8,13 @@
from bench.evaluator import Evaluator
from bench.loaders import CollectionDataLoader, IQLViewDataLoader, SQLViewDataLoader
from bench.metrics import (
AggregationAccuracy,
ExecutionAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationCorrectness,
IQLAggregationParseability,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
@@ -57,9 +60,12 @@ class EvaluationType(Enum):

EVALUATION_METRICS = {
EvaluationType.IQL.value: MetricSet(
AggregationAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationParseability,
IQLAggregationCorrectness,
IQLFiltersAccuracy,
IQLFiltersPrecision,
IQLFiltersRecall,
@@ -72,9 +78,12 @@ class EvaluationType(Enum):
ExecutionAccuracy,
),
EvaluationType.E2E.value: MetricSet(
AggregationAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationParseability,
IQLAggregationCorrectness,
IQLFiltersAccuracy,
IQLFiltersPrecision,
IQLFiltersRecall,
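
The hunks above extend the EVALUATION_METRICS registry in bench.py so that both the IQL and E2E evaluation types also compute the new aggregation metrics. A minimal sketch of the same construction pattern, assuming it is run from benchmarks/sql where the bench package is importable; how a MetricSet is evaluated afterwards is not part of this diff and is left out:

```python
from bench.metrics import (
    AggregationAccuracy,
    FilteringAccuracy,
    IQLAggregationCorrectness,
    IQLAggregationParseability,
    MetricSet,
)

# Mirrors bench.py: a MetricSet is assembled from metric classes, and bench.py
# keys such sets by EvaluationType value in EVALUATION_METRICS.
iql_metrics = MetricSet(
    AggregationAccuracy,
    FilteringAccuracy,
    IQLAggregationParseability,
    IQLAggregationCorrectness,
)
```
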
6 changes: 6 additions & 0 deletions benchmarks/sql/bench/metrics/__init__.py
@@ -1,8 +1,11 @@
from .base import Metric, MetricSet
from .iql import (
AggregationAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationCorrectness,
IQLAggregationParseability,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
@@ -15,14 +18,17 @@
__all__ = [
"Metric",
"MetricSet",
"AggregationAccuracy",
"FilteringAccuracy",
"FilteringPrecision",
"FilteringRecall",
"IQLAggregationParseability",
"IQLFiltersAccuracy",
"IQLFiltersPrecision",
"IQLFiltersRecall",
"IQLFiltersParseability",
"IQLFiltersCorrectness",
"IQLAggregationCorrectness",
"SQLExactMatch",
"ViewSelectionAccuracy",
"ViewSelectionPrecision",
128 changes: 106 additions & 22 deletions benchmarks/sql/bench/metrics/iql.py
@@ -1,30 +1,49 @@
from abc import ABC
from typing import Any, Dict, List

from ..pipelines import EvaluationResult
from .base import Metric


class FilteringAccuracy(Metric):
class AssessingAccuracy(Metric, ABC):
"""
Filtering accuracy is proportion of correct decisions (to filter or not) out of all decisions made.
Assessing accuracy is proportion of correct decisions out of all decisions made.
"""

prefix: str
iql: str

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the filtering accuracy.
Computes the assessing accuracy.
Args:
results: List of evaluation results.
Returns:
Filtering accuracy.
Assessing accuracy.
"""
results = [result for result in results if result.reference.iql and result.prediction.iql]
results = [
result
for result in results
if result.reference.iql
and result.prediction.iql
and result.reference.view_name
and result.prediction.view_name
and getattr(result.reference.iql, self.iql).generated
and getattr(result.prediction.iql, self.iql).generated
]
return {
"DM/FLT/ACC": (
f"DM/{self.prefix}/ACC": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported
(
getattr(result.reference.iql, self.iql).source is not None
or getattr(result.reference.iql, self.iql).unsupported
)
== (
getattr(result.prediction.iql, self.iql).source is not None
or getattr(result.prediction.iql, self.iql).unsupported
)
for result in results
)
/ len(results)
@@ -34,6 +53,24 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
}


class FilteringAccuracy(AssessingAccuracy):
"""
Filtering accuracy is proportion of correct decisions (to filter or not) out of all decisions made.
"""

prefix: str = "FLT"
iql: str = "filters"


class AggregationAccuracy(AssessingAccuracy):
"""
Aggregation accuracy is proportion of correct decisions (to aggregate or not) out of all decisions made.
"""

prefix: str = "AGG"
iql: str = "aggregation"


class FilteringPrecision(Metric):
"""
Filtering precision is proportion of correct decisions to filter out of all decisions to filter.
@@ -222,11 +259,14 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
}


class IQLFiltersParseability(Metric):
class IQLParseability(Metric, ABC):
"""
IQL filters parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

prefix: str
iql: str

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters parseability.
@@ -241,46 +281,90 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (result.reference.iql.filters and result.prediction.iql.filters)
and (result.reference.iql.filters.source and result.prediction.iql.filters.source)
and (getattr(result.reference.iql, self.iql) and getattr(result.prediction.iql, self.iql))
and (getattr(result.reference.iql, self.iql).source and getattr(result.prediction.iql, self.iql).source)
]
return {
"IQL/FLT/PARSEABILITY": (
sum(result.prediction.iql.filters.valid for result in results) / len(results) if results else None
f"IQL/{self.prefix}/PARSEABILITY": (
sum(getattr(result.prediction.iql, self.iql).valid for result in results) / len(results)
if results
else None
)
}


class IQLFiltersCorrectness(Metric):
class IQLFiltersParseability(IQLParseability):
"""
IQL filters correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
IQL filters parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

prefix: str = "FLT"
iql: str = "filters"


class IQLAggregationParseability(IQLParseability):
"""
IQL aggregation parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

prefix: str = "AGG"
iql: str = "aggregation"


class IQLCorrectness(Metric, ABC):
"""
IQL correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
"""

prefix: str
iql: str

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters correctness.
Computes the IQL correctness.
Args:
results: List of evaluation results.
Returns:
IQL filters correctness.
IQL correctness.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (
result.reference.iql.filters.source
and result.prediction.iql.filters.source
and result.prediction.iql.filters.valid
getattr(result.reference.iql, self.iql).source
and getattr(result.prediction.iql, self.iql).source
and getattr(result.prediction.iql, self.iql).valid
)
]
return {
"IQL/FLT/CORRECTNESS": (
sum(result.prediction.iql.filters.source == result.reference.iql.filters.source for result in results)
f"IQL/{self.prefix}/CORRECTNESS": (
sum(
getattr(result.prediction.iql, self.iql).source == getattr(result.reference.iql, self.iql).source
for result in results
)
/ len(results)
if results
else None
)
}


class IQLFiltersCorrectness(IQLCorrectness):
"""
IQL filters correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
"""

prefix: str = "FLT"
iql: str = "filters"


class IQLAggregationCorrectness(IQLCorrectness):
"""
IQL aggregation correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
"""

prefix: str = "AGG"
iql: str = "aggregation"
9 changes: 6 additions & 3 deletions benchmarks/sql/bench/metrics/sql.py
@@ -25,6 +25,7 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
Returns:
The exact match ratio.
"""
results = [result for result in results if result.reference.sql and result.prediction.sql]
return {
"SQL/EM": (
sum(result.prediction.sql == result.reference.sql for result in results) / len(results)
@@ -95,6 +96,7 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
Returns:
Execution accuracy score and valid efficiency score.
"""
results = [result for result in results if result.reference.sql and result.prediction.sql]
accurate_results = [result for result in results if self._execution_accuracy(result)]
return {
"EX": len(accurate_results) / len(results) if results else None,
@@ -121,9 +123,6 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool:
Returns:
True if the execution results are identical, False otherwise.
"""
if result.prediction.sql is None:
return False

try:
ref_results = self._execute_query(result.reference.sql, result.db_id)
pred_results = self._execute_query(result.prediction.sql, result.db_id)
Expand All @@ -138,6 +137,10 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool:
if reference.shape[0] != prediction.shape[0]:
return False

# If both dataframes have only one column, compare the values directly
if reference.shape[1] == prediction.shape[1] == 1:
return reference.iloc[:, 0].equals(prediction.iloc[:, 0])

# Returned view may have the same columns, or more columns than the ground truth
if not reference.columns.isin(prediction.columns).all():
return False
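
The sql.py changes above replace per-result None checks with an up-front filter and add a fast path to _execution_accuracy: when both result sets have exactly one column, the values are compared positionally, so a differing column alias between the reference and predicted SQL no longer counts as a mismatch. A small pandas illustration of that branch (toy data, not from the benchmark):

```python
import pandas as pd

# Same values, different column names - as when the predicted SQL
# uses an alias the reference query does not.
reference = pd.DataFrame({"avg_price": [10.5, 20.0]})
prediction = pd.DataFrame({"AVG(price)": [10.5, 20.0]})

if reference.shape[1] == prediction.shape[1] == 1:
    # Compare the single columns by position and value, ignoring their names.
    match = reference.iloc[:, 0].equals(prediction.iloc[:, 0])
else:
    match = False

print(match)  # True
```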