Skip to content

Commit

Permalink
Addressed review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
FastLee committed Jan 30, 2025
1 parent c55eb9f commit 1ca97ab
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 23 deletions.
5 changes: 4 additions & 1 deletion src/databricks/labs/ucx/recon/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass


Expand Down Expand Up @@ -82,7 +83,9 @@ def as_dict(self):

class TableMetadataRetriever(ABC):
@abstractmethod
def get_metadata(self, entity: TableIdentifier, /, case_sensitive: bool = False) -> TableMetadata:
def get_metadata(
self, entity: TableIdentifier, *, column_name_transformer: Callable[[str], str] = str
) -> TableMetadata:
"""
Get metadata for a given table
"""
Expand Down
18 changes: 9 additions & 9 deletions src/databricks/labs/ucx/recon/metadata_retriever.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Iterator
from collections.abc import Iterator, Callable

from databricks.labs.lsql.backends import SqlBackend
from databricks.labs.lsql.core import Row
Expand All @@ -10,7 +10,9 @@ class DatabricksTableMetadataRetriever(TableMetadataRetriever):
def __init__(self, sql_backend: SqlBackend):
self._sql_backend = sql_backend

def get_metadata(self, entity: TableIdentifier, /, case_sensitive: bool = False) -> TableMetadata:
def get_metadata(
self, entity: TableIdentifier, *, column_name_transformer: Callable[[str], str] = str
) -> TableMetadata:
"""
This method retrieves the metadata for a given table. It takes a TableIdentifier object as input,
which represents the table for which the metadata is to be retrieved.
Expand All @@ -24,13 +26,11 @@ def get_metadata(self, entity: TableIdentifier, /, case_sensitive: bool = False)
# Partition information is typically prefixed with a # symbol,
# so any column name starting with # is excluded from the final set of column metadata.
# The column metadata objects are sorted by column name to ensure a consistent order.
columns = {
ColumnMetadata(
str(row["col_name"] if case_sensitive else str(row["col_name"]).lower()), str(row["data_type"])
)
for row in query_result
if not str(row["col_name"]).startswith("#")
}
columns = set()
for row in query_result:
if str(row["col_name"]).startswith("#"):
continue
columns.add(ColumnMetadata(column_name_transformer(str(row["col_name"])), str(row["data_type"])))
return TableMetadata(entity, sorted(columns, key=lambda x: x.name))

@classmethod
Expand Down
25 changes: 12 additions & 13 deletions src/databricks/labs/ucx/recon/schema_comparator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import replace
from collections.abc import Callable

from .base import (
SchemaComparator,
Expand All @@ -15,8 +15,10 @@ def __init__(self, metadata_retriever: TableMetadataRetriever, *, case_sensitive
self._metadata_retriever = metadata_retriever
self._case_sensitive = case_sensitive

def _column_name_transformer(self, column_name: str) -> str:
return column_name if self._case_sensitive else column_name.lower()
def _column_name_transformer(self) -> Callable[[str], str]:
if self._case_sensitive:
return lambda _: _
return str.lower

def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> SchemaComparisonResult:
"""
Expand All @@ -32,8 +34,12 @@ def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> Sc
return SchemaComparisonResult(is_matching, comparison_result)

def _eval_schema_diffs(self, source: TableIdentifier, target: TableIdentifier) -> list[SchemaComparisonEntry]:
source_metadata = self._metadata_retriever.get_metadata(source, self._case_sensitive)
target_metadata = self._metadata_retriever.get_metadata(target, self._case_sensitive)
source_metadata = self._metadata_retriever.get_metadata(
source, column_name_transformer=self._column_name_transformer()
)
target_metadata = self._metadata_retriever.get_metadata(
target, column_name_transformer=self._column_name_transformer()
)
# Combine the sets of column names for both the source and target tables
# to create a set of all unique column names from both tables.
source_column_names = {column.name for column in source_metadata.columns}
Expand All @@ -51,16 +57,9 @@ def _eval_schema_diffs(self, source: TableIdentifier, target: TableIdentifier) -
def _build_comparison_result_entry(
source_col: ColumnMetadata | None,
target_col: ColumnMetadata | None,
/,
case_sensitive: bool = False,
) -> SchemaComparisonEntry:
if source_col and target_col:
if case_sensitive:
is_matching = source_col == target_col
else:
is_matching = replace(source_col, name=source_col.name.lower()) == replace(
target_col, name=target_col.name.lower()
)
is_matching = source_col == target_col
notes = None
else:
is_matching = False
Expand Down

0 comments on commit 1ca97ab

Please sign in to comment.