diff --git a/src/dbally/views/freeform/text2sql/__init__.py b/src/dbally/views/freeform/text2sql/__init__.py index 28f8f56e..8c1ad220 100644 --- a/src/dbally/views/freeform/text2sql/__init__.py +++ b/src/dbally/views/freeform/text2sql/__init__.py @@ -1,15 +1,10 @@ -from ._autodiscovery import AutoDiscoveryBuilder, AutoDiscoveryBuilderWithLLM, configure_text2sql_auto_discovery -from ._config import ColumnConfig, TableConfig, Text2SQLConfig, Text2SQLSimilarityType, Text2SQLTableConfig -from ._view import BaseText2SQLView +from .config import ColumnConfig, TableConfig +from .errors import Text2SQLError +from .view import BaseText2SQLView __all__ = [ - "TableConfig", - "ColumnConfig", - "Text2SQLConfig", - "Text2SQLTableConfig", - "Text2SQLSimilarityType", "BaseText2SQLView", - "configure_text2sql_auto_discovery", - "AutoDiscoveryBuilder", - "AutoDiscoveryBuilderWithLLM", + "ColumnConfig", + "TableConfig", + "Text2SQLError", ] diff --git a/src/dbally/views/freeform/text2sql/_autodiscovery.py b/src/dbally/views/freeform/text2sql/_autodiscovery.py deleted file mode 100644 index 45e3f69d..00000000 --- a/src/dbally/views/freeform/text2sql/_autodiscovery.py +++ /dev/null @@ -1,352 +0,0 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from sqlalchemy import Column, Connection, Engine, MetaData, Table -from sqlalchemy.sql.ddl import CreateTable -from typing_extensions import Self - -from dbally.llms.base import LLM -from dbally.prompts import PromptTemplate - -from ._config import Text2SQLConfig, Text2SQLTableConfig - - -class _DescriptionExtractionStrategy: - """ - Base class for strategies of extracting descriptions from the database. - """ - - -class _DBCommentsDescriptionExtraction(_DescriptionExtractionStrategy): - """ - Strategy for extracting descriptions from the database comments. - """ - - -class _LLMSummaryDescriptionExtraction(_DescriptionExtractionStrategy): - """ - Strategy for extracting descriptions from the database using LLM. - """ - - def __init__(self, example_rows_cnt: int = 5): - self.example_rows_cnt = example_rows_cnt - - -class _AutoDiscoveryBuilderBase: - """ - Builder class for configuring the auto-discovery of the database for text2sql freeform view. - """ - - _llm: Optional[LLM] - _blacklist: Optional[List[str]] - _whitelist: Optional[List[str]] - _description_extraction: _DescriptionExtractionStrategy - _similarity_enabled: bool - - def __init__( - self, - engine: Engine, - blacklist: Optional[List[str]] = None, - whitelist: Optional[List[str]] = None, - description_extraction: Optional[_DescriptionExtractionStrategy] = None, - similarity_enabled: bool = False, - llm: Optional[LLM] = None, - ) -> None: - self._engine = engine - self._llm = llm - - self._blacklist = blacklist - self._whitelist = whitelist - self._description_extraction = description_extraction or _DBCommentsDescriptionExtraction() - self._similarity_enabled = similarity_enabled - - def with_blacklist(self, blacklist: List[str]) -> Self: - """ - Set the blacklist of tables to exclude from the auto-discovery. - - Args: - blacklist: List of table names to exclude from the auto-discovery. - - Returns: - The builder instance. - - Raises: - ValueError: If both a whitelist and a blacklist are set. - """ - if self._whitelist is not None: - raise ValueError("Cannot have both a blacklist and a whitelist") - - self._blacklist = blacklist - return self - - def with_whitelist(self, whitelist: List[str]) -> Self: - """ - Set the whitelist of tables to include in the auto-discovery. - - Args: - whitelist: List of table names to include in the auto-discovery. - - Returns: - The builder instance. - - Raises: - ValueError: If both a whitelist and a blacklist are set. - """ - if self._blacklist is not None: - raise ValueError("Cannot have both a blacklist and a whitelist") - - self._whitelist = whitelist - return self - - def extract_description_from_comments(self) -> Self: - """ - Use the comments field in the database as a source for the table descriptions. - - Returns: - The builder instance. - """ - self._description_extraction = _DescriptionExtractionStrategy() - return self - - async def discover(self) -> Text2SQLConfig: - """ - Discover the tables in the database and return the configuration object. - - Returns: - Text2SQLConfig: The configuration object for the text2sql freeform view. - """ - return await _Text2SQLAutoDiscovery( - llm=self._llm, - engine=self._engine, - whitelist=self._whitelist, - blacklist=self._blacklist, - description_extraction=self._description_extraction, - similarity_enabled=self._similarity_enabled, - ).discover() - - -class AutoDiscoveryBuilderWithLLM(_AutoDiscoveryBuilderBase): - """ - Builder class for configuring the auto-discovery of the database for text2sql freeform view. - It extends the base builder with the ability to use LLM for extra tasks. - """ - - def generate_description_by_llm(self, example_rows_cnt: int = 5) -> Self: - """ - Use LLM to generate descriptions for the tables. - The descriptions are generated based on the table DDL and a configured count of example rows. - - Args: - example_rows_cnt: The number of example rows to use for generating the description. - - Returns: - The builder instance. - """ - self._description_extraction = _LLMSummaryDescriptionExtraction(example_rows_cnt) - return self - - def suggest_similarity_indexes(self) -> Self: - """ - Enable the suggestion of similarity indexes for the columns in the tables. - The suggestion is based on the generated table descriptions and example values from the column. - - Returns: - The builder instance. - """ - self._similarity_enabled = True - return self - - -class AutoDiscoveryBuilder(_AutoDiscoveryBuilderBase): - """ - Builder class for configuring the auto-discovery of the database for text2sql freeform view. - """ - - def use_llm(self, llm: LLM) -> AutoDiscoveryBuilderWithLLM: - """ - Set the LLM client to use for generating descriptions. - - Args: - llm: The LLM client to use for generating descriptions. - - Returns: - The builder instance. - """ - return AutoDiscoveryBuilderWithLLM( - engine=self._engine, - whitelist=self._whitelist, - blacklist=self._blacklist, - description_extraction=self._description_extraction, - similarity_enabled=self._similarity_enabled, - llm=llm, - ) - - -def configure_text2sql_auto_discovery(engine: Engine) -> AutoDiscoveryBuilder: - """ - This function is used to automatically discover the tables in the database and generate a yaml file with the tables - and their columns. The yaml file is used to configure the Text2SQLFreeformView in the dbally library. - - Args: - engine: The SQLAlchemy engine object used to connect to the database. - - Returns: - The builder object used to configure the auto-discovery process. - """ - return AutoDiscoveryBuilder(engine) - - -discovery_template = PromptTemplate( - chat=( - { - "role": "system", - "content": ( - "You are a very smart database programmer. " - "You will be provided with {dialect} table definition and example rows from this table.\n" - "Create a concise summary of provided table. Do not list the columns." - ), - }, - {"role": "user", "content": "DDL:\n {table_ddl}\nExample rows:\n {example_rows}"}, - ), -) - -similarity_template = PromptTemplate( - chat=( - { - "role": "system", - "content": ( - "Determine whether to use SEMANTIC or TRIGRAM search based on the given column.\n" - "Use SEMANTIC when values are categorical or synonym may be used (for example department names).\n" - "Use TRIGRAM when values are sensitive to typos and synonyms does not make sense " - "(for example person or company names).\n" - "Use NONE when a search of the values does not make sense and explicit values should be utilized.\n" - "Return only one of the following options: SEMANTIC, TRIGRAM, or NONE." - ), - }, - { - "role": "user", - "content": "TABLE SUMMARY: {table_summary}\n" "COLUMN_NAME: {column_name}\n" "EXAMPLE_VALUES: {values}", - }, - ) -) - - -class _Text2SQLAutoDiscovery: - """ - Class for auto-discovery of the database for text2sql freeform view. - """ - - def __init__( - self, - engine: Engine, - description_extraction: _DescriptionExtractionStrategy, - whitelist: Optional[List[str]] = None, - llm: Optional[LLM] = None, - blacklist: Optional[List[str]] = None, - similarity_enabled: bool = False, - ) -> None: - self._llm = llm - self._engine = engine - self._whitelist = whitelist - self._blacklist = blacklist - self._description_extraction = description_extraction - self._similarity_enabled = similarity_enabled - - async def discover(self) -> Text2SQLConfig: - """ - Discover tables in the database and return the configuration object. - - Returns: - Text2SQLConfig: The configuration object for the text2sql freeform view. - - Raises: - ValueError: If the description extraction strategy is invalid. - """ - - connection = self._engine.connect() - tables = {} - - for table_name, table in self._iterate_tables(): - if self._whitelist is not None and table_name not in self._whitelist: - continue - - if self._blacklist is not None and table_name in self._blacklist: - continue - - ddl = self._get_table_ddl(table) - - if isinstance(self._description_extraction, _DBCommentsDescriptionExtraction): - description = table.comment - elif isinstance(self._description_extraction, _LLMSummaryDescriptionExtraction): - example_rows = self._get_example_rows(connection, table, self._description_extraction.example_rows_cnt) - description = await self._generate_llm_summary(ddl, example_rows) - else: - raise ValueError(f"Invalid description extraction strategy: {self._description_extraction}") - - if self._similarity_enabled: - similarity = await self._suggest_similarity_indexes(connection, description or "", table) - else: - similarity = None - - tables[table_name] = Text2SQLTableConfig(ddl=ddl, description=description, similarity=similarity) - - connection.close() - - return Text2SQLConfig(tables) - - async def _suggest_similarity_indexes( - self, connection: Connection, description: str, table: Table - ) -> Dict[str, str]: - if self._llm is None: - raise ValueError("LLM client is required for suggesting similarity indexes.") - - similarity = {} - for column_name, column in self._iterate_str_columns(table): - example_values = self._get_column_example_values(connection, table, column) - similarity_type = await self._llm.generate_text( - template=similarity_template, - fmt={"table_summary": description, "column_name": column.name, "values": example_values}, - ) - - similarity_type = similarity_type.upper() - - if similarity_type in ["SEMANTIC", "TRIGRAM"]: - similarity[column_name] = similarity_type - - return similarity - - async def _generate_llm_summary(self, ddl: str, example_rows: List[dict]) -> str: - if self._llm is None: - raise ValueError("LLM client is required for generating descriptions.") - - return await self._llm.generate_text( - template=discovery_template, - fmt={"dialect": self._engine.dialect.name, "table_ddl": ddl, "example_rows": example_rows}, - ) - - def _iterate_tables(self) -> Iterator[Tuple[str, Table]]: - meta = MetaData() - meta.reflect(bind=self._engine) - for table in meta.sorted_tables: - yield str(table.name), table - - @staticmethod - def _iterate_str_columns(table: Table) -> Iterator[Tuple[str, Column]]: - for column in table.columns.values(): - if column.type.python_type is str: - yield str(column.name), column - - @staticmethod - def _get_example_rows(connection: Connection, table: Table, n: int = 5) -> List[Dict[str, Any]]: - rows = connection.execute(table.select().limit(n)).fetchall() - - # The underscore is used by sqlalchemy to avoid conflicts with column names - # pylint: disable=protected-access - return [{str(k): v for k, v in dict(row._mapping).items()} for row in rows] - - @staticmethod - def _get_column_example_values(connection: Connection, table: Table, column: Column) -> List[Any]: - example_values = connection.execute(table.select().with_only_columns(column).distinct().limit(5)).fetchall() - return [x[0] for x in example_values] - - def _get_table_ddl(self, table: Table) -> str: - return str(CreateTable(table).compile(self._engine)) diff --git a/src/dbally/views/freeform/text2sql/_config.py b/src/dbally/views/freeform/text2sql/_config.py deleted file mode 100644 index d95a09ba..00000000 --- a/src/dbally/views/freeform/text2sql/_config.py +++ /dev/null @@ -1,121 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple - -import yaml - -from dbally.similarity import SimilarityIndex - - -class ColumnConfig: - """ - Configuration of a column used in the Text2SQL view. - """ - - def __init__( - self, - name: str, - data_type: str, - description: Optional[str] = None, - similarity_index: Optional[SimilarityIndex] = None, - ): - self.name = name - self.data_type = data_type - self.description = description - self.similarity_index = similarity_index - - -class TableConfig: - """ - Configuration of a table used in the Text2SQL view. - """ - - def __init__(self, name: str, columns: List[ColumnConfig], description: Optional[str] = None): - self.name = name - self.columns = columns - self.description = description - self._column_index = {column.name: column for column in columns} - - @property - def ddl(self) -> str: - """ - Returns the DDL for the table which can be provided to the LLM as a context. - - Returns: - The DDL for the table. - """ - return ( - f"CREATE TABLE {self.name} (" - + ", ".join(f"{column.name} {column.data_type}" for column in self.columns) - + ");" - ) - - def __getitem__(self, item: str) -> ColumnConfig: - return self._column_index[item] - - -class Text2SQLSimilarityType(str, Enum): - """ - Enum for the types of similarity indexes supported by Text2SQL. - """ - - SEMANTIC = "SEMANTIC" - TRIGRAM = "TRIGRAM" - - -@dataclass -class Text2SQLTableConfig: - """ - Configuration of a table used in the Text2SQL view. - """ - - ddl: str - description: Optional[str] = None - similarity: Optional[Dict[str, Text2SQLSimilarityType]] = None - - -class Text2SQLConfig: - """ - Configuration object for the Text2SQL freeform view. - """ - - def __init__(self, tables: Dict[str, Text2SQLTableConfig]): - self.tables = tables - - @classmethod - def from_file(cls, file_path: Path) -> "Text2SQLConfig": - """ - Load the configuration object from a file. - - Args: - file_path: Path to the file containing the configuration. - - Returns: - Text2SQLConfig: The configuration object. - """ - data = yaml.safe_load(file_path.read_text()) - tables = {table_name: Text2SQLTableConfig(**table) for table_name, table in data.items()} - return cls(tables=tables) - - def to_file(self, file_path: Path) -> None: - """ - Save the configuration object to a file. - - Args: - file_path: Path to the file where the configuration should be saved. - """ - data = {table_name: table.__dict__ for table_name, table in self.tables.items()} - file_path.write_text(yaml.dump(data)) - - def iterate_similarity_indexes(self) -> Iterable[Tuple[str, str, Text2SQLSimilarityType]]: - """ - Iterate over the similarity indexes in the configuration. - - Yields: - The table name, column name, and similarity type. - """ - for table_name, table in self.tables.items(): - if table.similarity: - for column_name, similarity_type in table.similarity.items(): - yield table_name, column_name, similarity_type diff --git a/src/dbally/views/freeform/text2sql/config.py b/src/dbally/views/freeform/text2sql/config.py new file mode 100644 index 00000000..35ca8bec --- /dev/null +++ b/src/dbally/views/freeform/text2sql/config.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass +from typing import List, Optional + +from dbally.similarity import SimilarityIndex + + +@dataclass +class ColumnConfig: + """ + Configuration of a column used in the Text2SQL view. + """ + + name: str + data_type: str + description: Optional[str] = None + similarity_index: Optional[SimilarityIndex] = None + + +@dataclass +class TableConfig: + """ + Configuration of a table used in the Text2SQL view. + """ + + name: str + columns: List[ColumnConfig] + description: Optional[str] = None + + def __post_init__(self) -> None: + self._column_index = {column.name: column for column in self.columns} + + @property + def ddl(self) -> str: + """ + Returns the DDL for the table which can be provided to the LLM as a context. + + Returns: + The DDL for the table. + """ + return ( + f"CREATE TABLE {self.name} (" + + ", ".join(f"{column.name} {column.data_type}" for column in self.columns) + + ");" + ) + + def __getitem__(self, item: str) -> ColumnConfig: + return self._column_index[item] diff --git a/src/dbally/views/freeform/text2sql/_errors.py b/src/dbally/views/freeform/text2sql/errors.py similarity index 100% rename from src/dbally/views/freeform/text2sql/_errors.py rename to src/dbally/views/freeform/text2sql/errors.py diff --git a/src/dbally/views/freeform/text2sql/_view.py b/src/dbally/views/freeform/text2sql/view.py similarity index 96% rename from src/dbally/views/freeform/text2sql/_view.py rename to src/dbally/views/freeform/text2sql/view.py index 32f62db5..6c28c5c9 100644 --- a/src/dbally/views/freeform/text2sql/_view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -1,11 +1,10 @@ -import abc import json +from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Tuple -import sqlalchemy -from sqlalchemy import ColumnClause, Table, text +from sqlalchemy import ColumnClause, Engine, MetaData, Table, text from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult @@ -15,8 +14,8 @@ from dbally.similarity import AbstractSimilarityIndex, SimpleSqlAlchemyFetcher from dbally.views.base import BaseView, IndexLocation -from ._config import TableConfig -from ._errors import Text2SQLError +from .config import TableConfig +from .errors import Text2SQLError text2sql_prompt = PromptTemplate( chat=( @@ -95,19 +94,28 @@ async def value_with_similarity(self) -> str: return self.value -class BaseText2SQLView(BaseView, abc.ABC): +class BaseText2SQLView(BaseView, ABC): """ Text2SQLFreeformView is a class designed to interact with the database using text2sql queries. """ def __init__( self, - engine: sqlalchemy.engine.Engine, + engine: Engine, ) -> None: super().__init__() self._engine = engine self._table_index = {table.name: table for table in self.get_tables()} + @abstractmethod + def get_tables(self) -> List[TableConfig]: + """ + Get the tables used by the view. + + Returns: + The list of tables used by the view. + """ + async def ask( self, query: str, @@ -219,23 +227,13 @@ def _get_tables_context(self) -> str: context = "" for table in self._table_index.values(): context += f"{table.ddl}\n" - return context - @abc.abstractmethod - def get_tables(self) -> List[TableConfig]: - """ - Get the tables used by the view. - - Returns: - A dictionary of tables. - """ - def _create_default_fetcher(self, table: str, column: str) -> SimpleSqlAlchemyFetcher: return SimpleSqlAlchemyFetcher( sqlalchemy_engine=self._engine, column=ColumnClause(column), - table=Table(table, sqlalchemy.MetaData()), + table=Table(table, MetaData()), ) def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: diff --git a/src/dbally_cli/main.py b/src/dbally_cli/main.py index 8912b899..1976debb 100644 --- a/src/dbally_cli/main.py +++ b/src/dbally_cli/main.py @@ -2,7 +2,7 @@ @click.group() -def cli(): +def cli() -> None: """ Command line tool for interacting with dbally. """ diff --git a/src/dbally_codegen/__init__.py b/src/dbally_codegen/__init__.py index e69de29b..bba874f9 100644 --- a/src/dbally_codegen/__init__.py +++ b/src/dbally_codegen/__init__.py @@ -0,0 +1,11 @@ +from dbally_codegen.autodiscovery import ( + AutoDiscoveryBuilder, + AutoDiscoveryBuilderWithLLM, + configure_text2sql_auto_discovery, +) + +__all__ = [ + "AutoDiscoveryBuilder", + "AutoDiscoveryBuilderWithLLM", + "configure_text2sql_auto_discovery", +] diff --git a/src/dbally_codegen/autodiscovery.py b/src/dbally_codegen/autodiscovery.py new file mode 100644 index 00000000..d115edf2 --- /dev/null +++ b/src/dbally_codegen/autodiscovery.py @@ -0,0 +1,445 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Iterator, List, Optional + +from sqlalchemy import Column, Connection, Engine, MetaData, Table +from sqlalchemy.sql.ddl import CreateTable +from typing_extensions import Self + +from dbally.llms.base import LLM +from dbally.prompts import PromptTemplate +from dbally.similarity.index import SimilarityIndex +from dbally.similarity.store import SimilarityStore +from dbally.views.freeform.text2sql import ColumnConfig, TableConfig + +DISCOVERY_TEMPLATE = PromptTemplate( + chat=( + { + "role": "system", + "content": ( + "You are a very smart database programmer. " + "You will be provided with {dialect} table definition and example rows from this table.\n" + "Create a concise summary of provided table. Do not list the columns." + ), + }, + { + "role": "user", + "content": "DDL:\n {table_ddl}\n" "EXAMPLE ROWS:\n {samples}", + }, + ), +) + +SIMILARITY_TEMPLATE = PromptTemplate( + chat=( + { + "role": "system", + "content": ( + "Determine whether to use semantic search based on the given column.\n" + "Return TRUE when values are categorical, or synonyms may be used (for example department names),\n" + "or when values are typo-sensitive (for example person or company names).\n" + "Return FALSE when a search of the values does not make sense and explicit values should be utilized.\n" + "Return only one of the following options: TRUE or FALSE." + ), + }, + { + "role": "user", + "content": "TABLE SUMMARY: {table_summary}\n" "COLUMN NAME: {column_name}\n" "EXAMPLE VALUES: {samples}", + }, + ) +) + + +class DescriptionExtractionStrategy(ABC): + """ + Base class for strategies of extracting descriptions from the database. + """ + + @abstractmethod + async def extract_description(self, table: Table, connection: Connection) -> str: + """ + Extract the description from the table. + + Args: + table: The table to extract the description from. + connection: The connection to the database. + + Returns: + The extracted description. + """ + + +class DBCommentsDescriptionExtraction(DescriptionExtractionStrategy): + """ + Strategy for extracting descriptions from the database comments. + """ + + async def extract_description(self, table: Table, connection: Connection) -> str: + """ + Extract the description from the table comments. + + Args: + table: The table to extract the description from. + connection: The connection to the database. + + Returns: + The extracted description. + """ + return table.comment or "" + + +class LLMSummaryDescriptionExtraction(DescriptionExtractionStrategy): + """ + Strategy for extracting descriptions from the database using LLM. + """ + + def __init__(self, llm: LLM, engine: Engine, samples_count: int = 5) -> None: + self.llm = llm + self.engine = engine + self.samples_count = samples_count + + async def extract_description(self, table: Table, connection: Connection) -> str: + """ + Extract the description from the table using LLM. + + Args: + table: The table to extract the description from. + connection: The connection to the database. + + Returns: + The extracted description. + """ + ddl = self._generate_ddl(table) + samples = self._fetch_samples(connection, table) + return await self.llm.generate_text( + template=DISCOVERY_TEMPLATE, + fmt={ + "dialect": self.engine.dialect.name, + "table_ddl": ddl, + "samples": samples, + }, + ) + + def _fetch_samples(self, connection: Connection, table: Table) -> List[Dict[str, Any]]: + rows = connection.execute(table.select().limit(self.samples_count)).fetchall() + + # The underscore is used by sqlalchemy to avoid conflicts with column names + # pylint: disable=protected-access + return [{str(k): v for k, v in dict(row._mapping).items()} for row in rows] + + def _generate_ddl(self, table: Table) -> str: + return str(CreateTable(table).compile(self.engine)) + + +class SimilarityIndexSelectionStrategy(ABC): + """ + Base class for strategies of selecting similarity indexes for the columns. + """ + + @abstractmethod + async def select_index( + self, + table: Table, + column: Column, + description: str, + connection: Connection, + ) -> Optional[SimilarityStore]: + """ + Select the similarity index for the column. + + Args: + table: The table of the column. + column: The column to select the index for. + description: The description of the table. + connection: The connection to the database. + + Returns: + The similarity index to use for the column or None if no index should be used. + """ + + +class NoSimilarityIndexSelection(SimilarityIndexSelectionStrategy): + """ + Strategy for not suggesting any similarity indexes. + """ + + async def select_index( + self, + table: Table, + column: Column, + description: str, + connection: Connection, + ) -> Optional[SimilarityStore]: + """ + Select the similarity index for the column. + + Args: + table: The table of the column. + column: The column to select the index for. + description: The description of the table. + connection: The connection to the database. + """ + return None + + +class LLMSuggestedSimilarityIndexSelection(SimilarityIndexSelectionStrategy): + """ + Strategy for suggesting similarity indexes for the columns using LLM. + """ + + def __init__( + self, + llm: LLM, + index_builder: Callable[[Engine, Table, Column], SimilarityIndex], + samples_count: int = 5, + ) -> None: + self.llm = llm + self.index_builder = index_builder + self.samples_count = samples_count + + async def select_index( + self, + table: Table, + column: Column, + description: str, + connection: Connection, + ) -> Optional[SimilarityStore]: + """ + Select the similarity index for the column using LLM. + + Args: + table: The table of the column. + column: The column to select the index for. + description: The description of the table. + connection: The connection to the database. + + Returns: + The similarity index to use for the column or None if no index should be used. + """ + samples = self._fetch_samples( + connection=connection, + table=table, + column=column, + ) + use_index = await self.llm.generate_text( + template=SIMILARITY_TEMPLATE, + fmt={ + "table_summary": description, + "column_name": column.name, + "samples": samples, + }, + ) + return self.index_builder(connection.engine, table, column) if use_index.upper() == "TRUE" else None + + def _fetch_samples(self, connection: Connection, table: Table, column: Column) -> List[Any]: + values = connection.execute( + table.select().with_only_columns(column).distinct().limit(self.samples_count) + ).fetchall() + return [value[0] for value in values] + + +class _AutoDiscoveryBuilderBase: + """ + Builder class for configuring the auto-discovery of the database for text2sql freeform view. + """ + + def __init__( + self, + engine: Engine, + llm: Optional[LLM] = None, + blacklist: Optional[List[str]] = None, + whitelist: Optional[List[str]] = None, + description_extraction: Optional[DescriptionExtractionStrategy] = None, + similarity_selection: Optional[SimilarityIndexSelectionStrategy] = None, + ) -> None: + self._engine = engine + self._llm = llm + self._blacklist = blacklist + self._whitelist = whitelist + self._description_extraction = description_extraction or DBCommentsDescriptionExtraction() + self._similarity_selection = similarity_selection or NoSimilarityIndexSelection() + + def with_blacklist(self, blacklist: List[str]) -> Self: + """ + Set the blacklist of tables to exclude from the auto-discovery. + + Args: + blacklist: List of table names to exclude from the auto-discovery. + + Returns: + The builder instance. + + Raises: + ValueError: If both a whitelist and a blacklist are set. + """ + if self._whitelist is not None: + raise ValueError("Cannot have both a blacklist and a whitelist") + + self._blacklist = blacklist + return self + + def with_whitelist(self, whitelist: List[str]) -> Self: + """ + Set the whitelist of tables to include in the auto-discovery. + + Args: + whitelist: List of table names to include in the auto-discovery. + + Returns: + The builder instance. + + Raises: + ValueError: If both a whitelist and a blacklist are set. + """ + if self._blacklist is not None: + raise ValueError("Cannot have both a blacklist and a whitelist") + + self._whitelist = whitelist + return self + + def extract_description_from_comments(self) -> Self: + """ + Use the comments field in the database as a source for the table descriptions. + + Returns: + The builder instance. + """ + self._description_extraction = DBCommentsDescriptionExtraction() + return self + + async def discover(self) -> List[TableConfig]: + """ + Discover tables in the database and return the configuration object. + + Returns: + List of tables with their columns and descriptions. + """ + with self._engine.connect() as connection: + tables = [] + for table in self._iterate_tables(): + if self._whitelist is not None and table.name not in self._whitelist: + continue + + if self._blacklist is not None and table.name in self._blacklist: + continue + + description = await self._description_extraction.extract_description(table, connection) + + columns = [] + for column in self._iterate_columns(table): + similarity_index = await self._similarity_selection.select_index( + table=table, + column=column, + description=description, + connection=connection, + ) + columns.append( + ColumnConfig( + name=column.name, + data_type=str(column.type), + similarity_index=similarity_index, + ) + ) + tables.append( + TableConfig( + name=table.name, + description=description, + columns=columns, + ) + ) + return tables + + def _iterate_tables(self) -> Iterator[Table]: + meta = MetaData() + meta.reflect(bind=self._engine) + yield from meta.sorted_tables + + @staticmethod + def _iterate_columns(table: Table) -> Iterator[Column]: + for column in table.columns.values(): + if column.type.python_type is str: + yield column + + +class AutoDiscoveryBuilderWithLLM(_AutoDiscoveryBuilderBase): + """ + Builder class for configuring the auto-discovery of the database for text2sql freeform view. + It extends the base builder with the ability to use LLM for extra tasks. + """ + + def generate_description_by_llm(self, samples_count: int = 5) -> Self: + """ + Use LLM to generate descriptions for the tables. + The descriptions are generated based on the table DDL and a configured count of example rows. + + Args: + samples_count: The number of example rows to use for generating the description. + + Returns: + The builder instance. + """ + self._description_extraction = LLMSummaryDescriptionExtraction( + llm=self._llm, + engine=self._engine, + samples_count=samples_count, + ) + return self + + def suggest_similarity_indexes( + self, + index_builder: Callable[[Engine, Table, Column], SimilarityIndex], + samples_count: int = 5, + ) -> Self: + """ + Enable the suggestion of similarity indexes for the columns in the tables. + The suggestion is based on the generated table descriptions and example values from the column. + + Args: + index_builder: The function used to build the similarity index. + samples_count: The number of example values to use for generating the suggestion. + + Returns: + The builder instance. + """ + self._similarity_selection = LLMSuggestedSimilarityIndexSelection( + llm=self._llm, + index_builder=index_builder, + samples_count=samples_count, + ) + return self + + +class AutoDiscoveryBuilder(_AutoDiscoveryBuilderBase): + """ + Builder class for configuring the auto-discovery of the database for text2sql freeform view. + """ + + def use_llm(self, llm: LLM) -> AutoDiscoveryBuilderWithLLM: + """ + Set the LLM client to use for generating descriptions. + + Args: + llm: The LLM client to use for generating descriptions. + + Returns: + The builder instance. + """ + return AutoDiscoveryBuilderWithLLM( + engine=self._engine, + llm=llm, + blacklist=self._blacklist, + whitelist=self._whitelist, + description_extraction=self._description_extraction, + similarity_selection=self._similarity_selection, + ) + + +def configure_text2sql_auto_discovery(engine: Engine) -> AutoDiscoveryBuilder: + """ + Create a builder object used to configure the auto-discovery process of the database for text2sql freeform view. + + Args: + engine: The SQLAlchemy engine object used to connect to the database. + + Returns: + The builder object used to configure the auto-discovery process. + """ + return AutoDiscoveryBuilder(engine) diff --git a/tests/unit/views/text2sql/test_autodiscovery.py b/tests/unit/views/text2sql/test_autodiscovery.py index 3321c841..b9b8faf0 100644 --- a/tests/unit/views/text2sql/test_autodiscovery.py +++ b/tests/unit/views/text2sql/test_autodiscovery.py @@ -4,7 +4,7 @@ import sqlalchemy from sqlalchemy import Engine, text -from dbally.views.freeform.text2sql import configure_text2sql_auto_discovery +from dbally_codegen.autodiscovery import configure_text2sql_auto_discovery @pytest.fixture @@ -39,44 +39,38 @@ def test_builder_cant_set_whitelist_and_blacklist(): async def test_autodiscovery_blacklist(sample_db: Engine): - config = await configure_text2sql_auto_discovery(sample_db).with_blacklist(["authentication"]).discover() + tables = await configure_text2sql_auto_discovery(sample_db).with_blacklist(["authentication"]).discover() + table_names = [table.name for table in tables] - assert len(config.tables) == 2 - - tables = config.tables - - assert "customers" in tables - assert "orders" in tables - assert "authentication" not in tables + assert len(table_names) == 2 + assert "customers" in table_names + assert "orders" in table_names + assert "authentication" not in table_names async def test_autodiscovery_whitelist(sample_db: Engine): - config = await configure_text2sql_auto_discovery(sample_db).with_whitelist(["customers", "orders"]).discover() - - assert len(config.tables) == 2 + tables = await configure_text2sql_auto_discovery(sample_db).with_whitelist(["customers", "orders"]).discover() + table_names = [table.name for table in tables] - tables = config.tables - - assert "customers" in tables - assert "orders" in tables - assert "authentication" not in tables + assert len(table_names) == 2 + assert "customers" in table_names + assert "orders" in table_names + assert "authentication" not in table_names async def test_autodiscovery_llm_descriptions(sample_db: Engine): mock_client = Mock() mock_client.generate_text = AsyncMock(return_value="LLM mock answer") - config = await ( + tables = await ( configure_text2sql_auto_discovery(sample_db) .with_blacklist(["authentication"]) .use_llm(mock_client) .generate_description_by_llm() .discover() ) + table_descriptions = [table.description for table in tables] - assert len(config.tables) == 2 - - tables = config.tables - - assert tables["customers"].description == "LLM mock answer" - assert tables["orders"].description == "LLM mock answer" + assert len(table_descriptions) == 2 + assert table_descriptions[0] == "LLM mock answer" + assert table_descriptions[1] == "LLM mock answer" diff --git a/tests/unit/views/text2sql/test_config.py b/tests/unit/views/text2sql/test_config.py deleted file mode 100644 index f43be06e..00000000 --- a/tests/unit/views/text2sql/test_config.py +++ /dev/null @@ -1,34 +0,0 @@ -import tempfile -from pathlib import Path - -from dbally.views.freeform.text2sql import Text2SQLConfig -from dbally.views.freeform.text2sql._config import Text2SQLTableConfig - - -def test_text2sql_config_persistence(): - config = Text2SQLConfig( - tables={ - "table1": Text2SQLTableConfig( - ddl="SAMPLE DDL1", description="SAMPLE DESCRIPTION1", similarity={"col1": "SEMANTIC", "col2": "TRIGRAM"} - ), - "table2": Text2SQLTableConfig( - ddl="SAMPLE DDL2", description="SAMPLE DESCRIPTION2", similarity={"col3": "SEMANTIC", "col4": "TRIGRAM"} - ), - } - ) - - with tempfile.NamedTemporaryFile() as f: - config.to_file(Path(f.name)) - loaded_config = Text2SQLConfig.from_file(Path(f.name)) - - assert len(config.tables) == len(loaded_config.tables) - - table1 = loaded_config.tables["table1"] - assert table1.ddl == "SAMPLE DDL1" - assert table1.description == "SAMPLE DESCRIPTION1" - assert table1.similarity == {"col1": "SEMANTIC", "col2": "TRIGRAM"} - - table2 = loaded_config.tables["table2"] - assert table2.ddl == "SAMPLE DDL2" - assert table2.description == "SAMPLE DESCRIPTION2" - assert table2.similarity == {"col3": "SEMANTIC", "col4": "TRIGRAM"} diff --git a/tests/unit/views/text2sql/test_view.py b/tests/unit/views/text2sql/test_view.py index b8f1e44e..91b7b50f 100644 --- a/tests/unit/views/text2sql/test_view.py +++ b/tests/unit/views/text2sql/test_view.py @@ -6,8 +6,7 @@ from sqlalchemy import Engine, text import dbally -from dbally.views.freeform.text2sql import BaseText2SQLView -from dbally.views.freeform.text2sql._config import ColumnConfig, TableConfig +from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig from tests.unit.mocks import MockLLM