Skip to content

Commit

Permalink
fix metrics and ag view
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Sep 2, 2024
1 parent d3cbc37 commit 21b7502
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 23 deletions.
4 changes: 4 additions & 0 deletions benchmarks/sql/bench/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationCorrectness,
IQLAggregationParseability,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
Expand All @@ -18,11 +20,13 @@
"FilteringAccuracy",
"FilteringPrecision",
"FilteringRecall",
"IQLAggregationParseability",
"IQLFiltersAccuracy",
"IQLFiltersPrecision",
"IQLFiltersRecall",
"IQLFiltersParseability",
"IQLFiltersCorrectness",
"IQLAggregationCorrectness",
"SQLExactMatch",
"ViewSelectionAccuracy",
"ViewSelectionPrecision",
Expand Down
91 changes: 74 additions & 17 deletions benchmarks/sql/bench/metrics/iql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC
from typing import Any, Dict, List

from ..pipelines import EvaluationResult
Expand All @@ -19,12 +20,21 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
Returns:
Filtering accuracy.
"""
results = [result for result in results if result.reference.iql and result.prediction.iql]
results = [
result
for result in results
if result.reference.iql
and result.prediction.iql
and result.reference.view_name
and result.prediction.view_name
and result.reference.iql.filters.generated
and result.prediction.iql.filters.generated
]
return {
"DM/FLT/ACC": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported
(result.reference.iql.filters.source is not None or result.reference.iql.filters.unsupported)
== (result.prediction.iql.filters.source is not None or result.prediction.iql.filters.unsupported)
for result in results
)
/ len(results)
Expand Down Expand Up @@ -222,11 +232,14 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
}


class IQLFiltersParseability(Metric):
class IQLParseability(Metric, ABC):
"""
IQL filters parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

prefix: str
iql: str

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters parseability.
Expand All @@ -241,46 +254,90 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (result.reference.iql.filters and result.prediction.iql.filters)
and (result.reference.iql.filters.source and result.prediction.iql.filters.source)
and (getattr(result.reference.iql, self.iql) and getattr(result.prediction.iql, self.iql))
and (getattr(result.reference.iql, self.iql).source and getattr(result.prediction.iql, self.iql).source)
]
return {
"IQL/FLT/PARSEABILITY": (
sum(result.prediction.iql.filters.valid for result in results) / len(results) if results else None
f"IQL/{self.prefix}/PARSEABILITY": (
sum(getattr(result.prediction.iql, self.iql).valid for result in results) / len(results)
if results
else None
)
}


class IQLFiltersCorrectness(Metric):
class IQLFiltersParseability(IQLParseability):
"""
IQL filters correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
IQL filters parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

prefix: str = "FLT"
iql: str = "filters"


class IQLAggregationParseability(IQLParseability):
    """Parseability metric for aggregation IQLs.

    Reports the proportion of generated aggregation IQLs that are
    syntactically valid (parseable) out of all generated aggregation IQLs.
    """

    # Key segment used in the reported metric name: IQL/AGG/PARSEABILITY.
    prefix: str = "AGG"
    # Attribute of the IQL result object holding the aggregation clause.
    iql: str = "aggregation"


class IQLCorrectness(Metric, ABC):
"""
IQL correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
"""

prefix: str
iql: str

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters correctness.
Computes the IQL correctness.
Args:
results: List of evaluation results.
Returns:
IQL filters correctness.
IQL correctness.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (
result.reference.iql.filters.source
and result.prediction.iql.filters.source
and result.prediction.iql.filters.valid
getattr(result.reference.iql, self.iql).source
and getattr(result.prediction.iql, self.iql).source
and getattr(result.prediction.iql, self.iql).valid
)
]
return {
"IQL/FLT/CORRECTNESS": (
sum(result.prediction.iql.filters.source == result.reference.iql.filters.source for result in results)
f"IQL/{self.prefix}/CORRECTNESS": (
sum(
getattr(result.prediction.iql, self.iql).source == getattr(result.reference.iql, self.iql).source
for result in results
)
/ len(results)
if results
else None
)
}


class IQLFiltersCorrectness(IQLCorrectness):
    """Correctness metric for filter IQLs.

    Reports the proportion of parseable filter IQLs whose generated source
    exactly matches the reference source, out of all parseable filter IQLs.
    """

    # Key segment used in the reported metric name: IQL/FLT/CORRECTNESS.
    prefix: str = "FLT"
    # Attribute of the IQL result object holding the filters clause.
    iql: str = "filters"


class IQLAggregationCorrectness(IQLCorrectness):
    """Correctness metric for aggregation IQLs.

    Reports the proportion of parseable aggregation IQLs whose generated
    source exactly matches the reference source, out of all parseable
    aggregation IQLs.
    """

    # Key segment used in the reported metric name: IQL/AGG/CORRECTNESS.
    prefix: str = "AGG"
    # Attribute of the IQL result object holding the aggregation clause.
    iql: str = "aggregation"
14 changes: 8 additions & 6 deletions benchmarks/sql/bench/views/structured/superhero.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def count_superheroes(self) -> Select:
Returns:
The superheroes count.
"""
return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)
return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)

@view_aggregation()
def average_height(self) -> Select:
Expand All @@ -303,7 +303,9 @@ def average_height(self) -> Select:
Returns:
The superheroes' average height.
"""
return self.data.with_only_columns(func.avg(Superhero.height_cm).label("average_height")).group_by(Superhero.id)
return self.select.with_only_columns(func.avg(Superhero.height_cm).label("average_height")).group_by(
Superhero.id
)


class SuperheroColourFilterMixin:
Expand Down Expand Up @@ -387,7 +389,7 @@ def percentage_of_eye_colour(self, eye_colour: str) -> Select:
Returns:
The percentage of objects with eye colour.
"""
return self.data.with_only_columns(
return self.select.with_only_columns(
(
cast(func.count(case((self.eye_colour.colour == eye_colour, Superhero.id), else_=None)), Float)
* 100
Expand Down Expand Up @@ -431,7 +433,7 @@ def percentage_of_publisher(self, publisher_name: str) -> Select:
Returns:
The percentage of objects with publisher.
"""
return self.data.with_only_columns(
return self.select.with_only_columns(
(
cast(func.count(case((Publisher.publisher_name == publisher_name, Superhero.id), else_=None)), Float)
* 100
Expand Down Expand Up @@ -475,7 +477,7 @@ def percentage_of_alignment(self, alignment: Literal["Good", "Bad", "Neutral", "
Returns:
The percentage of objects with alignment.
"""
return self.data.with_only_columns(
return self.select.with_only_columns(
(
cast(func.count(case((Alignment.alignment == alignment, Superhero.id), else_=None)), Float)
* 100
Expand Down Expand Up @@ -519,7 +521,7 @@ def percentage_of_gender(self, gender: Literal["Male", "Female", "N/A"]) -> Sele
Returns:
The percentage of objects with gender.
"""
return self.data.with_only_columns(
return self.select.with_only_columns(
(
cast(func.count(case((Gender.gender == gender, Superhero.id), else_=None)), Float)
* 100
Expand Down

0 comments on commit 21b7502

Please sign in to comment.