From 339d1c9d9d50bce63316ac788626bea998e71b06 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 4 Sep 2024 09:38:39 +0800
Subject: [PATCH] [SPARK-49202][PS] Apply `ArrayBinarySearch` for histogram

### What changes were proposed in this pull request?
Apply `ArrayBinarySearch` for histogram.

### Why are the changes needed?
This expression is dedicated to histogram computation and supports codegen:

```
(5) Project [codegen id : 1]
Output [2]: [__group_id#37, cast(CASE WHEN ((__value#38 >= 1.0) AND (__value#38 <= 12.0)) THEN CASE WHEN (__value#38 = 12.0) THEN 11 WHEN (static_invoke(ArrayExpressionUtils.binarySearchNullSafe([1.0,1.9166666666666665,2.833333333333333,3.75,4.666666666666666,5.583333333333333,6.5,7.416666666666666,8.333333333333332,9.25,10.166666666666666,11.083333333333332,12.0], __value#38)) > 0) THEN static_invoke(ArrayExpressionUtils.binarySearchNullSafe([1.0,1.9166666666666665,2.833333333333333,3.75,4.666666666666666,5.583333333333333,6.5,7.416666666666666,8.333333333333332,9.25,10.166666666666666,11.083333333333332,12.0], __value#38)) ELSE (-static_invoke(ArrayExpressionUtils.binarySearchNullSafe([1.0,1.9166666666666665,2.833333333333333,3.75,4.666666666666666,5.583333333333333,6.5,7.416666666666666,8.333333333333332,9.25,10.166666666666666,11.083333333333332,12.0], __value#38)) - 2) END WHEN isnan(__value#38) THEN cast(raise_error(USER_RAISED_EXCEPTION, map(keys: [errorMessage], values: [Histogram encountered NaN value.]), NullType) as int) ELSE cast(raise_error(USER_RAISED_EXCEPTION, map(errorMessage, printf(value %s out of the bins bounds: [%s, %s], __value#38, 1.0, 12.0)), NullType) as int) END as double) AS __bucket#46]
Input [2]: [__group_id#37, __value#38]
```

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI and manual checks

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47970 from zhengruifeng/ps_apply_binary_search.

Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 python/pyspark/pandas/plot/core.py            | 34 +++++++------------
 python/pyspark/pandas/spark/functions.py      | 13 +++++++
 .../spark/sql/api/python/PythonSQLUtils.scala |  3 ++
 3 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py
index ea76dfa25bd99..453b17834020e 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import bisect
 import importlib
 import math
 
@@ -25,7 +24,6 @@
 from pandas.core.dtypes.inference import is_integer
 
 from pyspark.sql import functions as F, Column
-from pyspark.sql.types import DoubleType
 from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.missing import unsupported_function
 from pyspark.pandas.config import get_option
@@ -182,22 +180,16 @@ def compute_hist(psdf, bins):
     colnames = sdf.columns
     bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
 
-    # TODO(SPARK-49202): register this function in scala side
-    @F.udf(returnType=DoubleType())
-    def binary_search_for_buckets(value):
-        # Given bins = [1.0, 2.0, 3.0, 4.0]
-        # the intervals are:
-        # [1.0, 2.0) -> 0.0
-        # [2.0, 3.0) -> 1.0
-        # [3.0, 4.0] -> 2.0 (the last bucket is a closed interval)
-        if value < bins[0] or value > bins[-1]:
-            raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")
-
-        if value == bins[-1]:
-            idx = len(bins) - 2
-        else:
-            idx = bisect.bisect(bins, value) - 1
-        return float(idx)
+    # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
+    def binary_search_for_buckets(value: Column):
+        index = SF.binary_search(F.lit(bins), value)
+        bucket = F.when(index >= 0, index).otherwise(-index - 2)
+        unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]")
+        return (
+            F.when(value == F.lit(bins[-1]), F.lit(len(bins) - 2))
+            .when(value.between(F.lit(bins[0]), F.lit(bins[-1])), bucket)
+            .otherwise(F.raise_error(F.printf(unboundErrMsg, value)))
+        )
 
     output_df = (
         sdf.select(
@@ -205,10 +197,10 @@ def binary_search_for_buckets(value):
                 F.array([F.col(colname).cast("double") for colname in colnames])
             ).alias("__group_id", "__value")
         )
-        # to match handleInvalid="skip" in Bucketizer
-        .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()).select(
+        .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
+        .select(
             F.col("__group_id"),
-            binary_search_for_buckets(F.col("__value")).alias("__bucket"),
+            binary_search_for_buckets(F.col("__value")).cast("double").alias("__bucket"),
         )
     )
 
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 6bef3d9b87c05..6aaa63956c14b 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -187,6 +187,19 @@ def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
         return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse))
 
 
+def binary_search(col: Column, value: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
+
+        return _invoke_function_over_columns("array_binary_search", col, value)
+
+    else:
+        from pyspark import SparkContext
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc))
+
+
 def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
     unit_mapping = {
         "YEAR": "years",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index c1c9af2ea4273..7dbc586f64730 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -152,6 +152,9 @@ private[sql] object PythonSQLUtils extends Logging {
   def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
     Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
 
+  def binary_search(e: Column, value: Column): Column =
+    Column.internalFn("array_binary_search", e, value)
+
   def pandasProduct(e: Column, ignoreNA: Boolean): Column =
     Column.internalFn("pandas_product", e, lit(ignoreNA))
 
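Note (not part of the patch): the `-index - 2` arithmetic in `binary_search_for_buckets` relies on `array_binary_search` following the `java.util.Arrays.binarySearch` contract — the index of the value when it hits a bin edge, otherwise `-(insertion point) - 1` — the same convention used by `org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets`, which the new code comment references. The sketch below mirrors that bucket mapping in plain Python purely for illustration; `java_binary_search` and `bucket_for` are hypothetical helpers, not part of the PySpark API.

```python
import bisect


def java_binary_search(arr, value):
    # Assumed contract of java.util.Arrays.binarySearch: the index of the
    # match if found, otherwise -(insertion point) - 1.
    i = bisect.bisect_left(arr, value)
    if i < len(arr) and arr[i] == value:
        return i
    return -i - 1


def bucket_for(bins, value):
    # Mirrors the Column expression built in compute_hist:
    #   value == bins[-1]  -> last bucket (closed on the right)
    #   index >= 0         -> value sits exactly on a bin edge
    #   index < 0          -> -index - 2 is the bucket left of the insertion point
    if value == bins[-1]:
        return len(bins) - 2
    if not (bins[0] <= value <= bins[-1]):
        raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")
    index = java_binary_search(bins, value)
    return index if index >= 0 else -index - 2


bins = [1.0, 2.0, 3.0, 4.0]
assert bucket_for(bins, 1.0) == 0  # exact hit on the first edge
assert bucket_for(bins, 2.5) == 1  # falls inside [2.0, 3.0)
assert bucket_for(bins, 4.0) == 2  # last bucket is a closed interval
```

This lines up with the codegen plan quoted in the PR description, where `__value#38 = 12.0` maps to bucket 11 (`len(bins) - 2` for 13 bin edges) and a non-matching, negative search result is rewritten to `-result - 2`.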