diff --git a/snowflake_utils/models/table.py b/snowflake_utils/models/table.py index be2841d..4e7d9a2 100644 --- a/snowflake_utils/models/table.py +++ b/snowflake_utils/models/table.py @@ -344,32 +344,34 @@ def single_column_update( f"UPDATE {self.fqn} SET {target_column.name} = {new_column.name};" ) - def _current_tags(self, level: TagLevel) -> list[tuple[str, str, str]]: - with connect() as connection: - cursor = connection.cursor() - cursor.execute( - f"""select lower(column_name) as column_name, lower(tag_name) as tag_name, tag_value + def _current_tags( + self, level: TagLevel, cursor: SnowflakeCursor + ) -> list[tuple[str, str, str]]: + cursor.execute( + f"""select lower(column_name) as column_name, lower(tag_name) as tag_name, tag_value from table(information_schema.tag_references_all_columns('{self.fqn}', 'table')) where lower(level) = '{level.value}' """ - ) - return cursor.fetchall() + ) + return cursor.fetchall() - def current_column_tags(self) -> dict[str, dict[str, str]]: + def current_column_tags(self, cursor: SnowflakeCursor) -> dict[str, dict[str, str]]: tags = defaultdict(dict) - for column_name, tag_name, tag_value in self._current_tags(TagLevel.COLUMN): + for column_name, tag_name, tag_value in self._current_tags( + TagLevel.COLUMN, cursor + ): tags[column_name][tag_name] = tag_value return tags - def current_table_tags(self) -> dict[str, str]: + def current_table_tags(self, cursor: SnowflakeCursor) -> dict[str, str]: return { tag_name.casefold(): tag_value - for _, tag_name, tag_value in self._current_tags(TagLevel.TABLE) + for _, tag_name, tag_value in self._current_tags(TagLevel.TABLE, cursor) } def sync_tags_table(self, cursor: SnowflakeCursor) -> None: - tags = self.current_table_tags() + tags = self.current_table_tags(cursor=cursor) desired_tags = {k.casefold(): v for k, v in self.table_structure.tags.items()} for tag_name in desired_tags: if tag_name not in tags: @@ -393,7 +395,7 @@ def sync_tags(self, cursor: SnowflakeCursor) -> None: self.sync_tags_columns(cursor) def sync_tags_columns(self, cursor: SnowflakeCursor) -> None: - tags = self.current_column_tags() + tags = self.current_column_tags(cursor) existing_tags = { f"{column}.{tag_name}.{tags[column][tag_name]}".casefold(): ( column, diff --git a/tests/test_models.py b/tests/test_models.py index c3ad6aa..7369be9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -172,7 +172,10 @@ def test_merge() -> None: primary_keys=["id"], ) - assert dict(test_table.current_column_tags()) == {"id": {"pii": "personal"}} + with connect() as conn, conn.cursor() as cursor: + assert dict(test_table.current_column_tags(cursor)) == { + "id": {"pii": "personal"} + } @pytest.mark.snowflake_vcr @@ -232,55 +235,64 @@ def test_copy_with_tags() -> None: assert result[0][1] == "LOADED" - assert dict(test_table.current_column_tags()) == {"id": {"pii": "personal"}} - assert dict(test_table.current_table_tags()) == {"pii": "foo"} + with connect() as conn, conn.cursor() as cursor: + assert dict(test_table.current_column_tags(cursor)) == { + "id": {"pii": "personal"} + } + assert dict(test_table.current_table_tags(cursor)) == {"pii": "foo"} test_table.drop() @pytest.mark.snowflake_vcr def test_copy_custom() -> None: - result = test_table.copy_custom( - column_definitions={ - "id": "$1:id", - "name": "$1:name", - "last_name": "$1:last_name", - }, - path=path, - file_format=parquet_file_format, - storage_integration=storage_integration, - full_refresh=True, - sync_tags=True, - ) - assert result[0][1] == "LOADED" - assert dict(test_table.current_column_tags()) == {"id": {"pii": "personal"}} - assert dict(test_table.current_table_tags()) == {"pii": "foo"} + with connect() as conn, conn.cursor() as cursor: + result = test_table.copy_custom( + column_definitions={ + "id": "$1:id", + "name": "$1:name", + "last_name": "$1:last_name", + }, + path=path, + file_format=parquet_file_format, + storage_integration=storage_integration, + full_refresh=True, + sync_tags=True, + ) + assert result[0][1] == "LOADED" + assert dict(test_table.current_column_tags(cursor)) == { + "id": {"pii": "personal"} + } + assert dict(test_table.current_table_tags(cursor)) == {"pii": "foo"} test_table.drop() @pytest.mark.snowflake_vcr def test_merge_custom() -> None: - column_definitions = { - "id": "$1:id", - "name": "$1:name", - "last_name": "$1:last_name", - } - test_table.copy_custom( - column_definitions=column_definitions, - path=path, - file_format=parquet_file_format, - storage_integration=storage_integration, - full_refresh=True, - sync_tags=True, - ) - test_table.merge_custom( - column_definitions=column_definitions, - path=path, - file_format=parquet_file_format, - storage_integration=storage_integration, - primary_keys=["id"], - ) + with connect() as conn, conn.cursor() as cursor: + column_definitions = { + "id": "$1:id", + "name": "$1:name", + "last_name": "$1:last_name", + } + test_table.copy_custom( + column_definitions=column_definitions, + path=path, + file_format=parquet_file_format, + storage_integration=storage_integration, + full_refresh=True, + sync_tags=True, + ) + test_table.merge_custom( + column_definitions=column_definitions, + path=path, + file_format=parquet_file_format, + storage_integration=storage_integration, + primary_keys=["id"], + ) - assert dict(test_table.current_column_tags()) == {"id": {"pii": "personal"}} - assert dict(test_table.current_table_tags()) == {"pii": "foo"} + assert dict(test_table.current_column_tags(cursor)) == { + "id": {"pii": "personal"} + } + assert dict(test_table.current_table_tags(cursor)) == {"pii": "foo"} test_table.drop()