Optimize Snowflake/Snowpark Compare #354

Merged: 3 commits, Nov 19, 2024

187 changes: 106 additions & 81 deletions datacompy/snowflake.py
@@ -23,6 +23,7 @@

import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union, cast

@@ -31,8 +32,9 @@

try:
import snowflake.snowpark as sp
from snowflake.connector.errors import ProgrammingError
from snowflake.snowpark import Window
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.exceptions import SnowparkClientException
from snowflake.snowpark.functions import (
abs,
col,
@@ -45,6 +47,7 @@
trim,
when,
)

except ImportError:
pass # for non-snowflake users
from datacompy.base import BaseCompare
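
Note: the try/except ImportError guard above is what keeps datacompy importable when the Snowflake extras are not installed. A minimal sketch of the same pattern (the None fallback is illustrative; the real module simply passes):

# Illustrative sketch, not part of the diff.
try:
    import snowflake.snowpark as sp  # optional dependency
except ImportError:
    sp = None  # non-Snowflake users can still import the rest of the package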
@@ -288,8 +291,12 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:
LOG.debug("Duplicate rows found, deduping by order of remaining fields")
# setting internal index
LOG.info("Adding internal index to dataframes")
df1 = df1.withColumn("__index", monotonically_increasing_id())
df2 = df2.withColumn("__index", monotonically_increasing_id())
df1 = df1.withColumn(
"__index", monotonically_increasing_id()
).cache_result()
df2 = df2.withColumn(
"__index", monotonically_increasing_id()
).cache_result()

# Create order column for uniqueness of match
order_column = temp_column_name(df1, df2)
@@ -305,11 +312,6 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:
).drop("__index")
temp_join_columns.append(order_column)

# drop index
LOG.info("Dropping internal index")
df1 = df1.drop("__index")
df2 = df2.drop("__index")

if ignore_spaces:
for column in self.join_columns:
if "string" in next(
@@ -400,23 +402,15 @@ def _intersect_compare(self, ignore_spaces: bool) -> None:
"""Run the comparison on the intersect dataframe.

This loops through all columns that are shared between df1 and df2, and
creates a column column_match which is True for matches, False
otherwise.
creates a column column_match which is True for matches, False otherwise.
Finally calculates and stores the compare metrics for matching column pairs.
"""
LOG.debug("Comparing intersection")
max_diff: float
null_diff: int
row_cnt = self.intersect_rows.count()
for column in self.intersect_columns():
if column in self.join_columns:
match_cnt = row_cnt
col_match = ""
max_diff = 0
null_diff = 0
else:
col_1 = column + "_" + self.df1_name
col_2 = column + "_" + self.df2_name
col_match = column + "_MATCH"
for col in self.intersect_columns():
if col not in self.join_columns:
col_1 = col + "_" + self.df1_name
col_2 = col + "_" + self.df2_name
col_match = col + "_MATCH"
self.intersect_rows = columns_equal(
self.intersect_rows,
col_1,
@@ -426,46 +420,87 @@ def _intersect_compare(self, ignore_spaces: bool) -> None:
self.abs_tol,
ignore_spaces,
)
match_cnt = (
self.intersect_rows.select(col_match)
.where(col(col_match) == True) # noqa: E712
.count()
)
max_diff = calculate_max_diff(
self.intersect_rows,
col_1,
col_2,
row_cnt = self.intersect_rows.count()

with ThreadPoolExecutor() as executor:
futures = []
for column in self.intersect_columns():
future = executor.submit(
self._calculate_column_compare_stats, column, row_cnt
)
null_diff = calculate_null_diff(self.intersect_rows, col_1, col_2)
futures.append(future)
for future in as_completed(futures):
if future.exception():
raise future.exception()

if row_cnt > 0:
match_rate = float(match_cnt) / row_cnt
else:
match_rate = 0
LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match")

col1_dtype, _ = _get_column_dtypes(self.df1, column, column)
col2_dtype, _ = _get_column_dtypes(self.df2, column, column)

self.column_stats.append(
{
"column": column,
"match_column": col_match,
"match_cnt": match_cnt,
"unequal_cnt": row_cnt - match_cnt,
"dtype1": str(col1_dtype),
"dtype2": str(col2_dtype),
"all_match": all(
(
col1_dtype == col2_dtype,
row_cnt == match_cnt,
)
),
"max_diff": max_diff,
"null_diff": null_diff,
}
def _calculate_column_compare_stats(self, column: str, row_cnt: int) -> None:
"""Populate the column stats for all intersecting column pairs.

Calculates compare stats for intersecting column pairs. For the non-trivial case
where intersecting columns are not join columns, a match count, max difference,
and null difference must be calculated.
"""
if column in self.join_columns:
match_cnt = row_cnt
col_match = ""
max_diff = 0
null_diff = 0
else:
col_1 = column + "_" + self.df1_name
col_2 = column + "_" + self.df2_name
col_match = column + "_MATCH"

match_cnt = (
self.intersect_rows.select(col_match)
.where(col(col_match) == True) # noqa: E712
.count(block=False)
)

max_diff = calculate_max_diff(
self.intersect_rows,
col_1,
col_2,
)
null_diff = calculate_null_diff(self.intersect_rows, col_1, col_2)

match_cnt = match_cnt.result()
try:
max_diff = max_diff.result()[0][0]
except (SnowparkClientException, ProgrammingError):
max_diff = 0
try:
null_diff = null_diff.result()
except (SnowparkClientException, ProgrammingError):
null_diff = 0

if row_cnt > 0:
match_rate = float(match_cnt) / row_cnt
else:
match_rate = 0
LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match")

col1_dtype, _ = _get_column_dtypes(self.df1, column, column)
col2_dtype, _ = _get_column_dtypes(self.df2, column, column)

self.column_stats.append(
{
"column": column,
"match_column": col_match,
"match_cnt": match_cnt,
"unequal_cnt": row_cnt - match_cnt,
"dtype1": str(col1_dtype),
"dtype2": str(col2_dtype),
"all_match": all(
(
col1_dtype == col2_dtype,
row_cnt == match_cnt,
)
),
"max_diff": max_diff,
"null_diff": null_diff,
}
)
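
Note: _intersect_compare now fans the per-column work out to a thread pool, and _calculate_column_compare_stats is what each worker runs. A self-contained sketch of the submit/as_completed pattern with the same exception propagation (the worker body is a stand-in):

# Illustrative sketch, not part of the diff.
from concurrent.futures import ThreadPoolExecutor, as_completed

def compare_column(column: str, row_cnt: int) -> None:
    print(f"compared {column} over {row_cnt} rows")  # stand-in worker

columns = ["ID", "AMOUNT", "STATUS"]  # hypothetical intersect columns
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(compare_column, c, 100) for c in columns]
    for future in as_completed(futures):
        if future.exception():  # surface worker failures instead of dropping them
            raise future.exception()

Calling future.result() inside the loop would re-raise implicitly as well; the explicit exception() check matches the diff and makes the intent obvious.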

def all_columns_match(self) -> bool:
"""Whether the columns all match in the dataframes.

@@ -994,23 +1029,16 @@ def calculate_max_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> float:
max diff
"""
# Attempting to coalesce a maximum diff for non-numeric columns results in an error; if so, a max diff of 0 is used.
try:
diff = dataframe.select(
(col(col_1).astype("float") - col(col_2).astype("float")).alias("diff")
)
abs_diff = diff.select(abs(col("diff")).alias("abs_diff"))
max_diff: float = (
abs_diff.where(is_null(col("abs_diff")) == False) # noqa: E712
.agg({"abs_diff": "max"})
.collect()[0][0]
)
except SnowparkSQLException:
return None

if pd.isna(max_diff) or pd.isnull(max_diff) or max_diff is None:
return 0
else:
return max_diff
diff = dataframe.select(
(col(col_1).astype("float") - col(col_2).astype("float")).alias("diff")
)
abs_diff = diff.select(abs(col("diff")).alias("abs_diff"))
max_diff: float = (
abs_diff.where(is_null(col("abs_diff")) == False) # noqa: E712
.agg({"abs_diff": "max"})
.collect(block=False)
)
return max_diff
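
Note: with block=False, Snowpark action methods such as count and collect return an AsyncJob handle immediately instead of a value, so calculate_max_diff now hands the in-flight query back to the caller. A sketch of the pattern (df is assumed to be an existing Snowpark DataFrame):

# Illustrative sketch, not part of the diff.
count_job = df.count(block=False)    # query starts, call returns at once
rows_job = df.collect(block=False)   # second query runs concurrently
row_count = count_job.result()       # block only when the value is needed
rows = rows_job.result()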


def calculate_null_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> int:
@@ -1047,12 +1075,9 @@ def calculate_null_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> int:
null_diff = nulls_df.where(
((col("col_1_null") == False) & (col("col_2_null") == True)) # noqa: E712
| ((col("col_1_null") == True) & (col("col_2_null") == False)) # noqa: E712
).count()
).count(block=False)

if pd.isna(null_diff) or pd.isnull(null_diff) or null_diff is None:
return 0
else:
return null_diff
return null_diff


def _generate_id_within_group(
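Note: because a failed float cast now surfaces when the AsyncJob is awaited rather than when the query is built, the try/except moved out of calculate_max_diff and into the caller, and the caught types changed to SnowparkClientException plus the connector's ProgrammingError. A hedged sketch of the caller side (column names are hypothetical; mirrors _calculate_column_compare_stats above):

# Illustrative sketch, not part of the diff.
from snowflake.connector.errors import ProgrammingError
from snowflake.snowpark.exceptions import SnowparkClientException
from datacompy.snowflake import calculate_max_diff

max_diff_job = calculate_max_diff(intersect_rows, "VALUE_df1", "VALUE_df2")
try:
    max_diff = max_diff_job.result()[0][0]  # raises here on non-numeric columns
except (SnowparkClientException, ProgrammingError):
    max_diff = 0  # fall back to 0, matching the old behavior
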
2 changes: 1 addition & 1 deletion tests/test_snowflake.py
@@ -1187,7 +1187,7 @@ def test_calculate_max_diff(snowpark_session, column, expected):
)
MAX_DIFF_DF = snowpark_session.createDataFrame(pdf)
assert np.isclose(
calculate_max_diff(MAX_DIFF_DF, "BASE", column),
calculate_max_diff(MAX_DIFF_DF, "BASE", column).result()[0][0],
expected,
)
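
Note: the one-line test change reflects the new contract: calculate_max_diff returns an AsyncJob whose result() is a one-row, one-column collection, hence the .result()[0][0] unwrap before comparing:

# Illustrative sketch, not part of the diff.
job = calculate_max_diff(MAX_DIFF_DF, "BASE", column)  # returns immediately
value = job.result()[0][0]  # first row, first column of the collected result
assert np.isclose(value, expected)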
