From 8f7a166693c184631290f51e7906ccfb49acf2c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= <26008518+mhordynski@users.noreply.github.com> Date: Tue, 28 May 2024 14:31:26 +0200 Subject: [PATCH] feat: freeform text2sql with static configuration (#36) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --------- Co-authored-by: Ludwik Trammer Co-authored-by: Michał Pstrąg --- docs/how-to/update_similarity_indexes.md | 84 +------- docs/reference/similarity/detector.md | 7 - examples/freeform.py | 75 +++++++ mkdocs.yml | 1 - src/dbally/collection.py | 40 ++-- src/dbally/prompts/prompt_template.py | 2 +- src/dbally/similarity/detector.py | 204 ------------------ src/dbally/views/base.py | 14 +- .../views/freeform/text2sql/__init__.py | 10 +- .../views/freeform/text2sql/_autodiscovery.py | 17 +- src/dbally/views/freeform/text2sql/_config.py | 75 ++++++- src/dbally/views/freeform/text2sql/_view.py | 153 +++++++++++-- src/dbally/views/structured.py | 21 +- src/dbally_cli/main.py | 9 +- src/dbally_cli/similarity.py | 49 ----- src/dbally_codegen/__init__.py | 0 tests/unit/similarity/test_cli.py | 77 ------- tests/unit/similarity/test_detector.py | 120 ----------- tests/unit/views/text2sql/test_view.py | 41 ++-- 19 files changed, 395 insertions(+), 604 deletions(-) delete mode 100644 docs/reference/similarity/detector.md create mode 100644 examples/freeform.py delete mode 100644 src/dbally/similarity/detector.py delete mode 100644 src/dbally_cli/similarity.py create mode 100644 src/dbally_codegen/__init__.py delete mode 100644 tests/unit/similarity/test_cli.py delete mode 100644 tests/unit/similarity/test_detector.py diff --git a/docs/how-to/update_similarity_indexes.md b/docs/how-to/update_similarity_indexes.md index c93de01a..763bc1a7 100644 --- a/docs/how-to/update_similarity_indexes.md +++ b/docs/how-to/update_similarity_indexes.md @@ -4,50 +4,16 @@ The Similarity Index is a feature provided by db-ally that takes user input and While Similarity Indexes can be used directly, they are usually used with [Views](../concepts/views.md), annotating arguments to filter methods. This technique lets db-ally automatically match user-provided arguments to the most similar value in the data source. You can see an example of using similarity indexes with views on the [Quickstart Part 2: Semantic Similarity](../quickstart/quickstart2.md) page. -Similarity Indexes are designed to index all possible values (e.g., on disk or in a different data store). Consequently, when the data source undergoes changes, the Similarity Index must update to reflect these alterations. This guide will explain different ways to update Similarity Indexes. +Similarity Indexes are designed to index all possible values (e.g., on disk or in a different data store). Consequently, when the data source undergoes changes, the Similarity Index must update to reflect these alterations. This guide will explain how to update Similarity Indexes in your code. -You can update the Similarity Index through Python code or via the db-ally CLI. The following sections explain how to update these indexes using both methods: +* [Update a Single Similarity Index](#update-a-single-similarity-index) +* [Update Similarity Indexes from all Views in a Collection](#update-similarity-indexes-from-all-views-in-a-collection) -* [Update Similarity Indexes via the CLI](#update-similarity-indexes-via-the-cli) -* [Update Similarity Indexes via Python Code](#update-similarity-indexes-via-python-code) - * [Update on a Single Similarity Index](#update-on-a-single-similarity-index) - * [Update Similarity Indexes from all Views in a Collection](#update-similarity-indexes-from-all-views-in-a-collection) - * [Detect Similarity Indexes in Views](#detect-similarity-indexes-in-views) - -## Update Similarity Indexes via the CLI - -To update Similarity Indexes via the CLI, you can use the `dbally update-index` command. This command requires a path to what you wish to update. The path should follow this format: "path.to.module:ViewName.method_name.argument_name" where each part after the colon is optional. The more specific your target is, the fewer Similarity Indexes will be updated. - -For example, to update all Similarity Indexes in a module `my_module.views`, use this command: - -```bash -dbally update-index my_module.views -``` - -To update all Similarity Indexes in a specific View, add the name of the View following the module path: - -```bash -dbally update-index my_module.views:MyView -``` - -To update all Similarity Indexes within a specific method of a View, add the method's name after the View name: - -```bash -dbally update-index my_module.views:MyView.method_name -``` - -Lastly, to update all Similarity Indexes in a particular argument of a method, add the argument name after the method name: - -```bash -dbally update-index my_module.views:MyView.method_name.argument_name -``` - -## Update Similarity Indexes via Python Code -### Update on a Single Similarity Index +## Update a Single Similarity Index To manually update a Similarity Index, call the `update` method on the Similarity Index object. The `update` method will re-fetch all possible values from the data source and re-index them. Below is an example of how to manually update a Similarity Index: ```python -from db_ally import SimilarityIndex +from dbally import SimilarityIndex # Create a similarity index similarity_index = SimilarityIndex(fetcher=fetcher, store=store) @@ -56,14 +22,14 @@ similarity_index = SimilarityIndex(fetcher=fetcher, store=store) await similarity_index.update() ``` -### Update Similarity Indexes from all Views in a Collection +## Update Similarity Indexes from all Views in a Collection If you have a [collection](../concepts/collections.md) and want to update Similarity Indexes in all views, you can use the `update_similarity_indexes` method. This method will update all Similarity Indexes in all views registered with the collection: ```python -from db_ally import create_collection -from db_ally.llms.litellm import LiteLLM +from dbally import create_collection +from dbally.llms.litellm import LiteLLM -my_collection = create_collection("collection_name", llm=LiteLLM()) +my_collection = create_collection("my_collection", llm=LiteLLM()) # ... add views to the collection @@ -72,35 +38,3 @@ await my_collection.update_similarity_indexes() !!! info Alternatively, for more advanced use cases, you can use Collection's [`get_similarity_indexes`][dbally.Collection.get_similarity_indexes] method to get a list of all Similarity Indexes (allongside the places where they are used) and update them individually. - -### Detect Similarity Indexes in Views -If you are using Similarity Indexes to annotate arguments in views, you can use the [`SimilarityIndexDetector`][dbally.similarity.detector.SimilarityIndexDetector] to locate all Similarity Indexes in a view and update them. - -For example, to update all Similarity Indexes in a view named `MyView` in a module labeled `my_module.views`, use the following code: - -```python -from db_ally import SimilarityIndexDetector - -detector = SimilarityIndexDetector.from_path("my_module.views:MyView") -[await index.update() for index in detector.list_indexes()] -``` - -The `from_path` method constructs a `SimilarityIndexDetector` object from a view path string in the same format as the CLI command. The `list_indexes` method returns a list of Similarity Indexes detected in the view. - -For instance, to detect all Similarity Indexes in a module, provide only the path: - -```python -detector = SimilarityIndexDetector.from_path("my_module.views") -``` - -Conversely, to detect all Similarity Indexes in a specific method of a view, provide the method name: - -```python -detector = SimilarityIndexDetector.from_path("my_module.views:MyView.method_name") -``` - -Lastly, to detect all Similarity Indexes in a particular argument of a method, provide the argument name: - -```python -detector = SimilarityIndexDetector.from_path("my_module.views:MyView.method_name.argument_name") -``` \ No newline at end of file diff --git a/docs/reference/similarity/detector.md b/docs/reference/similarity/detector.md deleted file mode 100644 index 913d53d5..00000000 --- a/docs/reference/similarity/detector.md +++ /dev/null @@ -1,7 +0,0 @@ -# SimilarityIndexDetector - -SimilarityIndexDetector is a class that can be used to detect similarity indexes in views and update them. To see how to use it, see the [How-To: Update Similarity Indexes](../../how-to/update_similarity_indexes.md) guide. - -::: dbally.similarity.detector.SimilarityIndexDetector - -::: dbally.similarity.detector.SimilarityIndexDetectorException diff --git a/examples/freeform.py b/examples/freeform.py new file mode 100644 index 00000000..da3bea5e --- /dev/null +++ b/examples/freeform.py @@ -0,0 +1,75 @@ +import asyncio +from typing import List + +import sqlalchemy + +import dbally +from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler +from dbally.llms import LiteLLM +from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig + + +class MyText2SqlView(BaseText2SQLView): + """ + A Text2SQL view for the example. + """ + + def get_tables(self) -> List[TableConfig]: + """ + Get the tables used by the view. + + Returns: + A list of tables. + """ + return [ + TableConfig( + name="customers", + columns=[ + ColumnConfig("id", "SERIAL PRIMARY KEY"), + ColumnConfig("name", "VARCHAR(255)"), + ColumnConfig("city", "VARCHAR(255)"), + ColumnConfig("country", "VARCHAR(255)"), + ColumnConfig("age", "INTEGER"), + ], + ), + TableConfig( + name="products", + columns=[ + ColumnConfig("id", "SERIAL PRIMARY KEY"), + ColumnConfig("name", "VARCHAR(255)"), + ColumnConfig("category", "VARCHAR(255)"), + ColumnConfig("price", "REAL"), + ], + ), + TableConfig( + name="purchases", + columns=[ + ColumnConfig("customer_id", "INTEGER"), + ColumnConfig("product_id", "INTEGER"), + ColumnConfig("quantity", "INTEGER"), + ColumnConfig("date", "TEXT"), + ], + ), + ] + + +async def main(): + """Main function to run the example.""" + engine = sqlalchemy.create_engine("sqlite:///:memory:") + + # Create tables from config + with engine.connect() as connection: + for table_config in MyText2SqlView(engine).get_tables(): + connection.execute(sqlalchemy.text(table_config.ddl)) + + llm = LiteLLM() + collection = dbally.create_collection("text2sql", llm=llm, event_handlers=[CLIEventHandler()]) + collection.add(MyText2SqlView, lambda: MyText2SqlView(engine)) + + await collection.ask("What are the names of products bought by customers from London?") + await collection.ask("Which customers bought products from the category 'electronics'?") + await collection.ask("What is the total quantity of products bought by customers from the UK?") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index 637842cf..9e1b971f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -65,7 +65,6 @@ nav: - reference/similarity/similarity_fetcher/index.md - reference/similarity/similarity_fetcher/sqlalchemy.md - reference/similarity/similarity_fetcher/sqlalchemy_simple.md - - reference/similarity/detector.md - Embeddings: - reference/embeddings/index.md - reference/embeddings/litellm.md diff --git a/src/dbally/collection.py b/src/dbally/collection.py index d31c106e..922a6365 100644 --- a/src/dbally/collection.py +++ b/src/dbally/collection.py @@ -2,7 +2,8 @@ import inspect import textwrap import time -from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar +from collections import defaultdict +from typing import Callable, Dict, List, Optional, Type, TypeVar from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker @@ -14,8 +15,7 @@ from dbally.similarity.index import AbstractSimilarityIndex from dbally.utils.errors import NoViewFoundError from dbally.view_selection.base import ViewSelector -from dbally.views.base import BaseView -from dbally.views.structured import BaseStructuredView +from dbally.views.base import BaseView, IndexLocation class IndexUpdateError(Exception): @@ -248,26 +248,22 @@ async def ask( return result - def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[Tuple[str, str, str]]]: + def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: """ - List all similarity indexes from all structured views in the collection. + List all similarity indexes from all views in the collection. Returns: - Dictionary with similarity indexes as keys and values containing lists of places where they are used - (represented by a tupple containing view name, method name and argument name) + Mapping of similarity indexes to their locations, following view format. + For: + - freeform views, the format is (view_name, table_name, column_name) + - structured views, the format is (view_name, filter_name, argument_name) """ - indexes: Dict[AbstractSimilarityIndex, List[Tuple[str, str, str]]] = {} + indexes = defaultdict(list) for view_name in self._views: view = self.get(view_name) - - if not isinstance(view, BaseStructuredView): - continue - - filters = view.list_filters() - for filter_ in filters: - for param in filter_.parameters: - if param.similarity_index: - indexes.setdefault(param.similarity_index, []).append((view_name, filter_.name, param.name)) + view_indexes = view.list_similarity_indexes() + for index, location in view_indexes.items(): + indexes[index].extend(location) return indexes async def update_similarity_indexes(self) -> None: @@ -280,14 +276,12 @@ async def update_similarity_indexes(self) -> None: the dictionary were updated successfully. """ indexes = self.get_similarity_indexes() - update_corutines = [index.update() for index in indexes] - results = await asyncio.gather(*update_corutines, return_exceptions=True) + update_coroutines = [index.update() for index in indexes] + results = await asyncio.gather(*update_coroutines, return_exceptions=True) failed_indexes = { index: exception for index, exception in zip(indexes, results) if isinstance(exception, Exception) } if failed_indexes: failed_locations = [loc for index in failed_indexes for loc in indexes[index]] - description = ", ".join( - f"{view_name}.{method_name}.{param_name}" for view_name, method_name, param_name in failed_locations - ) - raise IndexUpdateError(f"Failed to update similarity indexes for {description}", failed_indexes) + descriptions = ", ".join(".".join(name for name in location) for location in failed_locations) + raise IndexUpdateError(f"Failed to update similarity indexes for {descriptions}", failed_indexes) diff --git a/src/dbally/prompts/prompt_template.py b/src/dbally/prompts/prompt_template.py index 2bd382f6..8e2746fe 100644 --- a/src/dbally/prompts/prompt_template.py +++ b/src/dbally/prompts/prompt_template.py @@ -41,7 +41,7 @@ class PromptTemplate: Class for prompt templates Attributes: - response_format: Optional argument used in the OpenAI API - used to force json output + response_format: Optional argument for OpenAI Turbo models - may be used to force json output llm_response_parser: Function parsing the LLM response into IQL """ diff --git a/src/dbally/similarity/detector.py b/src/dbally/similarity/detector.py deleted file mode 100644 index 83592936..00000000 --- a/src/dbally/similarity/detector.py +++ /dev/null @@ -1,204 +0,0 @@ -import importlib -from types import ModuleType -from typing import Any, Dict, List, Optional, Type - -from dbally.similarity import AbstractSimilarityIndex -from dbally.views import decorators -from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from dbally.views.methods_base import MethodsBaseView - - -class SimilarityIndexDetectorException(Exception): - """ - Exception that occured during similarity index discovery - """ - - def __init__(self, message: str): - self.message = message - super().__init__(message) - - def __str__(self) -> str: - return self.message - - -class SimilarityIndexDetector: - """ - Class used to detect similarity indexes. Works with method-based views that inherit - from MethodsBaseView (including all built-in dbally views). Automatically detects similarity - indexes on arguments of view's filter methods. - - Args: - module: The module to search for similarity indexes - chosen_view_name: The name of the view to search in (optional, all views if None) - chosen_method_name: The name of the method to search in (optional, all methods if None) - chosen_argument_name: The name of the argument to search in (optional, all arguments if None) - """ - - def __init__( - self, - module: ModuleType, - chosen_view_name: Optional[str] = None, - chosen_method_name: Optional[str] = None, - chosen_argument_name: Optional[str] = None, - ): - self.module = module - self.chosen_view_name = chosen_view_name - self.chosen_method_name = chosen_method_name - self.chosen_argument_name = chosen_argument_name - - @classmethod - def from_path(cls, path: str) -> "SimilarityIndexDetector": - """ - Create a SimilarityIndexDetector object from a path string in the format - "path.to.module:ViewName.method_name.argument_name" where each part after the - colon is optional. - - Args: - path: The path to the object - - Returns: - The SimilarityIndexDetector object - - Raises: - SimilarityIndexDetectorException: If the module is not found - """ - module_path, *object_path = path.split(":") - object_parts = object_path[0].split(".") if object_path else [] - chosen_view_name = object_parts[0] if object_parts else None - chosen_method_name = object_parts[1] if len(object_parts) > 1 else None - chosen_argument_name = object_parts[2] if len(object_parts) > 2 else None - - module = cls.get_module_from_path(module_path) - return cls(module, chosen_view_name, chosen_method_name, chosen_argument_name) - - @staticmethod - def get_module_from_path(module_path: str) -> ModuleType: - """ - Get the module from the given path - - Args: - module_path: The path to the module - - Returns: - The module - - Raises: - SimilarityIndexDetectorException: If the module is not found - """ - try: - module = importlib.import_module(module_path) - except ModuleNotFoundError as exc: - raise SimilarityIndexDetectorException(f"Module {module_path} not found.") from exc - return module - - def _is_methods_base_view(self, obj: Any) -> bool: - """ - Check if the given object is a subclass of MethodsBaseView - """ - return isinstance(obj, type) and issubclass(obj, MethodsBaseView) and obj is not MethodsBaseView - - def list_views(self) -> List[Type[MethodsBaseView]]: - """ - List method-based views in the module, filtering by the chosen view name if given during initialization. - - Returns: - List of views - - Raises: - SimilarityIndexDetectorException: If the chosen view is not found - """ - views = [ - getattr(self.module, name) - for name in dir(self.module) - if self._is_methods_base_view(getattr(self.module, name)) - ] - if self.chosen_view_name: - views = [view for view in views if view.__name__ == self.chosen_view_name] - if not views: - raise SimilarityIndexDetectorException( - f"View {self.chosen_view_name} not found in module {self.module.__name__}." - ) - return views - - def list_filters(self, view: Type[MethodsBaseView]) -> List[ExposedFunction]: - """ - List filters in the given view, filtering by the chosen method name if given during initialization. - - Args: - view: The view - - Returns: - List of filter names - - Raises: - SimilarityIndexDetectorException: If the chosen method is not found - """ - methods = view.list_methods_by_decorator(decorators.view_filter) - if self.chosen_method_name: - methods = [method for method in methods if method.name == self.chosen_method_name] - if not methods: - raise SimilarityIndexDetectorException( - f"Filter method {self.chosen_method_name} not found in view {view.__name__}." - ) - return methods - - def list_arguments(self, method: ExposedFunction) -> List[MethodParamWithTyping]: - """ - List arguments in the given method, filtering by the chosen argument name if given during initialization. - - Args: - method: The method - - Returns: - List of argument names - - Raises: - SimilarityIndexDetectorException: If the chosen argument is not found - """ - parameters = method.parameters - if self.chosen_argument_name: - parameters = [parameter for parameter in parameters if parameter.name == self.chosen_argument_name] - if not parameters: - raise SimilarityIndexDetectorException( - f"Argument {self.chosen_argument_name} not found in method {method.name}." - ) - return parameters - - def list_indexes(self, view: Optional[Type[MethodsBaseView]] = None) -> Dict[AbstractSimilarityIndex, List[str]]: - """ - List similarity indexes in the module, filtering by the chosen view, method and argument names if given - during initialization. - - Args: - view: The view to search in (optional, all views if None) - - Returns: - Dictionary mapping indexes to method arguments that use them - - Raises: - SimilarityIndexDetectorException: If any of the chosen path parts is not found - """ - indexes: Dict[AbstractSimilarityIndex, List[str]] = {} - views = self.list_views() if view is None else [view] - for view_class in views: - for method in self.list_filters(view_class): - for parameter in self.list_arguments(method): - if parameter.similarity_index: - indexes.setdefault(parameter.similarity_index, []).append( - f"{view_class.__name__}.{method.name}.{parameter.name}" - ) - return indexes - - async def update_indexes(self) -> None: - """ - Update similarity indexes in the module, filtering by the chosen view, method and argument names if given - during initialization. - - Raises: - SimilarityIndexDetectorException: If any of the chosen path parts is not found - """ - indexes = self.list_indexes() - if not indexes: - raise SimilarityIndexDetectorException("No similarity indexes found.") - for index in indexes: - await index.update() diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 24a86822..2be62049 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -1,10 +1,13 @@ import abc -from typing import Optional +from typing import Dict, List, Optional, Tuple from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions +from dbally.similarity import AbstractSimilarityIndex + +IndexLocation = Tuple[str, str, str] class BaseView(metaclass=abc.ABCMeta): @@ -37,3 +40,12 @@ async def ask( Returns: The result of the query. """ + + def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: + """ + Lists all the similarity indexes used by the view. + + Returns: + Mapping of similarity indexes to their locations. + """ + return {} diff --git a/src/dbally/views/freeform/text2sql/__init__.py b/src/dbally/views/freeform/text2sql/__init__.py index c1a32ecb..28f8f56e 100644 --- a/src/dbally/views/freeform/text2sql/__init__.py +++ b/src/dbally/views/freeform/text2sql/__init__.py @@ -1,10 +1,14 @@ from ._autodiscovery import AutoDiscoveryBuilder, AutoDiscoveryBuilderWithLLM, configure_text2sql_auto_discovery -from ._config import Text2SQLConfig -from ._view import Text2SQLFreeformView +from ._config import ColumnConfig, TableConfig, Text2SQLConfig, Text2SQLSimilarityType, Text2SQLTableConfig +from ._view import BaseText2SQLView __all__ = [ + "TableConfig", + "ColumnConfig", "Text2SQLConfig", - "Text2SQLFreeformView", + "Text2SQLTableConfig", + "Text2SQLSimilarityType", + "BaseText2SQLView", "configure_text2sql_auto_discovery", "AutoDiscoveryBuilder", "AutoDiscoveryBuilderWithLLM", diff --git a/src/dbally/views/freeform/text2sql/_autodiscovery.py b/src/dbally/views/freeform/text2sql/_autodiscovery.py index c649684f..45e3f69d 100644 --- a/src/dbally/views/freeform/text2sql/_autodiscovery.py +++ b/src/dbally/views/freeform/text2sql/_autodiscovery.py @@ -144,6 +144,17 @@ def generate_description_by_llm(self, example_rows_cnt: int = 5) -> Self: self._description_extraction = _LLMSummaryDescriptionExtraction(example_rows_cnt) return self + def suggest_similarity_indexes(self) -> Self: + """ + Enable the suggestion of similarity indexes for the columns in the tables. + The suggestion is based on the generated table descriptions and example values from the column. + + Returns: + The builder instance. + """ + self._similarity_enabled = True + return self + class AutoDiscoveryBuilder(_AutoDiscoveryBuilderBase): """ @@ -295,7 +306,11 @@ async def _suggest_similarity_indexes( template=similarity_template, fmt={"table_summary": description, "column_name": column.name, "values": example_values}, ) - similarity[column_name] = similarity_type + + similarity_type = similarity_type.upper() + + if similarity_type in ["SEMANTIC", "TRIGRAM"]: + similarity[column_name] = similarity_type return similarity diff --git a/src/dbally/views/freeform/text2sql/_config.py b/src/dbally/views/freeform/text2sql/_config.py index 648a1062..d95a09ba 100644 --- a/src/dbally/views/freeform/text2sql/_config.py +++ b/src/dbally/views/freeform/text2sql/_config.py @@ -1,9 +1,68 @@ from dataclasses import dataclass +from enum import Enum from pathlib import Path -from typing import Dict, Optional +from typing import Dict, Iterable, List, Optional, Tuple import yaml +from dbally.similarity import SimilarityIndex + + +class ColumnConfig: + """ + Configuration of a column used in the Text2SQL view. + """ + + def __init__( + self, + name: str, + data_type: str, + description: Optional[str] = None, + similarity_index: Optional[SimilarityIndex] = None, + ): + self.name = name + self.data_type = data_type + self.description = description + self.similarity_index = similarity_index + + +class TableConfig: + """ + Configuration of a table used in the Text2SQL view. + """ + + def __init__(self, name: str, columns: List[ColumnConfig], description: Optional[str] = None): + self.name = name + self.columns = columns + self.description = description + self._column_index = {column.name: column for column in columns} + + @property + def ddl(self) -> str: + """ + Returns the DDL for the table which can be provided to the LLM as a context. + + Returns: + The DDL for the table. + """ + return ( + f"CREATE TABLE {self.name} (" + + ", ".join(f"{column.name} {column.data_type}" for column in self.columns) + + ");" + ) + + def __getitem__(self, item: str) -> ColumnConfig: + return self._column_index[item] + + +class Text2SQLSimilarityType(str, Enum): + """ + Enum for the types of similarity indexes supported by Text2SQL. + """ + + SEMANTIC = "SEMANTIC" + TRIGRAM = "TRIGRAM" + @dataclass class Text2SQLTableConfig: @@ -13,7 +72,7 @@ class Text2SQLTableConfig: ddl: str description: Optional[str] = None - similarity: Optional[Dict[str, str]] = None + similarity: Optional[Dict[str, Text2SQLSimilarityType]] = None class Text2SQLConfig: @@ -48,3 +107,15 @@ def to_file(self, file_path: Path) -> None: """ data = {table_name: table.__dict__ for table_name, table in self.tables.items()} file_path.write_text(yaml.dump(data)) + + def iterate_similarity_indexes(self) -> Iterable[Tuple[str, str, Text2SQLSimilarityType]]: + """ + Iterate over the similarity indexes in the configuration. + + Yields: + The table name, column name, and similarity type. + """ + for table_name, table in self.tables.items(): + if table.similarity: + for column_name, similarity_type in table.similarity.items(): + yield table_name, column_name, similarity_type diff --git a/src/dbally/views/freeform/text2sql/_view.py b/src/dbally/views/freeform/text2sql/_view.py index a2610888..32f62db5 100644 --- a/src/dbally/views/freeform/text2sql/_view.py +++ b/src/dbally/views/freeform/text2sql/_view.py @@ -1,16 +1,21 @@ -from typing import Iterable, Optional, Tuple +import abc +import json +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple import sqlalchemy -from sqlalchemy import text +from sqlalchemy import ColumnClause, Table, text from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompts import PromptTemplate -from dbally.views.base import BaseView +from dbally.similarity import AbstractSimilarityIndex, SimpleSqlAlchemyFetcher +from dbally.views.base import BaseView, IndexLocation -from ._config import Text2SQLConfig +from ._config import TableConfig from ._errors import Text2SQLError text2sql_prompt = PromptTemplate( @@ -20,22 +25,88 @@ "content": "You are a very smart database programmer. " "You have access to the following {dialect} tables:\n" "{tables}\n" - "Create SQL query to answer user question. Return only SQL surrounded in ```sql ``` block.\n", + "Create SQL query to answer user question. Response with JSON containing following keys:\n\n" + "- sql: SQL query to answer the question, with parameter :placeholders for user input.\n" + "- parameters: a list of parameters to be used in the query, represented by maps with the following keys:\n" + " - name: the name of the parameter\n" + " - value: the value of the parameter\n" + " - table: the table the parameter is used with (if any)\n" + " - column: the column the parameter is compared to (if any)\n\n" + "Respond ONLY with the raw JSON response. Don't include any additional text or characters.", }, {"role": "user", "content": "{question}"}, ), + response_format={"type": "json_object"}, ) -class Text2SQLFreeformView(BaseView): +@dataclass +class SQLParameterOption: + """ + A class representing the options for a SQL parameter. + """ + + name: str + value: str + table: Optional[str] = None + column: Optional[str] = None + + # Maybe use pydantic instead of this method? + # On the other hand, it would introduce a new dependency + @staticmethod + def from_dict(data: Dict[str, str]) -> "SQLParameterOption": + """ + Creates an instance of SQLParameterOption from a dictionary. + + Args: + data: The dictionary to create the instance from. + + Returns: + An instance of SQLParameterOption. + + Raises: + ValueError: If the dictionary is invalid. + """ + if not isinstance(data, dict): + raise ValueError("Paramter data should be a dictionary") + + if "name" not in data or not isinstance(data["name"], str): + raise ValueError("Parameter name should be a string") + + if "value" not in data or not isinstance(data["value"], str): + raise ValueError(f"Value for parameter {data['name']} should be a string") + + if "table" in data and data["table"] is not None and not isinstance(data["table"], str): + raise ValueError(f"Table for parameter {data['name']} should be a string") + + if "column" in data and data["column"] is not None and not isinstance(data["column"], str): + raise ValueError(f"Column for parameter {data['name']} should be a string") + + return SQLParameterOption(data["name"], data["value"], data.get("table"), data.get("column")) + + async def value_with_similarity(self) -> str: + """ + Returns the value after passing it through a similarity index if available for the given table and column. + + Returns: + str: The value after passing it through a similarity index. + """ + # TODO: Lookup similarity index for the given volumn + return self.value + + +class BaseText2SQLView(BaseView, abc.ABC): """ Text2SQLFreeformView is a class designed to interact with the database using text2sql queries. """ - def __init__(self, engine: sqlalchemy.engine.Engine, config: Text2SQLConfig) -> None: + def __init__( + self, + engine: sqlalchemy.engine.Engine, + ) -> None: super().__init__() self._engine = engine - self._config = config + self._table_index = {table.name: table for table in self.get_tables()} async def ask( self, @@ -73,7 +144,7 @@ async def ask( # We want to catch all exceptions to retry the process. # pylint: disable=broad-except try: - sql, conversation = await self._generate_sql( + sql, parameters, conversation = await self._generate_sql( query=query, conversation=conversation, llm=llm, @@ -84,10 +155,10 @@ async def ask( if dry_run: return ViewExecutionResult(results=[], context={"sql": sql}) - rows = await self._execute_sql(sql) + rows = await self._execute_sql(sql, parameters, event_tracker=event_tracker) break except Exception as e: - conversation = conversation.add_user_message(f"Query is invalid! Error: {e}") + conversation = conversation.add_user_message(f"Response is invalid! Error: {e}") exceptions.append(e) continue @@ -110,7 +181,7 @@ async def _generate_sql( llm: LLM, event_tracker: EventTracker, llm_options: Optional[LLMOptions] = None, - ) -> Tuple[str, PromptTemplate]: + ) -> Tuple[str, List[SQLParameterOption], PromptTemplate]: response = await llm.generate_text( template=conversation, fmt={"tables": self._get_tables_context(), "dialect": self._engine.dialect.name, "question": query}, @@ -119,18 +190,64 @@ async def _generate_sql( ) conversation = conversation.add_assistant_message(response) + data = json.loads(response) + sql = data["sql"] + parameters = data.get("parameters", []) - response = response.split("```sql")[-1].strip("\n") - response = response.replace("```", "") - return response, conversation + if not isinstance(parameters, list): + raise ValueError("Parameters should be a list of dictionaries") + param_objs = [SQLParameterOption.from_dict(param) for param in parameters] + + return sql, param_objs, conversation + + async def _execute_sql( + self, sql: str, parameters: List[SQLParameterOption], event_tracker: EventTracker + ) -> Iterable: + param_values = {} + + for param in parameters: + if param.table in self._table_index and self._table_index[param.table][param.column].similarity_index: + similarity_index = self._table_index[param.table][param.column].similarity_index + param_values[param.name] = await similarity_index.similar(param.value, event_tracker=event_tracker) + else: + param_values[param.name] = param.value - async def _execute_sql(self, sql: str) -> Iterable: with self._engine.connect() as conn: - return conn.execute(text(sql)).fetchall() + return conn.execute(text(sql), param_values).fetchall() def _get_tables_context(self) -> str: context = "" - for table in self._config.tables.values(): + for table in self._table_index.values(): context += f"{table.ddl}\n" return context + + @abc.abstractmethod + def get_tables(self) -> List[TableConfig]: + """ + Get the tables used by the view. + + Returns: + A dictionary of tables. + """ + + def _create_default_fetcher(self, table: str, column: str) -> SimpleSqlAlchemyFetcher: + return SimpleSqlAlchemyFetcher( + sqlalchemy_engine=self._engine, + column=ColumnClause(column), + table=Table(table, sqlalchemy.MetaData()), + ) + + def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: + """ + List all similarity indexes used by the view. + + Returns: + Mapping of similarity indexes to their locations in the (view_name, table_name, column_name) format. + """ + indexes = defaultdict(list) + for table in self.get_tables(): + for column in table.columns: + if column.similarity_index: + indexes[column.similarity_index].append((self.__class__.__name__, table.name, column.name)) + return indexes diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index e6034a18..536c2ab6 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,5 +1,6 @@ import abc -from typing import List, Optional +from collections import defaultdict +from typing import Dict, List, Optional from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult @@ -9,7 +10,8 @@ from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction -from .base import BaseView +from ..similarity import AbstractSimilarityIndex +from .base import BaseView, IndexLocation class BaseStructuredView(BaseView): @@ -110,3 +112,18 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Args: dry_run: if True, should only generate the query without executing it """ + + def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: + """ + Lists all the similarity indexes used by the view. + + Returns: + Mapping of similarity indexes to their locations in the (view_name, filter_name, argument_name) format. + """ + indexes = defaultdict(list) + filters = self.list_filters() + for filter_ in filters: + for param in filter_.parameters: + if param.similarity_index: + indexes[param.similarity_index].append((self.__class__.__name__, filter_.name, param.name)) + return indexes diff --git a/src/dbally_cli/main.py b/src/dbally_cli/main.py index 44fe6324..8912b899 100644 --- a/src/dbally_cli/main.py +++ b/src/dbally_cli/main.py @@ -1,11 +1,8 @@ import click -from dbally_cli.similarity import update_index - @click.group() def cli(): - """Command line tool for interacting with Db-ally""" - - -cli.add_command(update_index) + """ + Command line tool for interacting with dbally. + """ diff --git a/src/dbally_cli/similarity.py b/src/dbally_cli/similarity.py deleted file mode 100644 index f57ea547..00000000 --- a/src/dbally_cli/similarity.py +++ /dev/null @@ -1,49 +0,0 @@ -# pylint: disable=missing-param-doc -import asyncio -import sys - -import click - -from dbally.similarity.detector import SimilarityIndexDetector, SimilarityIndexDetectorException - - -@click.command(short_help="Updates similarity indexes based on the given object path.") -@click.argument("path") -def update_index(path: str): - """ - Updates similarity indexes based on the given object path, looking for - arguments on view filter methods that are annotated with similarity indexes. - - Works with method-based views that inherit from MethodsBaseView (including - all built-in dbally views). - - The path take one of the following formats: - - path.to.module - - path.to.module:ViewName - - path.to.module:ViewName.method_name - - path.to.module:ViewName.method_name.argument_name - - Less specific path will cause more indexes to be updated. For example, - - Path to a specific argument will update only the index for that argument. - - Path to a specific method will update indexes of all arguments of that filter method. - - Path to a specific view will update indexes of all arguments of all methods of that view. - - Path to a module will update indexes of all arguments of all methods of all views of that module. - """ - click.echo(f"Looking for similarity indexes in {path}...") - - try: - updater = SimilarityIndexDetector.from_path(path) - indexes = updater.list_indexes() - except SimilarityIndexDetectorException as exc: - click.echo(exc.message, err=True) - sys.exit(1) - - if not indexes: - click.echo("No similarity indexes found.", err=True) - sys.exit(1) - - for index, index_users in indexes.items(): - click.echo(f"Updating index used by {', '.join(index_users)}...") - asyncio.run(index.update()) - - click.echo("Indexes updated successfully.") diff --git a/src/dbally_codegen/__init__.py b/src/dbally_codegen/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/similarity/test_cli.py b/tests/unit/similarity/test_cli.py deleted file mode 100644 index e970f75c..00000000 --- a/tests/unit/similarity/test_cli.py +++ /dev/null @@ -1,77 +0,0 @@ -from click.testing import CliRunner - -from dbally_cli.similarity import update_index - - -def test_cli(): - """ - Test the update_index command with a specific module path - """ - runner = CliRunner() - result = runner.invoke(update_index, ["sample_module.submodule"]) - assert result.exit_code == 0 - assert "BarView.method_baz.person, FooView.method_bar.year" in result.output - assert "FooView.method_bar.city, FooView.method_foo.idx" in result.output - assert "Indexes updated successfully" in result.output - - -def test_cli_with_view(): - """ - Test the update_index command with a specific module and view path - """ - runner = CliRunner() - result = runner.invoke(update_index, ["sample_module.submodule:FooView"]) - assert result.exit_code == 0 - assert "FooView.method_bar.year" in result.output - assert "FooView.method_bar.city, FooView.method_foo.idx" in result.output - assert "Indexes updated successfully" in result.output - - -def test_cli_no_indexes(): - """ - Test the update_index command with a module that has no indexes - """ - runner = CliRunner() - result = runner.invoke(update_index, ["sample_module.empty_submodule"]) - assert result.exit_code != 0 - assert "No similarity indexes found" in result.output - - -def test_cli_with_invalid_path(): - """ - Test the update_index command with an invalid path - """ - runner = CliRunner() - result = runner.invoke(update_index, ["sample_module.invalid_submodule"]) - assert result.exit_code != 0 - assert "Module sample_module.invalid_submodule not found" in result.output - - -def test_cli_with_invalid_view(): - """ - Test the update_index command with an invalid view path - """ - runner = CliRunner() - result = runner.invoke(update_index, ["sample_module.submodule:InvalidView"]) - assert result.exit_code != 0 - assert "View InvalidView not found in module sample_module.submodule" in result.output - - -def test_cli_with_invalid_method(): - """ - Test the update_index command with an invalid method path - """ - runner = CliRunner() - result = runner.invoke(update_index, ["sample_module.submodule:FooView.invalid_method"]) - assert result.exit_code != 0 - assert "Filter method invalid_method not found in view FooView" in result.output - - -def test_cli_with_invalid_argument(): - """ - Test the update_index command with an invalid argument path - """ - runner = CliRunner() - result = runner.invoke(update_index, ["sample_module.submodule:FooView.method_bar.invalid_argument"]) - assert result.exit_code != 0 - assert "Argument invalid_argument not found in method method_bar" in result.output diff --git a/tests/unit/similarity/test_detector.py b/tests/unit/similarity/test_detector.py deleted file mode 100644 index 5c5a406c..00000000 --- a/tests/unit/similarity/test_detector.py +++ /dev/null @@ -1,120 +0,0 @@ -import pytest -from sample_module import submodule - -from dbally.similarity.detector import SimilarityIndexDetector, SimilarityIndexDetectorException - - -def test_detector_with_module(): - """ - Test the SimilarityIndexDetector class with a specific module object - """ - detector = SimilarityIndexDetector(submodule) - assert detector.list_views() == [submodule.BarView, submodule.FooView] - assert detector.list_indexes() == { - submodule.index_bar: ["BarView.method_baz.person", "FooView.method_bar.year"], - submodule.index_foo: ["FooView.method_bar.city", "FooView.method_foo.idx"], - } - - -def test_detector_with_view(): - """ - Test the SimilarityIndexDetector class with a specific view object - """ - detector = SimilarityIndexDetector(submodule) - assert detector.list_indexes(view=submodule.FooView) == { - submodule.index_bar: ["FooView.method_bar.year"], - submodule.index_foo: ["FooView.method_bar.city", "FooView.method_foo.idx"], - } - - -def test_detectior_with_module_path(): - """ - Test the SimilarityIndexDetector class with a module path - """ - detector = SimilarityIndexDetector.from_path("sample_module.submodule") - assert detector.list_views() == [submodule.BarView, submodule.FooView] - assert detector.list_indexes() == { - submodule.index_bar: ["BarView.method_baz.person", "FooView.method_bar.year"], - submodule.index_foo: ["FooView.method_bar.city", "FooView.method_foo.idx"], - } - - -def test_detector_with_module_view_path(): - """ - Test the SimilarityIndexDetector class with a module and view path - """ - detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView") - assert detector.list_views() == [submodule.FooView] - assert detector.list_indexes() == { - submodule.index_bar: ["FooView.method_bar.year"], - submodule.index_foo: ["FooView.method_bar.city", "FooView.method_foo.idx"], - } - - -def test_detector_with_module_view_method_path(): - """ - Test the SimilarityIndexDetector class with a module, view, and method path - """ - detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView.method_bar") - assert detector.list_views() == [submodule.FooView] - assert detector.list_indexes() == { - submodule.index_bar: ["FooView.method_bar.year"], - submodule.index_foo: ["FooView.method_bar.city"], - } - - -def test_detector_with_module_view_method_argument_path(): - """ - Test the SimilarityIndexDetector class with a module, view, method, and argument path - """ - detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView.method_bar.city") - assert detector.list_views() == [submodule.FooView] - assert detector.list_indexes() == {submodule.index_foo: ["FooView.method_bar.city"]} - - -def test_detector_with_module_not_found(): - """ - Test the SimilarityIndexDetector class with a module that does not exist - """ - with pytest.raises(SimilarityIndexDetectorException) as exc: - SimilarityIndexDetector.from_path("not_found") - assert exc.value.message == "Module not_found not found." - - -def test_detector_with_empty_module(): - """ - Test the SimilarityIndexDetector class with an empty module - """ - detector = SimilarityIndexDetector.from_path("sample_module.empty_submodule") - assert detector.list_views() == [] - assert not detector.list_indexes() - - -def test_detector_with_view_not_found(): - """ - Test the SimilarityIndexDetector class with a view that does not exist - """ - detector = SimilarityIndexDetector.from_path("sample_module.submodule:NotFoundView") - with pytest.raises(SimilarityIndexDetectorException) as exc: - detector.list_views() - assert exc.value.message == "View NotFoundView not found in module sample_module.submodule." - - -def test_detector_with_method_not_found(): - """ - Test the SimilarityIndexDetector class with a method that does not exist - """ - detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView.not_found") - with pytest.raises(SimilarityIndexDetectorException) as exc: - detector.list_indexes() - assert exc.value.message == "Filter method not_found not found in view FooView." - - -def test_detector_with_argument_not_found(): - """ - Test the SimilarityIndexDetector class with an argument that does not exist - """ - detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView.method_bar.not_found") - with pytest.raises(SimilarityIndexDetectorException) as exc: - detector.list_indexes() - assert exc.value.message == "Argument not_found not found in method method_bar." diff --git a/tests/unit/views/text2sql/test_view.py b/tests/unit/views/text2sql/test_view.py index a33ebba4..b8f1e44e 100644 --- a/tests/unit/views/text2sql/test_view.py +++ b/tests/unit/views/text2sql/test_view.py @@ -1,3 +1,4 @@ +import json from unittest.mock import AsyncMock import pytest @@ -5,11 +6,28 @@ from sqlalchemy import Engine, text import dbally -from dbally.views.freeform.text2sql import Text2SQLConfig, Text2SQLFreeformView -from dbally.views.freeform.text2sql._config import Text2SQLTableConfig +from dbally.views.freeform.text2sql import BaseText2SQLView +from dbally.views.freeform.text2sql._config import ColumnConfig, TableConfig from tests.unit.mocks import MockLLM +class SampleText2SQLView(BaseText2SQLView): + def get_tables(self): + return [ + TableConfig( + name="customers", + columns=[ + ColumnConfig("id", "SERIAL PRIMARY KEY"), + ColumnConfig("name", "VARCHAR(255)"), + ColumnConfig("city", "VARCHAR(255)"), + ColumnConfig("country", "VARCHAR(255)"), + ColumnConfig("age", "INTEGER"), + ], + description="Customers table", + ) + ] + + @pytest.fixture def sample_db() -> Engine: engine = sqlalchemy.create_engine("sqlite:///:memory:") @@ -32,24 +50,19 @@ def sample_db() -> Engine: async def test_text2sql_view(sample_db: Engine): + llm_response = { + "sql": "SELECT * FROM customers WHERE city = :city", + "parameters": [{"name": "city", "value": "New York"}], + } llm = MockLLM() - llm.client.call = AsyncMock(return_value="SELECT * FROM customers WHERE city = 'New York'") - - config = Text2SQLConfig( - tables={ - "customers": Text2SQLTableConfig( - ddl="CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT, city TEXT)", - description="Customers table", - ) - } - ) + llm.client.call = AsyncMock(return_value=json.dumps(llm_response)) collection = dbally.create_collection(name="test_collection", llm=llm) - collection.add(Text2SQLFreeformView, lambda: Text2SQLFreeformView(sample_db, config)) + collection.add(SampleText2SQLView, lambda: SampleText2SQLView(sample_db)) response = await collection.ask("Show me customers from New York") - assert response.context["sql"] == "SELECT * FROM customers WHERE city = 'New York'" + assert response.context["sql"] == llm_response["sql"] assert response.results == [ {"id": 1, "name": "Alice", "city": "New York"}, {"id": 3, "name": "Charlie", "city": "New York"},