Skip to content

Commit

Permalink
DH-4739 When sync-schemas endpoint is executed set the status SYNCHRO…
Browse files Browse the repository at this point in the history
…NIZING
  • Loading branch information
jcjc712 committed Sep 27, 2023
1 parent 9e39613 commit be321ea
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
24 changes: 19 additions & 5 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from bson import json_util
from fastapi import BackgroundTasks, HTTPException
from overrides import override
from sqlalchemy import MetaData, inspect

from dataherald.api import API
from dataherald.api.types import Query
Expand Down Expand Up @@ -87,6 +86,22 @@ def scan_db(
)

scanner = self.system.instance(Scanner)
all_tables = scanner.get_all_tables_and_views(database)
if scanner_request.table_names:
for table in scanner_request.table_names:
if table not in all_tables:
raise HTTPException(
status_code=404, detail=f"Table named: {table} doesn't exist"
) # noqa: B904
else:
scanner_request.table_names = all_tables

scanner.synchronizing(
scanner_request.table_names,
scanner_request.db_connection_id,
DBScannerRepository(self.storage),
)

background_tasks.add_task(
async_scanning, scanner, database, scanner_request, self.storage
)
Expand Down Expand Up @@ -230,10 +245,9 @@ def list_table_descriptions(
db_connection_repository = DatabaseConnectionRepository(self.storage)
db_connection = db_connection_repository.find_by_id(db_connection_id)
database = SQLDatabase.get_sql_engine(db_connection)
inspector = inspect(database.engine)
meta = MetaData(bind=database.engine)
MetaData.reflect(meta, views=True)
all_tables = inspector.get_table_names() + inspector.get_view_names()

scanner = self.system.instance(Scanner)
all_tables = scanner.get_all_tables_and_views(database)

for table_description in table_descriptions:
if table_description.table_name not in all_tables:
Expand Down
11 changes: 10 additions & 1 deletion dataherald/db_scanner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Base class that all scanner classes inherit from."""
from abc import ABC, abstractmethod
from typing import Any, Union

from dataherald.config import Component
from dataherald.db_scanner.repository.base import DBScannerRepository
Expand All @@ -17,3 +16,13 @@ def scan(
repository: DBScannerRepository,
) -> None:
""" "Scan a db"""

@abstractmethod
def synchronizing(
self, tables: list[str], db_connection_id: str, repository: DBScannerRepository
) -> None:
""" "Update table_description status"""

@abstractmethod
def get_all_tables_and_views(self, database: SQLDatabase) -> list[str]:
""" "Retrieve all tables and views"""
33 changes: 22 additions & 11 deletions dataherald/db_scanner/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,27 @@


class SqlAlchemyScanner(Scanner):
@override
def synchronizing(
self, tables: list[str], db_connection_id: str, repository: DBScannerRepository
) -> None:
# persist tables to be scanned
for table in tables:
repository.save_table_info(
TableSchemaDetail(
db_connection_id=db_connection_id,
table_name=table,
status=TableDescriptionStatus.SYNCHRONIZING.value,
)
)

@override
def get_all_tables_and_views(self, database: SQLDatabase) -> list[str]:
inspector = inspect(database.engine)
meta = MetaData(bind=database.engine)
MetaData.reflect(meta, views=True)
return inspector.get_table_names() + inspector.get_view_names()

def get_table_examples(
self, meta: MetaData, db_engine: SQLDatabase, table: str, rows_number: int = 3
) -> List[Any]:
Expand Down Expand Up @@ -150,7 +171,7 @@ def scan_single_table(
meta=meta, db_engine=db_engine, table=table, rows_number=3
),
last_schema_sync=datetime.now(),
status="SYNCHRONIZED",
status=TableDescriptionStatus.SYNCHRONIZED.value,
)

repository.save_table_info(object)
Expand All @@ -176,16 +197,6 @@ def scan(
if len(tables) == 0:
raise ValueError("No table found")

# persist tables to be scanned
for table in tables:
repository.save_table_info(
TableSchemaDetail(
db_connection_id=db_connection_id,
table_name=table,
status=TableDescriptionStatus.SYNCHRONIZING.value,
)
)

for table in tables:
try:
self.scan_single_table(
Expand Down

0 comments on commit be321ea

Please sign in to comment.