From 21b7502a656891250d4f24034d1c907709e23864 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 2 Sep 2024 23:32:02 +0200 Subject: [PATCH] fix metrics and ag view --- benchmarks/sql/bench/metrics/__init__.py | 4 + benchmarks/sql/bench/metrics/iql.py | 91 +++++++++++++++---- .../sql/bench/views/structured/superhero.py | 14 +-- 3 files changed, 86 insertions(+), 23 deletions(-) diff --git a/benchmarks/sql/bench/metrics/__init__.py b/benchmarks/sql/bench/metrics/__init__.py index f0edc124..3684635f 100644 --- a/benchmarks/sql/bench/metrics/__init__.py +++ b/benchmarks/sql/bench/metrics/__init__.py @@ -3,6 +3,8 @@ FilteringAccuracy, FilteringPrecision, FilteringRecall, + IQLAggregationCorrectness, + IQLAggregationParseability, IQLFiltersAccuracy, IQLFiltersCorrectness, IQLFiltersParseability, @@ -18,11 +20,13 @@ "FilteringAccuracy", "FilteringPrecision", "FilteringRecall", + "IQLAggregationParseability", "IQLFiltersAccuracy", "IQLFiltersPrecision", "IQLFiltersRecall", "IQLFiltersParseability", "IQLFiltersCorrectness", + "IQLAggregationCorrectness", "SQLExactMatch", "ViewSelectionAccuracy", "ViewSelectionPrecision", diff --git a/benchmarks/sql/bench/metrics/iql.py b/benchmarks/sql/bench/metrics/iql.py index 07cf90c9..118c73d7 100644 --- a/benchmarks/sql/bench/metrics/iql.py +++ b/benchmarks/sql/bench/metrics/iql.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Any, Dict, List from ..pipelines import EvaluationResult @@ -19,12 +20,21 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: Returns: Filtering accuracy. """ - results = [result for result in results if result.reference.iql and result.prediction.iql] + results = [ + result + for result in results + if result.reference.iql + and result.prediction.iql + and result.reference.view_name + and result.prediction.view_name + and result.reference.iql.filters.generated + and result.prediction.iql.filters.generated + ] return { "DM/FLT/ACC": ( sum( - isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source)) - and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported + (result.reference.iql.filters.source is not None or result.reference.iql.filters.unsupported) + == (result.prediction.iql.filters.source is not None or result.prediction.iql.filters.unsupported) for result in results ) / len(results) @@ -222,11 +232,14 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: } -class IQLFiltersParseability(Metric): +class IQLParseability(Metric, ABC): """ IQL filters parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs. """ + prefix: str + iql: str + def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: """ Computes the IQL filters parseability. @@ -241,46 +254,90 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: result for result in results if (result.reference.iql and result.prediction.iql) - and (result.reference.iql.filters and result.prediction.iql.filters) - and (result.reference.iql.filters.source and result.prediction.iql.filters.source) + and (getattr(result.reference.iql, self.iql) and getattr(result.prediction.iql, self.iql)) + and (getattr(result.reference.iql, self.iql).source and getattr(result.prediction.iql, self.iql).source) ] return { - "IQL/FLT/PARSEABILITY": ( - sum(result.prediction.iql.filters.valid for result in results) / len(results) if results else None + f"IQL/{self.prefix}/PARSEABILITY": ( + sum(getattr(result.prediction.iql, self.iql).valid for result in results) / len(results) + if results + else None ) } -class IQLFiltersCorrectness(Metric): +class IQLFiltersParseability(IQLParseability): """ - IQL filters correctness is proportion of IQLs that produce correct results out of all parseable IQLs. + IQL filters parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs. + """ + + prefix: str = "FLT" + iql: str = "filters" + + +class IQLAggregationParseability(IQLParseability): + """ + IQL aggregation parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs. + """ + + prefix: str = "AGG" + iql: str = "aggregation" + + +class IQLCorrectness(Metric, ABC): + """ + IQL correctness is proportion of IQLs that produce correct results out of all parseable IQLs. """ + prefix: str + iql: str + def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: """ - Computes the IQL filters correctness. + Computes the IQL correctness. Args: results: List of evaluation results. Returns: - IQL filters correctness. + IQL correctness. """ results = [ result for result in results if (result.reference.iql and result.prediction.iql) and ( - result.reference.iql.filters.source - and result.prediction.iql.filters.source - and result.prediction.iql.filters.valid + getattr(result.reference.iql, self.iql).source + and getattr(result.prediction.iql, self.iql).source + and getattr(result.prediction.iql, self.iql).valid ) ] return { - "IQL/FLT/CORRECTNESS": ( - sum(result.prediction.iql.filters.source == result.reference.iql.filters.source for result in results) + f"IQL/{self.prefix}/CORRECTNESS": ( + sum( + getattr(result.prediction.iql, self.iql).source == getattr(result.reference.iql, self.iql).source + for result in results + ) / len(results) if results else None ) } + + +class IQLFiltersCorrectness(IQLCorrectness): + """ + IQL filters correctness is proportion of IQLs that produce correct results out of all parseable IQLs. + """ + + prefix: str = "FLT" + iql: str = "filters" + + +class IQLAggregationCorrectness(IQLCorrectness): + """ + IQL aggregation correctness is proportion of IQLs that produce correct results out of all parseable IQLs. + """ + + prefix: str = "AGG" + iql: str = "aggregation" diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index 305369bb..75a86ca7 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -293,7 +293,7 @@ def count_superheroes(self) -> Select: Returns: The superheros count. """ - return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) + return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) @view_aggregation() def average_height(self) -> Select: @@ -303,7 +303,9 @@ def average_height(self) -> Select: Returns: The superheros average height. """ - return self.data.with_only_columns(func.avg(Superhero.height_cm).label("average_height")).group_by(Superhero.id) + return self.select.with_only_columns(func.avg(Superhero.height_cm).label("average_height")).group_by( + Superhero.id + ) class SuperheroColourFilterMixin: @@ -387,7 +389,7 @@ def percentage_of_eye_colour(self, eye_colour: str) -> Select: Returns: The percentage of objects with eye colour. """ - return self.data.with_only_columns( + return self.select.with_only_columns( ( cast(func.count(case((self.eye_colour.colour == eye_colour, Superhero.id), else_=None)), Float) * 100 @@ -431,7 +433,7 @@ def percentage_of_publisher(self, publisher_name: str) -> Select: Returns: The percentage of objects with publisher. """ - return self.data.with_only_columns( + return self.select.with_only_columns( ( cast(func.count(case((Publisher.publisher_name == publisher_name, Superhero.id), else_=None)), Float) * 100 @@ -475,7 +477,7 @@ def percentage_of_alignment(self, alignment: Literal["Good", "Bad", "Neutral", " Returns: The percentage of objects with alignment. """ - return self.data.with_only_columns( + return self.select.with_only_columns( ( cast(func.count(case((Alignment.alignment == alignment, Superhero.id), else_=None)), Float) * 100 @@ -519,7 +521,7 @@ def percentage_of_gender(self, gender: Literal["Male", "Female", "N/A"]) -> Sele Returns: The percentage of objects with gender. """ - return self.data.with_only_columns( + return self.select.with_only_columns( ( cast(func.count(case((Gender.gender == gender, Superhero.id), else_=None)), Float) * 100