diff --git a/services/engine/dataherald/api/fastapi.py b/services/engine/dataherald/api/fastapi.py index e9edbd57..0b5ebf98 100644 --- a/services/engine/dataherald/api/fastapi.py +++ b/services/engine/dataherald/api/fastapi.py @@ -252,34 +252,9 @@ def update_database_connection( database_connection_request: DatabaseConnectionRequest, ) -> DatabaseConnectionResponse: try: - db_connection_repository = DatabaseConnectionRepository(self.storage) - db_connection = db_connection_repository.find_by_id(db_connection_id) - if not db_connection: - raise DatabaseConnectionNotFoundError( - f"Database connection {db_connection_id} not found" - ) - - db_connection = DatabaseConnection( - id=db_connection_id, - alias=database_connection_request.alias, - connection_uri=database_connection_request.connection_uri.strip(), - path_to_credentials_file=database_connection_request.path_to_credentials_file, - llm_api_key=database_connection_request.llm_api_key, - use_ssh=database_connection_request.use_ssh, - ssh_settings=database_connection_request.ssh_settings, - file_storage=database_connection_request.file_storage, - metadata=database_connection_request.metadata, - ) - - sql_database = SQLDatabase.get_sql_engine(db_connection, True) - - # Get tables and views and create missing table-descriptions as NOT_SCANNED and update DEPRECATED - scanner_repository = TableDescriptionRepository(self.storage) scanner = self.system.instance(Scanner) - - tables = sql_database.get_tables_and_views() - db_connection = db_connection_repository.update(db_connection) - scanner.refresh_tables(tables, str(db_connection.id), scanner_repository) + db_connection_service = DatabaseConnectionService(scanner, self.storage) + db_connection = db_connection_service.update(db_connection_id, database_connection_request) except Exception as e: # Encrypt sensible values fernet_encrypt = FernetEncrypt() diff --git a/services/engine/dataherald/sql_database/services/database_connection.py b/services/engine/dataherald/sql_database/services/database_connection.py index 0d7bf052..4ff08ed2 100644 --- a/services/engine/dataherald/sql_database/services/database_connection.py +++ b/services/engine/dataherald/sql_database/services/database_connection.py @@ -89,20 +89,7 @@ def add_schema_in_uri(self, connection_uri: str, schema: str, dialect: str) -> s return f"{connection_uri}?options=-csearch_path={schema}" return connection_uri - def create( - self, database_connection_request: DatabaseConnectionRequest - ) -> DatabaseConnection: - database_connection = DatabaseConnection( - alias=database_connection_request.alias, - connection_uri=database_connection_request.connection_uri.strip(), - schemas=database_connection_request.schemas, - path_to_credentials_file=database_connection_request.path_to_credentials_file, - llm_api_key=database_connection_request.llm_api_key, - use_ssh=database_connection_request.use_ssh, - ssh_settings=database_connection_request.ssh_settings, - file_storage=database_connection_request.file_storage, - metadata=database_connection_request.metadata, - ) + def get_schemas_and_tables(self, database_connection: DatabaseConnection) -> dict[str, list]: if database_connection.schemas and database_connection.dialect in [ "redshift", "awsathena", @@ -121,11 +108,13 @@ def create( schemas_and_tables = {} fernet_encrypt = FernetEncrypt() + connection_uri = fernet_encrypt.decrypt(database_connection.connection_uri) + if database_connection.schemas: for schema in database_connection.schemas: database_connection.connection_uri = fernet_encrypt.encrypt( self.add_schema_in_uri( - database_connection_request.connection_uri.strip(), + connection_uri.strip(), schema, str(database_connection.dialect), ) @@ -136,14 +125,34 @@ def create( sql_database = SQLDatabase.get_sql_engine(database_connection, True) schemas_and_tables[None] = sql_database.get_tables_and_views() - # Connect db - db_connection_repository = DatabaseConnectionRepository(self.storage) database_connection.connection_uri = fernet_encrypt.encrypt( self.remove_schema_in_uri( - database_connection_request.connection_uri.strip(), + connection_uri.strip(), str(database_connection.dialect), ) ) + + return schemas_and_tables + + def create( + self, database_connection_request: DatabaseConnectionRequest + ) -> DatabaseConnection: + database_connection = DatabaseConnection( + alias=database_connection_request.alias, + connection_uri=database_connection_request.connection_uri.strip(), + schemas=database_connection_request.schemas, + path_to_credentials_file=database_connection_request.path_to_credentials_file, + llm_api_key=database_connection_request.llm_api_key, + use_ssh=database_connection_request.use_ssh, + ssh_settings=database_connection_request.ssh_settings, + file_storage=database_connection_request.file_storage, + metadata=database_connection_request.metadata, + ) + + schemas_and_tables = self.get_schemas_and_tables(database_connection) + + # Connect db + db_connection_repository = DatabaseConnectionRepository(self.storage) db_connection = db_connection_repository.insert(database_connection) scanner_repository = TableDescriptionRepository(self.storage) @@ -153,3 +162,40 @@ def create( tables, str(db_connection.id), schema, scanner_repository ) return DatabaseConnection(**db_connection.dict()) + + def update( + self, + db_connection_id: str, + database_connection_request: DatabaseConnectionRequest + ) -> DatabaseConnection: + db_connection_repository = DatabaseConnectionRepository(self.storage) + database_connection = db_connection_repository.find_by_id(db_connection_id) + if not database_connection: + raise DatabaseConnectionNotFoundError( + f"Database connection {db_connection_id} not found" + ) + + database_connection = DatabaseConnection( + id=db_connection_id, + alias=database_connection_request.alias, + connection_uri=database_connection_request.connection_uri.strip(), + schemas=database_connection_request.schemas, + path_to_credentials_file=database_connection_request.path_to_credentials_file, + llm_api_key=database_connection_request.llm_api_key, + use_ssh=database_connection_request.use_ssh, + ssh_settings=database_connection_request.ssh_settings, + file_storage=database_connection_request.file_storage, + metadata=database_connection_request.metadata, + ) + + schemas_and_tables = self.get_schemas_and_tables(database_connection) + + # Connect db + db_connection = db_connection_repository.update(database_connection) + + scanner_repository = TableDescriptionRepository(self.storage) + # Refresh tables + self.scanner.refresh_tables( + schemas_and_tables, str(db_connection.id), scanner_repository + ) + return DatabaseConnection(**db_connection.dict())