diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index ae432e59..10062573 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -37,7 +37,6 @@ class SqlAlchemyScanner(Scanner): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.scanner_service: AbstractScanner = None @override def create_tables( @@ -137,7 +136,12 @@ def get_table_examples( return examples_dict def get_processed_column( # noqa: PLR0911 - self, meta: MetaData, table: str, column: dict, db_engine: SQLDatabase + self, + meta: MetaData, + table: str, + column: dict, + db_engine: SQLDatabase, + scanner_service: AbstractScanner, ) -> ColumnDetail: dynamic_meta_table = meta.tables[table] @@ -155,7 +159,7 @@ def get_processed_column( # noqa: PLR0911 data_type=str(column["type"]), low_cardinality=False, ) - category_values = self.scanner_service.cardinality_values( + category_values = scanner_service.cardinality_values( dynamic_meta_table.c[column["name"]], db_engine ) if category_values: @@ -229,6 +233,7 @@ def scan_single_table( db_engine: SQLDatabase, db_connection_id: str, repository: TableDescriptionRepository, + scanner_service: AbstractScanner, ) -> TableDescription: print(f"Scanning table: {table}") inspector = inspect(db_engine.engine) @@ -240,7 +245,11 @@ def scan_single_table( print(f"Scanning column: {column['name']}") table_columns.append( self.get_processed_column( - meta=meta, table=table, column=column, db_engine=db_engine + meta=meta, + table=table, + column=column, + db_engine=db_engine, + scanner_service=scanner_service, ) ) @@ -278,9 +287,9 @@ def scan( "pymssql": SqlServerScanner, "http": ClickHouseScanner, } - self.scanner_service = BaseScanner() + scanner_service = BaseScanner() if db_engine.engine.driver in services.keys(): - self.scanner_service = services[db_engine.engine.driver]() + scanner_service = services[db_engine.engine.driver]() inspector = inspect(db_engine.engine) meta = MetaData(bind=db_engine.engine) @@ -302,6 +311,7 @@ def scan( db_engine=db_engine, db_connection_id=db_connection_id, repository=repository, + scanner_service=scanner_service, ) except Exception as e: repository.save_table_info( @@ -314,7 +324,7 @@ def scan( ) try: logger.info(f"Get logs table: {table}") - query_history = self.scanner_service.get_logs( + query_history = scanner_service.get_logs( table, db_engine, db_connection_id ) if len(query_history) > 0: