Skip to content

Commit

Permalink
Merge pull request #233 from nijanthanvijayakumar/issue-49-remove-exists-forall
Browse files (browse the repository at this point in the history)

Issue 49 remove exists forall
  • Loading branch information
SemyonSinchenko authored Jul 11, 2024
2 parents f07755a + 7ef60e4 commit b67fc98
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 82 deletions.
2 changes: 0 additions & 2 deletions quinn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
approx_equal,
array_choice,
business_days_between,
exists,
forall,
is_false,
is_falsy,
is_not_in,
Expand Down
40 changes: 0 additions & 40 deletions quinn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable
from numbers import Number

from pyspark.sql import Column
Expand Down Expand Up @@ -84,45 +83,6 @@ def remove_non_word_characters(col: Column) -> Column:
return F.regexp_replace(col, "[^\\w\\s]+", "")


def exists(f: Callable[[Any], bool]) -> udf:
    """Create a user-defined function that checks whether any array element satisfies ``f``.

    The returned UDF takes a list expressed as a column of type ``ArrayType(AnyType)``
    and returns a boolean value indicating whether ``f`` is true for at least one
    element of that list. An empty list yields ``False``; a null list yields null
    instead of raising inside the UDF.

    :param f: Callable function - A callable function that takes an element of
        type Any and returns a boolean value.
    :return: A user-defined function that takes a list expressed as a column of
        type ArrayType(AnyType) as an argument and returns a boolean value
        indicating whether any element in the list is true according to ``f``.
    :rtype: UserDefinedFunction
    """

    def temp_udf(list_: list | None) -> bool | None:
        # A null array cell arrives here as ``None``; iterating it would raise
        # ``TypeError``, so propagate the null as the result instead.
        if list_ is None:
            return None
        return any(f(element) for element in list_)

    return F.udf(temp_udf, BooleanType())


def forall(f: Callable[[Any], bool]) -> udf:
    """Create a user-defined function that checks whether every array element satisfies ``f``.

    The **forall** function builds a Spark UDF which takes in a list of arguments,
    applies the given boolean function to each element of the list and returns a
    single boolean value. An empty list yields ``True``; a null list yields null
    instead of raising inside the UDF.

    :param f: A callable function ``f`` which takes in any type and returns a boolean
    :return: A spark UDF which accepts a list of arguments and returns True if all
        elements pass through the given boolean function, False otherwise.
    :rtype: UserDefinedFunction
    """

    def temp_udf(list_: list | None) -> bool | None:
        # A null array cell arrives here as ``None``; iterating it would raise
        # ``TypeError``, so propagate the null as the result instead.
        if list_ is None:
            return None
        return all(f(element) for element in list_)

    return F.udf(temp_udf, BooleanType())


def multi_equals(value: Any) -> udf: # noqa: ANN401
"""Create a user-defined function that checks if all the given columns have the designated value.
Expand Down
40 changes: 0 additions & 40 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,46 +97,6 @@ def test_anti_trim():
chispa.assert_column_equality(actual_df, "words_anti_trimmed", "expected")


def test_exists():
    """Verify ``quinn.exists`` marks rows whose array holds at least one value above 5."""
    rows = [
        ([1, 2, 3], False),
        ([4, 5, 6], True),
        ([10, 11, 12], True),
    ]
    schema = StructType(
        [
            StructField("nums", ArrayType(IntegerType(), True), True),
            StructField("expected", BooleanType(), True),
        ]
    )
    source_df = spark.createDataFrame(rows, schema)

    # Build the UDF once, then apply it to the array column.
    any_greater_than_5 = quinn.exists(lambda n: n > 5)
    result_df = source_df.withColumn(
        "any_num_greater_than_5", any_greater_than_5(F.col("nums"))
    )

    chispa.assert_column_equality(result_df, "any_num_greater_than_5", "expected")


def test_forall():
    """Verify ``quinn.forall`` marks rows whose array values are all above 3."""
    rows = [
        ([1, 2, 3], False),
        ([4, 5, 6], True),
        ([10, 11, 12], True),
    ]
    schema = StructType(
        [
            StructField("nums", ArrayType(IntegerType(), True), True),
            StructField("expected", BooleanType(), True),
        ]
    )
    source_df = spark.createDataFrame(rows, schema)

    # Build the UDF once, then apply it to the array column.
    all_greater_than_3 = quinn.forall(lambda n: n > 3)
    result_df = source_df.withColumn(
        "all_nums_greater_than_3", all_greater_than_3(F.col("nums"))
    )

    chispa.assert_column_equality(result_df, "all_nums_greater_than_3", "expected")


def test_multi_equals():
df = quinn.create_df(
spark,
Expand Down

0 comments on commit b67fc98

Please sign in to comment.