From 92a8f476cc1e71b9981eab436a42f8c1f93aec96 Mon Sep 17 00:00:00 2001
From: Andrew Snare
Date: Fri, 6 Sep 2024 16:12:28 +0200
Subject: [PATCH] Refactor: `CrawlerBase` and `PermissionsManager` snapshotting pattern (#2402)

## Changes

This PR:

- Refactors the existing permissions crawler so that it uses the same fetch/crawl/snapshot pattern as the rest of the crawlers.
- Removes the permissions crawler's bespoke `.cleanup()` implementation in favour of the `.reset()` method already available on the base class.
- Removes some dead code from the permissions crawler: `.load_all_for()`.
- Specifies `.snapshot()` as an interface (on `CrawlerBase`) that all crawlers need to support.
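
As a quick illustration of the fetch/crawl/snapshot pattern this PR standardises — a minimal sketch only, not part of the diff; `NodeInfo` and `NodesCrawler` are hypothetical names:

```python
from collections.abc import Iterable
from dataclasses import dataclass

from databricks.labs.lsql.backends import SqlBackend

from databricks.labs.ucx.framework.crawlers import CrawlerBase
from databricks.labs.ucx.framework.utils import escape_sql_identifier


@dataclass
class NodeInfo:
    """Hypothetical inventory record, for illustration only."""

    node_id: str
    state: str


class NodesCrawler(CrawlerBase[NodeInfo]):
    """Hypothetical crawler implementing the two methods the base class now requires."""

    def __init__(self, backend: SqlBackend, schema: str):
        super().__init__(backend, "hive_metastore", schema, "nodes", NodeInfo)

    def _try_fetch(self) -> Iterable[NodeInfo]:
        # Fast path: return records that a previous crawl persisted in the inventory table.
        for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
            yield NodeInfo(*row)

    def _crawl(self) -> Iterable[NodeInfo]:
        # Slow path: scan the live environment; runs only when the inventory is absent.
        yield NodeInfo(node_id="node-1", state="RUNNING")
```

Callers interact only with the base-class interface: `.snapshot()` returns the fetched inventory when present and otherwise crawls and persists the results, while `.reset()` clears the inventory so that the next `.snapshot()` re-crawls.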

### Linked issues

Progresses #2074 by laying the groundwork for crawlers to support refreshing.

Note that this PR and #2392 conflict behaviourally: once either is merged, the other will trigger a failure in `test_runtime_crawl_permissions`. Resolving this is trivial.

### Functionality

- modified existing workflow: `assessment`

### Tests

- [X] updated unit tests
- [X] updated integration tests
---
 src/databricks/labs/ucx/assessment/azure.py   |  3 -
 .../labs/ucx/assessment/clusters.py           |  6 --
 .../labs/ucx/assessment/init_scripts.py       |  3 -
 src/databricks/labs/ucx/assessment/jobs.py    |  6 --
 .../labs/ucx/assessment/pipelines.py          |  3 -
 .../labs/ucx/assessment/workflows.py          |  4 +-
 src/databricks/labs/ucx/framework/crawlers.py | 24 ++++++-
 .../labs/ucx/hive_metastore/grants.py         |  8 +--
 .../labs/ucx/hive_metastore/locations.py      | 28 ++++----
 .../labs/ucx/hive_metastore/mapping.py        |  4 +-
 .../ucx/hive_metastore/migration_status.py    |  3 -
 .../labs/ucx/hive_metastore/table_size.py     | 13 +---
 .../labs/ucx/hive_metastore/tables.py         | 11 +---
 .../labs/ucx/hive_metastore/udfs.py           | 11 +---
 .../labs/ucx/recon/migration_recon.py         |  3 -
 .../labs/ucx/workspace_access/generic.py      |  3 -
 .../labs/ucx/workspace_access/groups.py       | 12 +---
 .../labs/ucx/workspace_access/manager.py      | 52 +++------------
 .../test_permissions_manager.py               | 39 ++++++++---
 .../workspace_access/test_workflows.py        |  6 +-
 tests/unit/assessment/test_workflows.py       |  2 +-
 tests/unit/framework/test_crawlers.py         | 48 ++++++++----
 tests/unit/workspace_access/test_manager.py   | 65 +++++--------------
 23 files changed, 146 insertions(+), 211 deletions(-)

diff --git a/src/databricks/labs/ucx/assessment/azure.py b/src/databricks/labs/ucx/assessment/azure.py
index ff3765f4b8..540bbbcb90 100644
--- a/src/databricks/labs/ucx/assessment/azure.py
+++ b/src/databricks/labs/ucx/assessment/azure.py
@@ -44,9 +44,6 @@ def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
         super().__init__(sbe, "hive_metastore", schema, "azure_service_principals", AzureServicePrincipalInfo)
         self._ws = ws
 
-    def snapshot(self) -> Iterable[AzureServicePrincipalInfo]:
-        return self._snapshot(self._try_fetch, self._crawl)
-
     def _try_fetch(self) -> Iterable[AzureServicePrincipalInfo]:
         for row in self._fetch(f"SELECT * FROM {self._catalog}.{self._schema}.{self._table}"):
             yield AzureServicePrincipalInfo(*row)
diff --git a/src/databricks/labs/ucx/assessment/clusters.py b/src/databricks/labs/ucx/assessment/clusters.py
index 9660fa9676..02badb64ec 100644
--- a/src/databricks/labs/ucx/assessment/clusters.py
+++ b/src/databricks/labs/ucx/assessment/clusters.py
@@ -174,9 +174,6 @@ def _assess_clusters(self, all_clusters):
             cluster_info.failures = json.dumps(failures)
             yield cluster_info
 
-    def snapshot(self) -> Iterable[ClusterInfo]:
-        return self._snapshot(self._try_fetch, self._crawl)
-
     def _try_fetch(self) -> Iterable[ClusterInfo]:
         for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
             yield ClusterInfo(*row)
@@ -229,9 +226,6 @@ def _assess_policies(self, all_policices) -> Iterable[PolicyInfo]:
             policy_info.failures = json.dumps(failures)
             yield policy_info
 
-    def snapshot(self) -> Iterable[PolicyInfo]:
-        return self._snapshot(self._try_fetch, self._crawl)
-
     def _try_fetch(self) -> Iterable[PolicyInfo]:
         for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
             yield PolicyInfo(*row)
diff --git a/src/databricks/labs/ucx/assessment/init_scripts.py b/src/databricks/labs/ucx/assessment/init_scripts.py
index 4502a23824..909015b678 100644
--- a/src/databricks/labs/ucx/assessment/init_scripts.py
+++ b/src/databricks/labs/ucx/assessment/init_scripts.py
@@ -80,9 +80,6 @@ def _assess_global_init_scripts(self, all_global_init_scripts):
                 global_init_script_info.success = 0
             yield global_init_script_info
 
-    def snapshot(self) -> Iterable[GlobalInitScriptInfo]:
-        return self._snapshot(self._try_fetch, self._crawl)
-
     def _try_fetch(self) -> Iterable[GlobalInitScriptInfo]:
         for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
             yield GlobalInitScriptInfo(*row)
diff --git a/src/databricks/labs/ucx/assessment/jobs.py b/src/databricks/labs/ucx/assessment/jobs.py
index 345c9f821a..d5b77d68e0 100644
--- a/src/databricks/labs/ucx/assessment/jobs.py
+++ b/src/databricks/labs/ucx/assessment/jobs.py
@@ -128,9 +128,6 @@ def _prepare(all_jobs) -> tuple[dict[int, set[str]], dict[int, JobInfo]]:
             )
         return job_assessment, job_details
 
-    def snapshot(self) -> Iterable[JobInfo]:
-        return self._snapshot(self._try_fetch, self._crawl)
-
     def _try_fetch(self) -> Iterable[JobInfo]:
         for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
             yield JobInfo(*row)
@@ -166,9 +163,6 @@ def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema: str, num_days_h
         self._ws = ws
         self._num_days_history = num_days_history
 
-    def snapshot(self) -> Iterable[SubmitRunInfo]:
-        return self._snapshot(self._try_fetch, self._crawl)
-
     @staticmethod
     def _dt_to_ms(date_time: datetime):
         return int(date_time.timestamp() * 1000)
diff --git a/src/databricks/labs/ucx/assessment/pipelines.py b/src/databricks/labs/ucx/assessment/pipelines.py
index e9843f3e4f..8421e53084 100644
--- a/src/databricks/labs/ucx/assessment/pipelines.py
+++ b/src/databricks/labs/ucx/assessment/pipelines.py
@@ -70,9 +70,6 @@ def _pipeline_clusters(self, clusters, failures):
             if cluster.init_scripts:
                 failures.extend(self._check_cluster_init_script(cluster.init_scripts, "pipeline cluster"))
 
-    def snapshot(self) -> Iterable[PipelineInfo]:
-        return self._snapshot(self._try_fetch, self._crawl)
-
     def _try_fetch(self) -> Iterable[PipelineInfo]:
         for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
             yield PipelineInfo(*row)
diff --git a/src/databricks/labs/ucx/assessment/workflows.py b/src/databricks/labs/ucx/assessment/workflows.py
index a08b3ceed2..842afdb222 100644
--- a/src/databricks/labs/ucx/assessment/workflows.py
+++ b/src/databricks/labs/ucx/assessment/workflows.py
@@ -168,8 +168,8 @@ def crawl_permissions(self, ctx: RuntimeContext):
         This is the first step for the _group migration_ process, which is continued in the `migrate-groups` workflow.
         This step includes preparing Legacy Table ACLs for local group migration."""
         permission_manager = ctx.permission_manager
-        permission_manager.cleanup()
-        permission_manager.inventorize_permissions()
+        permission_manager.reset()
+        permission_manager.snapshot()
 
     @job_task
     def crawl_groups(self, ctx: RuntimeContext):
diff --git a/src/databricks/labs/ucx/framework/crawlers.py b/src/databricks/labs/ucx/framework/crawlers.py
index 9a43fc0055..4828e1fbd2 100644
--- a/src/databricks/labs/ucx/framework/crawlers.py
+++ b/src/databricks/labs/ucx/framework/crawlers.py
@@ -1,4 +1,5 @@
 import logging
+from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Sequence
 from typing import ClassVar, Generic, Protocol, TypeVar
 
@@ -19,7 +20,7 @@ class DataclassInstance(Protocol):
 ResultFn = Callable[[], Iterable[Result]]
 
 
-class CrawlerBase(Generic[Result]):
+class CrawlerBase(ABC, Generic[Result]):
     def __init__(self, backend: SqlBackend, catalog: str, schema: str, table: str, klass: type[Result]):
         """
         Initializes a CrawlerBase instance.
@@ -90,6 +91,25 @@ def _try_valid(cls, name: str | None):
             return None
         return cls._valid(name)
 
+    def snapshot(self) -> Iterable[Result]:
+        return self._snapshot(self._try_fetch, self._crawl)
+
+    @abstractmethod
+    def _try_fetch(self) -> Iterable[Result]:
+        """Fetch existing data that has (previously) been crawled by this crawler.
+
+        Returns:
+            Iterable[Result]: The data that has already been crawled.
+        """
+
+    @abstractmethod
+    def _crawl(self) -> Iterable[Result]:
+        """Perform the (potentially slow) crawling necessary to capture the current state of the environment.
+
+        Returns:
+            Iterable[Result]: Records that capture the results of crawling the environment.
+        """
+
     def _snapshot(self, fetcher: ResultFn, loader: ResultFn) -> list[Result]:
         """
         Tries to load dataset of records with `fetcher` function, otherwise automatically creates
@@ -105,7 +125,7 @@ def _snapshot(self, fetcher: ResultFn, loader: ResultFn) -> list[Result]:
             re-raised.
 
         Returns:
-            list[any]: A list of data records, either fetched or loaded.
+            list[Result]: A list of data records, either fetched or loaded.
""" logger.debug(f"[{self.full_name}] fetching {self._table} inventory") try: diff --git a/src/databricks/labs/ucx/hive_metastore/grants.py b/src/databricks/labs/ucx/hive_metastore/grants.py index 62db42e751..b3a0653956 100644 --- a/src/databricks/labs/ucx/hive_metastore/grants.py +++ b/src/databricks/labs/ucx/hive_metastore/grants.py @@ -207,13 +207,13 @@ def __init__(self, tc: TablesCrawler, udf: UdfsCrawler, include_databases: list[ def snapshot(self) -> Iterable[Grant]: try: - return self._snapshot(partial(self._try_load), partial(self._crawl)) + return super().snapshot() except Exception as e: # pylint: disable=broad-exception-caught log_fn = logger.warning if CLUSTER_WITHOUT_ACL_FRAGMENT in repr(e) else logger.error log_fn(f"Couldn't fetch grants snapshot: {e}") return [] - def _try_load(self): + def _try_fetch(self): for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"): yield Grant(*row) @@ -585,7 +585,7 @@ def __init__( self._compute_locations = cluster_locations def get_interactive_cluster_grants(self) -> list[Grant]: - tables = self._tables_crawler.snapshot() + tables = list(self._tables_crawler.snapshot()) mounts = list(self._mounts_crawler.snapshot()) grants: set[Grant] = set() @@ -794,7 +794,7 @@ def __init__( def migrate_acls(self, *, target_catalog: str | None = None, hms_fed: bool = False) -> None: workspace_name = self._workspace_info.current() - tables = self._table_crawler.snapshot() + tables = list(self._table_crawler.snapshot()) if not tables: logger.info("No tables found to acl") return diff --git a/src/databricks/labs/ucx/hive_metastore/locations.py b/src/databricks/labs/ucx/hive_metastore/locations.py index b3e409096b..632adb3977 100644 --- a/src/databricks/labs/ucx/hive_metastore/locations.py +++ b/src/databricks/labs/ucx/hive_metastore/locations.py @@ -201,7 +201,7 @@ def _add_jdbc_location(self, external_locations, location, table): if not dupe: external_locations.append(ExternalLocation(jdbc_location, 1)) - def _external_location_list(self) -> Iterable[ExternalLocation]: + def _crawl(self) -> Iterable[ExternalLocation]: tables_identifier = escape_sql_identifier(f"{self._catalog}.{self._schema}.tables") tables = list( self._backend.fetch( @@ -211,9 +211,6 @@ def _external_location_list(self) -> Iterable[ExternalLocation]: mounts = Mounts(self._backend, self._ws, self._schema).snapshot() return self._external_locations(list(tables), list(mounts)) - def snapshot(self) -> Iterable[ExternalLocation]: - return self._snapshot(self._try_fetch, self._external_location_list) - def _try_fetch(self) -> Iterable[ExternalLocation]: for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"): yield ExternalLocation(*row) @@ -320,15 +317,12 @@ def _deduplicate_mounts(self, mounts: list) -> list: deduplicated_mounts.append(obj) return deduplicated_mounts - def _list_mounts(self) -> Iterable[Mount]: + def _crawl(self) -> Iterable[Mount]: mounts = [] for mount_point, source, _ in self._dbutils.fs.mounts(): mounts.append(Mount(mount_point, source)) return self._deduplicate_mounts(mounts) - def snapshot(self) -> Iterable[Mount]: - return self._snapshot(self._try_fetch, self._list_mounts) - def _try_fetch(self) -> Iterable[Mount]: for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"): yield Mount(*row) @@ -366,21 +360,25 @@ def __init__( self._fiter_paths = irrelevant_patterns def snapshot(self) -> list[Table]: + updated_records = self._crawl() + self._overwrite_records(updated_records) + return 
+
+    def _crawl(self) -> list[Table]:
         logger.debug(f"[{self.full_name}] fetching {self._table} inventory")
         cached_results = []
         try:
-            cached_results = list(self._try_load())
+            cached_results = list(self._try_fetch())
         except NotFound:
             pass
         table_paths = self._get_tables_paths_from_assessment(cached_results)
         logger.debug(f"[{self.full_name}] crawling new batch for {self._table}")
-        loaded_records = list(self._crawl(table_paths))
-        if len(cached_results) > 0:
-            loaded_records = loaded_records + cached_results
-        self._overwrite_records(loaded_records)
+        loaded_records = list(self._crawl_tables(table_paths))
+        if cached_results:
+            loaded_records = [*loaded_records, *cached_results]
         return loaded_records
 
-    def _try_load(self) -> Iterable[Table]:
+    def _try_fetch(self) -> Iterable[Table]:
         """Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
         for row in self._fetch(
             f"SELECT * FROM {escape_sql_identifier(self.full_name)} WHERE NOT STARTSWITH(database, '{self.TABLE_IN_MOUNT_DB}')"
@@ -399,7 +397,7 @@ def _overwrite_records(self, items: Sequence[Table]):
         logger.debug(f"[{self.full_name}] found {len(items)} new records for {self._table}")
         self._backend.save_table(self.full_name, items, Table, mode="overwrite")
 
-    def _crawl(self, table_paths_from_assessment: dict[str, str]):
+    def _crawl_tables(self, table_paths_from_assessment: dict[str, str]):
         all_mounts = self._mounts_crawler.snapshot()
         all_tables = []
         for mount in all_mounts:
diff --git a/src/databricks/labs/ucx/hive_metastore/mapping.py b/src/databricks/labs/ucx/hive_metastore/mapping.py
index 74a7bdbb65..655172e388 100644
--- a/src/databricks/labs/ucx/hive_metastore/mapping.py
+++ b/src/databricks/labs/ucx/hive_metastore/mapping.py
@@ -96,8 +96,8 @@ def __init__(
         self._recon_tolerance_percent = recon_tolerance_percent
 
     def current_tables(self, tables: TablesCrawler, workspace_name: str, catalog_name: str):
-        tables_snapshot = tables.snapshot()
-        if len(tables_snapshot) == 0:
+        tables_snapshot = list(tables.snapshot())
+        if not tables_snapshot:
             msg = "No tables found. Please run: databricks labs ucx ensure-assessment-run"
             raise ValueError(msg)
         for table in tables_snapshot:
diff --git a/src/databricks/labs/ucx/hive_metastore/migration_status.py b/src/databricks/labs/ucx/hive_metastore/migration_status.py
index 3d6b71979a..5cdf0ceb21 100644
--- a/src/databricks/labs/ucx/hive_metastore/migration_status.py
+++ b/src/databricks/labs/ucx/hive_metastore/migration_status.py
@@ -74,9 +74,6 @@ def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema, table_crawler:
         self._ws = ws
         self._table_crawler = table_crawler
 
-    def snapshot(self) -> Iterable[MigrationStatus]:
-        return self._snapshot(self._try_fetch, self._crawl)
-
     def index(self) -> MigrationIndex:
         return MigrationIndex(list(self.snapshot()))
 
diff --git a/src/databricks/labs/ucx/hive_metastore/table_size.py b/src/databricks/labs/ucx/hive_metastore/table_size.py
index 5cf8209cd9..0106fbab48 100644
--- a/src/databricks/labs/ucx/hive_metastore/table_size.py
+++ b/src/databricks/labs/ucx/hive_metastore/table_size.py
@@ -1,7 +1,6 @@
 import logging
 from collections.abc import Iterable
 from dataclasses import dataclass
-from functools import partial
 
 from databricks.labs.lsql.backends import SqlBackend
 
@@ -54,21 +53,11 @@ def _crawl(self) -> Iterable[TableSize]:
                 catalog=table.catalog, database=table.database, name=table.name, size_in_bytes=size_in_bytes
             )
 
-    def _try_load(self) -> Iterable[TableSize]:
+    def _try_fetch(self) -> Iterable[TableSize]:
         """Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
         for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
             yield TableSize(*row)
 
-    def snapshot(self) -> list[TableSize]:
-        """
-        Takes a snapshot of tables in the specified catalog and database.
-        Return None if the table cannot be found anymore.
-
-        Returns:
-            list[Table]: A list of Table objects representing the snapshot of tables.
-        """
-        return self._snapshot(partial(self._try_load), partial(self._crawl))
-
     def _safe_get_table_size(self, table_full_name: str) -> int | None:
         logger.debug(f"Evaluating {table_full_name} table size.")
         try:
diff --git a/src/databricks/labs/ucx/hive_metastore/tables.py b/src/databricks/labs/ucx/hive_metastore/tables.py
index 7404b547fa..ef69a05fa4 100644
--- a/src/databricks/labs/ucx/hive_metastore/tables.py
+++ b/src/databricks/labs/ucx/hive_metastore/tables.py
@@ -357,15 +357,6 @@ def _all_databases(self) -> list[str]:
             return [row[0] for row in self._fetch("SHOW DATABASES")]
         return self._include_database
 
-    def snapshot(self) -> list[Table]:
-        """
-        Takes a snapshot of tables in the specified catalog and database.
-
-        Returns:
-            list[Table]: A list of Table objects representing the snapshot of tables.
- """ - return self._snapshot(partial(self._try_load), partial(self._crawl)) - def load_one(self, schema_name: str, table_name: str) -> Table | None: query = f"SELECT * FROM {escape_sql_identifier(self.full_name)} WHERE database='{schema_name}' AND name='{table_name}' LIMIT 1" for row in self._fetch(query): @@ -386,7 +377,7 @@ def parse_database_props(tbl_props: str) -> dict: # Convert key-value pairs to dictionary return dict(key_value_pairs) - def _try_load(self) -> Iterable[Table]: + def _try_fetch(self) -> Iterable[Table]: """Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error""" for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"): yield Table(*row) diff --git a/src/databricks/labs/ucx/hive_metastore/udfs.py b/src/databricks/labs/ucx/hive_metastore/udfs.py index 0837463194..6bfd173449 100644 --- a/src/databricks/labs/ucx/hive_metastore/udfs.py +++ b/src/databricks/labs/ucx/hive_metastore/udfs.py @@ -50,16 +50,7 @@ def _all_databases(self) -> list[str]: return [row[0] for row in self._fetch("SHOW DATABASES")] return self._include_database - def snapshot(self) -> list[Udf]: - """ - Takes a snapshot of tables in the specified catalog and database. - - Returns: - list[Udf]: A list of Udf objects representing the snapshot of tables. - """ - return self._snapshot(self._try_load, self._crawl) - - def _try_load(self) -> Iterable[Udf]: + def _try_fetch(self) -> Iterable[Udf]: """Tries to load udf information from the database or throws TABLE_OR_VIEW_NOT_FOUND error""" for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"): yield Udf(*row) diff --git a/src/databricks/labs/ucx/recon/migration_recon.py b/src/databricks/labs/ucx/recon/migration_recon.py index 164a2c4ddc..114f306381 100644 --- a/src/databricks/labs/ucx/recon/migration_recon.py +++ b/src/databricks/labs/ucx/recon/migration_recon.py @@ -52,9 +52,6 @@ def __init__( self._data_comparator = data_comparator self._default_threshold = default_threshold - def snapshot(self) -> Iterable[ReconResult]: - return self._snapshot(self._try_fetch, self._crawl) - def _crawl(self) -> Iterable[ReconResult]: self._migration_status_refresher.reset() migration_index = self._migration_status_refresher.index() diff --git a/src/databricks/labs/ucx/workspace_access/generic.py b/src/databricks/labs/ucx/workspace_access/generic.py index 228f9498e1..35125bf612 100644 --- a/src/databricks/labs/ucx/workspace_access/generic.py +++ b/src/databricks/labs/ucx/workspace_access/generic.py @@ -360,9 +360,6 @@ def _crawl(self) -> Iterable[WorkspaceObjectInfo]: language=raw.get("language", None), ) - def snapshot(self) -> Iterable[WorkspaceObjectInfo]: - return self._snapshot(self._try_fetch, self._crawl) - def _try_fetch(self) -> Iterable[WorkspaceObjectInfo]: for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"): yield WorkspaceObjectInfo( diff --git a/src/databricks/labs/ucx/workspace_access/groups.py b/src/databricks/labs/ucx/workspace_access/groups.py index b157c1ec86..308888d60f 100644 --- a/src/databricks/labs/ucx/workspace_access/groups.py +++ b/src/databricks/labs/ucx/workspace_access/groups.py @@ -431,12 +431,6 @@ def __init__( # pylint: disable=too-many-arguments self._external_id_match = external_id_match self._verify_timeout = verify_timeout - def snapshot(self) -> list[MigratedGroup]: - return self._snapshot(self._fetcher, self._crawler) - - def has_groups(self) -> bool: - return len(self.snapshot()) > 0 - def rename_groups(self): 
         account_groups_in_workspace = self._account_groups_in_workspace()
         workspace_groups_in_workspace = self._workspace_groups_in_workspace()
@@ -579,7 +573,7 @@ def reflect_account_groups_on_workspace(self):
         raise ManyError(errors)
 
     def get_migration_state(self) -> MigrationState:
-        return MigrationState(self.snapshot())
+        return MigrationState(list(self.snapshot()))
 
     def delete_original_workspace_groups(self):
         account_groups_in_workspace = self._account_groups_in_workspace()
@@ -626,7 +620,7 @@ def delete_original_workspace_groups(self):
         # Step 3: Confirm that enumeration no longer returns the deleted groups.
         self._wait_for_deleted_workspace_groups(deleted_groups)
 
-    def _fetcher(self) -> Iterable[MigratedGroup]:
+    def _try_fetch(self) -> Iterable[MigratedGroup]:
         state = []
         for row in self._backend.fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
             state.append(MigratedGroup(*row))
@@ -646,7 +640,7 @@ def _fetcher(self) -> Iterable[MigratedGroup]:
         )
         return new_state
 
-    def _crawler(self) -> Iterable[MigratedGroup]:
+    def _crawl(self) -> Iterable[MigratedGroup]:
         workspace_groups_in_workspace = self._workspace_groups_in_workspace()
         account_groups_in_account = self._account_groups_in_account()
         strategy = self._get_strategy(workspace_groups_in_workspace, account_groups_in_account)
diff --git a/src/databricks/labs/ucx/workspace_access/manager.py b/src/databricks/labs/ucx/workspace_access/manager.py
index 105a90a27e..1d870546b2 100644
--- a/src/databricks/labs/ucx/workspace_access/manager.py
+++ b/src/databricks/labs/ucx/workspace_access/manager.py
@@ -1,16 +1,11 @@
-import json
 import logging
-from collections.abc import Callable, Iterable, Iterator, Sequence
+from collections.abc import Callable, Iterable, Iterator
 from itertools import groupby
 
 from databricks.labs.blueprint.parallel import ManyError, Threads
 from databricks.labs.lsql.backends import SqlBackend
 
-from databricks.labs.ucx.framework.crawlers import (
-    CrawlerBase,
-    Dataclass,
-    DataclassInstance,
-)
+from databricks.labs.ucx.framework.crawlers import CrawlerBase
 from databricks.labs.ucx.framework.utils import escape_sql_identifier
 from databricks.labs.ucx.workspace_access.base import AclSupport, Permissions
 from databricks.labs.ucx.workspace_access.groups import MigrationState
@@ -25,8 +20,7 @@ def __init__(self, backend: SqlBackend, inventory_database: str, crawlers: list[
         super().__init__(backend, "hive_metastore", inventory_database, "permissions", Permissions)
         self._acl_support = crawlers
 
-    def inventorize_permissions(self):
-        # TODO: rename into snapshot()
+    def _crawl(self) -> Iterable[Permissions]:
         logger.debug("Crawling permissions")
         crawler_tasks = list(self._get_crawler_tasks())
         logger.info(f"Starting to crawl permissions. Total tasks: {len(crawler_tasks)}")
@@ -38,18 +32,17 @@ def inventorize_permissions(self):
                 continue
             logger.error(f"Error while crawling permissions: {error}")
             acute_errors.append(error)
-        if len(acute_errors) > 0:
+        if acute_errors:
             raise ManyError(acute_errors)
         logger.info(f"Total crawled permissions: {len(items)}")
-        self._save(items)
-        logger.info(f"Saved {len(items)} to {self.full_name}")
+        return items
 
     def apply_group_permissions(self, migration_state: MigrationState) -> bool:
         # list shall be sorted prior to using group by
         if len(migration_state) == 0:
             logger.info("No valid groups selected, nothing to do.")
             return True
-        items = sorted(self.load_all(), key=lambda i: i.object_type)
+        items = sorted(self.snapshot(), key=lambda i: i.object_type)
         logger.info(
             f"Applying the permissions to account groups. "
             f"Total groups to apply permissions: {len(migration_state)}. "
@@ -93,7 +86,7 @@ def apply_group_permissions(self, migration_state: MigrationState) -> bool:
         return True
 
     def verify_group_permissions(self) -> bool:
-        items = sorted(self.load_all(), key=lambda i: i.object_type)
+        items = sorted(self.snapshot(), key=lambda i: i.object_type)
         logger.info(f"Total permissions found: {len(items)}")
         verifier_tasks: list[Callable[..., bool]] = []
         appliers = self.object_type_support()
@@ -130,34 +123,9 @@ def object_type_support(self) -> dict[str, AclSupport]:
             appliers[object_type] = support
         return appliers
 
-    def cleanup(self):
-        logger.info(f"Cleaning up inventory table {self.full_name}")
-        self._exec(f"DROP TABLE IF EXISTS {escape_sql_identifier(self.full_name)}")
-        logger.info("Inventory table cleanup complete")
-
-    def _save(self, items: Sequence[Permissions]):
-        # keep in mind, that object_type and object_id are not primary keys.
-        self._append_records(items)  # TODO: update instead of append
-        logger.info("Successfully saved the items to inventory table")
-
-    def load_all(self) -> list[Permissions]:
-        logger.info(f"Loading inventory table {self.full_name}")
-        if list(self._fetch(f"SELECT COUNT(*) as cnt FROM {self.full_name}"))[0][0] == 0:  # noqa: RUF015
-            msg = (
" - f"Please ensure assessment job is run successfully and permissions populated" - ) - raise RuntimeError(msg) - return [ - Permissions(object_id, object_type, raw) - for object_id, object_type, raw in self._fetch(f"SELECT object_id, object_type, raw FROM {self.full_name}") - ] - - def load_all_for(self, object_type: str, object_id: str, klass: Dataclass) -> Iterable[DataclassInstance]: - for perm in self.load_all(): - if object_type == perm.object_type and object_id.lower() == perm.object_id.lower(): - raw = json.loads(perm.raw) - yield klass(**raw) + def _try_fetch(self) -> Iterable[Permissions]: + for row in self._fetch(f"SELECT object_id, object_type, raw FROM {escape_sql_identifier(self.full_name)}"): + yield Permissions(*row) def _get_crawler_tasks(self) -> Iterator[Callable[..., Permissions | None]]: for support in self._acl_support: diff --git a/tests/integration/workspace_access/test_permissions_manager.py b/tests/integration/workspace_access/test_permissions_manager.py index 2bd3934457..9868923f48 100644 --- a/tests/integration/workspace_access/test_permissions_manager.py +++ b/tests/integration/workspace_access/test_permissions_manager.py @@ -1,16 +1,39 @@ -from databricks.labs.ucx.workspace_access.base import Permissions +from collections.abc import Callable, Iterable + +from databricks.labs.ucx.workspace_access.base import Permissions, AclSupport +from databricks.labs.ucx.workspace_access.groups import MigrationState from databricks.labs.ucx.workspace_access.manager import PermissionManager -def test_permissions_save_and_load(ws, sql_backend, inventory_schema, env_or_skip): - permission_manager = PermissionManager(sql_backend, inventory_schema, []) +def test_permissions_snapshot(ws, sql_backend, inventory_schema): + class StubbedCrawler(AclSupport): + def get_crawler_tasks(self) -> Iterable[Callable[..., Permissions | None]]: + yield lambda: Permissions(object_id="abc", object_type="bcd", raw="def") + yield lambda: Permissions(object_id="efg", object_type="fgh", raw="ghi") - saved = [ + def get_apply_task(self, item: Permissions, migration_state: MigrationState) -> Callable[[], None] | None: ... + def get_verify_task(self, item: Permissions) -> Callable[[], bool] | None: ... + def object_types(self) -> set[str]: + return {"bcd", "fgh"} + + permission_manager = PermissionManager(sql_backend, inventory_schema, [StubbedCrawler()]) + snapshot = list(permission_manager.snapshot()) + # Snapshotting is multithreaded, meaning the order of results is non-deterministic. 
+    snapshot.sort(key=lambda x: x.object_id)
+
+    expected = [
         Permissions(object_id="abc", object_type="bcd", raw="def"),
         Permissions(object_id="efg", object_type="fgh", raw="ghi"),
     ]
+    assert snapshot == expected
 
-    permission_manager._save(saved)  # pylint: disable=protected-access
-    loaded = permission_manager.load_all()
-
-    assert saved == loaded
+    saved = [
+        Permissions(*row)
+        for row in sql_backend.fetch(
+            f"SELECT object_id, object_type, raw\n"
+            f"FROM {inventory_schema}.permissions\n"
+            f"ORDER BY object_id\n"
+            f"LIMIT {len(expected)+1}"
+        )
+    ]
+    assert saved == expected
diff --git a/tests/integration/workspace_access/test_workflows.py b/tests/integration/workspace_access/test_workflows.py
index aea31cca39..2a56d86bc1 100644
--- a/tests/integration/workspace_access/test_workflows.py
+++ b/tests/integration/workspace_access/test_workflows.py
@@ -42,7 +42,7 @@ def test_running_real_migrate_groups_job(
     ]
 
     installation_ctx.workspace_installation.run()
-    installation_ctx.permission_manager.inventorize_permissions()
+    installation_ctx.permission_manager.snapshot()
 
     installation_ctx.deployed_workflows.run_workflow("migrate-groups")
 
@@ -91,7 +91,7 @@ def test_running_real_validate_groups_permissions_job(
         f"secrets:{secret_scope}",
     ]
     installation_ctx.workspace_installation.run()
-    installation_ctx.permission_manager.inventorize_permissions()
+    installation_ctx.permission_manager.snapshot()
 
     # assert the job does not throw any exception
     installation_ctx.deployed_workflows.run_workflow("validate-groups-permissions")
@@ -115,7 +115,7 @@ def test_running_real_validate_groups_permissions_job_fails(
     installation_ctx.__dict__['include_group_names'] = [ws_group_a.display_name]
     installation_ctx.__dict__['include_object_permissions'] = [f'cluster-policies:{cluster_policy.policy_id}']
     installation_ctx.workspace_installation.run()
-    installation_ctx.permission_manager.inventorize_permissions()
+    installation_ctx.permission_manager.snapshot()
 
     # remove permission so the validation fails
     ws.permissions.set(
diff --git a/tests/unit/assessment/test_workflows.py b/tests/unit/assessment/test_workflows.py
index f0c4480772..83a5778c61 100644
--- a/tests/unit/assessment/test_workflows.py
+++ b/tests/unit/assessment/test_workflows.py
@@ -26,7 +26,7 @@ def test_runtime_crawl_grants(run_workflow):
 
 def test_runtime_crawl_permissions(run_workflow):
     ctx = run_workflow(Assessment.crawl_permissions)
-    assert "DROP TABLE IF EXISTS `hive_metastore`.`ucx`.`permissions`" in ctx.sql_backend.queries
+    assert "TRUNCATE TABLE `hive_metastore`.`ucx`.`permissions`" in ctx.sql_backend.queries
 
 
 def test_runtime_crawl_groups(run_workflow):
diff --git a/tests/unit/framework/test_crawlers.py b/tests/unit/framework/test_crawlers.py
index 5d75c228d8..d6f5c7778b 100644
--- a/tests/unit/framework/test_crawlers.py
+++ b/tests/unit/framework/test_crawlers.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterable
 from dataclasses import dataclass
 
 import pytest
@@ -5,9 +6,7 @@
 from databricks.labs.lsql.backends import MockBackend
 from databricks.sdk.errors import NotFound
 
-from databricks.labs.ucx.framework.crawlers import CrawlerBase
-
-# pylint: disable=protected-access
+from databricks.labs.ucx.framework.crawlers import CrawlerBase, Result, ResultFn
 
 
 @dataclass
@@ -29,21 +28,44 @@ class Bar:
     third: float
 
 
+class _CrawlerFixture(CrawlerBase[Result]):
+    def __init__(
+        self,
+        backend: MockBackend,
+        catalog: str,
+        schema: str,
+        table: str,
+        klass: type[Result],
+        *,
+        fetcher: ResultFn = lambda: [],
+        loader: ResultFn = lambda: [],
+    ):
+        super().__init__(backend, catalog, schema, table, klass)
+        self._fetcher = fetcher
+        self._loader = loader
+
+    def _try_fetch(self) -> Iterable[Result]:
+        return self._fetcher()
+
+    def _crawl(self) -> Iterable[Result]:
+        return self._loader()
+
+
 def test_invalid():
     with pytest.raises(ValueError):
-        CrawlerBase(MockBackend(), "a.a.a", "b", "c", Bar)
+        _CrawlerFixture(MockBackend(), "a.a.a", "b", "c", Bar)
 
 
 def test_full_name():
-    cb = CrawlerBase(MockBackend(), "a", "b", "c", Bar)
+    cb = _CrawlerFixture(MockBackend(), "a", "b", "c", Bar)
     assert cb.full_name == "a.b.c"
 
 
 def test_snapshot_appends_to_existing_table():
     mock_backend = MockBackend()
-    cb = CrawlerBase[Baz](mock_backend, "a", "b", "c", Baz)
+    cb = _CrawlerFixture[Baz](mock_backend, "a", "b", "c", Baz, loader=lambda: [Baz(first="first")])
 
-    result = cb._snapshot(fetcher=lambda: [], loader=lambda: [Baz(first="first")])
+    result = cb.snapshot()
 
     assert [Baz(first="first")] == result
     assert [Row(first="first", second=None)] == mock_backend.rows_written_for("a.b.c", "append")
@@ -51,13 +73,16 @@ def test_snapshot_appends_to_existing_table():
 
 def test_snapshot_appends_to_new_table():
     mock_backend = MockBackend()
-    cb = CrawlerBase[Foo](mock_backend, "a", "b", "c", Foo)
 
     def fetcher():
         msg = ".. TABLE_OR_VIEW_NOT_FOUND .."
         raise NotFound(msg)
 
-    result = cb._snapshot(fetcher=fetcher, loader=lambda: [Foo(first="first", second=True)])
+    cb = _CrawlerFixture[Foo](
+        mock_backend, "a", "b", "c", Foo, fetcher=fetcher, loader=lambda: [Foo(first="first", second=True)]
+    )
+
+    result = cb.snapshot()
 
     assert [Foo(first="first", second=True)] == result
     assert [Row(first="first", second=True)] == mock_backend.rows_written_for("a.b.c", "append")
@@ -65,11 +90,12 @@ def fetcher():
 
 def test_snapshot_wrong_error():
     sql_backend = MockBackend()
-    cb = CrawlerBase(sql_backend, "a", "b", "c", Bar)
 
    def fetcher():
         msg = "always fails"
         raise ValueError(msg)
 
+    cb = _CrawlerFixture[Bar](sql_backend, "a", "b", "c", Bar, fetcher=fetcher)
+
     with pytest.raises(ValueError):
-        cb._snapshot(fetcher=fetcher, loader=lambda: [Foo(first="first", second=True)])
+        cb.snapshot()
diff --git a/tests/unit/workspace_access/test_manager.py b/tests/unit/workspace_access/test_manager.py
index f78757a594..ca06bcdd30 100644
--- a/tests/unit/workspace_access/test_manager.py
+++ b/tests/unit/workspace_access/test_manager.py
@@ -19,77 +19,45 @@ def mock_backend():
     return MockBackend()
 
 
-def test_inventory_table_manager_init(mock_backend):
+def test_inventory_permission_manager_init(mock_backend):
     permission_manager = PermissionManager(mock_backend, "test_database", [])
 
     assert permission_manager.full_name == "hive_metastore.test_database.permissions"
 
 
-def test_cleanup(mock_backend):
-    permission_manager = PermissionManager(mock_backend, "test_database", [])
-
-    permission_manager.cleanup()
-
-    assert mock_backend.queries[0] == "DROP TABLE IF EXISTS `hive_metastore`.`test_database`.`permissions`"
-
-
-def test_save(mock_backend):
-    permission_manager = PermissionManager(mock_backend, "test_database", [])
-
-    permission_manager._save([Permissions("object1", "clusters", "test acl")])  # pylint: disable=protected-access
-
-    assert [Row(object_id="object1", object_type="clusters", raw="test acl")] == mock_backend.rows_written_for(
-        "hive_metastore.test_database.permissions", "append"
-    )
-
-
 _PermissionsRow = Row.factory(["object_id", "object_type", "raw"])
 
 
-def test_load_all():
+def test_snapshot_fetch() -> None:
+    """Verify that the snapshot will load existing data from the inventory."""
     sql_backend = MockBackend(
         rows={
-            "SELECT object_id": [
+            "SELECT object_id, object_type, raw FROM ": [
                 _PermissionsRow("object1", "clusters", "test acl"),
             ],
-            "SELECT COUNT": [Row(cnt=12)],
         }
     )
     permission_manager = PermissionManager(sql_backend, "test_database", [])
 
-    output = permission_manager.load_all()
+    output = list(permission_manager.snapshot())
     assert output[0] == Permissions(object_id="object1", object_type="clusters", raw="test acl")
 
 
-def test_load_all_no_rows_present():
-    sql_backend = MockBackend(
-        rows={
-            "SELECT object_id": [
-                _PermissionsRow("object1", "clusters", "test acl"),
-            ],
-            "SELECT COUNT": [Row(cnt=0)],
-        }
-    )
-
-    permission_manager = PermissionManager(sql_backend, "test_database", [])
-
-    with pytest.raises(RuntimeError):
-        permission_manager.load_all()
-
-
-def test_manager_inventorize(mock_backend, mocker):
+def test_snapshot_crawl_fallback(mocker) -> None:
+    """Verify that the snapshot will first attempt to load the (empty) inventory and then crawl."""
     some_crawler = mocker.Mock()
     some_crawler.get_crawler_tasks = lambda: [lambda: None, lambda: Permissions("a", "b", "c"), lambda: None]
-    permission_manager = PermissionManager(mock_backend, "test_database", [some_crawler])
+    sql_backend = MockBackend(rows={"SELECT object_id, object_type, raw FROM ": []})
+    permission_manager = PermissionManager(sql_backend, "test_database", [some_crawler])
 
-    permission_manager.inventorize_permissions()
+    permission_manager.snapshot()
 
-    assert [Row(object_id="a", object_type="b", raw="c")] == mock_backend.rows_written_for(
+    assert [Row(object_id="a", object_type="b", raw="c")] == sql_backend.rows_written_for(
         "hive_metastore.test_database.permissions", "append"
     )
 
 
-def test_manager_inventorize_ignore_error(mock_backend, mocker):
+def test_manager_snapshot_crawl_ignore_disabled_features(mock_backend, mocker):
     def raise_error():
         raise DatabricksError(
             "Model serving is not enabled for your shard. "
@@ -101,14 +69,14 @@ def raise_error():
     some_crawler.get_crawler_tasks = lambda: [lambda: None, lambda: Permissions("a", "b", "c"), raise_error]
     permission_manager = PermissionManager(mock_backend, "test_database", [some_crawler])
 
-    permission_manager.inventorize_permissions()
+    permission_manager.snapshot()
 
     assert [Row(object_id="a", object_type="b", raw="c")] == mock_backend.rows_written_for(
         "hive_metastore.test_database.permissions", "append"
    )
 
 
-def test_manager_inventorize_fail_with_error(mock_backend, mocker):
+def test_manager_snapshot_crawl_with_error(mock_backend, mocker):
     def raise_error():
         raise DatabricksError(
             "Fail the job",
@@ -123,7 +91,7 @@ def raise_error_no_code():
     permission_manager = PermissionManager(mock_backend, "test_database", [some_crawler])
 
     with pytest.raises(ManyError) as expected_err:
-        permission_manager.inventorize_permissions()
+        permission_manager.snapshot()
     assert len(expected_err.value.errs) == 2
 
 
@@ -235,7 +203,6 @@ def test_manager_verify():
                     ),
                 ),
             ],
-            "SELECT COUNT": [Row(cnt=12)],
         }
     )
 
@@ -276,7 +243,6 @@ def test_manager_verify_not_supported_type():
                     ),
                ),
            ],
-            "SELECT COUNT": [Row(cnt=12)],
         }
     )
 
@@ -311,7 +277,6 @@ def test_manager_verify_no_tasks():
                    ),
                ),
            ],
-            "SELECT COUNT": [Row(cnt=12)],
         }
     )