Skip to content

Commit

Permalink
feat: reuse connection for tag syncing
Browse files (browse the repository at this point in the history)
  • Loading branch information
Eli Yarson committed Dec 5, 2024
1 parent ba857b8 commit fe24013
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 53 deletions.
28 changes: 15 additions & 13 deletions snowflake_utils/models/table.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -344,32 +344,34 @@ def single_column_update(
f"UPDATE {self.fqn} SET {target_column.name} = {new_column.name};"
)

def _current_tags(self, level: TagLevel) -> list[tuple[str, str, str]]:
with connect() as connection:
cursor = connection.cursor()
cursor.execute(
f"""select lower(column_name) as column_name, lower(tag_name) as tag_name, tag_value
def _current_tags(
self, level: TagLevel, cursor: SnowflakeCursor
) -> list[tuple[str, str, str]]:
cursor.execute(
f"""select lower(column_name) as column_name, lower(tag_name) as tag_name, tag_value
from table(information_schema.tag_references_all_columns('{self.fqn}', 'table'))
where lower(level) = '{level.value}'
"""
)
return cursor.fetchall()
)
return cursor.fetchall()

def current_column_tags(self) -> dict[str, dict[str, str]]:
def current_column_tags(self, cursor: SnowflakeCursor) -> dict[str, dict[str, str]]:
tags = defaultdict(dict)

for column_name, tag_name, tag_value in self._current_tags(TagLevel.COLUMN):
for column_name, tag_name, tag_value in self._current_tags(
TagLevel.COLUMN, cursor
):
tags[column_name][tag_name] = tag_value
return tags

def current_table_tags(self) -> dict[str, str]:
def current_table_tags(self, cursor: SnowflakeCursor) -> dict[str, str]:
return {
tag_name.casefold(): tag_value
for _, tag_name, tag_value in self._current_tags(TagLevel.TABLE)
for _, tag_name, tag_value in self._current_tags(TagLevel.TABLE, cursor)
}

def sync_tags_table(self, cursor: SnowflakeCursor) -> None:
tags = self.current_table_tags()
tags = self.current_table_tags(cursor=cursor)
desired_tags = {k.casefold(): v for k, v in self.table_structure.tags.items()}
for tag_name in desired_tags:
if tag_name not in tags:
Expand All @@ -393,7 +395,7 @@ def sync_tags(self, cursor: SnowflakeCursor) -> None:
self.sync_tags_columns(cursor)

def sync_tags_columns(self, cursor: SnowflakeCursor) -> None:
tags = self.current_column_tags()
tags = self.current_column_tags(cursor)
existing_tags = {
f"{column}.{tag_name}.{tags[column][tag_name]}".casefold(): (
column,
Expand Down
92 changes: 52 additions & 40 deletions tests/test_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -172,7 +172,10 @@ def test_merge() -> None:
primary_keys=["id"],
)

assert dict(test_table.current_column_tags()) == {"id": {"pii": "personal"}}
with connect() as conn, conn.cursor() as cursor:
assert dict(test_table.current_column_tags(cursor)) == {
"id": {"pii": "personal"}
}


@pytest.mark.snowflake_vcr
Expand Down Expand Up @@ -232,55 +235,64 @@ def test_copy_with_tags() -> None:

assert result[0][1] == "LOADED"

assert dict(test_table.current_column_tags()) == {"id": {"pii": "personal"}}
assert dict(test_table.current_table_tags()) == {"pii": "foo"}
with connect() as conn, conn.cursor() as cursor:
assert dict(test_table.current_column_tags(cursor)) == {
"id": {"pii": "personal"}
}
assert dict(test_table.current_table_tags(cursor)) == {"pii": "foo"}

test_table.drop()


@pytest.mark.snowflake_vcr
def test_copy_custom() -> None:
result = test_table.copy_custom(
column_definitions={
"id": "$1:id",
"name": "$1:name",
"last_name": "$1:last_name",
},
path=path,
file_format=parquet_file_format,
storage_integration=storage_integration,
full_refresh=True,
sync_tags=True,
)
assert result[0][1] == "LOADED"
assert dict(test_table.current_column_tags()) == {"id": {"pii": "personal"}}
assert dict(test_table.current_table_tags()) == {"pii": "foo"}
with connect() as conn, conn.cursor() as cursor:
result = test_table.copy_custom(
column_definitions={
"id": "$1:id",
"name": "$1:name",
"last_name": "$1:last_name",
},
path=path,
file_format=parquet_file_format,
storage_integration=storage_integration,
full_refresh=True,
sync_tags=True,
)
assert result[0][1] == "LOADED"
assert dict(test_table.current_column_tags(cursor)) == {
"id": {"pii": "personal"}
}
assert dict(test_table.current_table_tags(cursor)) == {"pii": "foo"}
test_table.drop()


@pytest.mark.snowflake_vcr
def test_merge_custom() -> None:
column_definitions = {
"id": "$1:id",
"name": "$1:name",
"last_name": "$1:last_name",
}
test_table.copy_custom(
column_definitions=column_definitions,
path=path,
file_format=parquet_file_format,
storage_integration=storage_integration,
full_refresh=True,
sync_tags=True,
)
test_table.merge_custom(
column_definitions=column_definitions,
path=path,
file_format=parquet_file_format,
storage_integration=storage_integration,
primary_keys=["id"],
)
with connect() as conn, conn.cursor() as cursor:
column_definitions = {
"id": "$1:id",
"name": "$1:name",
"last_name": "$1:last_name",
}
test_table.copy_custom(
column_definitions=column_definitions,
path=path,
file_format=parquet_file_format,
storage_integration=storage_integration,
full_refresh=True,
sync_tags=True,
)
test_table.merge_custom(
column_definitions=column_definitions,
path=path,
file_format=parquet_file_format,
storage_integration=storage_integration,
primary_keys=["id"],
)

assert dict(test_table.current_column_tags()) == {"id": {"pii": "personal"}}
assert dict(test_table.current_table_tags()) == {"pii": "foo"}
assert dict(test_table.current_column_tags(cursor)) == {
"id": {"pii": "personal"}
}
assert dict(test_table.current_table_tags(cursor)) == {"pii": "foo"}
test_table.drop()

0 comments on commit fe24013

Please sign in to comment.