diff --git a/dataherald/db_scanner/repository/base.py b/dataherald/db_scanner/repository/base.py index 00e0b6eb..f69a5972 100644 --- a/dataherald/db_scanner/repository/base.py +++ b/dataherald/db_scanner/repository/base.py @@ -32,8 +32,8 @@ def get_table_info( return TableSchemaDetail(**row) return None - def get_all_tables_by_db(self, db_connection_id: str) -> List[TableSchemaDetail]: - rows = self.storage.find(DB_COLLECTION, {"db_connection_id": db_connection_id}) + def get_all_tables_by_db(self, query: dict) -> List[TableSchemaDetail]: + rows = self.storage.find(DB_COLLECTION, query) tables = [] for row in rows: row["id"] = row["_id"] diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index f4a3355c..cbec9ca5 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -581,8 +581,10 @@ def generate_response( ) repository = DBScannerRepository(storage) db_scan = repository.get_all_tables_by_db( - db_connection_id=database_connection.id, - status=TableDescriptionStatus.SYNCHRONIZED.value, + { + "db_connection_id": str(database_connection.id), + "status": TableDescriptionStatus.SYNCHRONIZED.value, + } ) if not db_scan: raise ValueError("No scanned tables found for database")