diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py
index ea76dfa25bd99..453b17834020e 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import bisect
 import importlib
 import math
 
@@ -25,7 +24,6 @@
 from pandas.core.dtypes.inference import is_integer
 
 from pyspark.sql import functions as F, Column
-from pyspark.sql.types import DoubleType
 from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.missing import unsupported_function
 from pyspark.pandas.config import get_option
@@ -182,22 +180,16 @@ def compute_hist(psdf, bins):
         colnames = sdf.columns
         bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
 
-        # TODO(SPARK-49202): register this function in scala side
-        @F.udf(returnType=DoubleType())
-        def binary_search_for_buckets(value):
-            # Given bins = [1.0, 2.0, 3.0, 4.0]
-            # the intervals are:
-            #   [1.0, 2.0) -> 0.0
-            #   [2.0, 3.0) -> 1.0
-            #   [3.0, 4.0] -> 2.0 (the last bucket is a closed interval)
-            if value < bins[0] or value > bins[-1]:
-                raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")
-
-            if value == bins[-1]:
-                idx = len(bins) - 2
-            else:
-                idx = bisect.bisect(bins, value) - 1
-            return float(idx)
+        # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
+        def binary_search_for_buckets(value: Column):
+            index = SF.binary_search(F.lit(bins), value)
+            bucket = F.when(index >= 0, index).otherwise(-index - 2)
+            unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]")
+            return (
+                F.when(value == F.lit(bins[-1]), F.lit(len(bins) - 2))
+                .when(value.between(F.lit(bins[0]), F.lit(bins[-1])), bucket)
+                .otherwise(F.raise_error(F.printf(unboundErrMsg, value)))
+            )
 
         output_df = (
             sdf.select(
@@ -205,10 +197,10 @@ def binary_search_for_buckets(value):
                     F.array([F.col(colname).cast("double") for colname in colnames])
                 ).alias("__group_id", "__value")
             )
-            # to match handleInvalid="skip" in Bucketizer
-            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()).select(
+            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
+            .select(
                 F.col("__group_id"),
-                binary_search_for_buckets(F.col("__value")).alias("__bucket"),
+                binary_search_for_buckets(F.col("__value")).cast("double").alias("__bucket"),
             )
         )
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 6bef3d9b87c05..6aaa63956c14b 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -187,6 +187,19 @@ def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
         return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse))
 
 
+def binary_search(col: Column, value: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
+
+        return _invoke_function_over_columns("array_binary_search", col, value)
+
+    else:
+        from pyspark import SparkContext
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc))
+
+
 def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
     unit_mapping = {
         "YEAR": "years",
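The binary_search_for_buckets expression in plot/core.py above encodes the same bucket arithmetic that the removed Python UDF computed per row with bisect. A minimal pure-Python sketch of that arithmetic, assuming array_binary_search follows java.util.Arrays.binarySearch semantics (the matching index on an exact hit, otherwise -(insertion point) - 1); bucket_for is a hypothetical name used only for illustration:

import bisect

def bucket_for(value, bins):
    # Emulate java.util.Arrays.binarySearch: exact hit -> its index,
    # miss -> -(insertion point) - 1.
    i = bisect.bisect_left(bins, value)
    index = i if i < len(bins) and bins[i] == value else -(i + 1)
    # Mirror the Column logic: a non-negative index is the bucket,
    # otherwise recover the insertion point and step back one interval.
    bucket = index if index >= 0 else -index - 2
    if value == bins[-1]:  # the last interval is closed on the right
        return len(bins) - 2
    if bins[0] <= value <= bins[-1]:
        return bucket
    raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")

# Given bins = [1.0, 2.0, 3.0, 4.0] the intervals are
#   [1.0, 2.0) -> 0, [2.0, 3.0) -> 1, [3.0, 4.0] -> 2
assert [bucket_for(v, [1.0, 2.0, 3.0, 4.0]) for v in (1.0, 1.5, 2.0, 3.9, 4.0)] == [0, 0, 1, 2, 2]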
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index c1c9af2ea4273..7dbc586f64730 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -152,6 +152,9 @@ private[sql] object PythonSQLUtils extends Logging {
   def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
     Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
 
+  def binary_search(e: Column, value: Column): Column =
+    Column.internalFn("array_binary_search", e, value)
+
   def pandasProduct(e: Column, ignoreNA: Boolean): Column =
     Column.internalFn("pandas_product", e, lit(ignoreNA))
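Taken together, the classic path reaches the internal expression through PythonSQLUtils.binary_search, while Spark Connect invokes the same array_binary_search function directly. A rough usage sketch, assuming this patch is applied, a running SparkSession named spark, and the java.util.Arrays.binarySearch semantics described above:

from pyspark.sql import functions as F
from pyspark.pandas.spark import functions as SF

bins = [1.0, 2.0, 3.0, 4.0]
df = spark.range(1).select(
    SF.binary_search(F.lit(bins), F.lit(2.0)).alias("hit"),   # exact edge 2.0 -> index 1
    SF.binary_search(F.lit(bins), F.lit(2.5)).alias("miss"),  # not found -> -(2) - 1 = -3
)
# binary_search_for_buckets would then map 1 -> bucket 1 and -3 -> -(-3) - 2 = 1,
# i.e. both values fall into the interval [2.0, 3.0).
df.show()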