diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 187dfa7b..cca85643 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -6,7 +6,6 @@ from bson import json_util from fastapi import BackgroundTasks, HTTPException from overrides import override -from sqlalchemy import MetaData, inspect from dataherald.api import API from dataherald.api.types import Query @@ -87,6 +86,22 @@ def scan_db( ) scanner = self.system.instance(Scanner) + all_tables = scanner.get_all_tables_and_views(database) + if scanner_request.table_names: + for table in scanner_request.table_names: + if table not in all_tables: + raise HTTPException( + status_code=404, detail=f"Table named: {table} doesn't exist" + ) # noqa: B904 + else: + scanner_request.table_names = all_tables + + scanner.synchronizing( + scanner_request.table_names, + scanner_request.db_connection_id, + DBScannerRepository(self.storage), + ) + background_tasks.add_task( async_scanning, scanner, database, scanner_request, self.storage ) @@ -230,10 +245,9 @@ def list_table_descriptions( db_connection_repository = DatabaseConnectionRepository(self.storage) db_connection = db_connection_repository.find_by_id(db_connection_id) database = SQLDatabase.get_sql_engine(db_connection) - inspector = inspect(database.engine) - meta = MetaData(bind=database.engine) - MetaData.reflect(meta, views=True) - all_tables = inspector.get_table_names() + inspector.get_view_names() + + scanner = self.system.instance(Scanner) + all_tables = scanner.get_all_tables_and_views(database) for table_description in table_descriptions: if table_description.table_name not in all_tables: diff --git a/dataherald/db_scanner/__init__.py b/dataherald/db_scanner/__init__.py index 6308a023..4626f1e4 100644 --- a/dataherald/db_scanner/__init__.py +++ b/dataherald/db_scanner/__init__.py @@ -1,6 +1,5 @@ """Base class that all scanner classes inherit from.""" from abc import ABC, abstractmethod -from typing import Any, Union from dataherald.config import Component from dataherald.db_scanner.repository.base import DBScannerRepository @@ -17,3 +16,13 @@ def scan( repository: DBScannerRepository, ) -> None: """ "Scan a db""" + + @abstractmethod + def synchronizing( + self, tables: list[str], db_connection_id: str, repository: DBScannerRepository + ) -> None: + """ "Update table_description status""" + + @abstractmethod + def get_all_tables_and_views(self, database: SQLDatabase) -> list[str]: + """ "Retrieve all tables and views""" diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 0d90388b..4f8e0d54 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -22,6 +22,27 @@ class SqlAlchemyScanner(Scanner): + @override + def synchronizing( + self, tables: list[str], db_connection_id: str, repository: DBScannerRepository + ) -> None: + # persist tables to be scanned + for table in tables: + repository.save_table_info( + TableSchemaDetail( + db_connection_id=db_connection_id, + table_name=table, + status=TableDescriptionStatus.SYNCHRONIZING.value, + ) + ) + + @override + def get_all_tables_and_views(self, database: SQLDatabase) -> list[str]: + inspector = inspect(database.engine) + meta = MetaData(bind=database.engine) + MetaData.reflect(meta, views=True) + return inspector.get_table_names() + inspector.get_view_names() + def get_table_examples( self, meta: MetaData, db_engine: SQLDatabase, table: str, rows_number: int = 3 ) -> List[Any]: @@ -150,7 +171,7 @@ def scan_single_table( meta=meta, db_engine=db_engine, table=table, rows_number=3 ), last_schema_sync=datetime.now(), - status="SYNCHRONIZED", + status=TableDescriptionStatus.SYNCHRONIZED.value, ) repository.save_table_info(object) @@ -176,16 +197,6 @@ def scan( if len(tables) == 0: raise ValueError("No table found") - # persist tables to be scanned - for table in tables: - repository.save_table_info( - TableSchemaDetail( - db_connection_id=db_connection_id, - table_name=table, - status=TableDescriptionStatus.SYNCHRONIZING.value, - ) - ) - for table in tables: try: self.scan_single_table(