Skip to content

Commit

Permalink
adding all_rows_overlap to fugue
Browse files Browse the repository at this point in the history
  • Loading branch information
fdosani committed Nov 10, 2023
1 parent b9fcbf9 commit e4a2cc1
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 4 deletions.
1 change: 1 addition & 0 deletions datacompy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from datacompy.core import *
from datacompy.fugue import (
all_columns_match,
all_rows_overlap,
intersect_columns,
is_match,
report,
Expand Down
100 changes: 98 additions & 2 deletions datacompy/fugue.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,102 @@ def is_match(
return all(matches)


def all_rows_overlap(
    df1: AnyDataFrame,
    df2: AnyDataFrame,
    join_columns: Union[str, List[str]],
    abs_tol: float = 0,
    rel_tol: float = 0,
    df1_name: str = "df1",
    df2_name: str = "df2",
    ignore_spaces: bool = False,
    ignore_case: bool = False,
    cast_column_names_lower: bool = True,
    parallelism: Optional[int] = None,
    strict_schema: bool = False,
) -> bool:
    """Check whether every row is present in both dataframes.

    Parameters
    ----------
    df1 : ``AnyDataFrame``
        First dataframe to check
    df2 : ``AnyDataFrame``
        Second dataframe to check
    join_columns : list or str
        Column(s) to join the dataframes on. If a string is passed in, that
        one column will be used.
    abs_tol : float, optional
        Absolute tolerance between two values.
    rel_tol : float, optional
        Relative tolerance between two values.
    df1_name : str, optional
        A string name for the first dataframe. This allows the reporting to
        print out an actual name instead of "df1", and allows human users to
        more easily track the dataframes.
    df2_name : str, optional
        A string name for the second dataframe
    ignore_spaces : bool, optional
        Flag to strip whitespace (including newlines) from string columns
        (including any join columns)
    ignore_case : bool, optional
        Flag to ignore the case of string columns
    cast_column_names_lower : bool, optional
        Boolean indicator that controls if column names will be cast into
        lower case
    parallelism : int, optional
        An integer representing the amount of parallelism. Entering a value
        for this will force the use of Fugue over just vanilla Pandas
    strict_schema : bool, optional
        The schema must match exactly if set to ``True``. This includes the
        names and types. Allows for a fast fail. Only forwarded to the
        distributed comparison; the plain Pandas path does not receive it.

    Returns
    -------
    bool
        True if all rows in df1 are in df2 and vice versa (based on
        existence for join option). On the distributed path this is True
        only when every result returned by the distributed comparison is
        truthy, and False when a ``_StrictSchemaError`` is raised.
    """
    # Options accepted by both the core ``Compare`` class and the
    # distributed comparison helper.
    compare_kwargs = dict(
        df1=df1,
        df2=df2,
        join_columns=join_columns,
        abs_tol=abs_tol,
        rel_tol=rel_tol,
        df1_name=df1_name,
        df2_name=df2_name,
        ignore_spaces=ignore_spaces,
        ignore_case=ignore_case,
        cast_column_names_lower=cast_column_names_lower,
    )

    both_pandas = isinstance(df1, pd.DataFrame) and isinstance(df2, pd.DataFrame)
    run_with_plain_pandas = (
        both_pandas
        and parallelism is None  # user did not specify parallelism
        and fa.get_current_parallelism() == 1  # currently on a local execution engine
    )
    if run_with_plain_pandas:
        return Compare(**compare_kwargs).all_rows_overlap()

    try:
        partial_overlaps = _distributed_compare(
            return_obj_func=lambda compare_obj: compare_obj.all_rows_overlap(),
            parallelism=parallelism,
            strict_schema=strict_schema,
            **compare_kwargs,
        )
    except _StrictSchemaError:
        # A strict-schema failure is reported as "rows do not overlap".
        return False

    return all(partial_overlaps)


def report(
df1: AnyDataFrame,
df2: AnyDataFrame,
Expand All @@ -210,7 +306,7 @@ def report(
column_count=10,
html_file=None,
parallelism: Optional[int] = None,
) -> None:
) -> str:
"""Returns a string representation of a report. The representation can
then be printed or saved to a file.
Expand Down Expand Up @@ -460,7 +556,7 @@ def _distributed_compare(
parallelism: Optional[int] = None,
strict_schema: bool = False,
) -> List[Any]:
"""Compare the data distributedly using the core Compare class
"""Compare the data distributively using the core Compare class
Both df1 and df2 should be dataframes containing all of the join_columns,
with unique column names. Differences between values are compared to
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,13 @@ dev = [
"datacompy[build]",
]

[isort]
[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
line_length = 88
profile = "black"

[edgetest.envs.core]
python_version = "3.9"
Expand Down
80 changes: 79 additions & 1 deletion tests/test_fugue.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from datacompy import (
Compare,
all_columns_match,
all_rows_overlap,
intersect_columns,
is_match,
report,
Expand All @@ -50,7 +51,14 @@ def ref_df():
df1_copy = df1.copy()
df2 = df1.copy().drop(columns=["c"])
df3 = df1.copy().drop(columns=["a", "b"])
return [df1, df1_copy, df2, df3]
df4 = pd.DataFrame(
dict(
a=np.random.randint(1, 12, 100), # shift the join col
b=np.random.rand(100),
c=np.random.choice(["aaa", "b_c", "csd"], 100),
)
)
return [df1, df1_copy, df2, df3, df4]


@pytest.fixture
Expand Down Expand Up @@ -590,3 +598,73 @@ def test_all_columns_match_duckdb(ref_df):
assert all_columns_match(df1, df3) is False
assert all_columns_match(df1_copy, df1) is True
assert all_columns_match(df3, df2) is False


def test_all_rows_overlap_native(
    ref_df,
    shuffle_df,
):
    """Row overlap on pandas inputs, with and without explicit parallelism."""
    base_df = ref_df[0]
    shifted_df = ref_df[4]  # fixture frame built with a shifted join column

    # defaults to Compare class
    assert all_rows_overlap(base_df, ref_df[0].copy(), join_columns="a")
    assert all_rows_overlap(base_df, shuffle_df, join_columns="a")
    assert not all_rows_overlap(base_df, shifted_df, join_columns="a")

    # Fugue: passing parallelism skips the plain-pandas shortcut
    assert all_rows_overlap(base_df, shuffle_df, join_columns="a", parallelism=2)
    assert not all_rows_overlap(
        base_df, shifted_df, join_columns="a", parallelism=2
    )


def test_all_rows_overlap_spark(
    spark_session,
    ref_df,
    shuffle_df,
):
    """Row overlap when both inputs are Spark dataframes."""
    base_pdf = ref_df[0]
    shifted_pdf = ref_df[4]

    # pandas 2 compatibility: alias ``iteritems`` to ``items`` on each frame
    # before handing it to ``createDataFrame``.
    for pandas_frame in (base_pdf, shifted_pdf, shuffle_df):
        pandas_frame.iteritems = pandas_frame.items

    rdf = spark_session.createDataFrame(base_pdf)
    rdf_copy = spark_session.createDataFrame(base_pdf)
    rdf4 = spark_session.createDataFrame(shifted_pdf)
    sdf = spark_session.createDataFrame(shuffle_df)

    assert all_rows_overlap(rdf, rdf_copy, join_columns="a")
    assert all_rows_overlap(rdf, sdf, join_columns="a")
    assert not all_rows_overlap(rdf, rdf4, join_columns="a")

    # Two independently created single-row SQL results must also overlap.
    left_sql = spark_session.sql("SELECT 'a' AS a, 'b' AS b")
    right_sql = spark_session.sql("SELECT 'a' AS a, 'b' AS b")
    assert all_rows_overlap(left_sql, right_sql, join_columns="a")


def test_all_rows_overlap_polars(
    ref_df,
    shuffle_df,
):
    """Row overlap when both inputs are Polars dataframes."""
    base_pdf = ref_df[0]

    rdf = pl.from_pandas(base_pdf)
    rdf_copy = pl.from_pandas(base_pdf.copy())
    rdf4 = pl.from_pandas(ref_df[4])
    sdf = pl.from_pandas(shuffle_df)

    # Identical and shuffled frames overlap; the shifted-join frame does not.
    assert all_rows_overlap(rdf, rdf_copy, join_columns="a")
    assert all_rows_overlap(rdf, sdf, join_columns="a")
    assert not all_rows_overlap(rdf, rdf4, join_columns="a")


def test_all_rows_overlap_duckdb(
    ref_df,
    shuffle_df,
):
    """Row overlap when both inputs are DuckDB relations."""
    with duckdb.connect():
        base_pdf = ref_df[0]

        rdf = duckdb.from_df(base_pdf)
        rdf_copy = duckdb.from_df(base_pdf.copy())
        rdf4 = duckdb.from_df(ref_df[4])
        sdf = duckdb.from_df(shuffle_df)

        assert all_rows_overlap(rdf, rdf_copy, join_columns="a")
        assert all_rows_overlap(rdf, sdf, join_columns="a")
        assert not all_rows_overlap(rdf, rdf4, join_columns="a")

        # Two independently created single-row SQL results must also overlap.
        left_sql = duckdb.sql("SELECT 'a' AS a, 'b' AS b")
        right_sql = duckdb.sql("SELECT 'a' AS a, 'b' AS b")
        assert all_rows_overlap(left_sql, right_sql, join_columns="a")

0 comments on commit e4a2cc1

Please sign in to comment.