From a381ce99f329384dbc5209f66d6465baf25d1979 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Wed, 17 Apr 2024 09:40:36 -0600 Subject: [PATCH 01/14] [DH-5733] Support schemas column to add a db connection --- dataherald/api/fastapi.py | 26 +--- dataherald/db_scanner/__init__.py | 1 + dataherald/db_scanner/models/types.py | 1 + dataherald/db_scanner/sqlalchemy.py | 2 + dataherald/sql_database/models/types.py | 1 + dataherald/sql_database/services/__init__.py | 0 .../services/database_connection.py | 119 ++++++++++++++++++ dataherald/types.py | 1 + 8 files changed, 130 insertions(+), 21 deletions(-) create mode 100644 dataherald/sql_database/services/__init__.py create mode 100644 dataherald/sql_database/services/database_connection.py diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index d854da05..3e36fded 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -70,6 +70,9 @@ SQLInjectionError, ) from dataherald.sql_database.models.types import DatabaseConnection +from dataherald.sql_database.services.database_connection import ( + DatabaseConnectionService, +) from dataherald.types import ( BaseLLM, CancelFineTuningRequest, @@ -173,27 +176,9 @@ def create_database_connection( self, database_connection_request: DatabaseConnectionRequest ) -> DatabaseConnectionResponse: try: - db_connection = DatabaseConnection( - alias=database_connection_request.alias, - connection_uri=database_connection_request.connection_uri.strip(), - path_to_credentials_file=database_connection_request.path_to_credentials_file, - llm_api_key=database_connection_request.llm_api_key, - use_ssh=database_connection_request.use_ssh, - ssh_settings=database_connection_request.ssh_settings, - file_storage=database_connection_request.file_storage, - metadata=database_connection_request.metadata, - ) - sql_database = SQLDatabase.get_sql_engine(db_connection, True) - - # Get tables and views and create table-descriptions as NOT_SCANNED - db_connection_repository = DatabaseConnectionRepository(self.storage) - - scanner_repository = TableDescriptionRepository(self.storage) scanner = self.system.instance(Scanner) - - tables = sql_database.get_tables_and_views() - db_connection = db_connection_repository.insert(db_connection) - scanner.create_tables(tables, str(db_connection.id), scanner_repository) + db_connection_service = DatabaseConnectionService(scanner, self.storage) + db_connection = db_connection_service.create(database_connection_request) except Exception as e: # Encrypt sensible values fernet_encrypt = FernetEncrypt() @@ -209,7 +194,6 @@ def create_database_connection( return error_response( e, database_connection_request.dict(), "invalid_database_connection" ) - return DatabaseConnectionResponse(**db_connection.dict()) @override diff --git a/dataherald/db_scanner/__init__.py b/dataherald/db_scanner/__init__.py index 67ea3bff..050a8870 100644 --- a/dataherald/db_scanner/__init__.py +++ b/dataherald/db_scanner/__init__.py @@ -34,6 +34,7 @@ def create_tables( self, tables: list[str], db_connection_id: str, + schema: str, repository: TableDescriptionRepository, metadata: dict = None, ) -> None: diff --git a/dataherald/db_scanner/models/types.py b/dataherald/db_scanner/models/types.py index aa737495..c6b33810 100644 --- a/dataherald/db_scanner/models/types.py +++ b/dataherald/db_scanner/models/types.py @@ -31,6 +31,7 @@ class TableDescriptionStatus(Enum): class TableDescription(BaseModel): id: str | None db_connection_id: str + schema_name: str | None table_name: str 
description: str | None table_schema: str | None diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index e1897083..51cc6d94 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -44,6 +44,7 @@ def create_tables( self, tables: list[str], db_connection_id: str, + schema: str, repository: TableDescriptionRepository, metadata: dict = None, ) -> None: @@ -51,6 +52,7 @@ def create_tables( repository.save_table_info( TableDescription( db_connection_id=db_connection_id, + schema_name=schema, table_name=table, status=TableDescriptionStatus.NOT_SCANNED.value, metadata=metadata, diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index 2a0cc0cc..538da8dd 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -96,6 +96,7 @@ class DatabaseConnection(BaseModel): dialect: SupportedDialects | None use_ssh: bool = False connection_uri: str | None + schemas: list[str] | None path_to_credentials_file: str | None llm_api_key: str | None = None ssh_settings: SSHSettings | None = None diff --git a/dataherald/sql_database/services/__init__.py b/dataherald/sql_database/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dataherald/sql_database/services/database_connection.py b/dataherald/sql_database/services/database_connection.py new file mode 100644 index 00000000..916c3839 --- /dev/null +++ b/dataherald/sql_database/services/database_connection.py @@ -0,0 +1,119 @@ +import re + +from sqlalchemy import inspect + +from dataherald.db import DB +from dataherald.db_scanner import Scanner +from dataherald.db_scanner.repository.base import TableDescriptionRepository +from dataherald.repositories.database_connections import DatabaseConnectionRepository +from dataherald.sql_database.base import SQLDatabase +from dataherald.sql_database.models.types import DatabaseConnection +from dataherald.types import DatabaseConnectionRequest +from dataherald.utils.encrypt import FernetEncrypt + + +class DatabaseConnectionService: + def __init__(self, scanner: Scanner, storage: DB): + self.scanner = scanner + self.storage = storage + + def get_current_schema(self, database_connection: DatabaseConnection) -> list[str]: + sql_database = SQLDatabase.get_sql_engine(database_connection, True) + inspector = inspect(sql_database.engine) + if inspector.default_schema_name and database_connection.dialect not in [ + "mssql", + "mysql", + "clickhouse", + "duckdb", + ]: + return [inspector.default_schema_name] + if database_connection.dialect == "bigquery": + pattern = r"([^:/]+)://([^/]+)/([^/]+)(\?[^/]+)" + match = re.match(pattern, str(sql_database.engine.url)) + if match: + return [match.group(3)] + elif database_connection.dialect == "databricks": + pattern = r"&schema=([^&]*)" + match = re.search(pattern, str(sql_database.engine.url)) + if match: + return [match.group(1)] + return ["default"] + + def remove_schema_in_uri(self, connection_uri: str, dialect: str) -> str: + if dialect in ["snowflake"]: + pattern = r"([^:/]+)://([^:]+):([^@]+)@([^:/]+)(?::(\d+))?/([^/]+)" + match = re.match(pattern, connection_uri) + if match: + return match.group(0) + if dialect in ["bigquery"]: + pattern = r"([^:/]+)://([^/]+)" + match = re.match(pattern, connection_uri) + if match: + return match.group(0) + elif dialect in ["databricks"]: + pattern = r"&schema=[^&]*" + return re.sub(pattern, "", connection_uri) + elif dialect in ["postgresql"]: + pattern = 
r"\?options=-csearch_path" r"=[^&]*" + return re.sub(pattern, "", connection_uri) + return connection_uri + + def add_schema_in_uri(self, connection_uri: str, schema: str, dialect: str) -> str: + connection_uri = self.remove_schema_in_uri(connection_uri, dialect) + if dialect in ["snowflake", "bigquery"]: + return f"{connection_uri}/{schema}" + if dialect in ["databricks"]: + return f"{connection_uri}&schema={schema}" + if dialect in ["postgresql"]: + return f"{connection_uri}?options=-csearch_path={schema}" + return connection_uri + + def create( + self, database_connection_request: DatabaseConnectionRequest + ) -> DatabaseConnection: + database_connection = DatabaseConnection( + alias=database_connection_request.alias, + connection_uri=database_connection_request.connection_uri.strip(), + schemas=database_connection_request.schemas, + path_to_credentials_file=database_connection_request.path_to_credentials_file, + llm_api_key=database_connection_request.llm_api_key, + use_ssh=database_connection_request.use_ssh, + ssh_settings=database_connection_request.ssh_settings, + file_storage=database_connection_request.file_storage, + metadata=database_connection_request.metadata, + ) + if not database_connection.schemas: + database_connection.schemas = self.get_current_schema(database_connection) + + schemas_and_tables = {} + fernet_encrypt = FernetEncrypt() + + if database_connection.schemas: + for schema in database_connection.schemas: + database_connection.connection_uri = fernet_encrypt.encrypt( + self.add_schema_in_uri( + database_connection_request.connection_uri.strip(), + schema, + str(database_connection.dialect), + ) + ) + sql_database = SQLDatabase.get_sql_engine(database_connection, True) + schemas_and_tables[schema] = sql_database.get_tables_and_views() + + # Connect db + db_connection_repository = DatabaseConnectionRepository(self.storage) + database_connection.connection_uri = fernet_encrypt.encrypt( + self.remove_schema_in_uri( + database_connection_request.connection_uri.strip(), + str(database_connection.dialect), + ) + ) + db_connection = db_connection_repository.insert(database_connection) + + scanner_repository = TableDescriptionRepository(self.storage) + # Add created tables + for schema, tables in schemas_and_tables.items(): + self.scanner.create_tables( + tables, str(db_connection.id), schema, scanner_repository + ) + return DatabaseConnection(**db_connection.dict()) diff --git a/dataherald/types.py b/dataherald/types.py index 44566756..85e62757 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -87,6 +87,7 @@ class DatabaseConnectionRequest(BaseModel): alias: str use_ssh: bool = False connection_uri: str + schemas: list[str] | None path_to_credentials_file: str | None llm_api_key: str | None ssh_settings: SSHSettings | None From 846010df94e7f5abaa0b368d6d7fe339c4e861a4 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Wed, 17 Apr 2024 17:09:01 -0600 Subject: [PATCH 02/14] DBs without schema should store None --- dataherald/sql_database/models/types.py | 2 +- dataherald/sql_database/services/database_connection.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index 538da8dd..dc969fa6 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -106,7 +106,7 @@ class DatabaseConnection(BaseModel): @classmethod def get_dialect(cls, input_string): - pattern = r"([^:/]+):/+([^/]+)/?([^/]+)" + pattern = 
r"([^:/]+)://" match = re.match(pattern, input_string) if not match: raise InvalidURIFormatError(f"Invalid URI format: {input_string}") diff --git a/dataherald/sql_database/services/database_connection.py b/dataherald/sql_database/services/database_connection.py index 916c3839..e011ce3b 100644 --- a/dataherald/sql_database/services/database_connection.py +++ b/dataherald/sql_database/services/database_connection.py @@ -17,7 +17,9 @@ def __init__(self, scanner: Scanner, storage: DB): self.scanner = scanner self.storage = storage - def get_current_schema(self, database_connection: DatabaseConnection) -> list[str]: + def get_current_schema( + self, database_connection: DatabaseConnection + ) -> list[str] | None: sql_database = SQLDatabase.get_sql_engine(database_connection, True) inspector = inspect(sql_database.engine) if inspector.default_schema_name and database_connection.dialect not in [ @@ -37,7 +39,7 @@ def get_current_schema(self, database_connection: DatabaseConnection) -> list[st match = re.search(pattern, str(sql_database.engine.url)) if match: return [match.group(1)] - return ["default"] + return None def remove_schema_in_uri(self, connection_uri: str, dialect: str) -> str: if dialect in ["snowflake"]: @@ -99,6 +101,9 @@ def create( ) sql_database = SQLDatabase.get_sql_engine(database_connection, True) schemas_and_tables[schema] = sql_database.get_tables_and_views() + else: + sql_database = SQLDatabase.get_sql_engine(database_connection, True) + schemas_and_tables[None] = sql_database.get_tables_and_views() # Connect db db_connection_repository = DatabaseConnectionRepository(self.storage) From f6d617efcb1e20bafaae1aac804be5191b64f723 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Fri, 19 Apr 2024 16:41:56 -0600 Subject: [PATCH 03/14] Add ids in sync-schemas endpoint --- dataherald/api/fastapi.py | 61 ++++++++----------- dataherald/db_scanner/__init__.py | 3 +- dataherald/db_scanner/sqlalchemy.py | 23 +++---- .../services/database_connection.py | 14 +++++ dataherald/types.py | 13 +++- 5 files changed, 59 insertions(+), 55 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 3e36fded..ef7383d4 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -97,11 +97,10 @@ MAX_ROWS_TO_CREATE_CSV_FILE = 50 -def async_scanning(scanner, database, scanner_request, storage): +def async_scanning(scanner, database, table_descriptions, storage): scanner.scan( database, - scanner_request.db_connection_id, - scanner_request.table_names, + table_descriptions, TableDescriptionRepository(storage), QueryHistoryRepository(storage), ) @@ -133,43 +132,35 @@ def scan_db( self, scanner_request: ScannerRequest, background_tasks: BackgroundTasks ) -> list[TableDescriptionResponse]: """Takes a db_connection_id and scan all the tables columns""" - try: - db_connection_repository = DatabaseConnectionRepository(self.storage) + scanner_repository = TableDescriptionRepository(self.storage) + data = {} + for id in scanner_request.ids: + table_description = scanner_repository.find_by_id(id) + if not table_description: + raise Exception("Table description not found") + if table_description.schema_name not in data.keys(): + data[table_description.schema_name] = [] + data[table_description.schema_name].append(table_description) + db_connection_repository = DatabaseConnectionRepository(self.storage) + scanner = self.system.instance(Scanner) + database_connection_service = DatabaseConnectionService(scanner, self.storage) + for schema, table_descriptions in 
data.items(): db_connection = db_connection_repository.find_by_id( - scanner_request.db_connection_id + table_descriptions[0].db_connection_id ) - - if not db_connection: - raise DatabaseConnectionNotFoundError( - f"Database connection {scanner_request.db_connection_id} not found" - ) - - database = SQLDatabase.get_sql_engine(db_connection, True) - all_tables = database.get_tables_and_views() - - if scanner_request.table_names: - for table in scanner_request.table_names: - if table not in all_tables: - raise HTTPException( - status_code=404, - detail=f"Table named: {table} doesn't exist", - ) # noqa: B904 - else: - scanner_request.table_names = all_tables - - scanner = self.system.instance(Scanner) - rows = scanner.synchronizing( - scanner_request, - TableDescriptionRepository(self.storage), + database = database_connection_service.get_sql_database( + db_connection, schema ) - except Exception as e: - return error_response(e, scanner_request.dict(), "invalid_database_sync") - background_tasks.add_task( - async_scanning, scanner, database, scanner_request, self.storage - ) - return [TableDescriptionResponse(**row.dict()) for row in rows] + background_tasks.add_task( + async_scanning, scanner, database, table_descriptions, self.storage + ) + return [ + TableDescriptionResponse(**row.dict()) + for _, table_descriptions in data.items() + for row in table_descriptions + ] @override def create_database_connection( diff --git a/dataherald/db_scanner/__init__.py b/dataherald/db_scanner/__init__.py index 050a8870..8931730f 100644 --- a/dataherald/db_scanner/__init__.py +++ b/dataherald/db_scanner/__init__.py @@ -14,8 +14,7 @@ class Scanner(Component, ABC): def scan( self, db_engine: SQLDatabase, - db_connection_id: str, - table_names: list[str] | None, + table_descriptions: list[TableDescription], repository: TableDescriptionRepository, query_history_repository: QueryHistoryRepository, ) -> None: diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 51cc6d94..7ebbbc22 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -278,8 +278,7 @@ def scan_single_table( def scan( self, db_engine: SQLDatabase, - db_connection_id: str, - table_names: list[str] | None, + table_descriptions: list[TableDescription], repository: TableDescriptionRepository, query_history_repository: QueryHistoryRepository, ) -> None: @@ -295,32 +294,24 @@ def scan( if db_engine.engine.dialect.name in services.keys(): scanner_service = services[db_engine.engine.dialect.name]() - inspector = inspect(db_engine.engine) + inspect(db_engine.engine) meta = MetaData(bind=db_engine.engine) MetaData.reflect(meta, views=True) - tables = inspector.get_table_names() + inspector.get_view_names() - if table_names: - table_names = [table.lower() for table in table_names] - tables = [ - table for table in tables if table and table.lower() in table_names - ] - if len(tables) == 0: - raise ValueError("No table found") - for table in tables: + for table in table_descriptions: try: self.scan_single_table( meta=meta, - table=table, + table=table.table_name, db_engine=db_engine, - db_connection_id=db_connection_id, + db_connection_id=table.db_connection_id, repository=repository, scanner_service=scanner_service, ) except Exception as e: repository.save_table_info( TableDescription( - db_connection_id=db_connection_id, + db_connection_id=table.db_connection_id, table_name=table, status=TableDescriptionStatus.FAILED.value, error_message=f"{e}", @@ -329,7 +320,7 @@ def scan( try: 
logger.info(f"Get logs table: {table}") query_history = scanner_service.get_logs( - table, db_engine, db_connection_id + table.table_name, db_engine, table.db_connection_id ) if len(query_history) > 0: for query in query_history: diff --git a/dataherald/sql_database/services/database_connection.py b/dataherald/sql_database/services/database_connection.py index e011ce3b..9f64ac04 100644 --- a/dataherald/sql_database/services/database_connection.py +++ b/dataherald/sql_database/services/database_connection.py @@ -17,6 +17,20 @@ def __init__(self, scanner: Scanner, storage: DB): self.scanner = scanner self.storage = storage + def get_sql_database( + self, database_connection: DatabaseConnection, schema: str = None + ) -> SQLDatabase: + fernet_encrypt = FernetEncrypt() + if schema: + database_connection.connection_uri = fernet_encrypt.encrypt( + self.add_schema_in_uri( + fernet_encrypt.decrypt(database_connection.connection_uri), + schema, + database_connection.dialect.value, + ) + ) + return SQLDatabase.get_sql_engine(database_connection, True) + def get_current_schema( self, database_connection: DatabaseConnection ) -> list[str] | None: diff --git a/dataherald/types.py b/dataherald/types.py index 85e62757..9f74a166 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -78,10 +78,19 @@ class SupportedDatabase(Enum): BIGQUERY = "BIGQUERY" -class ScannerRequest(DBConnectionValidation): - table_names: list[str] | None +class ScannerRequest(BaseModel): + ids: list[str] | None metadata: dict | None + @validator("ids") + def ids_validation(cls, ids: list = None): + try: + for id in ids: + ObjectId(id) + except InvalidId: + raise ValueError("Must be a valid ObjectId") # noqa: B904 + return ids + class DatabaseConnectionRequest(BaseModel): alias: str From 6d2094dd2f64e6b1d1e2627a7f5a516d1d025d97 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Fri, 19 Apr 2024 17:12:49 -0600 Subject: [PATCH 04/14] Support multi-schemas for refresh endpoint --- dataherald/api/fastapi.py | 24 ++++++++++++---- dataherald/db_scanner/__init__.py | 2 +- dataherald/db_scanner/sqlalchemy.py | 44 ++++++++++++++++------------- 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index ef7383d4..0f608416 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -195,18 +195,30 @@ def refresh_table_description( db_connection = db_connection_repository.find_by_id( refresh_table_description.db_connection_id ) - + scanner = self.system.instance(Scanner) + database_connection_service = DatabaseConnectionService(scanner, self.storage) try: - sql_database = SQLDatabase.get_sql_engine(db_connection, True) - tables = sql_database.get_tables_and_views() + data = {} + if db_connection.schemas: + for schema in db_connection.schemas: + sql_database = database_connection_service.get_sql_database( + db_connection, schema + ) + if schema not in data.keys(): + data[schema] = [] + data[schema] = sql_database.get_tables_and_views() + else: + sql_database = database_connection_service.get_sql_database( + db_connection + ) + data[None] = sql_database.get_tables_and_views() - # Get tables and views and create missing table-descriptions as NOT_SCANNED and update DEPRECATED scanner_repository = TableDescriptionRepository(self.storage) - scanner = self.system.instance(Scanner) + return [ TableDescriptionResponse(**record.dict()) for record in scanner.refresh_tables( - tables, str(db_connection.id), scanner_repository + data, str(db_connection.id), 
scanner_repository ) ] except Exception as e: diff --git a/dataherald/db_scanner/__init__.py b/dataherald/db_scanner/__init__.py index 8931730f..cd62dddd 100644 --- a/dataherald/db_scanner/__init__.py +++ b/dataherald/db_scanner/__init__.py @@ -42,7 +42,7 @@ def create_tables( @abstractmethod def refresh_tables( self, - tables: list[str], + schemas_and_tables: dict[str, list], db_connection_id: str, repository: TableDescriptionRepository, metadata: dict = None, diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 7ebbbc22..6d3a7733 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -62,34 +62,38 @@ def create_tables( @override def refresh_tables( self, - tables: list[str], + schemas_and_tables: dict[str, list], db_connection_id: str, repository: TableDescriptionRepository, metadata: dict = None, ) -> list[TableDescription]: - stored_tables = repository.find_by({"db_connection_id": str(db_connection_id)}) - stored_tables_list = [table.table_name for table in stored_tables] - rows = [] - for table_description in stored_tables: - if table_description.table_name not in tables: - table_description.status = TableDescriptionStatus.DEPRECATED.value - rows.append(repository.save_table_info(table_description)) - else: - rows.append(TableDescription(**table_description.dict())) + for schema, tables in schemas_and_tables.items(): + stored_tables = repository.find_by( + {"db_connection_id": str(db_connection_id), "schema": schema} + ) + stored_tables_list = [table.table_name for table in stored_tables] - for table in tables: - if table not in stored_tables_list: - rows.append( - repository.save_table_info( - TableDescription( - db_connection_id=db_connection_id, - table_name=table, - status=TableDescriptionStatus.NOT_SCANNED.value, - metadata=metadata, + for table_description in stored_tables: + if table_description.table_name not in tables: + table_description.status = TableDescriptionStatus.DEPRECATED.value + rows.append(repository.save_table_info(table_description)) + else: + rows.append(TableDescription(**table_description.dict())) + + for table in tables: + if table not in stored_tables_list: + rows.append( + repository.save_table_info( + TableDescription( + db_connection_id=db_connection_id, + table_name=table, + status=TableDescriptionStatus.NOT_SCANNED.value, + metadata=metadata, + schema_name=schema, + ) ) ) - ) return rows @override From 80c360e089d6e75b3a0d2e4bda85a7758d22d6f6 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Mon, 22 Apr 2024 13:56:25 -0600 Subject: [PATCH 05/14] Add schema not support error exception --- .../services/database_connection.py | 17 +++++++++++++++++ dataherald/utils/error_codes.py | 1 + 2 files changed, 18 insertions(+) diff --git a/dataherald/sql_database/services/database_connection.py b/dataherald/sql_database/services/database_connection.py index 9f64ac04..0d7bf052 100644 --- a/dataherald/sql_database/services/database_connection.py +++ b/dataherald/sql_database/services/database_connection.py @@ -10,6 +10,11 @@ from dataherald.sql_database.models.types import DatabaseConnection from dataherald.types import DatabaseConnectionRequest from dataherald.utils.encrypt import FernetEncrypt +from dataherald.utils.error_codes import CustomError + + +class SchemaNotSupportedError(CustomError): + pass class DatabaseConnectionService: @@ -98,6 +103,18 @@ def create( file_storage=database_connection_request.file_storage, metadata=database_connection_request.metadata, ) + if 
database_connection.schemas and database_connection.dialect in [ + "redshift", + "awsathena", + "mssql", + "mysql", + "clickhouse", + "duckdb", + ]: + raise SchemaNotSupportedError( + "Schema not supported for this db", + description=f"The {database_connection.dialect} dialect doesn't support schemas", + ) if not database_connection.schemas: database_connection.schemas = self.get_current_schema(database_connection) diff --git a/dataherald/utils/error_codes.py b/dataherald/utils/error_codes.py index 6680feb5..3b508f72 100644 --- a/dataherald/utils/error_codes.py +++ b/dataherald/utils/error_codes.py @@ -19,6 +19,7 @@ "SQLGenerationNotFoundError": "sql_generation_not_found", "NLGenerationError": "nl_generation_not_created", "MalformedGoldenSQLError": "invalid_golden_sql", + "SchemaNotSupportedError": "schema_not_supported", } From e3800f86562bf4c6b24f1e512478fafdb56aa836 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Tue, 23 Apr 2024 10:47:39 -0600 Subject: [PATCH 06/14] Add documentation for multi-schemas --- README.md | 25 ++++++++++++++++++++++--- dataherald/tests/test_api.py | 21 --------------------- docs/api.create_database_connection.rst | 25 ++++++++++++++++++++++++- docs/api.get_table_description.rst | 1 + docs/api.list_database_connections.rst | 2 ++ docs/api.list_table_description.rst | 1 + docs/api.refresh_table_description.rst | 1 + docs/api.scan_table_description.rst | 8 ++++---- 8 files changed, 55 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index fcfdd1aa..c2a4b768 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,24 @@ curl -X 'POST' \ }' ``` +##### Connecting multi-schemas +You can connect many schemas using one db connection if you want to create SQL joins between schemas. +Currently only `BigQuery`, `Snowflake`, `Databricks` and `Postgres` support this feature. +To use multi-schemas instead of sending the `schema` in the `connection_uri` set it in the `schemas` param, like this: + +``` +curl -X 'POST' \ + '/api/v1/database-connections' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "alias": "my_db_alias", + "use_ssh": false, + "connection_uri": snowflake://:@-/", + "schemas": ["schema_1", "schema_2", ...] +}' +``` + ##### Connecting to supported Data warehouses and using SSH You can find the details on how to connect to the supported data warehouses in the [docs](https://dataherald.readthedocs.io/en/latest/api.create_database_connection.html) @@ -194,7 +212,8 @@ While only the Database scan part is required to start generating SQL, adding ve #### Scanning the Database The database scan is used to gather information about the database including table and column names and identifying low cardinality columns and their values to be stored in the context store and used in the prompts to the LLM. In addition, it retrieves logs, which consist of historical queries associated with each database table. These records are then stored within the query_history collection. The historical queries retrieved encompass data from the past three months and are grouped based on query and user. -db_connection_id is the id of the database connection you want to scan, which is returned when you create a database connection. +The db_connection_id param is the id of the database connection you want to scan, which is returned when you create a database connection. +The ids param is the table_description_id that you want to scan. 
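If you are scripting against the API, the ids can be collected from the table-descriptions list endpoint first. A minimal sketch (the `requests` client, the localhost base URL and the example id are assumptions, not part of this patch):

```python
# Sketch: collect the table_description ids for a connection before
# triggering a scan. Adjust the base URL to your deployment; the
# connection id below is hypothetical.
import requests

BASE_URL = "http://localhost"
DB_CONNECTION_ID = "64dfa0e103f5134086f7090c"

table_descriptions = requests.get(
    f"{BASE_URL}/api/v1/table-descriptions",
    params={"db_connection_id": DB_CONNECTION_ID},
).json()

ids = [td["id"] for td in table_descriptions]
```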
You can trigger a scan of a database from the `POST /api/v1/table-descriptions/sync-schemas` endpoint. Example below @@ -205,11 +224,11 @@ curl -X 'POST' \ -H 'Content-Type: application/json' \ -d '{ "db_connection_id": "db_connection_id", - "table_names": ["table_name"] + "ids": ["", "", ...] }' ``` -Since the endpoint identifies low cardinality columns (and their values) it can take time to complete. Therefore while it is possible to trigger a scan on the entire DB by not specifying the `table_names`, we recommend against it for large databases. +Since the endpoint identifies low cardinality columns (and their values) it can take time to complete. #### Get logs per db connection Once a database was scanned you can use this endpoint to retrieve the tables logs diff --git a/dataherald/tests/test_api.py b/dataherald/tests/test_api.py index 5d086bc9..87869812 100644 --- a/dataherald/tests/test_api.py +++ b/dataherald/tests/test_api.py @@ -12,24 +12,3 @@ def test_heartbeat(): response = client.get("/api/v1/heartbeat") assert response.status_code == HTTP_200_CODE - - -def test_scan_all_tables(): - response = client.post( - "/api/v1/table-descriptions/sync-schemas", - json={"db_connection_id": "64dfa0e103f5134086f7090c"}, - ) - assert response.status_code == HTTP_201_CODE - - -def test_scan_one_table(): - try: - client.post( - "/api/v1/table-descriptions/sync-schemas", - json={ - "db_connection_id": "64dfa0e103f5134086f7090c", - "table_names": ["foo"], - }, - ) - except ValueError as e: - assert str(e) == "No table found" diff --git a/docs/api.create_database_connection.rst b/docs/api.create_database_connection.rst index 60e7ba77..c2fbf724 100644 --- a/docs/api.create_database_connection.rst +++ b/docs/api.create_database_connection.rst @@ -26,6 +26,9 @@ Once the database connection is established, it retrieves the table names and cr "alias": "string", "use_ssh": false, "connection_uri": "string", + "schemas": [ + "string" + ], "path_to_credentials_file": "string", "llm_api_key": "string", "ssh_settings": { @@ -189,7 +192,7 @@ Connections to supported Data warehouses ----------------------------------------- The format of the ``connection_uri`` parameter in the API call will depend on the data warehouse type you are connecting to. -You can find samples and how to generate them :ref:. +You can find samples and how to generate them below. Postgres ^^^^^^^^^^^^ @@ -324,3 +327,23 @@ Example:: "connection_uri": bigquery://v2-real-estate/K2 +**Connecting multi-schemas** + +You can connect many schemas using one db connection if you want to create SQL joins between schemas. +Currently only `BigQuery`, `Snowflake`, `Databricks` and `Postgres` support this feature. +To use multi-schemas instead of sending the `schema` in the `connection_uri` set it in the `schemas` param, like this: + +**Example** + +.. 
code-block:: rst + + curl -X 'POST' \ + '/api/v1/database-connections' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "alias": "my_db_alias_identifier", + "use_ssh": false, + "connection_uri": "snowflake://:@-/", + "schemas": ["foo", "bar"] + }' diff --git a/docs/api.get_table_description.rst b/docs/api.get_table_description.rst index 330e89c7..76b0f8dc 100644 --- a/docs/api.get_table_description.rst +++ b/docs/api.get_table_description.rst @@ -24,6 +24,7 @@ HTTP 200 code response "table_schema": "string", "status": "NOT_SCANNED | SYNCHRONIZING | DEPRECATED | SCANNED | FAILED" "error_message": "string", + "table_schema": "string", "columns": [ { "name": "string", diff --git a/docs/api.list_database_connections.rst b/docs/api.list_database_connections.rst index 1396ee23..d688d912 100644 --- a/docs/api.list_database_connections.rst +++ b/docs/api.list_database_connections.rst @@ -21,6 +21,7 @@ HTTP 200 code response "dialect": "databricks", "use_ssh": false, "connection_uri": "foooAABk91Q4wjoR2h07GR7_72BdQnxi8Rm6i_EjyS-mzz_o2c3RAWaEqnlUvkK5eGD5kUfE5xheyivl1Wfbk_EM7CgV4SvdLmOOt7FJV-3kG4zAbar=", + "schemas": null, "path_to_credentials_file": null, "llm_api_key": null, "ssh_settings": null @@ -31,6 +32,7 @@ HTTP 200 code response "dialect": "postgres", "use_ssh": true, "connection_uri": null, + "schemas": null, "path_to_credentials_file": "bar-LWxPdFcjQw9lU7CeK_2ELR3jGBq0G_uQ7E2rfPLk2RcFR4aDO9e2HmeAQtVpdvtrsQ_0zjsy9q7asdsadXExYJ0g==", "llm_api_key": "gAAAAABlCz5TeU0ym4hW3bf9u21dz7B9tlnttOGLRDt8gq2ykkblNvpp70ZjT9FeFcoyMv-Csvp3GNQfw66eYvQBrcBEPsLokkLO2Jc2DD-Q8Aw6g_8UahdOTxJdT4izA6MsiQrf7GGmYBGZqbqsjTdNmcq661wF9Q==", "ssh_settings": { diff --git a/docs/api.list_table_description.rst b/docs/api.list_table_description.rst index e257805d..9bff1318 100644 --- a/docs/api.list_table_description.rst +++ b/docs/api.list_table_description.rst @@ -33,6 +33,7 @@ HTTP 200 code response "table_schema": "string", "status": "NOT_SCANNED | SYNCHRONIZING | DEPRECATED | SCANNED | FAILED" "error_message": "string", + "table_schema": "string", "columns": [ { "name": "string", diff --git a/docs/api.refresh_table_description.rst b/docs/api.refresh_table_description.rst index 74c600c1..6e392e79 100644 --- a/docs/api.refresh_table_description.rst +++ b/docs/api.refresh_table_description.rst @@ -34,6 +34,7 @@ HTTP 201 code response "table_schema": "string", "status": "NOT_SCANNED | SYNCHRONIZING | DEPRECATED | SCANNED | FAILED" "error_message": "string", + "table_schema": "string", "columns": [ { "name": "string", diff --git a/docs/api.scan_table_description.rst b/docs/api.scan_table_description.rst index 1ccc9f8e..55488cc1 100644 --- a/docs/api.scan_table_description.rst +++ b/docs/api.scan_table_description.rst @@ -9,7 +9,7 @@ which consist of historical queries associated with each database table. These r query_history collection. The historical queries retrieved encompass data from the past three months and are grouped based on query and user. -It can scan all db tables or if you specify a `table_names` then It will only scan those tables. +The `ids` param is used to set the table description ids that you want to scan. The process is carried out through Background Tasks, ensuring that even if it operates slowly, taking several minutes, the HTTP response remains swift. 
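Internally, after the sync-schemas changes later in this series, the endpoint groups the requested ids by connection and schema before launching one background scanning task per schema. A condensed sketch of that grouping (the real implementation in ``dataherald/api/fastapi.py`` also resolves a SQL engine per schema):

.. code-block:: python

    # Condensed sketch: ids are bucketed by (db_connection_id, schema_name)
    # so each schema gets its own engine and its own background scan task.
    from collections import defaultdict

    def group_by_connection_and_schema(table_descriptions):
        data = defaultdict(lambda: defaultdict(list))
        for td in table_descriptions:
            data[td.db_connection_id][td.schema_name].append(td)
        return data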
@@ -23,7 +23,7 @@ Request this ``POST`` endpoint:: { "db_connection_id": "string", - "table_names": ["string"] # Optional + "ids": ["string"] } **Responses** @@ -36,7 +36,6 @@ HTTP 201 code response **Request example** -To scan all the tables in a db don't specify a `table_names` .. code-block:: rst @@ -45,5 +44,6 @@ To scan all the tables in a db don't specify a `table_names` -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_connection_id": "db_connection_id" + "db_connection_id": "db_connection_id", + "ids": ["14e52c5f7d6dc4bc510d6d27", "15e52c5f7d6dc4bc510d6d34"] }' From 5e76e53f2609c2288e0d4d1e533ab34f065d0f5c Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Tue, 23 Apr 2024 12:16:05 -0600 Subject: [PATCH 07/14] Fix sync schema method --- dataherald/api/fastapi.py | 10 +++++----- dataherald/db_scanner/sqlalchemy.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 0f608416..a59a4d4c 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -144,6 +144,10 @@ def scan_db( db_connection_repository = DatabaseConnectionRepository(self.storage) scanner = self.system.instance(Scanner) + rows = scanner.synchronizing( + scanner_request, + TableDescriptionRepository(self.storage), + ) database_connection_service = DatabaseConnectionService(scanner, self.storage) for schema, table_descriptions in data.items(): db_connection = db_connection_repository.find_by_id( @@ -156,11 +160,7 @@ def scan_db( background_tasks.add_task( async_scanning, scanner, database, table_descriptions, self.storage ) - return [ - TableDescriptionResponse(**row.dict()) - for _, table_descriptions in data.items() - for row in table_descriptions - ] + return [TableDescriptionResponse(**row.dict()) for row in rows] @override def create_database_connection( diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 6d3a7733..72b26c0f 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -102,14 +102,14 @@ def synchronizing( scanner_request: ScannerRequest, repository: TableDescriptionRepository, ) -> list[TableDescription]: - # persist tables to be scanned rows = [] - for table in scanner_request.table_names: + for id in scanner_request.ids: + table_description = repository.find_by_id(id) rows.append( repository.save_table_info( TableDescription( - db_connection_id=scanner_request.db_connection_id, - table_name=table, + db_connection_id=table_description.db_connection_id, + table_name=table_description.table_name, status=TableDescriptionStatus.SYNCHRONIZING.value, metadata=scanner_request.metadata, ) From a6eea8a4a702997e757e011bdcd7b20843b080a4 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Tue, 23 Apr 2024 12:49:06 -0600 Subject: [PATCH 08/14] Sync-schemas endpoint let adding ids from different db connection --- dataherald/api/fastapi.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index a59a4d4c..82931f8d 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -138,9 +138,18 @@ def scan_db( table_description = scanner_repository.find_by_id(id) if not table_description: raise Exception("Table description not found") - if table_description.schema_name not in data.keys(): - data[table_description.schema_name] = [] - 
data[table_description.schema_name].append(table_description) + if table_description.db_connection_id not in data.keys(): + data[table_description.db_connection_id] = {} + if ( + table_description.schema_name + not in data[table_description.db_connection_id].keys() + ): + data[table_description.db_connection_id][ + table_description.schema_name + ] = [] + data[table_description.db_connection_id][ + table_description.schema_name + ].append(table_description) db_connection_repository = DatabaseConnectionRepository(self.storage) scanner = self.system.instance(Scanner) @@ -149,17 +158,16 @@ def scan_db( TableDescriptionRepository(self.storage), ) database_connection_service = DatabaseConnectionService(scanner, self.storage) - for schema, table_descriptions in data.items(): - db_connection = db_connection_repository.find_by_id( - table_descriptions[0].db_connection_id - ) - database = database_connection_service.get_sql_database( - db_connection, schema - ) + for db_connection_id, schemas_and_table_descriptions in data.items(): + for schema, table_descriptions in schemas_and_table_descriptions.items(): + db_connection = db_connection_repository.find_by_id(db_connection_id) + database = database_connection_service.get_sql_database( + db_connection, schema + ) - background_tasks.add_task( - async_scanning, scanner, database, table_descriptions, self.storage - ) + background_tasks.add_task( + async_scanning, scanner, database, table_descriptions, self.storage + ) return [TableDescriptionResponse(**row.dict()) for row in rows] @override From 8493438fb120400d35a5e6dbbcfe2598e3a753f5 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Wed, 24 Apr 2024 08:46:16 -0600 Subject: [PATCH 09/14] Fix refresh endpoint --- dataherald/db_scanner/sqlalchemy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 72b26c0f..6511e32e 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -70,7 +70,7 @@ def refresh_tables( rows = [] for schema, tables in schemas_and_tables.items(): stored_tables = repository.find_by( - {"db_connection_id": str(db_connection_id), "schema": schema} + {"db_connection_id": str(db_connection_id), "schema_name": schema} ) stored_tables_list = [table.table_name for table in stored_tables] From ef9e4e628a634ed066d415d8eff3b03b7b36fc8d Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Wed, 24 Apr 2024 11:45:44 -0600 Subject: [PATCH 10/14] Fix table-description storage --- dataherald/db_scanner/repository/base.py | 1 + dataherald/db_scanner/sqlalchemy.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/dataherald/db_scanner/repository/base.py b/dataherald/db_scanner/repository/base.py index 7e6cf3da..c3656d44 100644 --- a/dataherald/db_scanner/repository/base.py +++ b/dataherald/db_scanner/repository/base.py @@ -59,6 +59,7 @@ def save_table_info(self, table_info: TableDescription) -> TableDescription: { "db_connection_id": table_info_dict["db_connection_id"], "table_name": table_info_dict["table_name"], + "schema_name": table_info_dict["schema_name"], }, table_info_dict, ) diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 6511e32e..2d3da7ca 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -112,6 +112,7 @@ def synchronizing( table_name=table_description.table_name, status=TableDescriptionStatus.SYNCHRONIZING.value, metadata=scanner_request.metadata, 
+ schema_name=table_description.schema_name, ) ) ) @@ -241,6 +242,7 @@ def scan_single_table( db_connection_id: str, repository: TableDescriptionRepository, scanner_service: AbstractScanner, + schema: str | None = None, ) -> TableDescription: print(f"Scanning table: {table}") inspector = inspect(db_engine.engine) @@ -273,6 +275,7 @@ def scan_single_table( last_schema_sync=datetime.now(), error_message="", status=TableDescriptionStatus.SCANNED.value, + schema_name=schema, ) repository.save_table_info(object) @@ -311,6 +314,7 @@ def scan( db_connection_id=table.db_connection_id, repository=repository, scanner_service=scanner_service, + schema=table.schema_name, ) except Exception as e: repository.save_table_info( @@ -319,6 +323,7 @@ def scan( table_name=table, status=TableDescriptionStatus.FAILED.value, error_message=f"{e}", + schema_name=table.schema_name, ) ) try: From 4eb7a3e02deeefa1736ec427417d221ef9360ff1 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Wed, 24 Apr 2024 12:56:24 -0600 Subject: [PATCH 11/14] Fix schema_name filter in table-description repository --- dataherald/db_scanner/repository/base.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dataherald/db_scanner/repository/base.py b/dataherald/db_scanner/repository/base.py index c3656d44..870263cb 100644 --- a/dataherald/db_scanner/repository/base.py +++ b/dataherald/db_scanner/repository/base.py @@ -53,14 +53,18 @@ def save_table_info(self, table_info: TableDescription) -> TableDescription: table_info_dict = { k: v for k, v in table_info_dict.items() if v is not None and v != [] } + + query = { + "db_connection_id": table_info_dict["db_connection_id"], + "table_name": table_info_dict["table_name"], + } + if "schema_name" in table_info_dict: + query["schema_name"] = table_info_dict["schema_name"] + table_info.id = str( self.storage.update_or_create( DB_COLLECTION, - { - "db_connection_id": table_info_dict["db_connection_id"], - "table_name": table_info_dict["table_name"], - "schema_name": table_info_dict["schema_name"], - }, + query, table_info_dict, ) ) From b24d9c838d8394bffad92b6479b3ef219e4fa837 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Thu, 25 Apr 2024 13:05:25 -0400 Subject: [PATCH 12/14] DH-5735/add support for multiple schemas for agents --- dataherald/api/types/requests.py | 1 + dataherald/api/types/responses.py | 1 + dataherald/services/prompts.py | 8 ++ dataherald/sql_generator/__init__.py | 17 ++- .../dataherald_finetuning_agent.py | 66 ++++++++--- .../sql_generator/dataherald_sqlagent.py | 111 +++++++++++++----- dataherald/types.py | 1 + docs/api.create_prompt.rst | 8 +- docs/api.create_prompt_sql_generation.rst | 3 + ...te_prompt_sql_generation_nl_generation.rst | 3 + docs/api.get_prompt.rst | 8 +- docs/api.list_prompts.rst | 5 +- 12 files changed, 178 insertions(+), 54 deletions(-) diff --git a/dataherald/api/types/requests.py b/dataherald/api/types/requests.py index fcb90348..108c775e 100644 --- a/dataherald/api/types/requests.py +++ b/dataherald/api/types/requests.py @@ -6,6 +6,7 @@ class PromptRequest(BaseModel): text: str db_connection_id: str + schemas: list[str] | None metadata: dict | None diff --git a/dataherald/api/types/responses.py b/dataherald/api/types/responses.py index 7edd7c28..ebacf601 100644 --- a/dataherald/api/types/responses.py +++ b/dataherald/api/types/responses.py @@ -25,6 +25,7 @@ def created_at_as_string(cls, v): class PromptResponse(BaseResponse): text: str db_connection_id: str + schemas: list[str] | None class 
SQLGenerationResponse(BaseResponse): diff --git a/dataherald/services/prompts.py b/dataherald/services/prompts.py index f42c272b..b7392b4a 100644 --- a/dataherald/services/prompts.py +++ b/dataherald/services/prompts.py @@ -4,6 +4,7 @@ DatabaseConnectionRepository, ) from dataherald.repositories.prompts import PromptNotFoundError, PromptRepository +from dataherald.sql_database.services.database_connection import SchemaNotSupportedError from dataherald.types import Prompt @@ -22,9 +23,16 @@ def create(self, prompt_request: PromptRequest) -> Prompt: f"Database connection {prompt_request.db_connection_id} not found" ) + if not db_connection.schemas and prompt_request.schemas: + raise SchemaNotSupportedError( + "Schema not supported for this db", + description=f"The {db_connection.dialect} dialect doesn't support schemas", + ) + prompt = Prompt( text=prompt_request.text, db_connection_id=prompt_request.db_connection_id, + schemas=prompt_request.schemas, metadata=prompt_request.metadata, ) return self.prompt_repository.insert(prompt) diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 16fc94d4..e997920e 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -11,10 +11,10 @@ from langchain.agents.agent import AgentExecutor from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, LLMResult -from langchain.schema.messages import BaseMessage from langchain_community.callbacks import get_openai_callback from dataherald.config import Component, System +from dataherald.db_scanner.models.types import TableDescription from dataherald.model.chat_model import ChatModel from dataherald.repositories.sql_generations import ( SQLGenerationRepository, @@ -62,6 +62,21 @@ def remove_markdown(self, query: str) -> str: return matches[0].strip() return query + @staticmethod + def get_table_schema(table_name: str, db_scan: List[TableDescription]) -> str: + for table in db_scan: + if table.table_name == table_name: + return table.schema_name + return "" + + @staticmethod + def filter_tables_by_schema( + db_scan: List[TableDescription], prompt: Prompt + ) -> List[TableDescription]: + if prompt.schemas: + return [table for table in db_scan if table.schema_name in prompt.schemas] + return db_scan + def format_sql_query_intermediate_steps(self, step: str) -> str: pattern = r"```sql(.*?)```" diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index 5fe0ed4d..5e2b423c 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -190,12 +190,20 @@ def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[s tables = Parser(example["sql"]).tables except Exception as e: logger.error(f"Error parsing SQL: {str(e)}") - most_similar_tables.update(tables) - df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True) + for table in tables: + found_tables = df[df.table_name == table] + for _, row in found_tables.iterrows(): + most_similar_tables.add((row["schema_name"], row["table_name"])) + df.drop( + df[ + df.table_name.isin([table[1] for table in most_similar_tables]) + ].index, + inplace=True, + ) return most_similar_tables @catch_exceptions() - def _run( + def _run( # noqa: PLR0912 self, user_question: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 @@ -214,9 +222,12 @@ def _run( table_rep = f"Table 
{table.table_name} contain columns: [{col_rep}], this tables has: {table.description}" else: table_rep = f"Table {table.table_name} contain columns: [{col_rep}]" - table_representations.append([table.table_name, table_rep]) + table_representations.append( + [table.schema_name, table.table_name, table_rep] + ) df = pd.DataFrame( - table_representations, columns=["table_name", "table_representation"] + table_representations, + columns=["schema_name", "table_name", "table_representation"], ) df["table_embedding"] = self.get_docs_embedding(df.table_representation) df["similarities"] = df.table_embedding.apply( @@ -227,12 +238,20 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' + if row["schema_name"] is not None: + table_name = row["schema_name"] + "." + row["table_name"] + else: + table_name = row["table_name"] + table_relevance += ( + f'Table: `{table_name}`, relevance score: {row["similarities"]}\n' + ) if len(most_similar_tables) > 0: for table in most_similar_tables: - table_relevance += ( - f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" - ) + if table[0] is not None: + table_name = table[0] + "." + table[1] + else: + table_name = table[1] + table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n" return table_relevance async def _arun( @@ -358,27 +377,32 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): db_scan: List[TableDescription] @catch_exceptions() - def _run( + def _run( # noqa: C901 self, table_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: """Get the schema for tables in a comma-separated list.""" table_names_list = table_names.split(", ") - table_names_list = [ - replace_unprocessable_characters(table_name) - for table_name in table_names_list - ] + processed_table_names = [] + for table in table_names_list: + formatted_table = replace_unprocessable_characters(table) + if "." 
in formatted_table: + processed_table_names.append(formatted_table.split(".")[1]) + else: + processed_table_names.append(formatted_table) tables_schema = "" for table in self.db_scan: - if table.table_name in table_names_list: + if table.table_name in processed_table_names: tables_schema += "```sql\n" tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: - descriptions.append( - f"Table `{table.table_name}`: {table.description}\n" - ) + if table.schema_name: + table_name = f"{table.schema_name}.{table.table_name}" + else: + table_name = table.table_name + descriptions.append(f"Table `{table_name}`: {table.description}\n") for column in table.columns: if column.description is not None: descriptions.append( @@ -555,6 +579,9 @@ def generate_response( ) if not db_scan: raise ValueError("No scanned tables found for database") + db_scan = SQLGenerator.filter_tables_by_schema( + db_scan=db_scan, prompt=user_prompt + ) few_shot_examples, instructions = context_store.retrieve_context_for_question( user_prompt, number_of_samples=5 ) @@ -658,6 +685,9 @@ def stream_response( ) if not db_scan: raise ValueError("No scanned tables found for database") + db_scan = SQLGenerator.filter_tables_by_schema( + db_scan=db_scan, prompt=user_prompt + ) _, instructions = context_store.retrieve_context_for_question( user_prompt, number_of_samples=1 ) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 899e6e87..b686bf2f 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -10,7 +10,6 @@ import numpy as np import openai import pandas as pd -import sqlalchemy from google.api_core.exceptions import GoogleAPIError from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.base import BaseToolkit @@ -27,9 +26,7 @@ from overrides import override from pydantic import BaseModel, Field from sql_metadata import Parser -from sqlalchemy import MetaData from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.sql import func from dataherald.context_store import ContextStore from dataherald.db import DB @@ -254,12 +251,20 @@ def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[s tables = Parser(example["sql"]).tables except Exception as e: logger.error(f"Error parsing SQL: {str(e)}") - most_similar_tables.update(tables) - df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True) + for table in tables: + found_tables = df[df.table_name == table] + for _, row in found_tables.iterrows(): + most_similar_tables.add((row["schema_name"], row["table_name"])) + df.drop( + df[ + df.table_name.isin([table[1] for table in most_similar_tables]) + ].index, + inplace=True, + ) return most_similar_tables @catch_exceptions() - def _run( + def _run( # noqa: PLR0912 self, user_question: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 @@ -278,9 +283,12 @@ def _run( table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}" else: table_rep = f"Table {table.table_name} contain columns: [{col_rep}]" - table_representations.append([table.table_name, table_rep]) + table_representations.append( + [table.schema_name, table.table_name, table_rep] + ) df = pd.DataFrame( - table_representations, columns=["table_name", "table_representation"] + table_representations, + columns=["schema_name", "table_name", "table_representation"], ) 
df["table_embedding"] = self.get_docs_embedding(df.table_representation) df["similarities"] = df.table_embedding.apply( @@ -291,12 +299,20 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' + if row["schema_name"] is not None: + table_name = row["schema_name"] + "." + row["table_name"] + else: + table_name = row["table_name"] + table_relevance += ( + f'Table: `{table_name}`, relevance score: {row["similarities"]}\n' + ) if len(most_similar_tables) > 0: for table in most_similar_tables: - table_relevance += ( - f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" - ) + if table[0] is not None: + table_name = table[0] + "." + table[1] + else: + table_name = table[1] + table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n" return table_relevance async def _arun( @@ -318,6 +334,8 @@ class ColumnEntityChecker(BaseSQLDatabaseTool, BaseTool): Example Input: table1 -> column2, entity """ + db_scan: List[TableDescription] + is_multiple_schema: bool def find_similar_strings( self, input_list: List[tuple], target_string: str, threshold=0.4 @@ -341,21 +359,25 @@ def _run( try: schema, entity = tool_input.split(",") table_name, column_name = schema.split("->") + table_name = replace_unprocessable_characters(table_name) + column_name = replace_unprocessable_characters(column_name).strip() + if "." not in table_name and self.is_multiple_schema: + raise Exception( + "Table name should be in the format schema_name.table_name" + ) except ValueError: return "Invalid input format, use following format: table_name -> column_name, entity (entity should be a string without ',')" search_pattern = f"%{entity.strip().lower()}%" - meta = MetaData(bind=self.db.engine) - table = sqlalchemy.Table(table_name.strip(), meta, autoload=True) + search_query = f"SELECT DISTINCT {column_name} FROM {table_name} WHERE {column_name} ILIKE :search_pattern" # noqa: S608 try: - search_query = sqlalchemy.select( - [func.distinct(table.c[column_name.strip()])] - ).where(func.lower(table.c[column_name.strip()]).like(search_pattern)) - search_results = self.db.engine.execute(search_query).fetchall() + search_results = self.db.engine.execute( + search_query, {"search_pattern": search_pattern} + ).fetchall() search_results = search_results[:25] except SQLAlchemyError: search_results = [] - distinct_query = sqlalchemy.select( - [func.distinct(table.c[column_name.strip()])] + distinct_query = ( + f"SELECT DISTINCT {column_name} FROM {table_name}" # noqa: S608 ) results = self.db.engine.execute(distinct_query).fetchall() results = self.find_similar_strings(results, entity) @@ -392,26 +414,31 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): db_scan: List[TableDescription] @catch_exceptions() - def _run( + def _run( # noqa: C901 self, table_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: """Get the schema for tables in a comma-separated list.""" table_names_list = table_names.split(", ") - table_names_list = [ - replace_unprocessable_characters(table_name) - for table_name in table_names_list - ] + processed_table_names = [] + for table in table_names_list: + formatted_table = replace_unprocessable_characters(table) + if "." 
in formatted_table: + processed_table_names.append(formatted_table.split(".")[1]) + else: + processed_table_names.append(formatted_table) tables_schema = "```sql\n" for table in self.db_scan: - if table.table_name in table_names_list: + if table.table_name in processed_table_names: tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: - descriptions.append( - f"Table `{table.table_name}`: {table.description}\n" - ) + if table.schema_name: + table_name = f"{table.schema_name}.{table.table_name}" + else: + table_name = table.table_name + descriptions.append(f"Table `{table_name}`: {table.description}\n") for column in table.columns: if column.description is not None: descriptions.append( @@ -446,7 +473,7 @@ class InfoRelevantColumns(BaseSQLDatabaseTool, BaseTool): db_scan: List[TableDescription] @catch_exceptions() - def _run( # noqa: C901 + def _run( # noqa: C901, PLR0912 self, column_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 @@ -457,6 +484,8 @@ def _run( # noqa: C901 for item in items_list: if " -> " in item: table_name, column_name = item.split(" -> ") + if "." in table_name: + table_name = table_name.split(".")[1] table_name = replace_unprocessable_characters(table_name) column_name = replace_unprocessable_characters(column_name) found = False @@ -474,7 +503,11 @@ def _run( # noqa: C901 for row in table.examples: col_info += row[column_name] + ", " col_info = col_info[:-2] - column_full_info += f"Table: {table_name}, column: {column_name}, additional info: {col_info}\n" + if table.schema_name: + schema_table = f"{table.schema_name}.{table.table_name}" + else: + schema_table = table.table_name + column_full_info += f"Table: {schema_table}, column: {column_name}, additional info: {col_info}\n" else: return "Malformed input, input should be in the following format Example Input: table1 -> column1, table1 -> column2, table2 -> column1" # noqa: E501 if not found: @@ -539,6 +572,7 @@ class SQLDatabaseToolkit(BaseToolkit): instructions: List[dict] | None = Field(exclude=True, default=None) db_scan: List[TableDescription] = Field(exclude=True) embedding: OpenAIEmbeddings = Field(exclude=True) + is_multiple_schema: bool = False @property def dialect(self) -> str: @@ -579,7 +613,12 @@ def get_tools(self) -> List[BaseTool]: db=self.db, context=self.context, db_scan=self.db_scan ) tools.append(info_relevant_tool) - column_sample_tool = ColumnEntityChecker(db=self.db, context=self.context) + column_sample_tool = ColumnEntityChecker( + db=self.db, + context=self.context, + db_scan=self.db_scan, + is_multiple_schema=self.is_multiple_schema, + ) tools.append(column_sample_tool) if self.few_shot_examples is not None: get_fewshot_examples_tool = GetFewShotExamples( @@ -700,6 +739,9 @@ def generate_response( ) if not db_scan: raise ValueError("No scanned tables found for database") + db_scan = SQLGenerator.filter_tables_by_schema( + db_scan=db_scan, prompt=user_prompt + ) few_shot_examples, instructions = context_store.retrieve_context_for_question( user_prompt, number_of_samples=self.max_number_of_examples ) @@ -716,6 +758,7 @@ def generate_response( context=context, few_shot_examples=new_fewshot_examples, instructions=instructions, + is_multiple_schema=True if user_prompt.schemas else False, db_scan=db_scan, embedding=OpenAIEmbeddings( openai_api_key=database_connection.decrypt_api_key(), @@ -802,6 +845,9 @@ def stream_response( ) if not db_scan: raise ValueError("No scanned tables found for database") + db_scan = 
SQLGenerator.filter_tables_by_schema(
+            db_scan=db_scan, prompt=user_prompt
+        )
         few_shot_examples, instructions = context_store.retrieve_context_for_question(
             user_prompt, number_of_samples=self.max_number_of_examples
         )
@@ -818,6 +864,7 @@ def stream_response(
             context=[{}],
             few_shot_examples=new_fewshot_examples,
             instructions=instructions,
+            is_multiple_schema=True if user_prompt.schemas else False,
             db_scan=db_scan,
             embedding=OpenAIEmbeddings(
                 openai_api_key=database_connection.decrypt_api_key(),
diff --git a/dataherald/types.py b/dataherald/types.py
index 9f74a166..8977fbca 100644
--- a/dataherald/types.py
+++ b/dataherald/types.py
@@ -178,6 +178,7 @@ class Prompt(BaseModel):
     id: str | None = None
     text: str
     db_connection_id: str
+    schemas: list[str] | None
     created_at: datetime = Field(default_factory=datetime.now)
     metadata: dict | None
diff --git a/docs/api.create_prompt.rst b/docs/api.create_prompt.rst
index e0bf7982..71867776 100644
--- a/docs/api.create_prompt.rst
+++ b/docs/api.create_prompt.rst
@@ -17,6 +17,9 @@ Request this ``POST`` endpoint to create a finetuning job::
     {
       "text": "string",
       "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ],
       "metadata": {}
     }
 
@@ -31,7 +34,10 @@ HTTP 201 code response
       "metadata": {},
       "created_at": "string",
       "text": "string",
-      "db_connection_id": "string"
+      "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ]
     }
 
 **Request example**
diff --git a/docs/api.create_prompt_sql_generation.rst b/docs/api.create_prompt_sql_generation.rst
index 3226e7d9..5b486e17 100644
--- a/docs/api.create_prompt_sql_generation.rst
+++ b/docs/api.create_prompt_sql_generation.rst
@@ -45,6 +45,9 @@ Request this ``POST`` endpoint to create a SQL query and a NL response for a giv
     "prompt": {
       "text": "string",
       "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ],
       "metadata": {}
     }
   }
diff --git a/docs/api.create_prompt_sql_generation_nl_generation.rst b/docs/api.create_prompt_sql_generation_nl_generation.rst
index 8b2f62af..7851b9c7 100644
--- a/docs/api.create_prompt_sql_generation_nl_generation.rst
+++ b/docs/api.create_prompt_sql_generation_nl_generation.rst
@@ -61,6 +61,9 @@ Request this ``POST`` endpoint to create a SQL query and a NL response for a giv
     "prompt": {
       "text": "string",
       "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ],
       "metadata": {}
     }
   }
diff --git a/docs/api.get_prompt.rst b/docs/api.get_prompt.rst
index fa02f0be..f86a9043 100644
--- a/docs/api.get_prompt.rst
+++ b/docs/api.get_prompt.rst
@@ -16,7 +16,10 @@ HTTP 200 code response
       "metadata": {},
       "created_at": "string",
       "text": "string",
-      "db_connection_id": "string"
+      "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ]
     }
 
 **Request example**
@@ -37,4 +40,7 @@ HTTP 200 code response
       "created_at": "2024-01-03 15:40:27.624000+00:00",
       "text": "What is the expected total full year sales in 2023?",
-      "db_connection_id": "659435a5453d359345834"
+      "db_connection_id": "659435a5453d359345834",
+      "schemas": [
+        "public"
+      ]
     }
\ No newline at end of file
diff --git a/docs/api.list_prompts.rst b/docs/api.list_prompts.rst
index 9f434a21..2122f117 100644
--- a/docs/api.list_prompts.rst
+++ b/docs/api.list_prompts.rst
@@ -29,7 +29,10 @@ HTTP 201 code response
       "metadata": {},
       "created_at": "string",
       "text": "string",
-      "db_connection_id": "string"
+      "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ]
     }
   ]
 

From b4f57c4001d292c429b4ea80b6f30a5f57f1a888 Mon Sep 17 00:00:00 2001
From: mohammadrezapourreza
Date: Fri, 26 Apr 2024 10:17:23 -0400
Subject: [PATCH 13/14] DH-5766/add validation to raise an exception for
 queries without a
schema in multiple schema setting --- dataherald/context_store/default.py | 13 +++++++++++++ dataherald/utils/sql_utils.py | 11 +++++++++++ 2 files changed, 24 insertions(+) create mode 100644 dataherald/utils/sql_utils.py diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index df0f40bd..34a1d8e8 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -13,6 +13,7 @@ from dataherald.repositories.golden_sqls import GoldenSQLRepository from dataherald.repositories.instructions import InstructionRepository from dataherald.types import GoldenSQL, GoldenSQLRequest, Prompt +from dataherald.utils.sql_utils import extract_the_schemas_from_sql logger = logging.getLogger(__name__) @@ -86,6 +87,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL f"Database connection not found, {record.db_connection_id}" ) + if db_connection.schemas: + schema_not_found = True + used_schemas = extract_the_schemas_from_sql(record.sql) + for schema in db_connection.schemas: + if schema in used_schemas: + schema_not_found = False + break + if schema_not_found: + raise MalformedGoldenSQLError( + f"SQL {record.sql} does not contain any of the schemas {db_connection.schemas}" + ) + prompt_text = record.prompt_text golden_sql = GoldenSQL( prompt_text=prompt_text, diff --git a/dataherald/utils/sql_utils.py b/dataherald/utils/sql_utils.py new file mode 100644 index 00000000..8f3e3a80 --- /dev/null +++ b/dataherald/utils/sql_utils.py @@ -0,0 +1,11 @@ +from sql_metadata import Parser + + +def extract_the_schemas_from_sql(sql): + table_names = Parser(sql).tables + schemas = [] + for table_name in table_names: + if "." in table_name: + schema = table_name.split(".")[0] + schemas.append(schema.strip()) + return schemas From 2b63aa8be14329cf6d20f44dd3578bd8ec5458c0 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Fri, 26 Apr 2024 11:55:09 -0400 Subject: [PATCH 14/14] DH-5765/add support multiple schema for finetuning --- dataherald/api/fastapi.py | 11 +++++-- dataherald/finetuning/openai_finetuning.py | 7 ++++ dataherald/types.py | 2 ++ dataherald/utils/sql_utils.py | 38 +++++++++++++++++++++- 4 files changed, 55 insertions(+), 3 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 82931f8d..23e3976c 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -91,6 +91,10 @@ ) from dataherald.utils.encrypt import FernetEncrypt from dataherald.utils.error_codes import error_response, stream_error_response +from dataherald.utils.sql_utils import ( + filter_golden_records_based_on_schema, + validate_finetuning_schema, +) logger = logging.getLogger(__name__) @@ -564,7 +568,6 @@ def create_finetuning_job( ) -> Finetuning: try: db_connection_repository = DatabaseConnectionRepository(self.storage) - db_connection = db_connection_repository.find_by_id( fine_tuning_request.db_connection_id ) @@ -572,7 +575,7 @@ def create_finetuning_job( raise DatabaseConnectionNotFoundError( f"Database connection not found, {fine_tuning_request.db_connection_id}" ) - + validate_finetuning_schema(fine_tuning_request, db_connection) golden_sqls_repository = GoldenSQLRepository(self.storage) golden_sqls = [] if fine_tuning_request.golden_sqls: @@ -593,6 +596,9 @@ def create_finetuning_job( raise GoldenSQLNotFoundError( f"No golden sqls found for db_connection: {fine_tuning_request.db_connection_id}" ) + golden_sqls = filter_golden_records_based_on_schema( + golden_sqls, fine_tuning_request.schemas 
+ ) default_base_llm = BaseLLM( model_provider="openai", model_name="gpt-3.5-turbo-1106", @@ -606,6 +612,7 @@ def create_finetuning_job( model = model_repository.insert( Finetuning( db_connection_id=fine_tuning_request.db_connection_id, + schemas=fine_tuning_request.schemas, alias=fine_tuning_request.alias if fine_tuning_request.alias else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}", diff --git a/dataherald/finetuning/openai_finetuning.py b/dataherald/finetuning/openai_finetuning.py index 5c553579..95b0f10f 100644 --- a/dataherald/finetuning/openai_finetuning.py +++ b/dataherald/finetuning/openai_finetuning.py @@ -69,6 +69,12 @@ def map_finetuning_status(status: str) -> str: return FineTuningStatus.QUEUED.value return mapped_statuses[status] + @staticmethod + def _filter_tables_by_schema(db_scan: List[TableDescription], schemas: List[str]): + if schemas: + return [table for table in db_scan if table.schema_name in schemas] + return db_scan + def format_columns( self, table: TableDescription, top_k: int = CATEGORICAL_COLUMNS_THRESHOLD ) -> str: @@ -197,6 +203,7 @@ def create_fintuning_dataset(self): "status": TableDescriptionStatus.SCANNED.value, } ) + db_scan = self._filter_tables_by_schema(db_scan, self.fine_tuning_model.schemas) golden_sqls_repository = GoldenSQLRepository(self.storage) finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl" model_repository = FinetuningsRepository(self.storage) diff --git a/dataherald/types.py b/dataherald/types.py index 8977fbca..de113e72 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -150,6 +150,7 @@ class Finetuning(BaseModel): id: str | None = None alias: str | None = None db_connection_id: str | None = None + schemas: list[str] | None status: str = "QUEUED" error: str | None = None base_llm: BaseLLM | None = None @@ -163,6 +164,7 @@ class Finetuning(BaseModel): class FineTuningRequest(BaseModel): db_connection_id: str + schemas: list[str] | None alias: str | None = None base_llm: BaseLLM | None = None golden_sqls: list[str] | None = None diff --git a/dataherald/utils/sql_utils.py b/dataherald/utils/sql_utils.py index 8f3e3a80..a1ab3680 100644 --- a/dataherald/utils/sql_utils.py +++ b/dataherald/utils/sql_utils.py @@ -1,7 +1,11 @@ from sql_metadata import Parser +from dataherald.sql_database.models.types import DatabaseConnection +from dataherald.sql_database.services.database_connection import SchemaNotSupportedError +from dataherald.types import FineTuningRequest, GoldenSQL -def extract_the_schemas_from_sql(sql): + +def extract_the_schemas_from_sql(sql: str) -> list[str]: table_names = Parser(sql).tables schemas = [] for table_name in table_names: @@ -9,3 +13,35 @@ def extract_the_schemas_from_sql(sql): schema = table_name.split(".")[0] schemas.append(schema.strip()) return schemas + + +def filter_golden_records_based_on_schema( + golden_sqls: list[GoldenSQL], schemas: list[str] +) -> list[GoldenSQL]: + filtered_records = [] + if not schemas: + return golden_sqls + for record in golden_sqls: + used_schemas = extract_the_schemas_from_sql(record.sql) + for schema in schemas: + if schema in used_schemas: + filtered_records.append(record) + break + return filtered_records + + +def validate_finetuning_schema( + finetuning_request: FineTuningRequest, db_connection: DatabaseConnection +): + if finetuning_request.schemas: + if not db_connection.schemas: + raise SchemaNotSupportedError( + "Schema not supported for this db", + description=f"The {db_connection.id} db doesn't have schemas", + ) + for schema in 
finetuning_request.schemas:
+            if schema not in db_connection.schemas:
+                raise SchemaNotSupportedError(
+                    f"Schema {schema} not supported for this db",
+                    description=f"The {db_connection.id} db is not configured with schema {schema}",
+                )
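
Note on the new helpers: the series keeps all schema reasoning in dataherald/utils/sql_utils.py, so the behavior is easy to check in isolation. The sketch below is not part of the patch; it exercises extract_the_schemas_from_sql and the any-match rule behind filter_golden_records_based_on_schema. GoldenSQLStub is a simplified stand-in for the real GoldenSQL model, and the only dependency is sql_metadata, which the patch already uses.

from dataclasses import dataclass

from sql_metadata import Parser  # the same SQL parser the patch relies on


@dataclass
class GoldenSQLStub:
    # Simplified stand-in for dataherald.types.GoldenSQL, illustration only.
    prompt_text: str
    sql: str


def extract_the_schemas_from_sql(sql: str) -> list[str]:
    # Mirrors dataherald/utils/sql_utils.py: a dotted table name is read as
    # schema.table and the schema part is collected.
    schemas = []
    for table_name in Parser(sql).tables:
        if "." in table_name:
            schemas.append(table_name.split(".")[0].strip())
    return schemas


def filter_golden_records_based_on_schema(golden_sqls, schemas):
    # Any-match rule from the patch: keep a record if its SQL references at
    # least one allowed schema; with no restriction, keep everything.
    if not schemas:
        return golden_sqls
    filtered = []
    for record in golden_sqls:
        used_schemas = extract_the_schemas_from_sql(record.sql)
        if any(schema in used_schemas for schema in schemas):
            filtered.append(record)
    return filtered


if __name__ == "__main__":
    records = [
        GoldenSQLStub("How many orders?", "SELECT COUNT(*) FROM sales.orders"),
        GoldenSQLStub("How many users?", "SELECT COUNT(*) FROM users"),
    ]
    print(extract_the_schemas_from_sql(records[0].sql))  # ['sales']
    kept = filter_golden_records_based_on_schema(records, ["sales"])
    print([r.sql for r in kept])  # the users query is dropped: no schema match

validate_finetuning_schema applies the same membership test in the other direction: every schema named in a FineTuningRequest must appear in DatabaseConnection.schemas, otherwise SchemaNotSupportedError is raised before any golden SQL is collected for the finetuning dataset.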