Refactor: CrawlerBase and PermissionsManager snapshotting pattern (#2402)

## Changes

This PR:

- Refactors the existing permissions crawler so that it uses the same
fetch/crawl/snapshot pattern as the rest of the crawlers.
- Removes the permissions crawler's bespoke `.cleanup()` implementation
in favour of the `.reset()` method already available on the base class.
- Removes dead code from the permissions crawler: `.load_for_all()`.
- Specifies `.snapshot()` as an interface (on `CrawlerBase`) that all
crawlers need to support; a sketch of the resulting pattern follows this list.
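
For orientation, here is a minimal sketch of the pattern this PR standardizes on, condensed from the `CrawlerBase` diff below. The `NotFound` stand-in and the elided persistence step are assumptions for illustration, not the real implementation:

```python
# Minimal sketch of the fetch/crawl/snapshot pattern; persistence to the
# inventory table is elided, and NotFound stands in for the backend error.
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Generic, TypeVar

Result = TypeVar("Result")


class NotFound(Exception):
    """Stand-in for the backend's "table or view not found" error."""


class CrawlerBase(ABC, Generic[Result]):
    def snapshot(self) -> Iterable[Result]:
        # Serve previously crawled results when available; otherwise crawl.
        return self._snapshot(self._try_fetch, self._crawl)

    @abstractmethod
    def _try_fetch(self) -> Iterable[Result]:
        """Fetch existing data that has previously been crawled."""

    @abstractmethod
    def _crawl(self) -> Iterable[Result]:
        """Perform the (potentially slow) crawl of the environment."""

    def _snapshot(self, fetcher, loader) -> list[Result]:
        try:
            cached = list(fetcher())
            if cached:
                return cached
        except NotFound:
            pass
        loaded = list(loader())
        # The real implementation persists `loaded` to the inventory table here.
        return loaded
```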

### Linked issues

Progresses #2074 by laying the groundwork for crawlers to support
refreshing.

Note that this PR and #2392 have a behaviour conflict that will trigger
a failure in `test_runtime_crawl_permissions` once both are merged:
resolving this is trivial.

### Functionality

- modified existing workflow: `assessment` (see the sketch below)
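
In rough terms, the updated `crawl_permissions` task now drives a full refresh through the generic base-class API (`reset()` then `snapshot()`). A hypothetical end-to-end example, reusing the `CrawlerBase` and `NotFound` stand-ins sketched above; the `WidgetInfo`/`WidgetsCrawler` names are illustrative, not from this PR:

```python
# Hypothetical crawler on top of the CrawlerBase sketch above; the inventory
# table is faked with an in-memory list so the example runs stand-alone.
from collections.abc import Iterable
from dataclasses import dataclass


@dataclass
class WidgetInfo:
    widget_id: str


class WidgetsCrawler(CrawlerBase[WidgetInfo]):
    def __init__(self, live_widgets: list[str]):
        self._live_widgets = live_widgets
        self._stored: list[WidgetInfo] = []  # stand-in for the inventory table

    def reset(self) -> None:
        # Mirrors the base-class reset() the workflow now calls: drop the
        # persisted snapshot so the next snapshot() call re-crawls.
        self._stored.clear()

    def _try_fetch(self) -> Iterable[WidgetInfo]:
        if not self._stored:
            raise NotFound("inventory table is empty")
        yield from self._stored

    def _crawl(self) -> Iterable[WidgetInfo]:
        yield from (WidgetInfo(widget_id) for widget_id in self._live_widgets)


crawler = WidgetsCrawler(["w1", "w2"])
crawler.reset()                      # like permission_manager.reset()
records = list(crawler.snapshot())   # like permission_manager.snapshot()
assert [r.widget_id for r in records] == ["w1", "w2"]
```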

### Tests

- [X] updated unit tests
- [X] updated integration tests
asnare authored Sep 6, 2024
1 parent cf2011d commit 92a8f47
Showing 23 changed files with 146 additions and 211 deletions.
3 changes: 0 additions & 3 deletions src/databricks/labs/ucx/assessment/azure.py
@@ -44,9 +44,6 @@ def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "azure_service_principals", AzureServicePrincipalInfo)
self._ws = ws

def snapshot(self) -> Iterable[AzureServicePrincipalInfo]:
return self._snapshot(self._try_fetch, self._crawl)

def _try_fetch(self) -> Iterable[AzureServicePrincipalInfo]:
for row in self._fetch(f"SELECT * FROM {self._catalog}.{self._schema}.{self._table}"):
yield AzureServicePrincipalInfo(*row)
6 changes: 0 additions & 6 deletions src/databricks/labs/ucx/assessment/clusters.py
@@ -174,9 +174,6 @@ def _assess_clusters(self, all_clusters):
cluster_info.failures = json.dumps(failures)
yield cluster_info

def snapshot(self) -> Iterable[ClusterInfo]:
return self._snapshot(self._try_fetch, self._crawl)

def _try_fetch(self) -> Iterable[ClusterInfo]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield ClusterInfo(*row)
@@ -229,9 +226,6 @@ def _assess_policies(self, all_policices) -> Iterable[PolicyInfo]:
policy_info.failures = json.dumps(failures)
yield policy_info

def snapshot(self) -> Iterable[PolicyInfo]:
return self._snapshot(self._try_fetch, self._crawl)

def _try_fetch(self) -> Iterable[PolicyInfo]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield PolicyInfo(*row)
3 changes: 0 additions & 3 deletions src/databricks/labs/ucx/assessment/init_scripts.py
@@ -80,9 +80,6 @@ def _assess_global_init_scripts(self, all_global_init_scripts):
global_init_script_info.success = 0
yield global_init_script_info

def snapshot(self) -> Iterable[GlobalInitScriptInfo]:
return self._snapshot(self._try_fetch, self._crawl)

def _try_fetch(self) -> Iterable[GlobalInitScriptInfo]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield GlobalInitScriptInfo(*row)
6 changes: 0 additions & 6 deletions src/databricks/labs/ucx/assessment/jobs.py
@@ -128,9 +128,6 @@ def _prepare(all_jobs) -> tuple[dict[int, set[str]], dict[int, JobInfo]]:
)
return job_assessment, job_details

def snapshot(self) -> Iterable[JobInfo]:
return self._snapshot(self._try_fetch, self._crawl)

def _try_fetch(self) -> Iterable[JobInfo]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield JobInfo(*row)
@@ -166,9 +163,6 @@ def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema: str, num_days_h
self._ws = ws
self._num_days_history = num_days_history

def snapshot(self) -> Iterable[SubmitRunInfo]:
return self._snapshot(self._try_fetch, self._crawl)

@staticmethod
def _dt_to_ms(date_time: datetime):
return int(date_time.timestamp() * 1000)
3 changes: 0 additions & 3 deletions src/databricks/labs/ucx/assessment/pipelines.py
@@ -70,9 +70,6 @@ def _pipeline_clusters(self, clusters, failures):
if cluster.init_scripts:
failures.extend(self._check_cluster_init_script(cluster.init_scripts, "pipeline cluster"))

def snapshot(self) -> Iterable[PipelineInfo]:
return self._snapshot(self._try_fetch, self._crawl)

def _try_fetch(self) -> Iterable[PipelineInfo]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield PipelineInfo(*row)
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/assessment/workflows.py
@@ -168,8 +168,8 @@ def crawl_permissions(self, ctx: RuntimeContext):
This is the first step for the _group migration_ process, which is continued in the `migrate-groups` workflow.
This step includes preparing Legacy Table ACLs for local group migration."""
permission_manager = ctx.permission_manager
permission_manager.cleanup()
permission_manager.inventorize_permissions()
permission_manager.reset()
permission_manager.snapshot()

@job_task
def crawl_groups(self, ctx: RuntimeContext):
24 changes: 22 additions & 2 deletions src/databricks/labs/ucx/framework/crawlers.py
@@ -1,4 +1,5 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from typing import ClassVar, Generic, Protocol, TypeVar

@@ -19,7 +20,7 @@ class DataclassInstance(Protocol):
ResultFn = Callable[[], Iterable[Result]]


class CrawlerBase(Generic[Result]):
class CrawlerBase(ABC, Generic[Result]):
def __init__(self, backend: SqlBackend, catalog: str, schema: str, table: str, klass: type[Result]):
"""
Initializes a CrawlerBase instance.
@@ -90,6 +91,25 @@ def _try_valid(cls, name: str | None):
return None
return cls._valid(name)

def snapshot(self) -> Iterable[Result]:
return self._snapshot(self._try_fetch, self._crawl)

@abstractmethod
def _try_fetch(self) -> Iterable[Result]:
"""Fetch existing data that has (previously) been crawled by this crawler.
Returns:
Iterable[Result]: The data that has already been crawled.
"""

@abstractmethod
def _crawl(self) -> Iterable[Result]:
"""Perform the (potentially slow) crawling necessary to capture the current state of the environment.
Returns:
Iterable[Result]: Records that capture the results of crawling the environment.
"""

def _snapshot(self, fetcher: ResultFn, loader: ResultFn) -> list[Result]:
"""
Tries to load dataset of records with `fetcher` function, otherwise automatically creates
@@ -105,7 +125,7 @@ def _snapshot(self, fetcher: ResultFn, loader: ResultFn) -> list[Result]:
re-raised.
Returns:
list[any]: A list of data records, either fetched or loaded.
list[Result]: A list of data records, either fetched or loaded.
"""
logger.debug(f"[{self.full_name}] fetching {self._table} inventory")
try:
8 changes: 4 additions & 4 deletions src/databricks/labs/ucx/hive_metastore/grants.py
@@ -207,13 +207,13 @@ def __init__(self, tc: TablesCrawler, udf: UdfsCrawler, include_databases: list[

def snapshot(self) -> Iterable[Grant]:
try:
return self._snapshot(partial(self._try_load), partial(self._crawl))
return super().snapshot()
except Exception as e: # pylint: disable=broad-exception-caught
log_fn = logger.warning if CLUSTER_WITHOUT_ACL_FRAGMENT in repr(e) else logger.error
log_fn(f"Couldn't fetch grants snapshot: {e}")
return []

def _try_load(self):
def _try_fetch(self):
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield Grant(*row)

@@ -585,7 +585,7 @@ def __init__(
self._compute_locations = cluster_locations

def get_interactive_cluster_grants(self) -> list[Grant]:
tables = self._tables_crawler.snapshot()
tables = list(self._tables_crawler.snapshot())
mounts = list(self._mounts_crawler.snapshot())
grants: set[Grant] = set()

@@ -794,7 +794,7 @@ def __init__(

def migrate_acls(self, *, target_catalog: str | None = None, hms_fed: bool = False) -> None:
workspace_name = self._workspace_info.current()
tables = self._table_crawler.snapshot()
tables = list(self._table_crawler.snapshot())
if not tables:
logger.info("No tables found to acl")
return
28 changes: 13 additions & 15 deletions src/databricks/labs/ucx/hive_metastore/locations.py
@@ -201,7 +201,7 @@ def _add_jdbc_location(self, external_locations, location, table):
if not dupe:
external_locations.append(ExternalLocation(jdbc_location, 1))

def _external_location_list(self) -> Iterable[ExternalLocation]:
def _crawl(self) -> Iterable[ExternalLocation]:
tables_identifier = escape_sql_identifier(f"{self._catalog}.{self._schema}.tables")
tables = list(
self._backend.fetch(
@@ -211,9 +211,6 @@ def _external_location_list(self) -> Iterable[ExternalLocation]:
mounts = Mounts(self._backend, self._ws, self._schema).snapshot()
return self._external_locations(list(tables), list(mounts))

def snapshot(self) -> Iterable[ExternalLocation]:
return self._snapshot(self._try_fetch, self._external_location_list)

def _try_fetch(self) -> Iterable[ExternalLocation]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield ExternalLocation(*row)
@@ -320,15 +317,12 @@ def _deduplicate_mounts(self, mounts: list) -> list:
deduplicated_mounts.append(obj)
return deduplicated_mounts

def _list_mounts(self) -> Iterable[Mount]:
def _crawl(self) -> Iterable[Mount]:
mounts = []
for mount_point, source, _ in self._dbutils.fs.mounts():
mounts.append(Mount(mount_point, source))
return self._deduplicate_mounts(mounts)

def snapshot(self) -> Iterable[Mount]:
return self._snapshot(self._try_fetch, self._list_mounts)

def _try_fetch(self) -> Iterable[Mount]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield Mount(*row)
@@ -366,21 +360,25 @@ def __init__(
self._fiter_paths = irrelevant_patterns

def snapshot(self) -> list[Table]:
updated_records = self._crawl()
self._overwrite_records(updated_records)
return updated_records

def _crawl(self) -> list[Table]:
logger.debug(f"[{self.full_name}] fetching {self._table} inventory")
cached_results = []
try:
cached_results = list(self._try_load())
cached_results = list(self._try_fetch())
except NotFound:
pass
table_paths = self._get_tables_paths_from_assessment(cached_results)
logger.debug(f"[{self.full_name}] crawling new batch for {self._table}")
loaded_records = list(self._crawl(table_paths))
if len(cached_results) > 0:
loaded_records = loaded_records + cached_results
self._overwrite_records(loaded_records)
loaded_records = list(self._crawl_tables(table_paths))
if cached_results:
loaded_records = [*loaded_records, *cached_results]
return loaded_records

def _try_load(self) -> Iterable[Table]:
def _try_fetch(self) -> Iterable[Table]:
"""Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
for row in self._fetch(
f"SELECT * FROM {escape_sql_identifier(self.full_name)} WHERE NOT STARTSWITH(database, '{self.TABLE_IN_MOUNT_DB}')"
@@ -399,7 +397,7 @@ def _overwrite_records(self, items: Sequence[Table]):
logger.debug(f"[{self.full_name}] found {len(items)} new records for {self._table}")
self._backend.save_table(self.full_name, items, Table, mode="overwrite")

def _crawl(self, table_paths_from_assessment: dict[str, str]):
def _crawl_tables(self, table_paths_from_assessment: dict[str, str]):
all_mounts = self._mounts_crawler.snapshot()
all_tables = []
for mount in all_mounts:
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/hive_metastore/mapping.py
@@ -96,8 +96,8 @@ def __init__(
self._recon_tolerance_percent = recon_tolerance_percent

def current_tables(self, tables: TablesCrawler, workspace_name: str, catalog_name: str):
tables_snapshot = tables.snapshot()
if len(tables_snapshot) == 0:
tables_snapshot = list(tables.snapshot())
if not tables_snapshot:
msg = "No tables found. Please run: databricks labs ucx ensure-assessment-run"
raise ValueError(msg)
for table in tables_snapshot:
3 changes: 0 additions & 3 deletions src/databricks/labs/ucx/hive_metastore/migration_status.py
@@ -74,9 +74,6 @@ def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema, table_crawler:
self._ws = ws
self._table_crawler = table_crawler

def snapshot(self) -> Iterable[MigrationStatus]:
return self._snapshot(self._try_fetch, self._crawl)

def index(self) -> MigrationIndex:
return MigrationIndex(list(self.snapshot()))

13 changes: 1 addition & 12 deletions src/databricks/labs/ucx/hive_metastore/table_size.py
@@ -1,7 +1,6 @@
import logging
from collections.abc import Iterable
from dataclasses import dataclass
from functools import partial

from databricks.labs.lsql.backends import SqlBackend

@@ -54,21 +53,11 @@ def _crawl(self) -> Iterable[TableSize]:
catalog=table.catalog, database=table.database, name=table.name, size_in_bytes=size_in_bytes
)

def _try_load(self) -> Iterable[TableSize]:
def _try_fetch(self) -> Iterable[TableSize]:
"""Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield TableSize(*row)

def snapshot(self) -> list[TableSize]:
"""
Takes a snapshot of tables in the specified catalog and database.
Return None if the table cannot be found anymore.
Returns:
list[Table]: A list of Table objects representing the snapshot of tables.
"""
return self._snapshot(partial(self._try_load), partial(self._crawl))

def _safe_get_table_size(self, table_full_name: str) -> int | None:
logger.debug(f"Evaluating {table_full_name} table size.")
try:
11 changes: 1 addition & 10 deletions src/databricks/labs/ucx/hive_metastore/tables.py
@@ -357,15 +357,6 @@ def _all_databases(self) -> list[str]:
return [row[0] for row in self._fetch("SHOW DATABASES")]
return self._include_database

def snapshot(self) -> list[Table]:
"""
Takes a snapshot of tables in the specified catalog and database.
Returns:
list[Table]: A list of Table objects representing the snapshot of tables.
"""
return self._snapshot(partial(self._try_load), partial(self._crawl))

def load_one(self, schema_name: str, table_name: str) -> Table | None:
query = f"SELECT * FROM {escape_sql_identifier(self.full_name)} WHERE database='{schema_name}' AND name='{table_name}' LIMIT 1"
for row in self._fetch(query):
@@ -386,7 +377,7 @@ def parse_database_props(tbl_props: str) -> dict:
# Convert key-value pairs to dictionary
return dict(key_value_pairs)

def _try_load(self) -> Iterable[Table]:
def _try_fetch(self) -> Iterable[Table]:
"""Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield Table(*row)
11 changes: 1 addition & 10 deletions src/databricks/labs/ucx/hive_metastore/udfs.py
@@ -50,16 +50,7 @@ def _all_databases(self) -> list[str]:
return [row[0] for row in self._fetch("SHOW DATABASES")]
return self._include_database

def snapshot(self) -> list[Udf]:
"""
Takes a snapshot of tables in the specified catalog and database.
Returns:
list[Udf]: A list of Udf objects representing the snapshot of tables.
"""
return self._snapshot(self._try_load, self._crawl)

def _try_load(self) -> Iterable[Udf]:
def _try_fetch(self) -> Iterable[Udf]:
"""Tries to load udf information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield Udf(*row)
3 changes: 0 additions & 3 deletions src/databricks/labs/ucx/recon/migration_recon.py
@@ -52,9 +52,6 @@ def __init__(
self._data_comparator = data_comparator
self._default_threshold = default_threshold

def snapshot(self) -> Iterable[ReconResult]:
return self._snapshot(self._try_fetch, self._crawl)

def _crawl(self) -> Iterable[ReconResult]:
self._migration_status_refresher.reset()
migration_index = self._migration_status_refresher.index()
3 changes: 0 additions & 3 deletions src/databricks/labs/ucx/workspace_access/generic.py
@@ -360,9 +360,6 @@ def _crawl(self) -> Iterable[WorkspaceObjectInfo]:
language=raw.get("language", None),
)

def snapshot(self) -> Iterable[WorkspaceObjectInfo]:
return self._snapshot(self._try_fetch, self._crawl)

def _try_fetch(self) -> Iterable[WorkspaceObjectInfo]:
for row in self._fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
yield WorkspaceObjectInfo(
12 changes: 3 additions & 9 deletions src/databricks/labs/ucx/workspace_access/groups.py
@@ -431,12 +431,6 @@ def __init__( # pylint: disable=too-many-arguments
self._external_id_match = external_id_match
self._verify_timeout = verify_timeout

def snapshot(self) -> list[MigratedGroup]:
return self._snapshot(self._fetcher, self._crawler)

def has_groups(self) -> bool:
return len(self.snapshot()) > 0

def rename_groups(self):
account_groups_in_workspace = self._account_groups_in_workspace()
workspace_groups_in_workspace = self._workspace_groups_in_workspace()
@@ -579,7 +573,7 @@ def reflect_account_groups_on_workspace(self):
raise ManyError(errors)

def get_migration_state(self) -> MigrationState:
return MigrationState(self.snapshot())
return MigrationState(list(self.snapshot()))

def delete_original_workspace_groups(self):
account_groups_in_workspace = self._account_groups_in_workspace()
@@ -626,7 +620,7 @@ def delete_original_workspace_groups(self):
# Step 3: Confirm that enumeration no longer returns the deleted groups.
self._wait_for_deleted_workspace_groups(deleted_groups)

def _fetcher(self) -> Iterable[MigratedGroup]:
def _try_fetch(self) -> Iterable[MigratedGroup]:
state = []
for row in self._backend.fetch(f"SELECT * FROM {escape_sql_identifier(self.full_name)}"):
state.append(MigratedGroup(*row))
Expand All @@ -646,7 +640,7 @@ def _fetcher(self) -> Iterable[MigratedGroup]:
)
return new_state

def _crawler(self) -> Iterable[MigratedGroup]:
def _crawl(self) -> Iterable[MigratedGroup]:
workspace_groups_in_workspace = self._workspace_groups_in_workspace()
account_groups_in_account = self._account_groups_in_account()
strategy = self._get_strategy(workspace_groups_in_workspace, account_groups_in_account)