From ef0b0be51fa4344913dee60bce48beec91099a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 30 Jul 2024 03:40:53 +0200 Subject: [PATCH] refactor --- benchmarks/sql/README.md | 10 +- benchmarks/sql/bench.py | 52 ++- benchmarks/sql/bench/evaluator.py | 3 +- benchmarks/sql/bench/loaders.py | 8 +- benchmarks/sql/bench/metrics/__init__.py | 32 +- benchmarks/sql/bench/metrics/iql.py | 216 +++++++-- benchmarks/sql/bench/metrics/selector.py | 60 ++- benchmarks/sql/bench/metrics/sql.py | 36 +- benchmarks/sql/bench/pipelines/base.py | 88 +--- benchmarks/sql/bench/pipelines/collection.py | 87 ++-- benchmarks/sql/bench/pipelines/view.py | 164 ++++--- benchmarks/sql/bench/views/__init__.py | 4 +- .../sql/bench/views/structured/superhero.py | 438 +----------------- benchmarks/sql/config/data/superhero.yaml | 3 +- benchmarks/sql/config/setup/collection.yaml | 9 +- benchmarks/sql/config/setup/iql-view.yaml | 8 +- benchmarks/sql/config/setup/sql-view.yaml | 3 +- .../setup/views/freeform/superhero.yaml | 1 + .../setup/views/structured/superhero.yaml | 4 + benchmarks/sql/tests/test_evaluator.py | 8 +- benchmarks/sql/tests/test_metrics.py | 245 +++++++++- setup.cfg | 7 - 22 files changed, 757 insertions(+), 729 deletions(-) create mode 100644 benchmarks/sql/config/setup/views/freeform/superhero.yaml create mode 100644 benchmarks/sql/config/setup/views/structured/superhero.yaml diff --git a/benchmarks/sql/README.md b/benchmarks/sql/README.md index bb49dc80..0b7da2d9 100644 --- a/benchmarks/sql/README.md +++ b/benchmarks/sql/README.md @@ -17,7 +17,13 @@ Before starting, download the `superhero.sqlite` database file from [BIRD](https Run the whole suite on the `superhero` database with `gpt-3.5-turbo`: ```bash -python bench.py --multirun setup=iql-view,sql-view,collection data=superhero +python bench.py --multirun setup=iql-view,sql-view,collection +``` + +Run on multiple databases: + +```bash +python bench.py setup=sql-view setup/views/freeform@setup.views='[superhero,...]' data=bird ``` You can also run each evaluation separately or in subgroups: @@ -34,7 +40,7 @@ python bench.py --multirun setup=iql-view setup/llm=gpt-3.5-turbo,claude-3.5-son python bench.py --multirun setup=sql-view setup/llm=gpt-3.5-turbo,claude-3.5-sonnet ``` -For the `collection` steup, you need to specify models for both the view selection and the IQL generation step: +For the `collection` setup, you need to specify models for both the view selection and the IQL generation step: ```bash python bench.py --multirun \ diff --git a/benchmarks/sql/bench.py b/benchmarks/sql/bench.py index d2668e88..a86c6ee4 100644 --- a/benchmarks/sql/bench.py +++ b/benchmarks/sql/bench.py @@ -8,15 +8,20 @@ from bench.evaluator import Evaluator from bench.loaders import CollectionDataLoader, IQLViewDataLoader, SQLViewDataLoader from bench.metrics import ( - ExactMatchAggregationIQL, - ExactMatchFiltersIQL, - ExactMatchIQL, - ExactMatchSQL, ExecutionAccuracy, + FilteringAccuracy, + FilteringPrecision, + FilteringRecall, + IQLFiltersAccuracy, + IQLFiltersCorrectness, + IQLFiltersParseability, + IQLFiltersPrecision, + IQLFiltersRecall, MetricSet, - UnsupportedIQL, - ValidIQL, + SQLExactMatch, ViewSelectionAccuracy, + ViewSelectionPrecision, + ViewSelectionRecall, ) from bench.pipelines import CollectionEvaluationPipeline, IQLViewEvaluationPipeline, SQLViewEvaluationPipeline from bench.utils import save @@ -52,26 +57,33 @@ class EvaluationType(Enum): EVALUATION_METRICS = { EvaluationType.IQL.value: MetricSet( - ExactMatchIQL, - 
ExactMatchFiltersIQL, - ExactMatchAggregationIQL, - ValidIQL, - ViewSelectionAccuracy, - UnsupportedIQL, + FilteringAccuracy, + FilteringPrecision, + FilteringRecall, + IQLFiltersAccuracy, + IQLFiltersPrecision, + IQLFiltersRecall, + IQLFiltersParseability, + IQLFiltersCorrectness, ExecutionAccuracy, ), EvaluationType.SQL.value: MetricSet( - ExactMatchSQL, + SQLExactMatch, ExecutionAccuracy, ), EvaluationType.E2E.value: MetricSet( - ExactMatchIQL, - ExactMatchFiltersIQL, - ExactMatchAggregationIQL, - ValidIQL, - UnsupportedIQL, + FilteringAccuracy, + FilteringPrecision, + FilteringRecall, + IQLFiltersAccuracy, + IQLFiltersPrecision, + IQLFiltersRecall, + IQLFiltersParseability, + IQLFiltersCorrectness, ViewSelectionAccuracy, - ExactMatchSQL, + ViewSelectionPrecision, + ViewSelectionRecall, + SQLExactMatch, ExecutionAccuracy, ), } @@ -113,7 +125,7 @@ async def bench(config: DictConfig) -> None: run["sys/tags"].add( [ config.setup.name, - config.data.db_id, + *config.data.db_ids, *config.data.difficulties, ] ) diff --git a/benchmarks/sql/bench/evaluator.py b/benchmarks/sql/bench/evaluator.py index 2a5a201d..5903732f 100644 --- a/benchmarks/sql/bench/evaluator.py +++ b/benchmarks/sql/bench/evaluator.py @@ -1,4 +1,5 @@ import time +from dataclasses import asdict from typing import Any, Callable, Dict, List, Tuple from datasets import Dataset @@ -81,7 +82,7 @@ def _results_processor(self, results: List[EvaluationResult]) -> Dict[str, Any]: Returns: The processed results. """ - return {"results": [result.dict() for result in results]} + return {"results": [asdict(result) for result in results]} def _compute_metrics(self, metrics: MetricSet, results: List[EvaluationResult]) -> Dict[str, Any]: """ diff --git a/benchmarks/sql/bench/loaders.py b/benchmarks/sql/bench/loaders.py index 8ec23be3..8e5d8387 100644 --- a/benchmarks/sql/bench/loaders.py +++ b/benchmarks/sql/bench/loaders.py @@ -54,9 +54,9 @@ async def load(self) -> Dataset: """ dataset = await super().load() return dataset.filter( - lambda x: x["db_id"] == self.config.data.db_id + lambda x: x["db_id"] in self.config.data.db_ids and x["difficulty"] in self.config.data.difficulties - and x["view"] is not None + and x["view_name"] is not None ) @@ -74,7 +74,7 @@ async def load(self) -> Dataset: """ dataset = await super().load() return dataset.filter( - lambda x: x["db_id"] == self.config.data.db_id and x["difficulty"] in self.config.data.difficulties + lambda x: x["db_id"] in self.config.data.db_ids and x["difficulty"] in self.config.data.difficulties ) @@ -92,5 +92,5 @@ async def load(self) -> Dataset: """ dataset = await super().load() return dataset.filter( - lambda x: x["db_id"] == self.config.data.db_id and x["difficulty"] in self.config.data.difficulties + lambda x: x["db_id"] in self.config.data.db_ids and x["difficulty"] in self.config.data.difficulties ) diff --git a/benchmarks/sql/bench/metrics/__init__.py b/benchmarks/sql/bench/metrics/__init__.py index 86d72f78..f0edc124 100644 --- a/benchmarks/sql/bench/metrics/__init__.py +++ b/benchmarks/sql/bench/metrics/__init__.py @@ -1,17 +1,31 @@ from .base import Metric, MetricSet -from .iql import ExactMatchAggregationIQL, ExactMatchFiltersIQL, ExactMatchIQL, UnsupportedIQL, ValidIQL -from .selector import ViewSelectionAccuracy -from .sql import ExactMatchSQL, ExecutionAccuracy +from .iql import ( + FilteringAccuracy, + FilteringPrecision, + FilteringRecall, + IQLFiltersAccuracy, + IQLFiltersCorrectness, + IQLFiltersParseability, + IQLFiltersPrecision, + IQLFiltersRecall, +) +from 
.selector import ViewSelectionAccuracy, ViewSelectionPrecision, ViewSelectionRecall +from .sql import ExecutionAccuracy, SQLExactMatch __all__ = [ "Metric", "MetricSet", - "ExactMatchSQL", - "ExactMatchIQL", - "ExactMatchFiltersIQL", - "ExactMatchAggregationIQL", - "ValidIQL", + "FilteringAccuracy", + "FilteringPrecision", + "FilteringRecall", + "IQLFiltersAccuracy", + "IQLFiltersPrecision", + "IQLFiltersRecall", + "IQLFiltersParseability", + "IQLFiltersCorrectness", + "SQLExactMatch", "ViewSelectionAccuracy", - "UnsupportedIQL", + "ViewSelectionPrecision", + "ViewSelectionRecall", "ExecutionAccuracy", ] diff --git a/benchmarks/sql/bench/metrics/iql.py b/benchmarks/sql/bench/metrics/iql.py index c9339f86..d7068ea2 100644 --- a/benchmarks/sql/bench/metrics/iql.py +++ b/benchmarks/sql/bench/metrics/iql.py @@ -1,15 +1,112 @@ from typing import Any, Dict, List -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError - from ..pipelines import EvaluationResult from .base import Metric -class ExactMatchIQL(Metric): +class FilteringAccuracy(Metric): + """ + Filtering accuracy indicating proportion of questions that were correctly identified as having filters. + """ + + def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: + """ + Computes the filtering accuracy. + + Args: + results: List of evaluation results. + + Returns: + Filtering accuracy. + """ + results = [result for result in results if result.reference.iql and result.prediction.iql] + return { + "DM/FLT/ACC": ( + sum( + isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source)) + and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported + for result in results + ) + / len(results) + if results + else None + ) + } + + +class FilteringPrecision(Metric): + """ + Filtering precision indicating proportion of questions that were identified as having filters correctly. + """ + + def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: + """ + Computes the filtering precision. + + Args: + results: List of evaluation results. + + Returns: + Filtering precision. + """ + results = [ + result + for result in results + if (result.reference.iql and result.prediction.iql) + and (result.prediction.iql.filters.source or result.prediction.iql.filters.unsupported) + ] + return { + "DM/FLT/PRECISION": ( + sum( + isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source)) + and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported + for result in results + ) + / len(results) + if results + else None + ) + } + + +class FilteringRecall(Metric): """ - Ratio of predicated queries that are identical to the ground truth ones. + Filtering recall indicating proportion of questions that were correctly identified as having filters. + """ + + def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: + """ + Computes the filtering recall. + + Args: + results: List of evaluation results. + + Returns: + Filtering recall. 
+ """ + results = [ + result + for result in results + if (result.reference.iql and result.prediction.iql) + and (result.reference.iql.filters.source or result.reference.iql.filters.unsupported) + ] + return { + "DM/FLT/RECALL": ( + sum( + isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source)) + and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported + for result in results + ) + / len(results) + if results + else None + ) + } + + +class IQLFiltersAccuracy(Metric): + """ + Ratio of predicated IQL filters that are identical to the ground truth ones. """ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: @@ -22,19 +119,33 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: Returns: Ratio of predicated queries that are identical to the ground truth ones. """ - results = [result for result in results if result.prediction.iql is not None] + results = [ + result + for result in results + if (result.reference.iql and result.prediction.iql) + and ( + result.reference.iql.filters.source + or result.reference.iql.filters.unsupported + and result.prediction.iql.filters.source + or result.prediction.iql.filters.unsupported + ) + ] return { - "EM_IQL": ( - sum(result.prediction.iql == result.reference.iql for result in results) / len(results) + "IQL/FLT/ACC": ( + sum( + isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source)) + for result in results + ) + / len(results) if results else None ) } -class ExactMatchFiltersIQL(Metric): +class IQLFiltersPrecision(Metric): """ - Ration of predicated IQL filters that are identical to the ground truth ones. + Ratio of predicated IQL filters that are identical to the ground truth ones. """ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: @@ -47,19 +158,32 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: Returns: Ratio of predicated queries that are identical to the ground truth ones. """ - results = [result for result in results if result.prediction.iql is not None] + results = [ + result + for result in results + if (result.reference.iql and result.prediction.iql) + and ( + result.reference.iql.filters.source + or result.reference.iql.filters.unsupported + and result.prediction.iql.filters.source + ) + ] return { - "EM_FLT_IQL": ( - sum(result.prediction.iql.filters == result.reference.iql.filters for result in results) / len(results) + "IQL/FLT/PRECISION": ( + sum( + isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source)) + for result in results + ) + / len(results) if results else None ) } -class ExactMatchAggregationIQL(Metric): +class IQLFiltersRecall(Metric): """ - Ratio of predicated aggregation that are identical to the ground truth ones. + Ratio of predicated IQL filters that are identical to the ground truth ones. """ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: @@ -72,10 +196,22 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: Returns: Ratio of predicated queries that are identical to the ground truth ones. 
""" - results = [result for result in results if result.prediction.iql is not None] + results = [ + result + for result in results + if (result.reference.iql and result.prediction.iql) + and ( + result.reference.iql.filters.source + and result.prediction.iql.filters.source + or result.prediction.iql.filters.unsupported + ) + ] return { - "EM_AGG_IQL": ( - sum(result.prediction.iql.aggregation == result.reference.iql.aggregation for result in results) + "IQL/FLT/RECALL": ( + sum( + isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source)) + for result in results + ) / len(results) if results else None @@ -83,55 +219,65 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: } -class UnsupportedIQL(Metric): +class IQLFiltersParseability(Metric): """ - Ratio of unsupported IQL queries. + Ratio of predicated IQL filters that are identical to the ground truth ones. """ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: """ - Calculates the unsupported IQL ratio. + Computes the exact match ratio. Args: results: List of evaluation results. Returns: - Unsupported queries ratio. + Ratio of predicated queries that are identical to the ground truth ones. """ results = [ result for result in results - if result.prediction.iql is not None or isinstance(result.prediction.exception, UnsupportedQueryError) + if (result.reference.iql and result.prediction.iql) + and (result.reference.iql.filters and result.prediction.iql.filters) + and (result.reference.iql.filters.source and result.prediction.iql.filters.source) ] return { - "UNSUPP_IQL": ( - sum(isinstance(result.prediction.exception, UnsupportedQueryError) for result in results) / len(results) - if results - else 0.0 + "IQL/FLT/PARSEABILITY": ( + sum(result.prediction.iql.filters.valid for result in results) / len(results) if results else None ) } -class ValidIQL(Metric): +class IQLFiltersCorrectness(Metric): """ - Ratio of valid IQL queries. + Ratio of predicated IQL filters that are identical to the ground truth ones. """ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: """ - Calculates the valid IQL ratio. + Computes the exact match ratio. Args: results: List of evaluation results. Returns: - Valid IQL queries ratio. + Ratio of predicated queries that are identical to the ground truth ones. """ - results = [result for result in results if result.prediction.iql is not None] + results = [ + result + for result in results + if (result.reference.iql and result.prediction.iql) + and ( + result.reference.iql.filters.source + and result.prediction.iql.filters.source + and result.prediction.iql.filters.valid + ) + ] return { - "VAL_IQL": ( - sum(not isinstance(result.prediction.exception, IQLError) for result in results) / len(results) + "IQL/FLT/CORRECTNESS": ( + sum(result.prediction.iql.filters.source == result.reference.iql.filters.source for result in results) + / len(results) if results - else 0.0 + else None ) } diff --git a/benchmarks/sql/bench/metrics/selector.py b/benchmarks/sql/bench/metrics/selector.py index 66c8ab3b..42b20ef8 100644 --- a/benchmarks/sql/bench/metrics/selector.py +++ b/benchmarks/sql/bench/metrics/selector.py @@ -20,8 +20,64 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: Ratio of predicated queries that are identical to the ground truth ones. 
""" return { - "ACC_VIEW": ( - sum(result.prediction.view == result.reference.view for result in results) / len(results) + "VIEW/ACC": ( + sum(result.reference.view_name == result.prediction.view_name for result in results) / len(results) + if results + else None + ) + } + + +class ViewSelectionPrecision(Metric): + """ + Ratio of predicated queries that are identical to the ground truth ones. + """ + + def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: + """ + Computes the exact match ratio. + + Args: + results: List of evaluation results. + + Returns: + Ratio of predicated queries that are identical to the ground truth ones. + """ + results = [result for result in results if result.prediction.view_name] + return { + "VIEW/PRECISION": ( + sum(result.prediction.view_name == result.reference.view_name for result in results) / len(results) + if results + else None + ) + } + + +class ViewSelectionRecall(Metric): + """ + Ratio of predicated queries that are identical to the ground truth ones. + """ + + def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: + """ + Computes the exact match ratio. + + Args: + results: List of evaluation results. + + Returns: + Ratio of predicated queries that are identical to the ground truth ones. + """ + results = [ + result + for result in results + if result.prediction.view_name is None + and result.reference.view_name + or result.prediction.view_name == result.reference.view_name + ] + return { + "VIEW/RECALL": ( + sum(result.prediction.view_name == result.reference.view_name for result in results) / len(results) if results else None ) diff --git a/benchmarks/sql/bench/metrics/sql.py b/benchmarks/sql/bench/metrics/sql.py index c8594455..0b5899e7 100644 --- a/benchmarks/sql/bench/metrics/sql.py +++ b/benchmarks/sql/bench/metrics/sql.py @@ -9,9 +9,10 @@ from .base import Metric -class ExactMatchSQL(Metric): +class SQLExactMatch(Metric): """ - Ratio of predicated queries that are identical to the ground truth ones. + Exact match ratio i.e. the proportion of examples in the evaluation set for which + the predicted SQL is identical to the ground truth SQL. """ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: @@ -22,10 +23,10 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: results: List of evaluation results. Returns: - Ratio of predicated queries that are identical to the ground truth ones. + The exact match ratio. """ return { - "EM_SQL": ( + "SQL/EM": ( sum(result.prediction.sql == result.reference.sql for result in results) / len(results) if results else 0.0 @@ -40,9 +41,9 @@ class _DBMixin: def __init__(self, config: Dict, *args: Any, **kwargs: Any) -> None: super().__init__(config, *args, **kwargs) - self.db = create_engine(config.data.db_url) + self.dbs = {db: create_engine(f"sqlite:///data/{db}.db") for db in config.data.db_ids} - def _execute_query(self, query: str) -> List[Dict[str, Any]]: + def _execute_query(self, query: str, db_id: str) -> List[Dict[str, Any]]: """ Execute the given query on the database. @@ -52,11 +53,11 @@ def _execute_query(self, query: str) -> List[Dict[str, Any]]: Returns: The query results. 
""" - with self.db.connect() as connection: + with self.dbs[db_id].connect() as connection: rows = connection.execute(text(query)).fetchall() return [dict(row._mapping) for row in rows] # pylint: disable=protected-access - def _avarage_execution_time(self, query: str, n: int = 100) -> float: + def _avarage_execution_time(self, query: str, db_id: str, n: int = 100) -> float: """ Execute the given query on the database n times and return the average execution time. @@ -70,7 +71,7 @@ def _avarage_execution_time(self, query: str, n: int = 100) -> float: total_time = 0 for _ in range(n): start_time = time.perf_counter() - self._execute_query(query) + self._execute_query(query, db_id) total_time += time.perf_counter() - start_time return total_time / n @@ -95,20 +96,19 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]: Execution accuracy score and valid efficiency score. """ accurate_results = [result for result in results if self._execution_accuracy(result)] - return { - "EX": len(accurate_results) / len(results) if results else 0.0, + "EX": len(accurate_results) / len(results) if results else None, "VES": sum( ( - self._avarage_execution_time(result.reference.sql) - / self._avarage_execution_time(result.prediction.sql) + self._avarage_execution_time(result.reference.sql, result.db_id) + / self._avarage_execution_time(result.prediction.sql, result.db_id) ) ** 0.5 for result in accurate_results ) / len(results) if results - else 0.0, + else None, } def _execution_accuracy(self, result: EvaluationResult) -> bool: @@ -125,13 +125,13 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool: return False try: - result.reference.results = self._execute_query(result.reference.sql) - result.prediction.results = self._execute_query(result.prediction.sql) + ref_results = self._execute_query(result.reference.sql, result.db_id) + pred_results = self._execute_query(result.prediction.sql, result.db_id) except SQLAlchemyError: return False - reference = pd.DataFrame(result.reference.results) - prediction = pd.DataFrame(result.prediction.results) + reference = pd.DataFrame(ref_results) + prediction = pd.DataFrame(pred_results) # If filtering works correctly, the number of rows will be the same # TODO: Sometimes a different number of rows is okay, e.g. if df has aggregated values that are expanded in gt diff --git a/benchmarks/sql/bench/pipelines/base.py b/benchmarks/sql/bench/pipelines/base.py index 58ba5186..38bcb304 100644 --- a/benchmarks/sql/bench/pipelines/base.py +++ b/benchmarks/sql/bench/pipelines/base.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - -from sqlalchemy import create_engine +from dataclasses import dataclass +from typing import Any, Dict, Optional from dbally.llms.base import LLM from dbally.llms.litellm import LiteLLM @@ -10,37 +8,25 @@ @dataclass -class IQLResult: +class IQL: """ - Represents the IQL result. + Represents the IQL. """ - filters: Optional[str] = None - aggregation: Optional[str] = None - - def __eq__(self, other: "IQLResult") -> bool: - """ - Compares two IQL results. - - Args: - other: The other IQL result to compare. + source: Optional[str] = None + unsupported: bool = False + valid: bool = True - Returns: - True if the two IQL results are equal, False otherwise. - """ - return self.filters == other.filters and self.aggregation == other.aggregation - def dict(self) -> Dict[str, Any]: - """ - Returns the dictionary representation of the object. 
+@dataclass +class IQLResult: + """ + Represents the result of an IQL query execution. + """ - Returns: - The dictionary representation. - """ - return { - "filters": self.filters, - "aggregation": self.aggregation, - } + filters: IQL + aggregation: IQL + context: bool = False @dataclass @@ -49,26 +35,9 @@ class ExecutionResult: Represents the result of a single query execution. """ - view: Optional[str] = None - sql: Optional[str] = None + view_name: Optional[str] = None iql: Optional[IQLResult] = None - results: List[Dict[str, Any]] = field(default_factory=list) - exception: Optional[Exception] = None - execution_time: Optional[float] = None - - def dict(self) -> Dict[str, Any]: - """ - Returns the dictionary representation of the object. - - Returns: - The dictionary representation. - """ - return { - "view": self.view, - "iql": self.iql.dict() if self.iql else None, - "sql": self.sql, - "len_results": len(self.results), - } + sql: Optional[str] = None @dataclass @@ -77,38 +46,17 @@ class EvaluationResult: Represents the result of a single evaluation. """ + db_id: str question: str reference: ExecutionResult prediction: ExecutionResult - def dict(self) -> Dict[str, Any]: - """ - Returns the dictionary representation of the object. - - Returns: - The dictionary representation. - """ - return { - "question": self.question, - "reference": self.reference.dict(), - "prediction": self.prediction.dict(), - } - class EvaluationPipeline(ABC): """ Collection evaluation pipeline. """ - def __init__(self, config: Dict) -> None: - """ - Constructs the pipeline for evaluating IQL predictions. - - Args: - config: The configuration for the pipeline. - """ - self.db = create_engine(config.data.db_url) - def get_llm(self, config: Dict) -> LLM: """ Returns the LLM based on the configuration. diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index 1efe9b9c..918cfbd9 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -1,13 +1,17 @@ from typing import Any, Dict +from sqlalchemy import create_engine + import dbally from dbally.collection.collection import Collection +from dbally.collection.exceptions import NoViewFoundError from dbally.iql._exceptions import IQLError from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.view_selection.llm_view_selector import LLMViewSelector +from dbally.views.structured import IQLGenerationError from ..views import VIEWS_REGISTRY -from .base import EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult +from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult class CollectionEvaluationPipeline(EvaluationPipeline): @@ -22,7 +26,6 @@ def __init__(self, config: Dict) -> None: Args: config: The configuration for the pipeline. 
""" - super().__init__(config) self.collection = self.get_collection(config.setup) def get_collection(self, config: Dict) -> Collection: @@ -46,20 +49,11 @@ def get_collection(self, config: Dict) -> Collection: ) collection.n_retries = 0 - for view_name in config.views: - view_cls = VIEWS_REGISTRY[view_name] - collection.add(view_cls, lambda: view_cls(self.db)) # pylint: disable=cell-var-from-loop - - if config.fallback: - fallback = dbally.create_collection( - name=config.fallback, - llm=generator_llm, - view_selector=view_selector, - ) - fallback.n_retries = 0 - fallback_cls = VIEWS_REGISTRY[config.fallback] - fallback.add(fallback_cls, lambda: fallback_cls(self.db)) - collection.set_fallback(fallback) + for db_name, view_names in config.views.items(): + db = create_engine(f"sqlite:///data/{db_name}.db") + for view_name in view_names: + view_cls = VIEWS_REGISTRY[view_name] + collection.add(view_cls, lambda: view_cls(db)) # pylint: disable=cell-var-from-loop return collection @@ -79,32 +73,67 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: dry_run=True, return_natural_response=False, ) - # TODO: Refactor exception handling for IQLError for filters and aggregation - except IQLError as exc: + except NoViewFoundError: prediction = ExecutionResult( - iql=IQLResult(filters=exc.source), - exception=exc, + view_name=None, + iql=None, + sql=None, + ) + except IQLGenerationError as exc: + prediction = ExecutionResult( + view_name=exc.view_name, + iql=IQLResult( + filters=IQL( + source=exc.filters, + unsupported=isinstance(exc.__cause__, UnsupportedQueryError), + valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), + ), + aggregation=IQL( + source=exc.aggregation, + unsupported=isinstance(exc.__cause__, UnsupportedQueryError), + valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), + ), + ), + sql=None, ) - # TODO: Remove this broad exception handling once the Text2SQL view is fixed - except (UnsupportedQueryError, Exception) as exc: # pylint: disable=broad-except - prediction = ExecutionResult(exception=exc) else: - iql = IQLResult(filters=result.context["iql"]) if "iql" in result.context else None prediction = ExecutionResult( - view=result.view_name, - iql=iql, + view_name=result.view_name, + iql=IQLResult( + filters=IQL( + source=result.context.get("iql"), + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), sql=result.context.get("sql"), ) reference = ExecutionResult( - view=data["view"], + view_name=data["view_name"], iql=IQLResult( - filters=data["iql_filters"], - aggregation=data["iql_aggregation"], + filters=IQL( + source=data["iql_filters"], + unsupported=data["iql_filters_unsupported"], + valid=True, + ), + aggregation=IQL( + source=data["iql_aggregation"], + unsupported=data["iql_aggregation_unsupported"], + valid=True, + ), + context=data["iql_context"], ), sql=data["sql"], ) + return EvaluationResult( + db_id=data["db_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index f0108b42..37969365 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -1,18 +1,23 @@ -from abc import ABC +# pylint: disable=duplicate-code + +from abc import ABC, abstractmethod from typing import Any, Dict, Type +from sqlalchemy import create_engine + from dbally.iql._exceptions import IQLError from 
dbally.iql_generator.prompt import UnsupportedQueryError from dbally.views.freeform.text2sql.view import BaseText2SQLView from dbally.views.sqlalchemy_base import SqlAlchemyBaseView +from dbally.views.structured import IQLGenerationError from ..views import VIEWS_REGISTRY -from .base import EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult +from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult class ViewEvaluationPipeline(EvaluationPipeline, ABC): """ - Collection evaluation pipeline. + View evaluation pipeline. """ def __init__(self, config: Dict) -> None: @@ -22,36 +27,53 @@ def __init__(self, config: Dict) -> None: Args: config: The configuration for the pipeline. """ - super().__init__(config) self.llm = self.get_llm(config.setup.llm) + self.dbs = self.get_dbs(config.setup) + self.views = self.get_views(config.setup) + def get_dbs(self, config: Dict) -> Dict: + """ + Returns the database object based on the database name. -class IQLViewEvaluationPipeline(ViewEvaluationPipeline): - """ - Collection evaluation pipeline. - """ + Args: + config: The database configuration. - def __init__(self, config: Dict) -> None: + Returns: + The database object. """ - Constructs the pipeline for evaluating IQL predictions. + return {db: create_engine(f"sqlite:///data/{db}.db") for db in config.views} + + @abstractmethod + def get_views(self, config: Dict) -> Dict[str, Type[SqlAlchemyBaseView]]: + """ + Creates the view classes mapping based on the configuration. Args: - config: The configuration for the pipeline. + config: The views configuration. + + Returns: + The view classes mapping. """ - super().__init__(config) - self.views = self.get_views(config.setup) + + +class IQLViewEvaluationPipeline(ViewEvaluationPipeline): + """ + IQL view evaluation pipeline. + """ def get_views(self, config: Dict) -> Dict[str, Type[SqlAlchemyBaseView]]: """ - Returns the view object based on the view name. + Creates the view classes mapping based on the configuration. Args: - config: The view configuration. + config: The views configuration. Returns: - The view object. + The view classes mapping. """ - return {view: VIEWS_REGISTRY[view] for view in config.views} + return { + view_name: VIEWS_REGISTRY[view_name] for view_names in config.views.values() for view_name in view_names + } async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: """ @@ -63,7 +85,8 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: Returns: The evaluation result. 
""" - view = self.views[data["view"]](self.db) + view = self.views[data["view_name"]](self.dbs[data["db_id"]]) + try: result = await view.ask( query=data["question"], @@ -71,34 +94,61 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: dry_run=True, n_retries=0, ) - # TODO: Refactor exception handling for IQLError for filters and aggregation - except IQLError as exc: + except IQLGenerationError as exc: prediction = ExecutionResult( - view=data["view"], - iql=IQLResult(filters=exc.source), - exception=exc, - ) - except (UnsupportedQueryError, Exception) as exc: # pylint: disable=broad-except - prediction = ExecutionResult( - view=data["view"], - exception=exc, + view_name=data["view_name"], + iql=IQLResult( + filters=IQL( + source=exc.filters, + unsupported=isinstance(exc.__cause__, UnsupportedQueryError), + valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), + ), + aggregation=IQL( + source=exc.aggregation, + unsupported=isinstance(exc.__cause__, UnsupportedQueryError), + valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), + ), + ), + sql=None, ) else: prediction = ExecutionResult( - view=data["view"], - iql=IQLResult(filters=result.context["iql"]), + view_name=data["view_name"], + iql=IQLResult( + filters=IQL( + source=result.context["iql"], + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), sql=result.context["sql"], ) reference = ExecutionResult( - view=data["view"], + view_name=data["view_name"], iql=IQLResult( - filters=data["iql_filters"], - aggregation=data["iql_aggregation"], + filters=IQL( + source=data["iql_filters"], + unsupported=data["iql_filters_unsupported"], + valid=True, + ), + aggregation=IQL( + source=data["iql_aggregation"], + unsupported=data["iql_aggregation_unsupported"], + valid=True, + ), + context=data["iql_context"], ), sql=data["sql"], ) + return EvaluationResult( + db_id=data["db_id"], question=data["question"], reference=reference, prediction=prediction, @@ -107,30 +157,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: class SQLViewEvaluationPipeline(ViewEvaluationPipeline): """ - Collection evaluation pipeline. + SQL view evaluation pipeline. """ - def __init__(self, config: Dict) -> None: - """ - Constructs the pipeline for evaluating IQL predictions. - - Args: - config: The configuration for the pipeline. - """ - super().__init__(config) - self.view = self.get_view(config.setup) - - def get_view(self, config: Dict) -> Type[BaseText2SQLView]: + def get_views(self, config: Dict) -> Dict[str, Type[BaseText2SQLView]]: """ - Returns the view object based on the view name. + Creates the view classes mapping based on the configuration. Args: - config: The view configuration. + config: The views configuration. Returns: - The view object. + The view classes mapping. """ - return VIEWS_REGISTRY[config.view] + return {db_id: VIEWS_REGISTRY[view_name] for db_id, view_name in config.views.items()} async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: """ @@ -142,7 +182,8 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: Returns: The evaluation result. 
""" - view = self.view(self.db) + view = self.views[data["db_id"]](self.dbs[data["db_id"]]) + try: result = await view.ask( query=data["question"], @@ -151,26 +192,15 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: n_retries=0, ) # TODO: Remove this broad exception handling once the Text2SQL view is fixed - except Exception as exc: # pylint: disable=broad-except - prediction = ExecutionResult( - view=self.view.__name__, - exception=exc, - ) + except Exception: # pylint: disable=broad-except + prediction = ExecutionResult() else: - prediction = ExecutionResult( - view=self.view.__name__, - sql=result.context["sql"], - ) + prediction = ExecutionResult(sql=result.context["sql"]) + + reference = ExecutionResult(sql=data["sql"]) - reference = ExecutionResult( - view=data["view"], - iql=IQLResult( - filters=data["iql_filters"], - aggregation=data["iql_aggregation"], - ), - sql=data["sql"], - ) return EvaluationResult( + db_id=data["db_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/views/__init__.py b/benchmarks/sql/bench/views/__init__.py index 732779e2..9c7230e7 100644 --- a/benchmarks/sql/bench/views/__init__.py +++ b/benchmarks/sql/bench/views/__init__.py @@ -3,12 +3,10 @@ from dbally.views.base import BaseView from .freeform.superhero import SuperheroFreeformView -from .structured.superhero import HeroAttributeView, HeroPowerView, PublisherView, SuperheroView +from .structured.superhero import PublisherView, SuperheroView VIEWS_REGISTRY: Dict[str, Type[BaseView]] = { PublisherView.__name__: PublisherView, - HeroAttributeView.__name__: HeroAttributeView, - HeroPowerView.__name__: HeroPowerView, SuperheroView.__name__: SuperheroView, SuperheroFreeformView.__name__: SuperheroFreeformView, } diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index 76f9e290..db57498e 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -4,8 +4,8 @@ from typing import Literal from sqlalchemy import ColumnElement, Engine, Select, func, select -from sqlalchemy.ext.declarative import DeferredReflection, declarative_base -from sqlalchemy.orm import aliased +from sqlalchemy.ext.declarative import DeferredReflection +from sqlalchemy.orm import aliased, declarative_base from dbally.views.decorators import view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -174,17 +174,18 @@ def filter_by_height_cm_greater_than(self, height_cm: float) -> ColumnElement: return Superhero.height_cm > height_cm @view_filter() - def filter_by_height_greater_than_percentage_of_average(self, average_percentage: int) -> ColumnElement: + def filter_by_height_cm_between(self, begin_height_cm: float, end_height_cm: float) -> ColumnElement: """ - Filters the view by the height greater than the percentage of average of superheroes. + Filters the view by the height of the superhero. Args: - average_percentage: The percentage of the average height. + begin_height_cm: The begin height of the superhero. + end_height_cm: The end height of the superhero. Returns: The filter condition. 
""" - return Superhero.height_cm * 100 > select(func.avg(Superhero.height_cm)).scalar_subquery() * average_percentage + return Superhero.height_cm.between(begin_height_cm, end_height_cm) @view_filter() def filter_by_the_tallest(self) -> ColumnElement: @@ -279,226 +280,6 @@ def filter_by_missing_publisher(self) -> ColumnElement: return Superhero.publisher_id == None -class SuperheroHeroPowerFilterMixin: - """ - Mixin for filtering the view by the superhero superpowers. - """ - - @view_filter() - def filter_by_number_powers(self, number_powers: int) -> ColumnElement: - """ - Filters the view by the number of superpowers. - - Args: - number_powers: The number of hero superpowers. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroPower.hero_id) - .group_by(HeroPower.hero_id) - .having(func.count(HeroPower.power_id) == number_powers) - ) - - @view_filter() - def filter_by_number_super_powers_greater_than(self, number_powers: int) -> ColumnElement: - """ - Filters the view by the number of superpowers. - - Args: - number_powers: The number of hero superpowers. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroPower.hero_id).group_by(HeroPower.hero_id).having(func.count(HeroPower.power_id) > number_powers) - ) - - @view_filter() - def filter_by_number_powers_less_than(self, number_powers: int) -> ColumnElement: - """ - Filters the view by the number of superpowers. - - Args: - number_powers: The number of hero superpowers. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroPower.hero_id).group_by(HeroPower.hero_id).having(func.count(HeroPower.power_id) < number_powers) - ) - - @view_filter() - def filter_by_power_name(self, power_name: str) -> ColumnElement: - """ - Filters the view by the superpower name. - - Args: - power_name: The name of the superpower. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroPower.hero_id) - .join(Superpower, Superpower.id == HeroPower.power_id) - .where(Superpower.power_name == power_name) - ) - - @view_filter() - def filter_by_the_most_super_powers(self) -> ColumnElement: - """ - Filters the view by the most superpowers. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroPower.hero_id) - .group_by(HeroPower.hero_id) - .order_by(func.count(HeroPower.power_id).desc()) - .limit(1) - ) - - -class SuperheroHeroAttributeFilterMixin: - """ - Mixin for filtering the view by the superhero attributes. - """ - - @view_filter() - def filter_by_attribute_name(self, attribute_name: str) -> ColumnElement: - """ - Filters the view by the hero attribute name. - - Args: - attribute_name: The name of the hero attribute. - - Returns: - The filter condition. - """ - return Superpower.id.in_( - select(HeroAttribute.hero_id) - .join(Attribute, Attribute.id == HeroAttribute.attribute_id) - .where(Attribute.attribute_name == attribute_name) - ) - - @view_filter() - def filter_by_attribute_value(self, attribute_value: int) -> ColumnElement: - """ - Filters the view by the hero attribute value. - - Args: - attribute_value: The value of the hero attribute. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroAttribute.hero_id) - .group_by(HeroAttribute.hero_id) - .having(HeroAttribute.attribute_value == attribute_value) - ) - - @view_filter() - def filter_by_the_lowest_attribute_value(self) -> ColumnElement: - """ - Filters the view by the lowest hero attribute value. 
- - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroAttribute.hero_id) - .group_by(HeroAttribute.hero_id) - .having(HeroAttribute.attribute_value == select(func.min(HeroAttribute.attribute_value)).scalar_subquery()) - ) - - @view_filter() - def filter_by_the_highest_attribute_value(self) -> ColumnElement: - """ - Filters the view by the highest hero attribute value. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroAttribute.hero_id) - .group_by(HeroAttribute.hero_id) - .having(HeroAttribute.attribute_value == select(func.max(HeroAttribute.attribute_value)).scalar_subquery()) - ) - - @view_filter() - def filter_by_attribute_value_less_than(self, attribute_value: int) -> ColumnElement: - """ - Filters the view by the hero attribute value. - - Args: - attribute_value: The value of the hero attribute. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroAttribute.hero_id) - .group_by(HeroAttribute.hero_id) - .having(func.min(HeroAttribute.attribute_value) < attribute_value) - ) - - @view_filter() - def filter_by_attribute_value_between(self, begin_attribute_value: int, end_attribute_value: int) -> ColumnElement: - """ - Filters the view by the hero attribute value. - - Args: - begin_attribute_value: The begin value of the hero attribute. - end_attribute_value: The end value of the hero attribute. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroAttribute.hero_id) - .group_by(HeroAttribute.hero_id) - .having(HeroAttribute.attribute_value.between(begin_attribute_value, end_attribute_value)) - ) - - @view_filter() - def filter_by_the_fastest(self) -> ColumnElement: - """ - Filters the view by the fastest superhero. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroAttribute.hero_id) - .join(Attribute, Attribute.id == HeroAttribute.attribute_id) - .where(Attribute.attribute_name == "Speed") - .group_by(HeroAttribute.hero_id) - .having(HeroAttribute.attribute_value == select(func.max(HeroAttribute.attribute_value)).scalar_subquery()) - ) - - @view_filter() - def filter_by_the_dumbest(self) -> ColumnElement: - """ - Filters the view by the dumbest superhero. - - Returns: - The filter condition. - """ - return Superhero.id.in_( - select(HeroAttribute.hero_id) - .join(Attribute, Attribute.id == HeroAttribute.attribute_id) - .where(Attribute.attribute_name == "Intelligence") - .group_by(HeroAttribute.hero_id) - .having(HeroAttribute.attribute_value == select(func.min(HeroAttribute.attribute_value)).scalar_subquery()) - ) - - class SuperheroColourFilterMixin: """ Mixin for filtering the view by the superhero colour attributes. @@ -589,43 +370,6 @@ def filter_by_publisher_name(self, publisher_name: str) -> ColumnElement: return Publisher.publisher_name == publisher_name -class PublisherSuperheroMixin: - """ - Mixin for filtering the publisher view by superheros. - """ - - @view_filter() - def filter_by_superhero_name(self, superhero_name: str) -> ColumnElement: - """ - Filters the view by the superhero name. - - Args: - superhero_name: The name of the superhero. - - Returns: - The filter condition. - """ - return Publisher.id.in_(select(Superhero.publisher_id).where(Superhero.superhero_name == superhero_name)) - - @view_filter() - def filter_by_the_slowest_superhero(self) -> ColumnElement: - """ - Filters the view by the slowest superhero. - - Returns: - The filter condition. 
- """ - return Publisher.id.in_( - select(Superhero.publisher_id) - .join(HeroAttribute, HeroAttribute.hero_id == Superhero.id) - .join(Attribute, Attribute.id == HeroAttribute.attribute_id) - .where( - Attribute.attribute_name == "Speed", - HeroAttribute.attribute_value == select(func.min(HeroAttribute.attribute_value)).scalar_subquery(), - ) - ) - - class AlignmentFilterMixin: """ Mixin for filtering the view by the alignment attributes. @@ -645,44 +389,6 @@ def filter_by_alignment(self, alignment: Literal["Good", "Bad", "Neutral", "N/A" return Alignment.alignment == alignment -class SuperpowerFilterMixin: - """ - Mixin for filtering the view by the superpower attributes. - """ - - @view_filter() - def filter_by_power_name(self, power_name: str) -> ColumnElement: - """ - Filters the view by the superpower name. - - Args: - power_name: The name of the superpower. - - Returns: - The filter condition. - """ - return Superpower.power_name == power_name - - -class RaceFilterMixin: - """ - Mixin for filtering the view by the race. - """ - - @view_filter() - def filter_by_race(self, race: str) -> ColumnElement: - """ - Filters the view by the object race. - - Args: - race: The race of the object. - - Returns: - The filter condition. - """ - return Race.race == race - - class GenderFilterMixin: """ Mixin for filtering the view by the gender. @@ -702,73 +408,23 @@ def filter_by_gender(self, gender: Literal["Male", "Female", "N/A"]) -> ColumnEl return Gender.gender == gender -class HeroAttributeFilterMixin: - """ - Mixin for filtering the view by the hero attribute. - """ - - @view_filter() - def filter_by_the_lowest_attribute_value(self) -> ColumnElement: - """ - Filters the view by the lowest hero attribute value. - - Returns: - The filter condition. - """ - return HeroAttribute.attribute_value == select(func.min(HeroAttribute.attribute_value)).scalar_subquery() - - @view_filter() - def filter_by_the_highest_attribute_value(self) -> ColumnElement: - """ - Filters the view by the highest hero attribute value. - - Returns: - The filter condition. - """ - return HeroAttribute.attribute_value == select(func.max(HeroAttribute.attribute_value)).scalar_subquery() - - -class HeroPowerFilterMixin: - """ - Mixin for filtering the view by the hero power. - """ - - @view_filter() - def filter_by_the_most_popular_power(self) -> ColumnElement: - """ - Filters the view by the most popular hero power. - - Returns: - The filter condition. - """ - return HeroPower.power_id == ( - select(HeroPower.power_id) - .group_by(HeroPower.power_id) - .order_by(func.count(HeroPower.power_id).desc()) - .limit(1) - .scalar_subquery() - ) - - -class AttributeFilterMixin: +class RaceFilterMixin: """ - Mixin for filtering the view by the attribute. + Mixin for filtering the view by the race. """ @view_filter() - def filter_by_attribute_name( - self, attribute_name: Literal["Intelligence", "Strength", "Speed", "Durability", "Power", "Combat"] - ) -> ColumnElement: + def filter_by_race(self, race: str) -> ColumnElement: """ - Filters the view by the attribute name. + Filters the view by the object race. Args: - attribute_name: The name of the attribute. + race: The race of the object. Returns: The filter condition. 
""" - return Attribute.attribute_name == attribute_name + return Race.race == race class SuperheroView( @@ -776,11 +432,9 @@ class SuperheroView( SqlAlchemyBaseView, SuperheroFilterMixin, SuperheroColourFilterMixin, - SuperheroHeroPowerFilterMixin, - SuperheroHeroAttributeFilterMixin, - PublisherFilterMixin, AlignmentFilterMixin, GenderFilterMixin, + PublisherFilterMixin, RaceFilterMixin, ): """ @@ -802,79 +456,25 @@ def get_select(self) -> Select: Superhero.full_name, Superhero.height_cm, Superhero.weight_kg, - Publisher.publisher_name, + Alignment.alignment, Gender.gender, + Publisher.publisher_name, Race.race, - Alignment.alignment, self.eye_colour.colour.label("eye_colour"), self.hair_colour.colour.label("hair_colour"), self.skin_colour.colour.label("skin_colour"), ) + .join(Alignment, Alignment.id == Superhero.alignment_id) + .join(Gender, Gender.id == Superhero.gender_id) .join(Publisher, Publisher.id == Superhero.publisher_id) .join(Race, Race.id == Superhero.race_id) - .join(Gender, Gender.id == Superhero.gender_id) - .join(Alignment, Alignment.id == Superhero.alignment_id) .join(self.eye_colour, self.eye_colour.id == Superhero.eye_colour_id) .join(self.hair_colour, self.hair_colour.id == Superhero.hair_colour_id) .join(self.skin_colour, self.skin_colour.id == Superhero.skin_colour_id) ) -class HeroAttributeView( - DBInitMixin, - SqlAlchemyBaseView, - HeroAttributeFilterMixin, - AttributeFilterMixin, - SuperheroFilterMixin, - AlignmentFilterMixin, -): - """ - View for querying only hero attributes data. Contains the attribute name and attribute value. - """ - - def get_select(self) -> Select: - """ - Initializes the select object for the view. - - Returns: - The select object. - """ - return ( - select( - Attribute.attribute_name, - HeroAttribute.attribute_value, - ) - .join(Attribute, Attribute.id == HeroAttribute.attribute_id) - .join(Superhero, Superhero.id == HeroAttribute.hero_id) - .join(Alignment, Alignment.id == Superhero.alignment_id) - .join(Publisher, Publisher.id == Superhero.publisher_id) - ) - - -class HeroPowerView(DBInitMixin, SqlAlchemyBaseView, HeroPowerFilterMixin, SuperheroFilterMixin, SuperpowerFilterMixin): - """ - View for querying only hero super powers data. Contains the power id and power name. - """ - - def get_select(self) -> Select: - """ - Initializes the select object for the view. - - Returns: - The select object. - """ - return ( - select( - HeroPower.power_id, - Superpower.power_name, - ) - .join(Superhero, Superhero.id == HeroPower.hero_id) - .join(Superpower, Superpower.id == HeroPower.power_id) - .group_by(HeroPower.power_id) - ) - - -class PublisherView(DBInitMixin, SqlAlchemyBaseView, PublisherFilterMixin, PublisherSuperheroMixin): +class PublisherView(DBInitMixin, SqlAlchemyBaseView, PublisherFilterMixin): """ View for querying only publisher data. Contains the publisher id and publisher name. 
""" diff --git a/benchmarks/sql/config/data/superhero.yaml b/benchmarks/sql/config/data/superhero.yaml index bb556c46..23412721 100644 --- a/benchmarks/sql/config/data/superhero.yaml +++ b/benchmarks/sql/config/data/superhero.yaml @@ -1,5 +1,4 @@ path: "micpst/bird-iql" split: "dev" -db_id: "superhero" +db_ids: ["superhero"] difficulties: ["simple", "moderate", "challenging"] -db_url: "sqlite:///data/superhero.db" diff --git a/benchmarks/sql/config/setup/collection.yaml b/benchmarks/sql/config/setup/collection.yaml index 3a7073b0..2eafb34a 100644 --- a/benchmarks/sql/config/setup/collection.yaml +++ b/benchmarks/sql/config/setup/collection.yaml @@ -1,12 +1,7 @@ name: COLLECTION -views: [ - "HeroAttributeView", - "HeroPowerView", - "PublisherView", - "SuperheroView", -] -fallback: "SuperheroFreeformView" defaults: - llm@selector_llm: gpt-3.5-turbo - llm@generator_llm: gpt-3.5-turbo + - views/structured@views: + - superhero diff --git a/benchmarks/sql/config/setup/iql-view.yaml b/benchmarks/sql/config/setup/iql-view.yaml index 9b6bcdde..e652bc3b 100644 --- a/benchmarks/sql/config/setup/iql-view.yaml +++ b/benchmarks/sql/config/setup/iql-view.yaml @@ -1,10 +1,6 @@ name: IQL_VIEW -views: [ - "HeroAttributeView", - "HeroPowerView", - "PublisherView", - "SuperheroView", -] defaults: - llm: gpt-3.5-turbo + - views/structured@views: + - superhero diff --git a/benchmarks/sql/config/setup/sql-view.yaml b/benchmarks/sql/config/setup/sql-view.yaml index f501b0d8..e4e1f7d9 100644 --- a/benchmarks/sql/config/setup/sql-view.yaml +++ b/benchmarks/sql/config/setup/sql-view.yaml @@ -1,5 +1,6 @@ name: SQL_VIEW -view: SuperheroFreeformView defaults: - llm: gpt-3.5-turbo + - views/freeform@views: + - superhero diff --git a/benchmarks/sql/config/setup/views/freeform/superhero.yaml b/benchmarks/sql/config/setup/views/freeform/superhero.yaml new file mode 100644 index 00000000..aa0cf958 --- /dev/null +++ b/benchmarks/sql/config/setup/views/freeform/superhero.yaml @@ -0,0 +1 @@ +superhero: SuperheroFreeformView diff --git a/benchmarks/sql/config/setup/views/structured/superhero.yaml b/benchmarks/sql/config/setup/views/structured/superhero.yaml new file mode 100644 index 00000000..6497bf6c --- /dev/null +++ b/benchmarks/sql/config/setup/views/structured/superhero.yaml @@ -0,0 +1,4 @@ +superhero: [ + PublisherView, + SuperheroView, +] diff --git a/benchmarks/sql/tests/test_evaluator.py b/benchmarks/sql/tests/test_evaluator.py index 26adf003..ea328e8b 100644 --- a/benchmarks/sql/tests/test_evaluator.py +++ b/benchmarks/sql/tests/test_evaluator.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Dict, List import pytest @@ -15,9 +16,9 @@ def compute(self, results) -> Dict[str, float]: return {"accuracy": 0.95} +@dataclass class MockEvaluationResult: - def dict(self) -> Dict[str, str]: - return {"result": "processed_data"} + result: str = "processed_data" class MockEvaluationPipeline: @@ -52,7 +53,6 @@ async def test_call_pipeline() -> None: assert "total_time_in_seconds" in perf_results["time_perf"] -@pytest.mark.asyncio def test_results_processor() -> None: evaluator = Evaluator(task="test_task") results = [MockEvaluationResult(), MockEvaluationResult()] @@ -63,7 +63,6 @@ def test_results_processor() -> None: assert len(processed_results["results"]) == len(results) -@pytest.mark.asyncio def test_compute_metrics() -> None: evaluator = Evaluator(task="test_task") metrics = MockMetricSet() @@ -75,7 +74,6 @@ def test_compute_metrics() -> None: assert computed_metrics["metrics"]["accuracy"] == 0.95 
-@pytest.mark.asyncio def test_compute_time_perf() -> None: evaluator = Evaluator(task="test_task") start_time = 0 diff --git a/benchmarks/sql/tests/test_metrics.py b/benchmarks/sql/tests/test_metrics.py index f26233e4..71396139 100644 --- a/benchmarks/sql/tests/test_metrics.py +++ b/benchmarks/sql/tests/test_metrics.py @@ -4,13 +4,24 @@ import pytest -from benchmarks.sql.bench.metrics import ExactMatchIQL, ExecutionAccuracy, ValidIQL -from benchmarks.sql.bench.pipelines import EvaluationResult, ExecutionResult +from benchmarks.sql.bench.metrics.iql import ( + FilteringAccuracy, + FilteringPrecision, + FilteringRecall, + IQLFiltersAccuracy, + IQLFiltersCorrectness, + IQLFiltersParseability, + IQLFiltersPrecision, + IQLFiltersRecall, +) +from benchmarks.sql.bench.metrics.sql import ExecutionAccuracy, SQLExactMatch +from benchmarks.sql.bench.pipelines import EvaluationResult, ExecutionResult, IQLResult +from benchmarks.sql.bench.pipelines.base import IQL @dataclass class MockDataConfig: - db_url: str = "sqlite:///:memory:" + db_ids: str = "db_id" @dataclass @@ -22,48 +33,238 @@ class MockConfig: def evaluation_results() -> List[EvaluationResult]: return [ EvaluationResult( + db_id="db_id", question="question1", - reference=ExecutionResult(iql="filter_by_column1(10)"), - prediction=ExecutionResult(iql="filter_by_column1(10)"), + reference=ExecutionResult( + iql=IQLResult( + filters=IQL( + source="filter_by_column1(10)", + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column1 = 10", + ), + prediction=ExecutionResult( + sql="SELECT * FROM table WHERE column1 = 10", + ), ), EvaluationResult( + db_id="db_id", question="question2", - reference=ExecutionResult(iql="filter_by_column2(20)"), - prediction=ExecutionResult(iql="filter_by_column2(30)"), + reference=ExecutionResult( + iql=IQLResult( + filters=IQL( + source="filter_by_column2(20)", + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column2 = 20", + ), + prediction=ExecutionResult( + iql=IQLResult( + filters=IQL( + source="filter_by_column2(20)", + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column2 = 30", + ), ), EvaluationResult( + db_id="db_id", question="question3", - reference=ExecutionResult(iql="filter_by_column3('Test')"), - prediction=ExecutionResult(iql="filter_by_column3(30)"), + reference=ExecutionResult( + iql=IQLResult( + filters=IQL( + source="filter_by_column3('TEST')", + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column3 = 'TEST'", + ), + prediction=ExecutionResult( + iql=IQLResult( + filters=IQL( + source="filter_by_column3('test')", + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column3 = 'test'", + ), ), EvaluationResult( + db_id="db_id", question="question4", - reference=ExecutionResult(iql="filter_by_column4(40)"), - prediction=ExecutionResult(iql="filter_by_column4(40)"), + reference=ExecutionResult( + iql=IQLResult( + filters=IQL( + source=None, + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column4 = 
40", + ), + prediction=ExecutionResult( + iql=IQLResult( + filters=IQL( + source="filter_by_column4(40)", + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column3 = 'TEST'", + ), + ), + EvaluationResult( + db_id="db_id", + question="question5", + reference=ExecutionResult( + iql=IQLResult( + filters=IQL( + source="filter_by_column5(50)", + unsupported=False, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column5 = 50", + ), + prediction=ExecutionResult( + iql=IQLResult( + filters=IQL( + source=None, + unsupported=True, + valid=True, + ), + aggregation=IQL( + source=None, + unsupported=False, + valid=True, + ), + ), + sql="SELECT * FROM table WHERE column5 = 50", + ), ), ] -def test_exact_match_iql(evaluation_results: List[EvaluationResult]) -> None: - metric = ExactMatchIQL() +def test_filtering_accuracy(evaluation_results: List[EvaluationResult]) -> None: + metric = FilteringAccuracy() + scores = metric.compute(evaluation_results) + assert scores["DM/FLT/ACC"] == 0.5 + + +def test_filtering_precision(evaluation_results: List[EvaluationResult]) -> None: + metric = FilteringPrecision() + scores = metric.compute(evaluation_results) + assert scores["DM/FLT/PRECISION"] == 0.5 + + +def test_filtering_recall(evaluation_results: List[EvaluationResult]) -> None: + metric = FilteringRecall() + scores = metric.compute(evaluation_results) + assert scores["DM/FLT/RECALL"] == 0.6666666666666666 + + +def test_iql_filters_accuracy(evaluation_results: List[EvaluationResult]) -> None: + metric = IQLFiltersAccuracy() + scores = metric.compute(evaluation_results) + assert scores["IQL/FLT/ACC"] == 0.6666666666666666 + + +def test_iql_filters_precision(evaluation_results: List[EvaluationResult]) -> None: + metric = IQLFiltersPrecision() scores = metric.compute(evaluation_results) - assert scores["EM_IQL"] == 0.5 + assert scores["IQL/FLT/PRECISION"] == 0.6666666666666666 -def test_valid_iql(evaluation_results) -> None: - metric = ValidIQL() +def test_iql_filters_recall(evaluation_results: List[EvaluationResult]) -> None: + metric = IQLFiltersRecall() scores = metric.compute(evaluation_results) - assert scores["VAL_IQL"] == 1.0 + assert scores["IQL/FLT/RECALL"] == 0.6666666666666666 + + +def test_iql_filters_parseability(evaluation_results: List[EvaluationResult]) -> None: + metric = IQLFiltersParseability() + scores = metric.compute(evaluation_results) + assert scores["IQL/FLT/PARSEABILITY"] == 1 + + +def test_iql_filters_correctness(evaluation_results: List[EvaluationResult]) -> None: + metric = IQLFiltersCorrectness() + scores = metric.compute(evaluation_results) + assert scores["IQL/FLT/CORRECTNESS"] == 0.5 + + +def test_exact_match_sql(evaluation_results: List[EvaluationResult]) -> None: + metric = SQLExactMatch() + scores = metric.compute(evaluation_results) + assert scores["SQL/EM"] == 0.4 @pytest.mark.parametrize( "acc, avg_times, expected_ex, expected_ves", [ - ([True, False, True, True], [1.2, 1.2, 12.2, 12.2, 13.2, 13.2, 232.1, 232.1], 0.75, 0.75), - ([True, True, True, True], [1.2, 1.2, 12.2, 12.2, 13.2, 13.2, 232.1, 232.1], 1.0, 1.0), - ([False, False, False, False], [1.2, 1.2, 12.2, 12.2, 13.2, 13.2, 232.1, 232.1], 0.0, 0.0), - ([True, False, True, True], [1.2, 3.2, 12.2, 15.2, 13.2, 17.2, 232.1, 287.1], 0.75, 0.5960767767585372), - ([True, False, True, True], [3.2, 1.2, 15.2, 12.2, 17.2, 13.2, 287.1, 232.1], 0.75, 
0.9726740826467557), + ([True, False, True, True, True], [1.2, 1.2, 12.2, 12.2, 13.2, 13.2, 232.1, 232.1, 3, 3], 0.8, 0.8), + ([True, True, True, True, True], [1.2, 1.2, 12.2, 12.2, 13.2, 13.2, 232.1, 232.1, 3, 3], 1.0, 1.0), + ([False, False, False, False, False], [1.2, 1.2, 12.2, 12.2, 13.2, 13.2, 232.1, 232.1, 3, 3], 0.0, 0.0), + ( + [True, False, True, True, True], + [1.2, 3.2, 12.2, 15.2, 13.2, 17.2, 232.1, 287.1, 3, 3], + 0.8, + 0.6566867943411235, + ), + ( + [True, False, True, True, True], + [3.2, 1.2, 15.2, 12.2, 17.2, 13.2, 287.1, 232.1, 3, 3], + 0.8, + 1.00057728666646, + ), ], ) def test_execution_accuracy( diff --git a/setup.cfg b/setup.cfg index 77b7ef07..8ea3dff2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,16 +56,9 @@ examples = pydantic_settings~=2.1.0 psycopg2-binary~=2.9.9 benchmarks = - asyncpg~=0.28.0 datasets~=2.20.0 - eval-type-backport~=0.1.3 hydra-core~=1.3.2 - loguru~=0.7.0 neptune~=1.6.3 - pydantic~=2.6.1 - pydantic-core~=2.16.2 - pydantic-settings~=2.0.3 - psycopg2-binary~=2.9.9 elasticsearch = elasticsearch~=8.13.1 gradio =
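
For reference, a minimal usage sketch of the refactored result dataclasses and metric classes introduced in this patch — assuming the `benchmarks` package is importable (e.g. running from the repository root) and using illustrative question/IQL/SQL values rather than real dataset entries; it mirrors the fixtures in `benchmarks/sql/tests/test_metrics.py`:

```python
# Hypothetical example: exercising the new IQL/IQLResult/ExecutionResult
# dataclasses and the renamed metric classes added in this patch.
from benchmarks.sql.bench.metrics import FilteringAccuracy, IQLFiltersParseability
from benchmarks.sql.bench.pipelines import EvaluationResult, ExecutionResult, IQLResult
from benchmarks.sql.bench.pipelines.base import IQL

results = [
    EvaluationResult(
        db_id="superhero",
        question="Which superheroes are taller than 200 cm?",
        reference=ExecutionResult(
            view_name="SuperheroView",
            iql=IQLResult(
                filters=IQL(source="filter_by_height_cm_greater_than(200)"),
                aggregation=IQL(source=None),
            ),
            sql="SELECT * FROM superhero WHERE height_cm > 200",
        ),
        prediction=ExecutionResult(
            view_name="SuperheroView",
            iql=IQLResult(
                filters=IQL(source="filter_by_height_cm_greater_than(200)"),
                aggregation=IQL(source=None),
            ),
            sql="SELECT * FROM superhero WHERE height_cm > 200",
        ),
    ),
]

# Each metric returns a dict keyed by its new hierarchical metric name.
print(FilteringAccuracy().compute(results))       # {'DM/FLT/ACC': 1.0}
print(IQLFiltersParseability().compute(results))  # {'IQL/FLT/PARSEABILITY': 1.0}
```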