diff --git a/quinn/__init__.py b/quinn/__init__.py index 119cf8a7..469cac82 100644 --- a/quinn/__init__.py +++ b/quinn/__init__.py @@ -34,8 +34,6 @@ approx_equal, array_choice, business_days_between, - exists, - forall, is_false, is_falsy, is_not_in, diff --git a/quinn/functions.py b/quinn/functions.py index f802ae55..12464c81 100644 --- a/quinn/functions.py +++ b/quinn/functions.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Callable from numbers import Number from pyspark.sql import Column @@ -84,45 +83,6 @@ def remove_non_word_characters(col: Column) -> Column: return F.regexp_replace(col, "[^\\w\\s]+", "") -def exists(f: Callable[[Any], bool]) -> udf: - """Create a user-defined function. - - It takes a list expressed as a column of type ``ArrayType(AnyType)`` as an argument and returns a boolean value indicating - whether any element in the list is true according to the argument ``f`` of the ``exists()`` function. - - :param f: Callable function - A callable function that takes an element of - type Any and returns a boolean value. - :return: A user-defined function that takes - a list expressed as a column of type ArrayType(AnyType) as an argument and - returns a boolean value indicating whether any element in the list is true - according to the argument ``f`` of the ``exists()`` function. - :rtype: UserDefinedFunction - """ - - def temp_udf(list_: list) -> bool: - return any(map(f, list_)) - - return F.udf(temp_udf, BooleanType()) - - -def forall(f: Callable[[Any], bool]) -> udf: - """The **forall** function allows for mapping a given boolean function to a list of arguments and return a single boolean value. - - It does this by creating a Spark UDF which takes in a list of arguments, applying the given boolean function to - each element of the list and returning a single boolean value if all the elements pass through the given boolean function. - - :param f: A callable function ``f`` which takes in any type and returns a boolean - :return: A spark UDF which accepts a list of arguments and returns True if all - elements pass through the given boolean function, False otherwise. - :rtype: UserDefinedFunction - """ - - def temp_udf(list_: list) -> bool: - return all(map(f, list_)) - - return F.udf(temp_udf, BooleanType()) - - def multi_equals(value: Any) -> udf: # noqa: ANN401 """Create a user-defined function that checks if all the given columns have the designated value. diff --git a/tests/test_functions.py b/tests/test_functions.py index 7277bdcf..c3a0e3f4 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -97,46 +97,6 @@ def test_anti_trim(): chispa.assert_column_equality(actual_df, "words_anti_trimmed", "expected") -def test_exists(): - df = spark.createDataFrame( - [ - ([1, 2, 3], False), - ([4, 5, 6], True), - ([10, 11, 12], True), - ], - StructType( - [ - StructField("nums", ArrayType(IntegerType(), True), True), - StructField("expected", BooleanType(), True), - ] - ), - ) - actual_df = df.withColumn( - "any_num_greater_than_5", quinn.exists(lambda n: n > 5)(F.col("nums")) - ) - chispa.assert_column_equality(actual_df, "any_num_greater_than_5", "expected") - - -def test_forall(): - df = spark.createDataFrame( - [ - ([1, 2, 3], False), - ([4, 5, 6], True), - ([10, 11, 12], True), - ], - StructType( - [ - StructField("nums", ArrayType(IntegerType(), True), True), - StructField("expected", BooleanType(), True), - ] - ), - ) - actual_df = df.withColumn( - "all_nums_greater_than_3", quinn.forall(lambda n: n > 3)(F.col("nums")) - ) - chispa.assert_column_equality(actual_df, "all_nums_greater_than_3", "expected") - - def test_multi_equals(): df = quinn.create_df( spark,